# NOTE(review): file-viewer header residue ("63 lines · 2.0 KiB · Python"); not part of the program.
import os
import os.path

import wandb
import yaml
from dotenv import load_dotenv

from train import get_data_loaders, load_model, train_model

# --- Weights & Biases authentication ---------------------------------------
# NOTE: everything in this section and the next runs at import time,
# including the wandb.login() and wandb.sweep() calls.
# load_dotenv() is called before the API key is read so that a key defined in
# a .env file is visible through os.environ.
load_dotenv()

wandb_api_key = os.environ.get("WANDB_API_KEY")
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")

wandb.login(key=wandb_api_key)

# --- Sweep registration ------------------------------------------------------
# The project root is two directory levels above this file.
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sweep_config_path = os.path.join(project_path, "experiments", "sweep4.yaml")

# NOTE(review): yaml.FullLoader is kept from the original. If the sweep config
# only contains plain YAML, yaml.safe_load would be the stricter choice —
# confirm and switch.
with open(sweep_config_path, encoding="utf-8") as file:
    sweep_configuration = yaml.load(file, Loader=yaml.FullLoader)

sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project="training_seisbench_models_on_igf_data",
)

# Default sampling rate; tune_training_hyperparams() seeds each run's config
# with this value.
sampling_rate = 100

def tune_training_hyperparams():
    """Run one sweep trial: build data loaders, train a model, log the result.

    Passed as the ``function`` callback to ``wandb.agent``. The run config is
    seeded with the module-level ``sampling_rate``; ``batch_size``,
    ``dataset``, ``model_name`` and ``pretrained`` are read from the run
    config and are expected to be supplied by the sweep.

    ``train_model`` is given a target path under ``<project_path>/models/``;
    the file at that path is then attached to a W&B artifact named ``model``
    and logged on the run.
    """
    # Using the run as a context manager ensures the run is finished even
    # when data loading or training raises, instead of only on success.
    with wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config={"sampling_rate": sampling_rate},
    ) as run:
        config = run.config

        # Save only this sweep script with the run.
        run.log_code(
            ".",
            include_fn=lambda path: path.endswith("training_wandb_sweep.py"),
        )

        # The third loader (test split) is not used in this function.
        train_loader, dev_loader, _ = get_data_loaders(
            batch_size=config.batch_size,
            sampling_rate=config.sampling_rate,
            sb_dataset=config.dataset,
        )

        model_name = config.model_name
        pretrained = config.pretrained
        print(config)
        print(model_name, pretrained, type(pretrained), config.sampling_rate)

        # NOTE(review): a falsy `pretrained` is passed through unchanged —
        # confirm load_model accepts the value the sweep provides.
        model = load_model(name=model_name, pretrained=pretrained)

        # Create the output directory up front so a missing folder cannot
        # fail the trial when the model file is written.
        models_dir = os.path.join(project_path, "models")
        os.makedirs(models_dir, exist_ok=True)
        path_to_trained_model = os.path.join(
            models_dir,
            f"{model_name}_pretrained_on_{pretrained}_finetuned_on_{config.dataset}.pt",
        )

        train_model(model, path_to_trained_model, train_loader, dev_loader)

        artifact = wandb.Artifact("model", type="model")
        artifact.add_file(path_to_trained_model)
        run.log_artifact(artifact)


if __name__ == "__main__":
    # Start an agent for the sweep registered at module level; each trial
    # calls tune_training_hyperparams, and `count` caps the number of trials.
    wandb.agent(
        sweep_id,
        function=tune_training_hyperparams,
        count=10,
    )