platform-demo-scripts/scripts/training_wandb_sweep.py
2023-07-05 09:58:06 +02:00

63 lines
2.0 KiB
Python

import os.path
import wandb
import yaml
from train import get_data_loaders, load_model, train_model
from dotenv import load_dotenv

# Pull WANDB_API_KEY (and any other secrets) from a local .env file.
load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")
wandb.login(key=wandb_api_key)

# Project root is one directory above this script's directory.
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sweep_config_path = os.path.join(project_path, "experiments", "sweep4.yaml")
with open(sweep_config_path) as file:
    # safe_load is sufficient for a plain sweep-config YAML and, unlike
    # FullLoader, cannot construct arbitrary Python objects from tags.
    sweep_configuration = yaml.safe_load(file)

# Register the sweep with wandb; the agent at the bottom of this script
# draws hyperparameter combinations from it.
sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project='training_seisbench_models_on_igf_data'
)

# Waveform sampling rate (Hz) recorded in every run's config.
sampling_rate = 100
def tune_training_hyperparams():
    """Run one sweep trial: train a single model with agent-chosen hyperparameters.

    Reads ``batch_size``, ``sampling_rate``, ``dataset``, ``model_name`` and
    ``pretrained`` from ``wandb.config`` (populated by the sweep agent),
    trains the model, saves the weights under ``<project>/models/`` and logs
    the file as a wandb artifact.
    """
    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config={"sampling_rate": sampling_rate}
    )
    # Snapshot this script alongside the run for reproducibility.
    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    train_loader, dev_loader, test_loader = get_data_loaders(
        batch_size=wandb.config.batch_size,
        sampling_rate=wandb.config.sampling_rate,
        sb_dataset=wandb.config.dataset,
    )

    model_name = wandb.config.model_name
    pretrained = wandb.config.pretrained
    print(wandb.config)
    print(model_name, pretrained, type(pretrained), wandb.config.sampling_rate)

    # NOTE(review): the sweep presumably delivers `pretrained` as either a
    # dataset name (string) or a falsy value — confirm against sweep4.yaml.
    # Normalize any falsy value to an explicit False for load_model.
    # (The original line here was the bare expression `pretrained`, a no-op;
    # the assignment below restores the evident intent.)
    if not pretrained:
        pretrained = False

    model = load_model(name=model_name, pretrained=pretrained)

    path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    # Upload the trained weights as a versioned wandb artifact.
    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)
    run.finish()
if __name__ == "__main__":
    # Launch a sweep agent that runs up to 10 trials, each executing
    # tune_training_hyperparams with a fresh hyperparameter combination.
    wandb.agent(sweep_id, function=tune_training_hyperparams, count=10)