"""
This script handles the training of models based on model configuration files.
"""

import argparse
import datetime
import json
import logging
import os
import time
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import seisbench.generate as sbg
import torch
import wandb
from dotenv import load_dotenv

# https://github.com/Lightning-AI/lightning/pull/12554
# https://github.com/Lightning-AI/lightning/issues/11796
from pytorch_lightning.callbacks import ModelCheckpoint

# from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from seisbench.util import worker_seeding
from torch.utils.data import DataLoader

import data
import models
import util


def train(config, experiment_name, test_run):
    """
    Runs the model training defined by the config.

    Config parameters:

        - model: Model used as in the models.py file, but without the Lit suffix
        - dataset_name: Name of the dataset, resolved through
          ``data.get_dataset_by_name``
        - data_path: Path of a custom dataset, relative to the parent of the
          current working directory. Only used if resolving ``dataset_name``
          raises a ValueError.
        - model_args: Arguments passed to the constructor of the model lightning module
        - trainer_args: Arguments passed to the lightning trainer
        - batch_size: Batch size for training and validation. Defaults to 1024.
          If a list is given, its first entry is used.
        - num_workers: Number of workers for data loading. If not set, uses
          ``util.default_workers``
        - restrict_to_phase: Filters datasets only to examples containing the given phase.
          By default, uses all phases.
        - training_fraction: Fraction of training blocks to use as float between 0 and 1.
          Defaults to 1.

    :param config: Configuration parameters for training
    :param experiment_name: Name of the experiment, passed to the CSV logger
                            writing below the ``weights`` directory
    :param test_run: If true, makes a test run with less data and less logging.
                     Intended for debug purposes.
    :return: None
    """
    # The lightning module class is looked up by name in the models module.
    model_cls = getattr(models, config["model"] + "Lit")
    model = model_cls(**config.get("model_args", {}))

    train_loader, dev_loader = prepare_data(config, model, test_run)

    # CSV logger - also used for saving configuration as yaml
    csv_logger = CSVLogger("weights", experiment_name)
    csv_logger.log_hyperparams(config)
    loggers = [csv_logger]

    default_root_dir = "weights"  # Experiment name is parsed from the loggers
    if not test_run:
        # tb_logger = TensorBoardLogger("tb_logs", experiment_name)
        # tb_logger.log_hyperparams(config)
        # loggers += [tb_logger]
        wandb_logger = WandbLogger()
        wandb_logger.watch(model)
        loggers.append(wandb_logger)

    # save_top_k=1, monitor="val_loss", mode="min":
    # save the best model in terms of validation loss
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1, filename="{epoch}-{step}", monitor="val_loss", mode="min"
    )
    callbacks = [checkpoint_callback]

    ## Uncomment the following 2 lines to enable
    # device_stats = DeviceStatsMonitor()
    # callbacks.append(device_stats)

    trainer = pl.Trainer(
        default_root_dir=default_root_dir,
        logger=loggers,
        callbacks=callbacks,
        **config.get("trainer_args", {}),
    )

    trainer.fit(model, train_loader, dev_loader)


def prepare_data(config, model, test_run):
    """
    Returns the training and validation data loaders

    :param config: Configuration parameters, see :py:func:`train` for the keys read
    :param model: Model providing the augmentations through
                  ``get_train_augmentations`` and ``get_val_augmentations``
    :param test_run: If true, keeps at most the first 5000 training and
                     the first 5000 validation examples
    :return: Tuple of training data loader and validation data loader
    """
    batch_size = config.get("batch_size", 1024)
    if isinstance(batch_size, list):
        # Only the first entry of a list-valued batch size is used here.
        batch_size = batch_size[0]

    num_workers = config.get("num_workers", util.default_workers)
    try:
        dataset = data.get_dataset_by_name(config["dataset_name"])(
            sampling_rate=100,
            component_order="ZNE",
            dimension_order="NCW",
            cache="full",
        )
    except ValueError:
        # Fall back to a custom dataset located relative to the parent of the
        # current working directory.
        data_path = str(Path.cwd().parent / config["data_path"])
        print(data_path)
        dataset = data.get_custom_dataset(data_path)

    restrict_to_phase = config.get("restrict_to_phase", None)
    if restrict_to_phase is not None:
        mask = generate_phase_mask(dataset, restrict_to_phase)
        dataset.filter(mask, inplace=True)

    if "split" not in dataset.metadata.columns:
        logging.warning("No split defined, adding auxiliary split.")
        # Positional split: first 60 % train, next 10 % dev, remaining 30 % test.
        n_examples = len(dataset)
        split = np.array(["train"] * n_examples)
        split[int(0.6 * n_examples) : int(0.7 * n_examples)] = "dev"
        split[int(0.7 * n_examples) :] = "test"

        dataset._metadata["split"] = split

    train_data = dataset.train()
    dev_data = dataset.dev()

    if test_run:
        # Only use a small part of the dataset
        train_mask = np.zeros(len(train_data), dtype=bool)
        train_mask[:5000] = True
        train_data.filter(train_mask, inplace=True)

        dev_mask = np.zeros(len(dev_data), dtype=bool)
        dev_mask[:5000] = True
        dev_data.filter(dev_mask, inplace=True)

    training_fraction = config.get("training_fraction", 1.0)
    apply_training_fraction(training_fraction, train_data)

    train_data.preload_waveforms(pbar=True)
    dev_data.preload_waveforms(pbar=True)

    train_generator = sbg.GenericGenerator(train_data)
    dev_generator = sbg.GenericGenerator(dev_data)

    train_generator.add_augmentations(model.get_train_augmentations())
    dev_generator.add_augmentations(model.get_val_augmentations())

    train_loader = DataLoader(
        train_generator,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=worker_seeding,
        drop_last=True,  # Avoid crashes from batch norm layers for batch size 1
    )
    dev_loader = DataLoader(
        dev_generator,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=worker_seeding,
    )

    return train_loader, dev_loader


def apply_training_fraction(training_fraction, train_data):
    """
    Reduces the size of train_data to training_fraction by inplace filtering.
    Filter blockwise for efficient memory savings.

    The block of a trace is the part of its ``trace_name`` before the first ``$``.
    The blocks to keep are drawn with ``np.random.shuffle``.

    :param training_fraction: Training fraction between 0 and 1.
    :param train_data: Training dataset
    :return: None
    :raises ValueError: If training_fraction is not in the interval (0, 1]
    """

    if not 0.0 < training_fraction <= 1.0:
        raise ValueError("Training fraction needs to be between 0 and 1.")

    if training_fraction < 1:
        blocks = train_data["trace_name"].apply(lambda x: x.split("$")[0])
        unique_blocks = blocks.unique()
        np.random.shuffle(unique_blocks)
        # int() truncates, so the number of kept blocks is rounded down.
        n_target = int(training_fraction * len(unique_blocks))
        target_blocks = set(unique_blocks[:n_target])
        mask = blocks.isin(target_blocks)
        train_data.filter(mask, inplace=True)


def generate_phase_mask(dataset, phases):
    """
    Builds a boolean mask selecting the examples containing any of the given phases.

    For every entry of ``models.phase_dict`` whose value is in ``phases`` and
    whose key is present in the dataset metadata, examples with a non-NaN value
    for that key are selected.

    :param dataset: Dataset whose metadata is inspected
    :param phases: Phases to select
    :return: Boolean numpy array with one entry per example in the dataset
    """
    mask = np.zeros(len(dataset), dtype=bool)

    for key, phase in models.phase_dict.items():
        if phase in phases and key in dataset.metadata:
            mask = np.logical_or(mask, ~np.isnan(dataset.metadata[key]))

    return mask


def main():
    """
    Command line entry point: logs in to wandb, loads the json config and trains.

    Command line arguments:

        - --config: Path to the json configuration file (required)
        - --test_run: Flag passed on to :py:func:`train` as test_run
        - --lr: Optional learning rate, overwrites ``model_args["lr"]`` of the config

    :return: None
    :raises ValueError: If the WANDB_API_KEY environment variable is not set
    """
    load_dotenv()
    wandb_api_key = os.environ.get("WANDB_API_KEY")
    if wandb_api_key is None:
        raise ValueError("WANDB_API_KEY environment variable is not set.")

    wandb.login(key=wandb_api_key)

    code_start_time = time.perf_counter()

    # NOTE(review): only torch is seeded here, while apply_training_fraction
    # draws from the global NumPy RNG - confirm whether that is intended.
    torch.manual_seed(42)
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--test_run", action="store_true")
    parser.add_argument("--lr", default=None, type=float)
    args = parser.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        config = json.load(f)

    # File name of the config without its extension
    experiment_name = Path(args.config).stem
    if args.lr is not None:
        logging.warning(f"Overwriting learning rate to {args.lr}")
        experiment_name += f"_{args.lr}"
        # model_args is optional in the config, so create it if missing.
        config.setdefault("model_args", {})["lr"] = args.lr

    wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data_with_pick-benchmark",
        # track hyperparameters and run metadata
        config=config,
    )

    if args.test_run:
        experiment_name = experiment_name + "_test"
    train(config, experiment_name, test_run=args.test_run)

    running_time = str(
        datetime.timedelta(seconds=time.perf_counter() - code_start_time)
    )
    print(f"Running time: {running_time}")


if __name__ == "__main__":
    main()