# -----------------
# Copyright © 2023 ACK Cyfronet AGH, Poland.
# This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
# -----------------

import os
import os.path
import argparse
import traceback
import logging

import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from dotenv import load_dotenv

import models
import train
import util
import config_loader

# Avoid "too many open files" errors when DataLoader workers share tensors.
torch.multiprocessing.set_sharing_strategy('file_system')
os.system("ulimit -n unlimited")

# W&B credentials and project settings are read from the environment (.env).
load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")
host = os.environ.get("WANDB_HOST")
if host is None:
    raise ValueError("WANDB_HOST environment variable is not set.")

wandb.login(key=wandb_api_key, host=host)
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user_name = os.environ.get("WANDB_USER")

script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.INFO)


def set_random_seed(seed=3):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_trainer_args(config):
    trainer_args = {'max_epochs': config.max_epochs[0]}
    return trainer_args


def get_arg(arg):
    # Sweep parameters defined as single-element lists are unwrapped here.
    if isinstance(arg, list):
        return arg[0]
    return arg


class HyperparameterSweep:
    def __init__(self, project_name, sweep_config):
        self.project_name = project_name
        self.sweep_config = sweep_config
        self.sweep_id = None

    def run_sweep(self):
        # Create the sweep
        self.sweep_id = wandb.sweep(self.sweep_config, project=self.project_name)
        logger.info("Created sweep with ID: " + self.sweep_id)

        # Run the sweep
        wandb.agent(self.sweep_id, function=self.run_experiment, count=config_loader.experiment_count)

    def all_runs_finished(self):
        sweep_path = f"{wandb_user_name}/{wandb_project_name}/{self.sweep_id}"
        logger.debug(f"Sweep path: {sweep_path}")
        sweep_runs = wandb.Api().sweep(sweep_path).runs

        all_finished = all(run.state == "finished" for run in sweep_runs)
        if all_finished:
            logger.info("All runs finished successfully.")

        all_not_running = all(run.state != "running" for run in sweep_runs)
        if all_not_running and not all_finished:
            logger.warning("Some runs are not finished but failed or crashed.")

        return all_not_running

    def run_experiment(self):
        try:
            logger.info("Starting a new run...")
            run = wandb.init(
                project=self.project_name,
                config=config_loader.config,
                save_code=True,
                entity=wandb_user_name
            )
            run.log_code(
                root=".",
                include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
                exclude_fn=lambda path: path.endswith("template.sh")
            )

            model_name = get_arg(wandb.config.model_name)
            model_args = models.get_model_specific_args(wandb.config)

            if "pretrained" in wandb.config:
                weights = get_arg(wandb.config.pretrained)
                if weights != "false":
                    model_args["pretrained"] = weights

            if "norm" in wandb.config:
                model_args["norm"] = get_arg(wandb.config.norm)

            if "finetuning" in wandb.config:
                model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)

            if "lr_reduce_factor" in wandb.config:
                model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)

            logger.debug(f"Initializing {model_name} with args: {model_args}")
            model = getattr(models, model_name + "Lit")(**model_args)

            train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)

            wandb_logger = WandbLogger(project=self.project_name, log_model="all")
            wandb_logger.watch(model)

            # CSV logger - also used for saving configuration as yaml
            experiment_name = f"{config_loader.dataset_name}_{model_name}"
            csv_logger = CSVLogger(config_loader.models_path, experiment_name, version=run.id)
            csv_logger.log_hyperparams(wandb.config)
            loggers = [wandb_logger, csv_logger]

            experiment_signature = f"{experiment_name}_sweep={self.sweep_id}-run={run.id}"
            logger.debug("Experiment signature: " + experiment_signature)

            # save_top_k=1, monitor="val_loss", mode="min": keep only the best
            # checkpoint in terms of validation loss.
            checkpoint_callback = ModelCheckpoint(
                save_top_k=1,
                filename=experiment_signature + "-{epoch}-{val_loss:.3f}",
                monitor="val_loss",
                mode="min",
                dirpath=f"{config_loader.models_path}/{experiment_name}/",
            )
            checkpoint_callback.STARTING_VERSION = 1

            early_stopping_callback = EarlyStopping(
                monitor="val_loss",
                patience=5,
                verbose=True,
                mode="min"
            )
            lr_monitor = LearningRateMonitor(logging_interval='epoch')
            callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]

            trainer = pl.Trainer(
                default_root_dir=config_loader.models_path,
                logger=loggers,
                callbacks=callbacks,
                **get_trainer_args(wandb.config)
            )

            trainer.fit(model, train_loader, dev_loader)

        except Exception as e:
            logger.error("Caught error: %s", str(e))
            traceback_str = traceback.format_exc()
            logger.error(traceback_str)

        run.finish()


def start_sweep(sweep_config):
    logger.info("Starting sweep with config: " + str(sweep_config))
    set_random_seed(config_loader.seed)
    sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
    sweep_runner.run_sweep()

    return sweep_runner


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--sweep_config", type=str, required=True)
    args = parser.parse_args()

    sweep_config = util.load_sweep_config(args.sweep_config)
    start_sweep(sweep_config)
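
# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the executable module).
# It assumes that `util.load_sweep_config` reads a standard W&B sweep YAML
# file; the parameter names below mirror the keys consumed in `run_experiment`,
# while the concrete values and the model name are hypothetical placeholders.
#
#   $ python <this_script>.py --sweep_config sweep.yaml
#
#   # sweep.yaml
#   method: grid                  # or "random" / "bayes"
#   metric:
#     name: val_loss
#     goal: minimize
#   parameters:
#     model_name:
#       value: ExampleModel       # hypothetical; resolved to models.ExampleModelLit
#     max_epochs:
#       value: [20]               # get_trainer_args reads max_epochs[0]
#     lr_reduce_factor:
#       values: [0.1, 0.5]        # optional; forwarded to the model as steplr_gamma
# ---------------------------------------------------------------------------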