import os

import wandb
import yaml
from dotenv import load_dotenv

from train import get_data_loaders, load_model, train_model

# Load the W&B API key from a .env file (expects a line WANDB_API_KEY=<key>)
# and authenticate before creating the sweep.
load_dotenv()
wandb_api_key = os.environ.get("WANDB_API_KEY")
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")
wandb.login(key=wandb_api_key)

# Resolve the project root (one level above this script's directory) and
# read the sweep definition from the experiments folder.
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sweep_config_path = os.path.join(project_path, "experiments", "sweep4.yaml")
with open(sweep_config_path) as file:
    sweep_configuration = yaml.safe_load(file)

sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project="training_seisbench_models_on_igf_data",
)

sampling_rate = 100


def tune_training_hyperparams():
    run = wandb.init(
        # Set the wandb project where this run will be logged.
        project="training_seisbench_models_on_igf_data",
        # Track hyperparameters and run metadata; the sweep parameters from
        # the YAML config are merged into wandb.config automatically.
        config={"sampling_rate": sampling_rate},
    )
    run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    train_loader, dev_loader, _test_loader = get_data_loaders(
        batch_size=wandb.config.batch_size,
        sampling_rate=wandb.config.sampling_rate,
        sb_dataset=wandb.config.dataset,
    )

    model_name = wandb.config.model_name
    pretrained = wandb.config.pretrained
    # Debug output: inspect the merged sweep configuration for this run.
    print(wandb.config)
    print(model_name, pretrained, type(pretrained), wandb.config.sampling_rate)

    # Sweep configs may encode "no pretraining" as a falsy value; normalise
    # it to None so load_model skips loading pretrained weights.
    if not pretrained:
        pretrained = None
    model = load_model(name=model_name, pretrained=pretrained)

    path_to_trained_model = (
        f"{project_path}/models/{model_name}_pretrained_on_{pretrained}"
        f"_finetuned_on_{wandb.config.dataset}.pt"
    )
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    # Upload the trained weights as a W&B artifact attached to this run.
    artifact = wandb.Artifact("model", type="model")
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)
    run.finish()


if __name__ == "__main__":
    wandb.agent(sweep_id, function=tune_training_hyperparams, count=10)
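
# ---------------------------------------------------------------------------
# For reference: the sweep YAML must define every parameter this script reads
# from wandb.config (batch_size, dataset, model_name, pretrained). The sketch
# below shows a hypothetical equivalent configuration as a Python dict, which
# wandb.sweep also accepts in place of a YAML-loaded mapping. The method,
# metric name, and concrete values are illustrative assumptions, not the
# actual contents of experiments/sweep4.yaml.
#
# example_sweep_configuration = {
#     "method": "bayes",  # or "grid" / "random"
#     "metric": {"name": "val_loss", "goal": "minimize"},
#     "parameters": {
#         "batch_size": {"values": [64, 128, 256]},
#         "dataset": {"value": "igf"},
#         "model_name": {"values": ["PhaseNet", "EQTransformer"]},
#         "pretrained": {"values": [False, "stead"]},
#     },
# }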