"""
|
|
This script handles the training of models base on model configuration files.
|
|
"""

import seisbench.generate as sbg
from seisbench.util import worker_seeding

import pytorch_lightning as pl

# from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.loggers import WandbLogger, CSVLogger

# https://github.com/Lightning-AI/lightning/pull/12554
# https://github.com/Lightning-AI/lightning/issues/11796
from pytorch_lightning.callbacks import ModelCheckpoint

import argparse
import json
import numpy as np
from torch.utils.data import DataLoader
import torch
import os
import logging
from pathlib import Path
from dotenv import load_dotenv

import models, data, util
import time
import datetime
import wandb


def train(config, experiment_name, test_run):
    """
    Runs the model training defined by the config.

    Config parameters:

    - model: Model used as in the models.py file, but without the Lit suffix
    - dataset_name: Dataset used, as in seisbench.data. If no seisbench dataset of
      this name exists, the dataset is loaded from the data_path given in the config.
    - model_args: Arguments passed to the constructor of the model lightning module
    - trainer_args: Arguments passed to the lightning trainer
    - batch_size: Batch size for training and validation
    - num_workers: Number of workers for data loading.
      If not set, uses the environment variable BENCHMARK_DEFAULT_WORKERS.
    - restrict_to_phase: Filters datasets only to examples containing the given phase.
      By default, uses all phases.
    - training_fraction: Fraction of training blocks to use as a float between 0 and 1. Defaults to 1.

    :param config: Configuration parameters for training
    :param experiment_name: Name of the experiment, used for the logging and checkpoint directories
    :param test_run: If true, makes a test run with less data and less logging. Intended for debug purposes.
    """
    model = getattr(models, config["model"] + "Lit")(
        **config.get("model_args", {})
    )

    train_loader, dev_loader = prepare_data(config, model, test_run)

    # CSV logger - also used for saving the configuration as yaml
    csv_logger = CSVLogger("weights", experiment_name)
    csv_logger.log_hyperparams(config)
    loggers = [csv_logger]

    default_root_dir = os.path.join(
        "weights"
    )  # Experiment name is parsed from the loggers
    if not test_run:
        # tb_logger = TensorBoardLogger("tb_logs", experiment_name)
        # tb_logger.log_hyperparams(config)
        # loggers += [tb_logger]
        wandb_logger = WandbLogger()
        wandb_logger.watch(model)

        loggers += [wandb_logger]

    checkpoint_callback = ModelCheckpoint(
        save_top_k=1, filename="{epoch}-{step}", monitor="val_loss", mode="min"
    )  # Save only the best model in terms of validation loss
    callbacks = [checkpoint_callback]

    # Uncomment the following two lines to enable device statistics monitoring
    # (requires importing DeviceStatsMonitor from pytorch_lightning.callbacks):
    # device_stats = DeviceStatsMonitor()
    # callbacks.append(device_stats)

    trainer = pl.Trainer(
        default_root_dir=default_root_dir,
        logger=loggers,
        callbacks=callbacks,
        **config.get("trainer_args", {}),
    )

    trainer.fit(model, train_loader, dev_loader)


def prepare_data(config, model, test_run):
    """
    Returns the training and validation data loaders.

    :param config: Configuration parameters, see train
    :param model: Model lightning module providing the train and validation augmentations
    :param test_run: If true, restricts both datasets to a small subset
    :return: Tuple of training and validation data loaders
    """
    batch_size = config.get("batch_size", 1024)
    if isinstance(batch_size, list):
        batch_size = batch_size[0]

    num_workers = config.get("num_workers", util.default_workers)
    try:
        dataset = data.get_dataset_by_name(config["dataset_name"])(
            sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full"
        )
    except ValueError:
        # Fall back to a custom dataset located relative to the parent directory
        data_path = str(Path.cwd().parent / config["data_path"])
        print(data_path)
        dataset = data.get_custom_dataset(data_path)

    restrict_to_phase = config.get("restrict_to_phase", None)
    if restrict_to_phase is not None:
        mask = generate_phase_mask(dataset, restrict_to_phase)
        dataset.filter(mask, inplace=True)

    if "split" not in dataset.metadata.columns:
        logging.warning("No split defined, adding auxiliary split.")
        # Auxiliary split: 60% train, 10% dev, 30% test
        split = np.array(["train"] * len(dataset))
        split[int(0.6 * len(dataset)) : int(0.7 * len(dataset))] = "dev"
        split[int(0.7 * len(dataset)) :] = "test"

        dataset._metadata["split"] = split

    train_data = dataset.train()
    dev_data = dataset.dev()

    if test_run:
        # Only use a small part of the dataset
        train_mask = np.zeros(len(train_data), dtype=bool)
        train_mask[:5000] = True
        train_data.filter(train_mask, inplace=True)

        dev_mask = np.zeros(len(dev_data), dtype=bool)
        dev_mask[:5000] = True
        dev_data.filter(dev_mask, inplace=True)

    training_fraction = config.get("training_fraction", 1.0)
    apply_training_fraction(training_fraction, train_data)

    train_data.preload_waveforms(pbar=True)
    dev_data.preload_waveforms(pbar=True)

    train_generator = sbg.GenericGenerator(train_data)
    dev_generator = sbg.GenericGenerator(dev_data)

    train_generator.add_augmentations(model.get_train_augmentations())
    dev_generator.add_augmentations(model.get_val_augmentations())

    train_loader = DataLoader(
        train_generator,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=worker_seeding,
        drop_last=True,  # Avoid crashes from batch norm layers for batch size 1
    )
    dev_loader = DataLoader(
        dev_generator,
        batch_size=batch_size,
        num_workers=num_workers,
        worker_init_fn=worker_seeding,
    )

    return train_loader, dev_loader


def apply_training_fraction(training_fraction, train_data):
    """
    Reduces the size of train_data to training_fraction by in-place filtering.
    Filters blockwise for efficient memory savings.

    :param training_fraction: Training fraction between 0 and 1.
    :param train_data: Training dataset
    :return: None
    """

    if not 0.0 < training_fraction <= 1.0:
        raise ValueError("Training fraction needs to be between 0 and 1.")

    if training_fraction < 1:
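        # Block filtering relies on the seisbench trace_name convention, where
        # names look like "bucket0$0,:3,:6000": the part before the "$" names
        # the storage block (bucket), so sampling whole blocks avoids loading
        # scattered traces.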
        blocks = train_data["trace_name"].apply(lambda x: x.split("$")[0])
        unique_blocks = blocks.unique()
        np.random.shuffle(unique_blocks)
        target_blocks = unique_blocks[: int(training_fraction * len(unique_blocks))]
        target_blocks = set(target_blocks)
        mask = blocks.isin(target_blocks)
        train_data.filter(mask, inplace=True)


def generate_phase_mask(dataset, phases):
    """
    Returns a boolean mask selecting all examples that contain at least one
    pick of the given phases.

    :param dataset: Dataset to generate the mask for
    :param phases: Iterable of phase labels, e.g., ["P", "S"]
    :return: Boolean mask with one entry per example
    """
    mask = np.zeros(len(dataset), dtype=bool)
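
    # models.phase_dict is assumed to map pick metadata columns to phase labels,
    # e.g. {"trace_p_arrival_sample": "P", "trace_s_arrival_sample": "S"}; an
    # example passes the filter if any matching column holds a non-NaN pick.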
    for key, phase in models.phase_dict.items():
        if phase not in phases:
            continue
        if key in dataset.metadata:
            mask = np.logical_or(mask, ~np.isnan(dataset.metadata[key]))

    return mask


if __name__ == "__main__":
    load_dotenv()
    wandb_api_key = os.environ.get("WANDB_API_KEY")
    if wandb_api_key is None:
        raise ValueError("WANDB_API_KEY environment variable is not set.")

    wandb.login(key=wandb_api_key)

    code_start_time = time.perf_counter()

    torch.manual_seed(42)
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--test_run", action="store_true")
    parser.add_argument("--lr", default=None, type=float)
    args = parser.parse_args()
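
    # Typical invocations, assuming this file is saved as train.py and the
    # config path is illustrative:
    #   python train.py --config configs/phasenet.json
    #   python train.py --config configs/phasenet.json --test_run
    #   python train.py --config configs/phasenet.json --lr 0.0005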

    with open(args.config, "r") as f:
        config = json.load(f)

    experiment_name = os.path.basename(args.config)[:-5]  # Strip the ".json" suffix
    if args.lr is not None:
        logging.warning(f"Overwriting learning rate to {args.lr}")
        experiment_name += f"_{args.lr}"
        config["model_args"]["lr"] = args.lr

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data_with_pick-benchmark",
        # track hyperparameters and run metadata
        config=config,
    )

    if args.test_run:
        experiment_name = experiment_name + "_test"
    train(config, experiment_name, test_run=args.test_run)

    running_time = str(
        datetime.timedelta(seconds=time.perf_counter() - code_start_time)
    )
    print(f"Running time: {running_time}")