finetuning (#1)

Reviewed-on: #1
Extended Phasenet finetuning options
This commit is contained in:
k.milian 2024-07-29 13:41:05 +02:00
parent 5c3ce04868
commit e86f131cc0
11 changed files with 381 additions and 54 deletions

View File

@ -111,7 +111,7 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
The script performs the following steps:
1. Generates evaluation targets in `datasets/<dataset_name>/targets` directory.
1. Trains multiple versions of GPD, PhaseNet and ... models to find the best hyperparameters, producing the lowest validation loss.
1. Trains multiple versions of GPD, PhaseNet, BasicPhaseAE, and EQTransformer models to find the best hyperparameters, producing the lowest validation loss.
This step utilizes the Weights & Biases platform to perform the hyperparameters search (called sweeping) and track the training process and store the results.
The results are available at
@ -126,12 +126,29 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
The results are saved in the `scripts/pred/results.csv` file. They are additionally logged in Weights & Biases platform as summary metrics of corresponding runs.
<br/>
The default settings are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for GPD model, run:
The default settings for the maximum number of experiments and for the paths are saved in the `config.json` file. To change them, edit `config.json` or pass the new settings as arguments to the script. For example, to change the sweep configuration file for the GPD model, run:
```python pipeline.py --gpd_config <new config file>```
The new config file should be placed in the `experiments` folder or as specified in the `configs_path` parameter in the config.json file.
Sweep configs define the maximum number of epochs to run and the hyperparameter search space for the following parameters:
* `batch_size`
* `learning_rate`
The PhaseNet model has the following additional parameters:
* `norm` - normalization method, options ('peak', 'std')
* `pretrained` - pretrained seisbench models used for transfer learning
* `finetuning` - the type of layers to finetune first, options ('all', 'top', 'encoder', 'decoder')
* `lr_reduce_factor` - factor to reduce learning rate after unfreezing layers
The GPD model has additional parameters for filtering:
* `highpass` - highpass filter frequency
* `lowpass` - lowpass filter frequency
The sweep configs are saved in the `experiments` folder.
If you have multiple datasets, you can run the pipeline for each dataset separately by specifying the dataset name as an argument:
```python pipeline.py --dataset <dataset_name>```

View File

@ -1,6 +1,6 @@
{
"dataset_name": "bogdanka",
"data_path": "datasets/bogdanka/seisbench_format/",
"dataset_name": "bogdanka_2018_2022",
"data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
"targets_path": "datasets/targets",
"models_path": "weights",
"configs_path": "experiments",
@ -13,5 +13,5 @@
"BasicPhaseAE": "sweep_basicphase_ae.yaml",
"EQTransformer": "sweep_eqtransformer.yaml"
},
"experiment_count": 20
"experiment_count": 15
}

View File

@ -1,3 +1,4 @@
name: BasicPhaseAE
method: bayes
metric:
goal: minimize
@ -7,13 +8,9 @@ parameters:
value:
- BasicPhaseAE
batch_size:
distribution: int_uniform
max: 1024
min: 256
values: [64, 128, 256]
max_epochs:
value:
- 20
- 30
learning_rate:
distribution: uniform
max: 0.02
min: 0.001
values: [0.01, 0.005, 0.001]

View File

@ -8,13 +8,9 @@ parameters:
value:
- EQTransformer
batch_size:
distribution: int_uniform
max: 1024
min: 256
values: [64, 128, 256]
max_epochs:
value:
- 30
learning_rate:
distribution: uniform
max: 0.02
min: 0.005
values: [0.01, 0.005, 0.001]

View File

@ -1,4 +1,4 @@
name: GPD_fixed_highpass:2-10
name: GPD
method: bayes
metric:
goal: minimize
@ -8,16 +8,12 @@ parameters:
value:
- GPD
batch_size:
distribution: int_uniform
max: 1024
min: 256
values: [64, 128, 256]
max_epochs:
value:
- 30
learning_rate:
distribution: uniform
max: 0.02
min: 0.005
values: [0.01, 0.005, 0.001]
highpass:
value:
- 1

View File

@ -1,8 +1,13 @@
"""
This file contains functionality related to data.
"""
import os.path
import seisbench.data as sbd
import logging
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('data')
def get_dataset_by_name(name):
@ -30,3 +35,26 @@ def get_custom_dataset(path):
except AttributeError:
raise ValueError(f"Unknown dataset '{path}'.")
def validate_custom_dataset(data_path):
    """
    Check that a custom dataset directory exists and has all three splits.

    :param data_path: path to the dataset directory
    :return: None
    :raises ValueError: if the directory does not exist or if the train,
        dev or test split of the loaded dataset is empty
    """
    # Fail early when the dataset directory is missing.
    directory_found = os.path.isdir(data_path)
    if not directory_found:
        raise ValueError(f"Data path {data_path} does not exist.")

    waveforms = sbd.WaveformDataset(data_path)

    # Reject the dataset if any of the three splits is empty
    # (checked in the order train, dev, test).
    split_labels = (("train", "Training"), ("dev", "Dev"), ("test", "Test"))
    for accessor, label in split_labels:
        if len(getattr(waveforms, accessor)()) == 0:
            raise ValueError(f"{label} set is empty.")

    logger.info("Custom dataset validated successfully.")

View File

@ -7,8 +7,9 @@ import os
import os.path
import argparse
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
import wandb
import torch
@ -20,6 +21,8 @@ import train
import util
import config_loader
torch.multiprocessing.set_sharing_strategy('file_system')
os.system("ulimit -n unlimited")
@ -31,14 +34,13 @@ host = os.environ.get("WANDB_HOST")
if host is None:
raise ValueError("WANDB_HOST environment variable is not set.")
wandb.login(key=wandb_api_key, host=host)
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user_name = os.environ.get("WANDB_USER")
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.WARNING)
logger.setLevel(logging.INFO)
def set_random_seed(seed=3):
@ -54,6 +56,11 @@ def get_trainer_args(config):
return trainer_args
def get_arg(arg):
    """
    Unwrap a sweep parameter that may be given as a list.

    Sweep configs may provide a fixed value either as a scalar or wrapped
    in a list; in the latter case the first element is used.

    :param arg: scalar value or list of values
    :return: the first element if ``arg`` is a list, otherwise ``arg`` itself
    :raises IndexError: if ``arg`` is an empty list
    """
    # isinstance (rather than an exact type comparison) also unwraps
    # list subclasses.
    if isinstance(arg, list):
        return arg[0]
    return arg
class HyperparameterSweep:
def __init__(self, project_name, sweep_config):
self.project_name = project_name
@ -84,25 +91,40 @@ class HyperparameterSweep:
return all_not_running
def run_experiment(self):
try:
logger.debug("Starting a new run...")
logger.info("Starting a new run...")
run = wandb.init(
project=self.project_name,
config=config_loader.config,
save_code=True
save_code=True,
entity=wandb_user_name
)
run.log_code(
root=".",
include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
exclude_fn=lambda path: path.endswith("template.sh")
) # not working as expected
)
model_name = wandb.config.model_name[0]
model_name = get_arg(wandb.config.model_name)
model_args = models.get_model_specific_args(wandb.config)
logger.debug(f"Initializing {model_name}")
if "pretrained" in wandb.config:
weights = get_arg(wandb.config.pretrained)
if weights != "false":
model_args["pretrained"] = weights
if "norm" in wandb.config:
model_args["norm"] = get_arg(wandb.config.norm)
if "finetuning" in wandb.config:
model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
if "lr_reduce_factor" in wandb.config:
model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
logger.debug(f"Initializing {model_name} with args: {model_args}")
model = models.__getattribute__(model_name + "Lit")(**model_args)
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
@ -132,10 +154,13 @@ class HyperparameterSweep:
early_stopping_callback = EarlyStopping(
monitor="val_loss",
patience=3,
patience=5,
verbose=True,
mode="min")
callbacks = [checkpoint_callback, early_stopping_callback]
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]
trainer = pl.Trainer(
default_root_dir=config_loader.models_path,
@ -143,7 +168,6 @@ class HyperparameterSweep:
callbacks=callbacks,
**get_trainer_args(wandb.config)
)
trainer.fit(model, train_loader, dev_loader)
except Exception as e:
@ -155,7 +179,6 @@ class HyperparameterSweep:
def start_sweep(sweep_config):
logger.info("Starting sweep with config: " + str(sweep_config))
set_random_seed(config_loader.seed)
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
@ -165,7 +188,6 @@ def start_sweep(sweep_config):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--sweep_config", type=str, required=True)
args = parser.parse_args()

146
scripts/input_validate.py Normal file
View File

@ -0,0 +1,146 @@
from pydantic import BaseModel, ConfigDict, field_validator
from typing_extensions import Literal
from typing import Union, List, Optional
import yaml
import logging
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('input_validator')
#todo
# 1. check if a single value is allowed in a sweep
# 2. merge input params
# 3. change names of the classes
# 4. add constraints for PhaseNet, GPD
# Model names accepted in the `model_name` sweep parameter.
model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
# Accepted values of the `norm` sweep parameter.
norm_values = Literal["peak", "std"]
# Accepted values of the `finetuning` sweep parameter.
finetuning_values = Literal["all", "top", "decoder", "encoder"]
# Accepted values of the `pretrained` sweep parameter; the boolean False is
# also allowed (presumably meaning "no pretrained weights" - verify against
# the code consuming this parameter).
pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic',
                            'original', 'scedc', False]
class Metric(BaseModel):
    """The `metric` section of a sweep config: an optimisation goal and a metric name."""
    goal: str
    name: str
class NumericValue(BaseModel):
    """Sweep parameter given under the `value` key as a number or a list of numbers."""
    value: Union[int, float, List[Union[int, float]]]
class NumericValues(BaseModel):
    """Sweep parameter given under the `values` key as a list of numbers."""
    values: List[Union[int, float]]
class IntDistribution(BaseModel):
    """Sweep parameter described by a distribution name and integer min/max bounds."""
    # Defaults to "int_uniform" when the config does not name a distribution.
    distribution: str = "int_uniform"
    min: int
    max: int
class FloatDistribution(BaseModel):
    """Sweep parameter described by a distribution name and float min/max bounds."""
    # Defaults to "uniform" when the config does not name a distribution.
    distribution: str = "uniform"
    min: float
    max: float
class Pretrained(BaseModel):
    """
    The `pretrained` sweep parameter.

    Either `values` (a list of candidates) or `value` (a single candidate or
    a list) may be given; every entry must be one of `pretrained_values`.
    """
    distribution: Optional[str] = "categorical"
    # Both fields default to None, so they are annotated as Optional; this
    # also lets an explicit null in the config validate.
    values: Optional[List[pretrained_values]] = None
    value: Optional[Union[pretrained_values, List[pretrained_values]]] = None
class Finetuning(BaseModel):
    """
    The `finetuning` sweep parameter.

    Either `values` (a list of candidates) or `value` (a single candidate or
    a list) may be given; every entry must be one of `finetuning_values`.
    """
    distribution: Optional[str] = "categorical"
    # Both fields default to None, so they are annotated as Optional; this
    # also lets an explicit null in the config validate.
    values: Optional[List[finetuning_values]] = None
    value: Optional[Union[finetuning_values, List[finetuning_values]]] = None
class Norm(BaseModel):
    """
    The `norm` sweep parameter.

    Either `values` (a list of candidates) or `value` (a single candidate or
    a list) may be given; every entry must be one of `norm_values`.
    """
    distribution: Optional[str] = "categorical"
    # Both fields default to None, so they are annotated as Optional; this
    # also lets an explicit null in the config validate.
    values: Optional[List[norm_values]] = None
    value: Optional[Union[norm_values, List[norm_values]]] = None
class ModelType(BaseModel):
    """
    The `model_name` sweep parameter.

    Either `value` (a single name or a list) or `values` (a list of names)
    may be given; every entry must be one of `model_names`.
    """
    distribution: Optional[str] = "categorical"
    # Both fields default to None, so they are annotated as Optional; this
    # also lets an explicit null in the config validate.
    value: Optional[Union[model_names, List[model_names]]] = None
    values: Optional[List[model_names]] = None
class Parameters(BaseModel):
    """
    The `parameters` section shared by all sweep configs.

    Unknown keys are rejected (`extra='forbid'`).
    """
    # protected_namespaces=() - presumably set so that the `model_name` field
    # does not clash with pydantic's protected `model_` prefix; TODO confirm.
    model_config = ConfigDict(extra='forbid', protected_namespaces=())
    model_name: ModelType
    batch_size: Union[IntDistribution, NumericValue, NumericValues]
    learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
    max_epochs: Union[IntDistribution, NumericValue, NumericValues]
class PhaseNetParameters(Parameters):
    """
    Sweep parameters with the PhaseNet-specific additions.

    On top of the common parameters this accepts `norm`, `pretrained`,
    `finetuning`, `lr_reduce_factor`, `highpass` and `lowpass`. All of them
    are optional. Validation fails unless "PhaseNet" is among the declared
    model names.
    """
    model_config = ConfigDict(extra='forbid')

    # Every additional field defaults to None, hence the Optional annotations.
    norm: Optional[Norm] = None
    pretrained: Optional[Pretrained] = None
    finetuning: Optional[Finetuning] = None
    lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
    highpass: Optional[Union[NumericValue, NumericValues, FloatDistribution, IntDistribution]] = None
    lowpass: Optional[Union[NumericValue, NumericValues, FloatDistribution, IntDistribution]] = None

    @field_validator("model_name")
    @classmethod
    def validate_model(cls, v):
        """
        Ensure the additional parameters are only used with PhaseNet.

        :param v: validated `ModelType` instance
        :return: `v` unchanged
        :raises ValueError: if "PhaseNet" is not among the declared names
        """
        # Collect names from both `value` and `values`; each may be None, and
        # `value` may be a single string, so normalise everything to a list
        # before the membership test.
        declared = []
        for candidate in (v.value, v.values):
            if candidate is None:
                continue
            declared.extend(candidate if isinstance(candidate, list) else [candidate])
        if "PhaseNet" not in declared:
            raise ValueError("Additional parameters implemented for PhaseNet only")
        return v
class FilteringParameters(Parameters):
    """
    Sweep parameters with the optional `highpass` / `lowpass` filter settings.

    Validation fails unless every declared model name is "GPD" or "PhaseNet".
    """
    model_config = ConfigDict(extra='forbid')

    # Both filter settings default to None, hence the Optional annotations.
    highpass: Optional[Union[NumericValue, NumericValues, FloatDistribution, IntDistribution]] = None
    lowpass: Optional[Union[NumericValue, NumericValues, FloatDistribution, IntDistribution]] = None

    @field_validator("model_name")
    @classmethod
    def validate_model(cls, v):
        """
        Ensure filtering parameters are only used with GPD or PhaseNet.

        :param v: validated `ModelType` instance
        :return: `v` unchanged
        :raises ValueError: if no model name is declared or any declared
            name is not "GPD" or "PhaseNet"
        """
        # Collect names from both `value` and `values`; each may be None, and
        # `value` may be a single string, so normalise everything to a list.
        declared = []
        for candidate in (v.value, v.values):
            if candidate is None:
                continue
            declared.extend(candidate if isinstance(candidate, list) else [candidate])
        if not declared or any(name not in ("GPD", "PhaseNet") for name in declared):
            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
        # The validator must return the value; otherwise the field is
        # replaced with None.
        return v
class InputParams(BaseModel):
    """
    Top-level structure of a sweep config: name, search method, metric and
    the parameters section.
    """
    name: str
    method: str
    metric: Metric
    # The parameters section may match the common set or one of the extended
    # variants (PhaseNet-specific or filtering).
    parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
def validate_sweep_yaml(yaml_filename, model_name=None):
    """
    Read a sweep config from a YAML file and validate it.

    :param yaml_filename: path to the YAML sweep config
    :param model_name: optional model name the config must be consistent with
    :return: None
    """
    # Parse the YAML file into a plain Python structure.
    with open(yaml_filename, 'r') as config_file:
        loaded_config = yaml.safe_load(config_file)

    # Delegate the actual checks to the dict-based validator.
    validate_sweep_config(loaded_config, model_name)
def validate_sweep_config(sweep_config, model_name=None):
    """
    Validate a sweep config and its consistency with the requested model.

    :param sweep_config: sweep configuration as a dict
    :param model_name: optional model name; if given, it must be one of the
        model names declared in the sweep config
    :return: None
    :raises ValueError: if `model_name` is not declared in the sweep config
        (schema violations are raised by `InputParams`)
    """
    # Validate the structure of the sweep config.
    input_params = InputParams(**sweep_config)

    # The model name may be declared under `value` (string or list) or under
    # `values` (list); normalise to one list so the check below is an exact
    # membership test and never runs against None or a bare string.
    model_spec = input_params.parameters.model_name
    sweep_model_names = []
    for candidate in (model_spec.value, model_spec.values):
        if candidate is None:
            continue
        sweep_model_names.extend(candidate if isinstance(candidate, list) else [candidate])

    # Check consistency of input parameters and sweep configuration.
    if model_name is not None and model_name not in sweep_model_names:
        info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_names}."
        logger.error(info)
        raise ValueError(info)

    logger.info("Input validation successful.")
if __name__ == "__main__":
    # Manual check: validate one sweep config when the module is run as a
    # script. The path is relative to the current working directory.
    yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
    validate_sweep_yaml(yaml_filename, None)

View File

@ -7,16 +7,26 @@ import seisbench.generate as sbg
import pytorch_lightning as pl
import torch
from torch.optim import lr_scheduler
import torch.nn.functional as F
import numpy as np
from abc import abstractmethod, ABC
# import lightning as L
# Allows to import this file in both jupyter notebook and code
try:
from .augmentations import DuplicateEvent
except ImportError:
from augmentations import DuplicateEvent
import os
import logging
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.DEBUG)
# Phase dict for labelling. We only study P and S phases without differentiating between them.
phase_dict = {
@ -131,8 +141,32 @@ class PhaseNetLit(SeisBenchModuleLit):
self.sigma = sigma
self.sample_boundaries = sample_boundaries
self.loss = vector_cross_entropy
self.pretrained = kwargs.pop("pretrained", None)
self.norm = kwargs.pop("norm", "peak")
self.highpass = kwargs.pop("highpass", None)
self.lowpass = kwargs.pop("lowpass", None)
if self.pretrained is not None:
self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
# self.norm = self.model.norm
else:
self.model = sbm.PhaseNet(**kwargs)
self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
self.reduce_lr_on_plateau = False
self.initial_epochs = 0
if self.finetuning_strategy is not None:
if self.finetuning_strategy == "top":
self.initial_epochs = 3
elif self.finetuning_strategy in ["decoder", "encoder"]:
self.initial_epochs = 6
self.freeze()
def forward(self, x):
return self.model(x)
@ -154,9 +188,49 @@ class PhaseNetLit(SeisBenchModuleLit):
def configure_optimizers(self):
    """
    Create the Adam optimizer together with its learning-rate scheduler.

    With a finetuning strategy set, a ``LambdaLR`` driven by
    ``self.lr_lambda`` is used; otherwise ``ReduceLROnPlateau`` monitoring
    ``val_loss`` is used. ``self.reduce_lr_on_plateau`` records which kind
    was chosen.

    :return: dict with the optimizer and the scheduler configuration
    """
    optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

    # No early `return optimizer` here: it would make the scheduler setup
    # below unreachable.
    if self.finetuning_strategy is not None:
        scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
        self.reduce_lr_on_plateau = False
    else:
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
        self.reduce_lr_on_plateau = True

    return {
        'optimizer': optimizer,
        'lr_scheduler': {
            'scheduler': scheduler,
            'monitor': 'val_loss',
            'interval': 'epoch',
            'reduce_on_plateau': self.reduce_lr_on_plateau,
        },
    }
def lr_lambda(self, epoch):
    """
    Return the multiplicative learning-rate factor for ``epoch``.

    This is the function handed to ``LambdaLR``, which multiplies the
    optimizer's initial learning rate by the returned factor. The factor is
    1.0 during the initial epochs and ``self.steplr_gamma`` from epoch
    ``self.initial_epochs`` onwards.

    :param epoch: epoch index supplied by the scheduler
    :return: factor applied to the initial learning rate
    """
    # Return a factor, not an absolute learning rate, and keep the function
    # free of side effects: the scheduler may call it more than once for the
    # same epoch, so `self.lr` must not be modified here.
    if epoch >= self.initial_epochs:
        return self.steplr_gamma
    return 1.0
def lr_scheduler_step(self, scheduler, metric):
    """
    Advance the learning-rate scheduler by one step.

    :param scheduler: scheduler created in ``configure_optimizers``
    :param metric: monitored metric value; only used when the scheduler is
        ``ReduceLROnPlateau`` (``self.reduce_lr_on_plateau`` is True)
    """
    # The `epoch` argument of `step()` is deprecated in PyTorch, so the
    # scheduler keeps its own epoch counter. Passing the current epoch
    # explicitly also made the schedule lag one epoch behind.
    if self.reduce_lr_on_plateau:
        scheduler.step(metric)
    else:
        scheduler.step()
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
# scheduler.step(epoch=self.current_epoch)
def get_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
logger.info(f"Using highpass filer {self.highpass}")
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
logger.info(f"Using lowpass filer {self.lowpass}")
logger.info(filter)
return [
# In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
# Uses strategy variable, as padding will be handled by the random window.
@ -180,18 +254,26 @@ class PhaseNetLit(SeisBenchModuleLit):
windowlen=3001,
strategy="pad",
),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
*filter,
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
sbg.ProbabilisticLabeller(
label_columns=phase_dict, sigma=self.sigma, dim=0
),
]
def get_eval_augmentations(self):
    """
    Build the list of augmentations applied to evaluation samples.

    :return: list of seisbench generator augmentations, in application order
    """
    # Optional filter steps, added only when the cut-off is configured.
    filter = []
    if self.highpass is not None:
        filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
    if self.lowpass is not None:
        filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]

    return [
        sbg.SteeredWindow(windowlen=3001, strategy="pad"),
        # Normalisation type is configurable through self.norm.
        sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
        *filter,
        sbg.ChangeDtype(np.float32),
        # NOTE(review): this second, hard-coded "peak" normalisation runs after
        # the configurable one above and looks like it overrides self.norm -
        # confirm whether it is intentional or a leftover.
        sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
    ]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
@ -219,6 +301,27 @@ class PhaseNetLit(SeisBenchModuleLit):
return score_detection, score_p_or_s, p_sample, s_sample
def freeze(self):
    """
    Freeze the layers that are NOT fine-tuned first.

    Depending on ``self.finetuning_strategy``:

    * ``"decoder"`` - freeze ``down_branch`` (the encoder), train the decoder
    * ``"encoder"`` - freeze ``up_branch`` (the decoder), train the encoder
    * ``"top"`` - freeze everything except the output layer ``out``

    Any other value leaves all parameters untouched.
    """
    if self.finetuning_strategy == "decoder":  # finetune decoder branch and freeze encoder branch
        for p in self.model.down_branch.parameters():
            p.requires_grad = False
    elif self.finetuning_strategy == "encoder":  # finetune encoder branch and freeze decoder branch
        for p in self.model.up_branch.parameters():
            p.requires_grad = False
    elif self.finetuning_strategy == "top":  # finetune the output layer and freeze the rest
        # As in the other branches, the named part is the one that stays
        # trainable: freeze the whole model, then re-enable the output layer.
        for p in self.model.parameters():
            p.requires_grad = False
        for p in self.model.out.parameters():
            p.requires_grad = True
def unfreeze(self):
    """Mark every parameter of the wrapped model as trainable."""
    logger.info("Unfreezing layers")
    for param in self.model.parameters():
        param.requires_grad = True
def on_train_epoch_start(self):
    """Unfreeze all layers once the epoch counter reaches ``self.initial_epochs``."""
    # Nothing to do in any other epoch.
    if self.current_epoch != self.initial_epochs:
        return
    self.unfreeze()
class GPDLit(SeisBenchModuleLit):
"""
@ -846,7 +949,7 @@ class BasicPhaseAELit(SeisBenchModuleLit):
# Create overlapping windows
re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
for i, start in enumerate(range(0, 2401, 400)):
re[:, :, i] = x[:, :, start : start + 600]
re[:, :, i] = x[:, :, start: start + 600]
x = re
x = x.permute(0, 2, 1, 3) # --> (batch, windows, channels, samples)
@ -862,9 +965,9 @@ class BasicPhaseAELit(SeisBenchModuleLit):
for i, start in enumerate(range(0, 2401, 400)):
if start == 0:
# Use full window (for start==0, the end will be overwritten)
pred[:, :, start : start + 600] = window_pred[:, i]
pred[:, :, start: start + 600] = window_pred[:, i]
else:
pred[:, :, start + 100 : start + 600] = window_pred[:, i, :, 100:]
pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:]
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0])
@ -1112,7 +1215,6 @@ class DPPPickerLit(SeisBenchModuleLit):
def get_model_specific_args(config):
model = config.model_name[0]
lr = config.learning_rate
if type(lr) == list:
@ -1125,8 +1227,12 @@ def get_model_specific_args(config):
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass[0]
args['lowpass'] = config.lowpass
case 'PhaseNet':
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass
if 'sample_boundaries' in config:
args['sample_boundaries'] = config.sample_boundaries
case 'DPPPicker':

View File

@ -60,6 +60,8 @@ def split_events(events, input_path):
logger.info(f"Split: {events_stats['split'].value_counts()}")
logger.info(f"Split: {events_stats['split'].value_counts()}")
return events_stats
@ -201,6 +203,7 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
events_to_convert = events_stats[events_stats['pick_count'] > 0]
logger.debug("Catalog loaded, starting converting {events_to_convert} events ...")
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:

View File

@ -17,6 +17,8 @@ import eval
import collect_results
import importlib
import config_loader
import input_validate
import data
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('pipeline')
@ -40,11 +42,22 @@ def load_sweep_config(model_name, args):
return util.load_sweep_config(sweep_fname)
def find_the_best_params(model_name, args):
def validate_pipeline_input(args):
    """
    Validate the pipeline inputs before any processing starts.

    Checks the sweep configuration of every requested model and then the
    dataset found at ``config_loader.data_path``.

    :param args: parsed command-line arguments; ``args.models`` lists the
        models to run
    """
    # Sweep configuration of each requested model.
    for requested_model in args.models:
        input_validate.validate_sweep_config(
            load_sweep_config(requested_model, args),
            requested_model,
        )

    # Dataset layout and splits.
    data.validate_custom_dataset(config_loader.data_path)
def find_the_best_params(sweep_config):
# find the best hyperparams for the model_name
model_name = sweep_config['parameters']['model_name']
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
sweep_config = load_sweep_config(model_name, args)
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
# wait for all runs to finish
@ -91,6 +104,8 @@ def main():
util.set_dataset(args.dataset)
importlib.reload(config_loader)
validate_pipeline_input(args)
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
# generate labels
@ -101,7 +116,8 @@ def main():
# find the best hyperparams for the models
logger.info("Started training the models.")
for model_name in args.models:
sweep_id = find_the_best_params(model_name, args)
sweep_config = load_sweep_config(model_name, args)
sweep_id = find_the_best_params(sweep_config)
generate_predictions(sweep_id, model_name)
# collect results