""" This script handles the training of models base on model configuration files. """ import seisbench.generate as sbg from seisbench.util import worker_seeding import pytorch_lightning as pl # from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger from pytorch_lightning.loggers import WandbLogger, CSVLogger # https://github.com/Lightning-AI/lightning/pull/12554 # https://github.com/Lightning-AI/lightning/issues/11796 from pytorch_lightning.callbacks import ModelCheckpoint import argparse import json import numpy as np from torch.utils.data import DataLoader import torch import os import logging from pathlib import Path import models, data, util import time import datetime import wandb # # load_dotenv() # wandb_api_key = os.environ.get('WANDB_API_KEY') # if wandb_api_key is None: # raise ValueError("WANDB_API_KEY environment variable is not set.") # # wandb.login(key=wandb_api_key) def train(config, experiment_name, test_run): """ Runs the model training defined by the config. Config parameters: - model: Model used as in the models.py file, but without the Lit suffix - data: Dataset used, as in seisbench.data - model_args: Arguments passed to the constructor of the model lightning module - trainer_args: Arguments passed to the lightning trainer - batch_size: Batch size for training and validation - num_workers: Number of workers for data loading. If not set, uses environment variable BENCHMARK_DEFAULT_WORKERS - restrict_to_phase: Filters datasets only to examples containing the given phase. By default, uses all phases. - training_fraction: Fraction of training blocks to use as float between 0 and 1. Defaults to 1. :param config: Configuration parameters for training :param test_run: If true, makes a test run with less data and less logging. Intended for debug purposes. 
""" model = models.__getattribute__(config["model"] + "Lit")( **config.get("model_args", {}) ) train_loader, dev_loader = prepare_data(config, model, test_run) # CSV logger - also used for saving configuration as yaml csv_logger = CSVLogger("weights", experiment_name) csv_logger.log_hyperparams(config) loggers = [csv_logger] default_root_dir = os.path.join( "weights" ) # Experiment name is parsed from the loggers if not test_run: # tb_logger = TensorBoardLogger("tb_logs", experiment_name) # tb_logger.log_hyperparams(config) # loggers += [tb_logger] wandb_logger = WandbLogger() wandb_logger.watch(model) loggers +=[wandb_logger] checkpoint_callback = ModelCheckpoint( save_top_k=1, filename="{epoch}-{step}", monitor="val_loss", mode="min" ) # save_top_k=1, monitor="val_loss", mode="min": save the best model in terms of validation loss callbacks = [checkpoint_callback] ## Uncomment the following 2 lines to enable # device_stats = DeviceStatsMonitor() # callbacks.append(device_stats) trainer = pl.Trainer( default_root_dir=default_root_dir, logger=loggers, callbacks=callbacks, **config.get("trainer_args", {}), ) trainer.fit(model, train_loader, dev_loader) def prepare_data(config, model, test_run): """ Returns the training and validation data loaders :param config: :param model: :param test_run: :return: """ batch_size = config.get("batch_size", 1024) if type(batch_size) == list: batch_size = batch_size[0] num_workers = config.get("num_workers", util.default_workers) try: dataset = data.get_dataset_by_name(config["dataset_name"])( sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full" ) except ValueError: data_path = str(Path.cwd().parent) + '/' + config['data_path'] print(data_path) dataset = data.get_custom_dataset(data_path) restrict_to_phase = config.get("restrict_to_phase", None) if restrict_to_phase is not None: mask = generate_phase_mask(dataset, restrict_to_phase) dataset.filter(mask, inplace=True) if "split" not in dataset.metadata.columns: logging.warning("No split defined, adding auxiliary split.") split = np.array(["train"] * len(dataset)) split[int(0.6 * len(dataset)) : int(0.7 * len(dataset))] = "dev" split[int(0.7 * len(dataset)) :] = "test" dataset._metadata["split"] = split train_data = dataset.train() dev_data = dataset.dev() if test_run: # Only use a small part of the dataset train_mask = np.zeros(len(train_data), dtype=bool) train_mask[:5000] = True train_data.filter(train_mask, inplace=True) dev_mask = np.zeros(len(dev_data), dtype=bool) dev_mask[:5000] = True dev_data.filter(dev_mask, inplace=True) training_fraction = config.get("training_fraction", 1.0) apply_training_fraction(training_fraction, train_data) train_data.preload_waveforms(pbar=True) dev_data.preload_waveforms(pbar=True) train_generator = sbg.GenericGenerator(train_data) dev_generator = sbg.GenericGenerator(dev_data) train_generator.add_augmentations(model.get_train_augmentations()) dev_generator.add_augmentations(model.get_val_augmentations()) train_loader = DataLoader( train_generator, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=worker_seeding, drop_last=True, # Avoid crashes from batch norm layers for batch size 1 ) dev_loader = DataLoader( dev_generator, batch_size=batch_size, num_workers=num_workers, worker_init_fn=worker_seeding, ) return train_loader, dev_loader def apply_training_fraction(training_fraction, train_data): """ Reduces the size of train_data to train_fraction by inplace filtering. Filter blockwise for efficient memory savings. 
    :param training_fraction: Training fraction between 0 and 1.
    :param train_data: Training dataset
    :return: None
    """
    if not 0.0 < training_fraction <= 1.0:
        raise ValueError("Training fraction needs to be between 0 and 1.")

    if training_fraction < 1:
        blocks = train_data["trace_name"].apply(lambda x: x.split("$")[0])
        unique_blocks = blocks.unique()
        np.random.shuffle(unique_blocks)
        target_blocks = unique_blocks[: int(training_fraction * len(unique_blocks))]
        target_blocks = set(target_blocks)
        mask = blocks.isin(target_blocks)
        train_data.filter(mask, inplace=True)


def generate_phase_mask(dataset, phases):
    """
    Returns a boolean mask selecting the examples that contain at least one
    non-NaN arrival of the given phases, based on the metadata keys in
    models.phase_dict.
    """
    mask = np.zeros(len(dataset), dtype=bool)

    for key, phase in models.phase_dict.items():
        if phase not in phases:
            continue
        if key in dataset.metadata:
            mask = np.logical_or(mask, ~np.isnan(dataset.metadata[key]))

    return mask


if __name__ == "__main__":
    code_start_time = time.perf_counter()

    torch.manual_seed(42)
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--test_run", action="store_true")
    parser.add_argument("--lr", default=None, type=float)
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = json.load(f)

    experiment_name = os.path.basename(args.config)[:-5]  # strip the ".json" extension

    if args.lr is not None:
        logging.warning(f"Overwriting learning rate to {args.lr}")
        experiment_name += f"_{args.lr}"
        config.setdefault("model_args", {})["lr"] = args.lr

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data_with_pick-benchmark",
        # track hyperparameters and run metadata
        config=config,
    )

    if args.test_run:
        experiment_name = experiment_name + "_test"

    train(config, experiment_name, test_run=args.test_run)

    running_time = str(
        datetime.timedelta(seconds=time.perf_counter() - code_start_time)
    )
    print(f"Running time: {running_time}")
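
# Illustrative usage (the config path below is a placeholder; actual config files
# depend on this repository's setup):
#
#   python train.py --config configs/my_model.json
#   python train.py --config configs/my_model.json --test_run
#   python train.py --config configs/my_model.json --lr 1e-3
#
# --test_run trains on a small subset with reduced logging; --lr overrides the
# learning rate from the config and appends it to the experiment name.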