extended pipeline arguments
This commit is contained in:
		@@ -79,8 +79,14 @@ def main():
 | 
			
		||||
    parser.add_argument("--basic_phase_ae_config", type=str, required=False)
 | 
			
		||||
    parser.add_argument("--eqtransformer_config", type=str, required=False)
 | 
			
		||||
    parser.add_argument("--dataset", type=str, required=False)
 | 
			
		||||
    available_models = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]
 | 
			
		||||
    parser.add_argument("--models", nargs='*', required=False, choices=available_models, default=available_models,
 | 
			
		||||
                        help="Models to train and evaluate (default: all)")
 | 
			
		||||
    parser.add_argument("--collect_results", action="store_true", help="Collect and log results without training")
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    if not args.collect_results:
 | 
			
		||||
 | 
			
		||||
        if args.dataset is not None:
 | 
			
		||||
            util.set_dataset(args.dataset)
 | 
			
		||||
            importlib.reload(config_loader)
 | 
			
		||||
@@ -94,9 +100,7 @@ def main():
 | 
			
		||||
 | 
			
		||||
        # find the best hyperparams for the models
 | 
			
		||||
        logger.info("Started training the models.")
 | 
			
		||||
    for model_name in ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]:
 | 
			
		||||
        if config_loader.dataset_name == "lumineos" and model_name == "EQTransformer":
 | 
			
		||||
            break
 | 
			
		||||
        for model_name in args.models:
 | 
			
		||||
            sweep_id = find_the_best_params(model_name, args)
 | 
			
		||||
            generate_predictions(sweep_id, model_name)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user