platform-demo-scripts/scripts/train.py

"""
This script handles the training of models based on model configuration files.
"""
import seisbench.generate as sbg
from seisbench.util import worker_seeding
import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.loggers import WandbLogger, CSVLogger
# https://github.com/Lightning-AI/lightning/pull/12554
# https://github.com/Lightning-AI/lightning/issues/11796
from pytorch_lightning.callbacks import ModelCheckpoint, DeviceStatsMonitor
import argparse
import json
import numpy as np
from torch.utils.data import DataLoader
import torch
import os
import logging
from pathlib import Path
from dotenv import load_dotenv
import models, data, util
import time
import datetime
import wandb


def train(config, experiment_name, test_run):
    """
    Runs the model training defined by the config.

    Config parameters:

    - model: Model used as in the models.py file, but without the Lit suffix
    - dataset_name: Dataset used, as in seisbench.data. If the name is unknown,
      a custom dataset is loaded from data_path instead.
    - model_args: Arguments passed to the constructor of the model lightning module
    - trainer_args: Arguments passed to the lightning trainer
    - batch_size: Batch size for training and validation
    - num_workers: Number of workers for data loading.
      If not set, uses environment variable BENCHMARK_DEFAULT_WORKERS
    - restrict_to_phase: Filters datasets only to examples containing the given phase.
      By default, uses all phases.
    - training_fraction: Fraction of training blocks to use as float between 0 and 1. Defaults to 1.

    :param config: Configuration parameters for training
    :param experiment_name: Name of the experiment, used for logging and checkpoint directories
    :param test_run: If true, makes a test run with less data and less logging. Intended for debug purposes.
    """
    model = models.__getattribute__(config["model"] + "Lit")(
        **config.get("model_args", {})
    )

    train_loader, dev_loader = prepare_data(config, model, test_run)

    # CSV logger - also used for saving configuration as yaml
    csv_logger = CSVLogger("weights", experiment_name)
    csv_logger.log_hyperparams(config)
    loggers = [csv_logger]

    default_root_dir = os.path.join(
        "weights"
    )  # Experiment name is parsed from the loggers

    if not test_run:
        # tb_logger = TensorBoardLogger("tb_logs", experiment_name)
        # tb_logger.log_hyperparams(config)
        # loggers += [tb_logger]
        wandb_logger = WandbLogger()
        wandb_logger.watch(model)
        loggers += [wandb_logger]

    checkpoint_callback = ModelCheckpoint(
        save_top_k=1, filename="{epoch}-{step}", monitor="val_loss", mode="min"
    )  # save_top_k=1, monitor="val_loss", mode="min": save the best model in terms of validation loss
    callbacks = [checkpoint_callback]

    ## Uncomment the following 2 lines to enable device statistics monitoring
    # device_stats = DeviceStatsMonitor()
    # callbacks.append(device_stats)

    trainer = pl.Trainer(
        default_root_dir=default_root_dir,
        logger=loggers,
        callbacks=callbacks,
        **config.get("trainer_args", {}),
    )

    trainer.fit(model, train_loader, dev_loader)


def prepare_data(config, model, test_run):
    """
    Returns the training and validation data loaders.

    :param config: Configuration parameters for training
    :param model: Lightning module providing the training and validation augmentations
    :param test_run: If true, restricts the datasets to a small subset
    :return: Tuple of training and validation data loaders
    """
    batch_size = config.get("batch_size", 1024)
    if isinstance(batch_size, list):
        batch_size = batch_size[0]

    num_workers = config.get("num_workers", util.default_workers)
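    # Try to resolve the dataset by name (e.g. a SeisBench benchmark dataset); if the
    # name is unknown, fall back to a custom dataset located at config["data_path"],
    # resolved relative to the parent of the current working directory.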
    try:
        dataset = data.get_dataset_by_name(config["dataset_name"])(
            sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full"
        )
    except ValueError:
        data_path = str(Path.cwd().parent / config["data_path"])
        print(data_path)
        dataset = data.get_custom_dataset(data_path)
    restrict_to_phase = config.get("restrict_to_phase", None)
    if restrict_to_phase is not None:
        mask = generate_phase_mask(dataset, restrict_to_phase)
        dataset.filter(mask, inplace=True)

    if "split" not in dataset.metadata.columns:
        logging.warning("No split defined, adding auxiliary split.")
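        # Auxiliary split: the first 60% of the traces are used for training,
        # the next 10% for dev, and the remaining 30% for test.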
        split = np.array(["train"] * len(dataset))
        split[int(0.6 * len(dataset)) : int(0.7 * len(dataset))] = "dev"
        split[int(0.7 * len(dataset)) :] = "test"

        dataset._metadata["split"] = split
    train_data = dataset.train()
    dev_data = dataset.dev()

    if test_run:
        # Only use a small part of the dataset
        train_mask = np.zeros(len(train_data), dtype=bool)
        train_mask[:5000] = True
        train_data.filter(train_mask, inplace=True)

        dev_mask = np.zeros(len(dev_data), dtype=bool)
        dev_mask[:5000] = True
        dev_data.filter(dev_mask, inplace=True)

    training_fraction = config.get("training_fraction", 1.0)
    apply_training_fraction(training_fraction, train_data)

    train_data.preload_waveforms(pbar=True)
    dev_data.preload_waveforms(pbar=True)
    train_generator = sbg.GenericGenerator(train_data)
    dev_generator = sbg.GenericGenerator(dev_data)
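    # The augmentation pipelines (e.g. windowing, normalization, label generation)
    # are defined per model by the corresponding Lit module in models.py.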
    train_generator.add_augmentations(model.get_train_augmentations())
    dev_generator.add_augmentations(model.get_val_augmentations())

    train_loader = DataLoader(
        train_generator,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=worker_seeding,
        drop_last=True,  # Avoid crashes from batch norm layers for batch size 1
    )
    dev_loader = DataLoader(
        dev_generator,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=worker_seeding,
    )

    return train_loader, dev_loader


def apply_training_fraction(training_fraction, train_data):
    """
    Reduces the size of train_data to training_fraction by inplace filtering.
    Filters blockwise for efficient memory savings.

    :param training_fraction: Training fraction between 0 and 1.
    :param train_data: Training dataset
    :return: None
    """
    if not 0.0 < training_fraction <= 1.0:
        raise ValueError("Training fraction needs to be between 0 and 1.")

    if training_fraction < 1:
        blocks = train_data["trace_name"].apply(lambda x: x.split("$")[0])
        unique_blocks = blocks.unique()
        np.random.shuffle(unique_blocks)
        target_blocks = unique_blocks[: int(training_fraction * len(unique_blocks))]
        target_blocks = set(target_blocks)
        mask = blocks.isin(target_blocks)
        train_data.filter(mask, inplace=True)


def generate_phase_mask(dataset, phases):
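    """
    Returns a boolean mask selecting all examples that contain an arrival for at least
    one of the given phases, based on the metadata columns in models.phase_dict.

    :param dataset: Dataset to build the mask for
    :param phases: Phases to keep (e.g. "P", "S")
    :return: Boolean numpy array with one entry per example
    """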
    mask = np.zeros(len(dataset), dtype=bool)
    for key, phase in models.phase_dict.items():
        if phase not in phases:
            continue
        if key in dataset.metadata:
            mask = np.logical_or(mask, ~np.isnan(dataset.metadata[key]))

    return mask


if __name__ == "__main__":
    load_dotenv()
    wandb_api_key = os.environ.get("WANDB_API_KEY")
    if wandb_api_key is None:
        raise ValueError("WANDB_API_KEY environment variable is not set.")
    wandb.login(key=wandb_api_key)

    code_start_time = time.perf_counter()
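    # Fix the global PyTorch seed so runs are reproducible.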
    torch.manual_seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--test_run", action="store_true")
    parser.add_argument("--lr", default=None, type=float)
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = json.load(f)
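    # The experiment name is the config file name without its ".json" extension.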
    experiment_name = os.path.basename(args.config)[:-5]

    if args.lr is not None:
        logging.warning(f"Overwriting learning rate to {args.lr}")
        experiment_name += f"_{args.lr}"
        config["model_args"]["lr"] = args.lr

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data_with_pick-benchmark",
        # track hyperparameters and run metadata
        config=config,
    )

    if args.test_run:
        experiment_name = experiment_name + "_test"

    train(config, experiment_name, test_run=args.test_run)

    running_time = str(
        datetime.timedelta(seconds=time.perf_counter() - code_start_time)
    )
    print(f"Running time: {running_time}")