platform-demo-scripts/scripts/hyperparameter_sweep.py
k.milian e86f131cc0 finetuning (#1)
Reviewed-on: #1
Extended Phasenet finetuning options
2024-07-29 13:41:05 +02:00

198 lines
6.6 KiB
Python

# -----------------
# Copyright © 2023 ACK Cyfronet AGH, Poland.
# This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
# -----------------
import os
import os.path
import argparse
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
import wandb
import torch
import traceback
import logging
from dotenv import load_dotenv
import models
import train
import util
import config_loader
torch.multiprocessing.set_sharing_strategy('file_system')
os.system("ulimit -n unlimited")
load_dotenv()

# Both W&B credentials are mandatory; stop at import time if either is missing.
wandb_api_key = os.getenv('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")

host = os.getenv("WANDB_HOST")
if host is None:
    raise ValueError("WANDB_HOST environment variable is not set.")

wandb.login(host=host, key=wandb_api_key)
# Optional W&B settings; each is None when the variable is not set.
wandb_project_name = os.getenv("WANDB_PROJECT")
wandb_user_name = os.getenv("WANDB_USER")

# Script logger, named after this file without its directory or extension.
script_name = os.path.basename(os.path.splitext(__file__)[0])
logger = logging.getLogger(script_name)
logger.setLevel(logging.INFO)
def set_random_seed(seed=3):
    """Seed all random number generators used in training and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global generator (when NumPy is
    installed) and PyTorch on all devices. It then turns off cuDNN benchmarking
    and forces deterministic cuDNN algorithms.

    Args:
        seed: Integer seed applied to every generator. Defaults to 3.
    """
    import random

    random.seed(seed)
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        # NumPy's legacy global seeding only accepts values in [0, 2**32).
        np.random.seed(int(seed) % 2**32)
    # torch.manual_seed seeds every device; the CUDA call makes that explicit.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def get_trainer_args(config):
    """Build keyword arguments for ``pl.Trainer`` from a run config.

    Args:
        config: Object with a ``max_epochs`` attribute. The value may be plain, or
            a list whose first element is used (see ``get_arg``).

    Returns:
        dict: ``{"max_epochs": <value>}``.
    """
    # Unwrap through get_arg like every other config read, so that a scalar
    # max_epochs (e.g. a value set by the sweep) no longer fails on indexing.
    return {'max_epochs': get_arg(config.max_epochs)}


def get_arg(arg):
    """Return the first element of ``arg`` if it is a list, otherwise ``arg`` itself.

    Config values may be given either as a one-element list or as a plain value;
    this helper normalizes both forms to the plain value.
    """
    if isinstance(arg, list):
        return arg[0]
    return arg
class HyperparameterSweep:
    """Run a W&B hyperparameter sweep whose trials train PyTorch Lightning models.

    ``run_sweep`` registers ``sweep_config`` with W&B and starts an agent in this
    process. The agent calls ``run_experiment`` once per trial.
    """

    def __init__(self, project_name, sweep_config):
        """Store the sweep settings; nothing is sent to W&B until ``run_sweep``.

        Args:
            project_name: W&B project that owns the sweep and its runs.
            sweep_config: Sweep definition passed unchanged to ``wandb.sweep``.
        """
        self.project_name = project_name
        self.sweep_config = sweep_config
        # Filled in by run_sweep(); None until the sweep has been created.
        self.sweep_id = None

    def run_sweep(self):
        """Create the sweep on W&B, then run ``config_loader.experiment_count`` trials here."""
        self.sweep_id = wandb.sweep(self.sweep_config, project=self.project_name)
        logger.info("Created sweep with ID: %s", self.sweep_id)
        wandb.agent(self.sweep_id, function=self.run_experiment, count=config_loader.experiment_count)

    def all_runs_finished(self):
        """Report whether the sweep has no run left in the ``"running"`` state.

        Logs an info message when every run is ``"finished"``. Logs a warning when
        no run is running but at least one ended in some other state. Any state
        other than ``"running"`` counts as done, so an empty sweep also returns True.

        Returns:
            bool: True if none of the sweep's runs is currently running.
        """
        # Look the sweep up in the project it was created in (see run_sweep).
        sweep_path = f"{wandb_user_name}/{self.project_name}/{self.sweep_id}"
        logger.debug("Sweep path: %s", sweep_path)
        # Collect the states once; they are checked twice below.
        states = [run.state for run in wandb.Api().sweep(sweep_path).runs]
        all_finished = all(state == "finished" for state in states)
        if all_finished:
            logger.info("All runs finished successfully.")
        all_not_running = "running" not in states
        if all_not_running and not all_finished:
            logger.warning("Some runs are not finished but failed or crashed.")
        return all_not_running

    @staticmethod
    def _model_args_from_config(config):
        """Return constructor kwargs for the model selected by ``config``.

        Starts from ``models.get_model_specific_args(config)`` and adds whichever
        optional sweep parameters ``config`` contains:
        - ``pretrained``, skipped when it is the string ``"false"``;
        - ``norm``;
        - ``finetuning``, passed as ``finetuning_strategy``;
        - ``lr_reduce_factor``, passed as ``steplr_gamma``.
        """
        model_args = models.get_model_specific_args(config)
        if "pretrained" in config:
            weights = get_arg(config.pretrained)
            if weights != "false":
                model_args["pretrained"] = weights
        # Optional config keys copied as-is under the model's keyword name.
        for config_key, model_key in (
            ("norm", "norm"),
            ("finetuning", "finetuning_strategy"),
            ("lr_reduce_factor", "steplr_gamma"),
        ):
            if config_key in config:
                model_args[model_key] = get_arg(getattr(config, config_key))
        return model_args

    def run_experiment(self):
        """Run one sweep trial: build the model from ``wandb.config`` and train it.

        Called by ``wandb.agent``. An exception during setup or training is logged
        with its traceback and not re-raised, so the agent can continue with the
        next trial. In that case the W&B run is finished with exit code 1, so it is
        not recorded as a successful run.
        """
        run = None
        exit_code = None  # passed to run.finish(); None keeps wandb's default, 1 marks failure
        try:
            logger.info("Starting a new run...")
            run = wandb.init(
                project=self.project_name,
                config=config_loader.config,
                save_code=True,
                entity=wandb_user_name,
            )
            run.log_code(
                root=".",
                include_fn=lambda path: path.endswith((".py", ".sh")),
                exclude_fn=lambda path: path.endswith("template.sh"),
            )

            model_name = get_arg(wandb.config.model_name)
            model_args = self._model_args_from_config(wandb.config)
            logger.debug("Initializing %s with args: %s", model_name, model_args)
            model = getattr(models, model_name + "Lit")(**model_args)

            train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)

            wandb_logger = WandbLogger(project=self.project_name, log_model="all")
            wandb_logger.watch(model)
            # CSV logger - also used for saving configuration as yaml
            experiment_name = f"{config_loader.dataset_name}_{model_name}"
            csv_logger = CSVLogger(config_loader.models_path, experiment_name, version=run.id)
            csv_logger.log_hyperparams(wandb.config)
            loggers = [wandb_logger, csv_logger]

            experiment_signature = f"{experiment_name}_sweep={self.sweep_id}-run={run.id}"
            logger.debug("Experiment signature: %s", experiment_signature)

            # Keep only the single checkpoint with the lowest validation loss.
            checkpoint_callback = ModelCheckpoint(
                save_top_k=1,
                filename=experiment_signature + "-{epoch}-{val_loss:.3f}",
                monitor="val_loss",
                mode="min",
                dirpath=f"{config_loader.models_path}/{experiment_name}/",
            )
            checkpoint_callback.STARTING_VERSION = 1
            early_stopping_callback = EarlyStopping(
                monitor="val_loss",
                patience=5,
                verbose=True,
                mode="min",
            )
            lr_monitor = LearningRateMonitor(logging_interval='epoch')
            callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]

            trainer = pl.Trainer(
                default_root_dir=config_loader.models_path,
                logger=loggers,
                callbacks=callbacks,
                **get_trainer_args(wandb.config),
            )
            trainer.fit(model, train_loader, dev_loader)
        except Exception as e:
            # Best-effort trial: log with traceback and let the agent move on.
            logger.exception("caught error: %s", e)
            exit_code = 1
        # wandb.init itself may have failed, in which case there is no run to finish.
        if run is not None:
            run.finish(exit_code=exit_code)
def start_sweep(sweep_config):
    """Seed the RNGs, then create and run a W&B sweep for ``sweep_config``.

    Args:
        sweep_config: Sweep definition handed to ``HyperparameterSweep``.

    Returns:
        HyperparameterSweep: The runner, after its sweep agent has returned.
    """
    logger.info("Starting sweep with config: " + str(sweep_config))
    set_random_seed(config_loader.seed)

    runner = HyperparameterSweep(
        project_name=wandb_project_name,
        sweep_config=sweep_config,
    )
    runner.run_sweep()
    return runner
if __name__ == "__main__":
    # Single required CLI option, forwarded to util.load_sweep_config.
    cli = argparse.ArgumentParser()
    cli.add_argument("--sweep_config", type=str, required=True)
    cli_args = cli.parse_args()

    start_sweep(util.load_sweep_config(cli_args.sweep_config))