finetuning (#1)

Reviewed-on: #1
Extended PhaseNet finetuning options
This commit is contained in:
2024-07-29 13:41:05 +02:00
parent 5c3ce04868
commit e86f131cc0
11 changed files with 381 additions and 54 deletions

View File

@@ -1,8 +1,13 @@
"""
This file contains functionality related to data.
"""
import os.path
import seisbench.data as sbd
import logging
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('data')
def get_dataset_by_name(name):
@@ -30,3 +35,26 @@ def get_custom_dataset(path):
except AttributeError:
raise ValueError(f"Unknown dataset '{path}'.")
def validate_custom_dataset(data_path):
"""
Validate the dataset
:param data_path: path to the dataset
:return:
"""
# check if path exists
if not os.path.isdir(data_path):
raise ValueError(f"Data path {data_path} does not exist.")
dataset = sbd.WaveformDataset(data_path)
# check if the dataset is split into train, dev and test
if len(dataset.train()) == 0:
raise ValueError(f"Training set is empty.")
if len(dataset.dev()) == 0:
raise ValueError(f"Dev set is empty.")
if len(dataset.test()) == 0:
raise ValueError(f"Test set is empty.")
logger.info("Custom dataset validated successfully.")

View File

@@ -7,8 +7,9 @@ import os
import os.path
import argparse
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
import wandb
import torch
@@ -20,6 +21,8 @@ import train
import util
import config_loader
torch.multiprocessing.set_sharing_strategy('file_system')
os.system("ulimit -n unlimited")
@@ -31,14 +34,13 @@ host = os.environ.get("WANDB_HOST")
if host is None:
raise ValueError("WANDB_HOST environment variable is not set.")
wandb.login(key=wandb_api_key, host=host)
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user_name = os.environ.get("WANDB_USER")
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.WARNING)
logger.setLevel(logging.INFO)
def set_random_seed(seed=3):
@@ -54,6 +56,11 @@ def get_trainer_args(config):
return trainer_args
def get_arg(arg):
if type(arg) == list:
return arg[0]
return arg
class HyperparameterSweep:
def __init__(self, project_name, sweep_config):
self.project_name = project_name
@@ -84,25 +91,40 @@ class HyperparameterSweep:
return all_not_running
def run_experiment(self):
try:
logger.debug("Starting a new run...")
logger.info("Starting a new run...")
run = wandb.init(
project=self.project_name,
config=config_loader.config,
save_code=True
save_code=True,
entity=wandb_user_name
)
run.log_code(
root=".",
include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
exclude_fn=lambda path: path.endswith("template.sh")
) # not working as expected
)
model_name = wandb.config.model_name[0]
model_name = get_arg(wandb.config.model_name)
model_args = models.get_model_specific_args(wandb.config)
logger.debug(f"Initializing {model_name}")
if "pretrained" in wandb.config:
weights = get_arg(wandb.config.pretrained)
if weights != "false":
model_args["pretrained"] = weights
if "norm" in wandb.config:
model_args["norm"] = get_arg(wandb.config.norm)
if "finetuning" in wandb.config:
model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
if "lr_reduce_factor" in wandb.config:
model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
logger.debug(f"Initializing {model_name} with args: {model_args}")
model = models.__getattribute__(model_name + "Lit")(**model_args)
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
@@ -132,10 +154,13 @@ class HyperparameterSweep:
early_stopping_callback = EarlyStopping(
monitor="val_loss",
patience=3,
patience=5,
verbose=True,
mode="min")
callbacks = [checkpoint_callback, early_stopping_callback]
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]
trainer = pl.Trainer(
default_root_dir=config_loader.models_path,
@@ -143,7 +168,6 @@ class HyperparameterSweep:
callbacks=callbacks,
**get_trainer_args(wandb.config)
)
trainer.fit(model, train_loader, dev_loader)
except Exception as e:
@@ -155,7 +179,6 @@ class HyperparameterSweep:
def start_sweep(sweep_config):
logger.info("Starting sweep with config: " + str(sweep_config))
set_random_seed(config_loader.seed)
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
@@ -165,7 +188,6 @@ def start_sweep(sweep_config):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--sweep_config", type=str, required=True)
args = parser.parse_args()
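Note: for a sweep run that sets the new options, the config handling in run_experiment() above roughly reduces to the sketch below; the values are illustrative and the base arguments from get_model_specific_args() (learning rate, filtering, ...) are omitted for brevity, so this is a sketch rather than code from the commit:

import models

# Hypothetical kwargs for one sweep run; the base args from get_model_specific_args()
# would be merged in as before and are omitted here.
model_args = {
    "pretrained": "instance",          # get_arg(wandb.config.pretrained), skipped when "false"
    "norm": "peak",                    # get_arg(wandb.config.norm)
    "finetuning_strategy": "decoder",  # get_arg(wandb.config.finetuning)
    "steplr_gamma": 0.5,               # get_arg(wandb.config.lr_reduce_factor)
}
model = models.PhaseNetLit(**model_args)  # as in run_experiment() above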

scripts/input_validate.py (new file, 146 lines added)
View File

@@ -0,0 +1,146 @@
from pydantic import BaseModel, ConfigDict, field_validator
from typing_extensions import Literal
from typing import Union, List, Optional
import yaml
import logging
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('input_validator')
#todo
# 1. check if a single value is allowed in a sweep
# 2. merge input params
# 3. change names of the classes
# 4. add constraints for PhaseNet, GPD
model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
norm_values = Literal["peak", "std"]
finetuning_values = Literal["all", "top", "decoder", "encoder"]
pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic',
'original', 'scedc', False]
class Metric(BaseModel):
goal: str
name: str
class NumericValue(BaseModel):
value: Union[int, float, List[Union[int, float]]]
class NumericValues(BaseModel):
values: List[Union[int, float]]
class IntDistribution(BaseModel):
distribution: str = "int_uniform"
min: int
max: int
class FloatDistribution(BaseModel):
distribution: str = "uniform"
min: float
max: float
class Pretrained(BaseModel):
distribution: Optional[str] = "categorical"
values: List[pretrained_values] = None
value: Union[pretrained_values, List[pretrained_values]] = None
class Finetuning(BaseModel):
distribution: Optional[str] = "categorical"
values: List[finetuning_values] = None
value: Union[finetuning_values, List[finetuning_values]] = None
class Norm(BaseModel):
distribution: Optional[str] = "categorical"
values: List[norm_values] = None
value: Union[norm_values, List[norm_values]] = None
class ModelType(BaseModel):
distribution: Optional[str] = "categorical"
value: Union[model_names, List[model_names]] = None
values: List[model_names] = None
class Parameters(BaseModel):
model_config = ConfigDict(extra='forbid', protected_namespaces=())
model_name: ModelType
batch_size: Union[IntDistribution, NumericValue, NumericValues]
learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
max_epochs: Union[IntDistribution, NumericValue, NumericValues]
class PhaseNetParameters(Parameters):
model_config = ConfigDict(extra='forbid')
norm: Norm = None
pretrained: Pretrained = None
finetuning: Finetuning = None
lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
@field_validator("model_name")
def validate_model(cls, v):
if "PhaseNet" not in v.value:
raise ValueError("Additional parameters implemented for PhaseNet only")
return v
class FilteringParameters(Parameters):
model_config = ConfigDict(extra='forbid')
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
@field_validator("model_name")
def validate_model(cls, v):
if v.value[0] not in ["GPD", "PhaseNet"]:
raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
return v
class InputParams(BaseModel):
name: str
method: str
metric: Metric
parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
def validate_sweep_yaml(yaml_filename, model_name=None):
# Load YAML configuration
with open(yaml_filename, 'r') as f:
sweep_config = yaml.safe_load(f)
validate_sweep_config(sweep_config, model_name)
def validate_sweep_config(sweep_config, model_name=None):
# Validate sweep config
input_params = InputParams(**sweep_config)
# Check consistency of input parameters and sweep configuration
sweep_model_name = input_params.parameters.model_name.value
if model_name is not None and model_name not in sweep_model_name:
info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
logger.info(info)
raise ValueError(info)
logger.info("Input validation successful.")
if __name__ == "__main__":
yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
validate_sweep_yaml(yaml_filename, None)
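Note: a minimal sweep configuration that the schema above accepts, written as the dict yaml.safe_load would produce; the names and values are hypothetical, not a config shipped with this commit:

import input_validate

example_sweep = {
    "name": "phasenet_finetuning_example",
    "method": "grid",
    "metric": {"goal": "minimize", "name": "val_loss"},
    "parameters": {
        "model_name": {"value": ["PhaseNet"]},
        "batch_size": {"value": 256},
        "learning_rate": {"values": [0.001, 0.0001]},
        "max_epochs": {"value": 30},
        "pretrained": {"value": ["instance"]},
        "finetuning": {"values": ["decoder", "top"]},
        "lr_reduce_factor": {"value": 0.5},
    },
}
input_validate.validate_sweep_config(example_sweep, model_name="PhaseNet")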

View File

@@ -7,16 +7,26 @@ import seisbench.generate as sbg
import pytorch_lightning as pl
import torch
from torch.optim import lr_scheduler
import torch.nn.functional as F
import numpy as np
from abc import abstractmethod, ABC
# import lightning as L
# Allows importing this file from both a Jupyter notebook and regular code
try:
from .augmentations import DuplicateEvent
except ImportError:
from augmentations import DuplicateEvent
import os
import logging
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.DEBUG)
# Phase dict for labelling. We only study P and S phases without differentiating between them.
phase_dict = {
@@ -131,7 +141,31 @@ class PhaseNetLit(SeisBenchModuleLit):
self.sigma = sigma
self.sample_boundaries = sample_boundaries
self.loss = vector_cross_entropy
self.model = sbm.PhaseNet(**kwargs)
self.pretrained = kwargs.pop("pretrained", None)
self.norm = kwargs.pop("norm", "peak")
self.highpass = kwargs.pop("highpass", None)
self.lowpass = kwargs.pop("lowpass", None)
if self.pretrained is not None:
self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
# self.norm = self.model.norm
else:
self.model = sbm.PhaseNet(**kwargs)
self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
self.reduce_lr_on_plateau = False
self.initial_epochs = 0
if self.finetuning_strategy is not None:
if self.finetuning_strategy == "top":
self.initial_epochs = 3
elif self.finetuning_strategy in ["decoder", "encoder"]:
self.initial_epochs = 6
self.freeze()
def forward(self, x):
return self.model(x)
@@ -154,9 +188,49 @@ class PhaseNetLit(SeisBenchModuleLit):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
if self.finetuning_strategy is not None:
scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
self.reduce_lr_on_plateau = False
else:
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
self.reduce_lr_on_plateau = True
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'monitor': 'val_loss',
'interval': 'epoch',
'reduce_on_plateau': self.reduce_lr_on_plateau,
},
}
def lr_lambda(self, epoch):
# LambdaLR multiplies the optimizer's base lr by the returned factor:
# keep the base lr during the initial (frozen) epochs, then reduce it by steplr_gamma.
return 1.0 if epoch < self.initial_epochs else self.steplr_gamma
def lr_scheduler_step(self, scheduler, metric):
if self.reduce_lr_on_plateau:
scheduler.step(metric, epoch=self.current_epoch)
else:
scheduler.step(epoch=self.current_epoch)
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
# scheduler.step(epoch=self.current_epoch)
def get_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
logger.info(f"Using highpass filer {self.highpass}")
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
logger.info(f"Using lowpass filer {self.lowpass}")
logger.info(filter)
return [
# In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
# Uses strategy variable, as padding will be handled by the random window.
@@ -180,18 +254,26 @@ class PhaseNetLit(SeisBenchModuleLit):
windowlen=3001,
strategy="pad",
),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
*filter,
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
sbg.ProbabilisticLabeller(
label_columns=phase_dict, sigma=self.sigma, dim=0
),
]
def get_eval_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
return [
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
*filter,
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
@@ -219,6 +301,27 @@ class PhaseNetLit(SeisBenchModuleLit):
return score_detection, score_p_or_s, p_sample, s_sample
def freeze(self):
if self.finetuning_strategy == "decoder": # finetune decoder branch and freeze encoder branch
for p in self.model.down_branch.parameters():
p.requires_grad = False
elif self.finetuning_strategy == "encoder": # finetune encoder branch and freeze decoder branch
for p in self.model.up_branch.parameters():
p.requires_grad = False
elif self.finetuning_strategy == "top":
for p in self.model.out.parameters():
p.requires_grad = False
def unfreeze(self):
logger.info("Unfreezing layers")
for p in self.model.parameters():
p.requires_grad = True
def on_train_epoch_start(self):
# Unfreeze some layers after x initial epochs
if self.current_epoch == self.initial_epochs:
self.unfreeze()
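Note: a minimal, self-contained sketch of the freeze-then-unfreeze schedule these PhaseNetLit changes implement; the toy module, epoch count and reduction factor below are illustrative, not values from this commit:

import torch
from torch import nn
from torch.optim import lr_scheduler

toy = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 3))
for p in toy[0].parameters():                # freeze the "encoder" part for the initial epochs
    p.requires_grad = False

# All parameters go to the optimizer; frozen ones simply receive no gradients.
optimizer = torch.optim.Adam(toy.parameters(), lr=1e-3)

# LambdaLR multiplies the base lr by the returned factor:
# 1.0 while layers are frozen, then a reduction once everything trains.
initial_epochs, gamma = 6, 0.1
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: 1.0 if e < initial_epochs else gamma)

for epoch in range(10):
    if epoch == initial_epochs:              # mirrors on_train_epoch_start() above
        for p in toy.parameters():
            p.requires_grad = True
    # ... one training epoch would run here ...
    scheduler.step()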
class GPDLit(SeisBenchModuleLit):
"""
@@ -846,7 +949,7 @@ class BasicPhaseAELit(SeisBenchModuleLit):
# Create overlapping windows
re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
for i, start in enumerate(range(0, 2401, 400)):
re[:, :, i] = x[:, :, start : start + 600]
re[:, :, i] = x[:, :, start: start + 600]
x = re
x = x.permute(0, 2, 1, 3) # --> (batch, windows, channels, samples)
@@ -862,9 +965,9 @@ class BasicPhaseAELit(SeisBenchModuleLit):
for i, start in enumerate(range(0, 2401, 400)):
if start == 0:
# Use full window (for start==0, the end will be overwritten)
pred[:, :, start : start + 600] = window_pred[:, i]
pred[:, :, start: start + 600] = window_pred[:, i]
else:
pred[:, :, start + 100 : start + 600] = window_pred[:, i, :, 100:]
pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:]
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0])
@@ -1112,7 +1215,6 @@ class DPPPickerLit(SeisBenchModuleLit):
def get_model_specific_args(config):
model = config.model_name[0]
lr = config.learning_rate
if type(lr) == list:
@@ -1125,8 +1227,12 @@ def get_model_specific_args(config):
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass[0]
args['lowpass'] = config.lowpass
case 'PhaseNet':
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass
if 'sample_boundaries' in config:
args['sample_boundaries'] = config.sample_boundaries
case 'DPPPicker':

View File

@@ -57,6 +57,8 @@ def split_events(events, input_path):
events_stats.loc[i, 'split'] = 'dev'
else:
break
logger.info(f"Split: {events_stats['split'].value_counts()}")
logger.info(f"Split: {events_stats['split'].value_counts()}")
@@ -198,9 +200,10 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
metadata_path = output_path + "/metadata.csv"
waveforms_path = output_path + "/waveforms.hdf5"
events_to_convert = events_stats[events_stats['pick_count'] > 0]
logger.debug("Catalog loaded, starting converting {events_to_convert} events ...")
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:

View File

@@ -17,6 +17,8 @@ import eval
import collect_results
import importlib
import config_loader
import input_validate
import data
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('pipeline')
@@ -40,11 +42,22 @@ def load_sweep_config(model_name, args):
return util.load_sweep_config(sweep_fname)
def find_the_best_params(model_name, args):
def validate_pipeline_input(args):
# validate input parameters
for model_name in args.models:
sweep_config = load_sweep_config(model_name, args)
input_validate.validate_sweep_config(sweep_config, model_name)
# validate dataset
data.validate_custom_dataset(config_loader.data_path)
def find_the_best_params(sweep_config):
# find the best hyperparams for the model_name
model_name = sweep_config['parameters']['model_name']
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
sweep_config = load_sweep_config(model_name, args)
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
# wait for all runs to finish
@@ -91,6 +104,8 @@ def main():
util.set_dataset(args.dataset)
importlib.reload(config_loader)
validate_pipeline_input(args)
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
# generate labels
@@ -101,7 +116,8 @@ def main():
# find the best hyperparams for the models
logger.info("Started training the models.")
for model_name in args.models:
sweep_id = find_the_best_params(model_name, args)
sweep_config = load_sweep_config(model_name, args)
sweep_id = find_the_best_params(sweep_config)
generate_predictions(sweep_id, model_name)
# collect results