# -----------------
# Copyright © 2023 ACK Cyfronet AGH, Poland.
# This work was partially funded by the EPOS Project within the framework of PL-POIR4.2
# -----------------
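
"""Run a Weights & Biases hyperparameter sweep over PyTorch Lightning trainings.

The sweep configuration is supplied on the command line, e.g.:

    python <this_script>.py --sweep_config path/to/sweep.yaml

(The path above is illustrative; any file readable by util.load_sweep_config works.)
"""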

import argparse
import logging
import os
import resource  # used below to raise the open-file limit in-process
import traceback

import pytorch_lightning as pl
import torch
import wandb
from dotenv import load_dotenv
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger, CSVLogger

import config_loader
import models
import train
import util

torch.multiprocessing.set_sharing_strategy('file_system')

# Raise the soft open-file limit to the hard maximum; dataloaders using the
# file_system sharing strategy can keep many files open at once. (Running
# "ulimit" via os.system would only affect the spawned subshell.)
_soft_limit, _hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (_hard_limit, _hard_limit))

# W&B credentials and project settings come from the environment,
# optionally populated from a .env file.
load_dotenv()

wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")

host = os.environ.get("WANDB_HOST")
if host is None:
    raise ValueError("WANDB_HOST environment variable is not set.")

wandb.login(key=wandb_api_key, host=host)
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user_name = os.environ.get("WANDB_USER")

script_name = os.path.splitext(os.path.basename(__file__))[0]
# Configure a handler; without one, INFO records would be swallowed by
# logging's WARNING-level last-resort handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(script_name)
logger.setLevel(logging.INFO)


def set_random_seed(seed=3):
    """Seed torch's CPU and CUDA RNGs and force deterministic cuDNN kernels."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels improve reproducibility at some cost in speed.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
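
# Note: only torch's RNGs are seeded here; if Python's `random` or NumPy are
# used elsewhere in the pipeline, they must be seeded separately for fully
# reproducible runs.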


def get_trainer_args(config):
    """Build the keyword arguments passed to pl.Trainer."""
    trainer_args = {'max_epochs': get_arg(config.max_epochs)}
    return trainer_args


def get_arg(arg):
    """Unwrap hyperparameters that the sweep config stores as single-element lists."""
    if isinstance(arg, list):
        return arg[0]
    return arg
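
# For example, a value stored as [0.001] unwraps to 0.001, while a plain
# scalar passes through unchanged.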


class HyperparameterSweep:
    """Creates a W&B sweep and runs each of its experiments as a Lightning training."""

    def __init__(self, project_name, sweep_config):
        self.project_name = project_name
        self.sweep_config = sweep_config
        self.sweep_id = None

    def run_sweep(self):
        # Create the sweep
        self.sweep_id = wandb.sweep(self.sweep_config, project=self.project_name)
        logger.info("Created sweep with ID: " + self.sweep_id)

        # Run the sweep
        wandb.agent(self.sweep_id, function=self.run_experiment, count=config_loader.experiment_count)

    def all_runs_finished(self):
        """Return True once no run in the sweep is still running."""
        sweep_path = f"{wandb_user_name}/{wandb_project_name}/{self.sweep_id}"
        logger.debug(f"Sweep path: {sweep_path}")
        sweep_runs = wandb.Api().sweep(sweep_path).runs

        all_finished = all(run.state == "finished" for run in sweep_runs)
        if all_finished:
            logger.info("All runs finished successfully.")

        all_not_running = all(run.state != "running" for run in sweep_runs)
        if all_not_running and not all_finished:
            logger.warning("Some runs did not finish successfully; they failed or crashed.")

        return all_not_running

    def run_experiment(self):
        run = None
        try:
            logger.info("Starting a new run...")

            run = wandb.init(
                project=self.project_name,
                config=config_loader.config,
                save_code=True,
                entity=wandb_user_name
            )
            run.log_code(
                root=".",
                include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
                exclude_fn=lambda path: path.endswith("template.sh")
            )

            model_name = get_arg(wandb.config.model_name)
            model_args = models.get_model_specific_args(wandb.config)

            if "pretrained" in wandb.config:
                weights = get_arg(wandb.config.pretrained)
                # The string "false" means "do not load pretrained weights".
                if weights != "false":
                    model_args["pretrained"] = weights

            if "norm" in wandb.config:
                model_args["norm"] = get_arg(wandb.config.norm)

            if "finetuning" in wandb.config:
                model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)

            if "lr_reduce_factor" in wandb.config:
                model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)

            logger.debug(f"Initializing {model_name} with args: {model_args}")
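            # The Lightning module class is looked up by name with a "Lit"
            # suffix; e.g. a model_name of "PhaseNet" (name illustrative)
            # would resolve to models.PhaseNetLit.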
            model = getattr(models, model_name + "Lit")(**model_args)

            train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)

            wandb_logger = WandbLogger(project=self.project_name, log_model="all")
            wandb_logger.watch(model)

            # CSV logger - also used for saving configuration as yaml
            experiment_name = f"{config_loader.dataset_name}_{model_name}"
            csv_logger = CSVLogger(config_loader.models_path, experiment_name, version=run.id)
            csv_logger.log_hyperparams(wandb.config)

            loggers = [wandb_logger, csv_logger]

            experiment_signature = f"{experiment_name}_sweep={self.sweep_id}-run={run.id}"
            logger.debug("Experiment signature: " + experiment_signature)

            # save_top_k=1, monitor="val_loss", mode="min": keep only the best
            # checkpoint in terms of validation loss.
            checkpoint_callback = ModelCheckpoint(
                save_top_k=1,
                filename=experiment_signature + "-{epoch}-{val_loss:.3f}",
                monitor="val_loss",
                mode="min",
                dirpath=f"{config_loader.models_path}/{experiment_name}/",
            )
            checkpoint_callback.STARTING_VERSION = 1

            early_stopping_callback = EarlyStopping(
                monitor="val_loss",
                patience=5,
                verbose=True,
                mode="min",
            )

            lr_monitor = LearningRateMonitor(logging_interval='epoch')
            callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]

            trainer = pl.Trainer(
                default_root_dir=config_loader.models_path,
                logger=loggers,
                callbacks=callbacks,
                **get_trainer_args(wandb.config)
            )
            trainer.fit(model, train_loader, dev_loader)

        except Exception as e:
            logger.error("Caught error: %s", e)
            logger.error(traceback.format_exc())

        finally:
            # Close the run even when training fails, so the sweep agent can
            # move on to the next experiment.
            if run is not None:
                run.finish()
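
    # run_experiment catches all exceptions so that a single failed run does
    # not kill the sweep agent; the failure remains visible in the logs and in
    # the run's state on W&B.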


def start_sweep(sweep_config):
    logger.info("Starting sweep with config: " + str(sweep_config))
    set_random_seed(config_loader.seed)
    sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
    sweep_runner.run_sweep()

    return sweep_runner


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--sweep_config", type=str, required=True,
                        help="Path to the sweep configuration file.")
    args = parser.parse_args()

    sweep_config = util.load_sweep_config(args.sweep_config)
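
    # For reference, a minimal sweep configuration might look like this
    # (keys follow the W&B sweep schema; parameter names and values are
    # illustrative):
    #
    #   method: grid
    #   metric: {name: val_loss, goal: minimize}
    #   parameters:
    #     model_name: {values: [PhaseNet]}
    #     max_epochs: {values: [20]}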
    start_sweep(sweep_config)