From ebc759c21503d9a4d51e3a9740da87b5f858a6f2 Mon Sep 17 00:00:00 2001
From: Krystyna Milian
Date: Mon, 13 May 2024 13:44:04 +0200
Subject: [PATCH 1/2] Added PhaseNet finetuning option and input validation

---
 README.md                       |  25 ++++--
 scripts/data.py                 |  28 ++++++
 scripts/hyperparameter_sweep.py |  46 ++++++---
 scripts/input_validate.py       | 142 ++++++++++++++++++++++++++++
 scripts/models.py               | 102 +++++++++++++++++---
 scripts/mseeds_to_seisbench.py  |   5 +-
 scripts/pipeline.py             |  22 ++++-
 7 files changed, 343 insertions(+), 27 deletions(-)
 create mode 100644 scripts/input_validate.py

diff --git a/README.md b/README.md
index ed88444..7803cbf 100644
--- a/README.md
+++ b/README.md
@@ -111,7 +111,7 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
 The script performs the following steps:
 
 1. Generates evaluation targets in the `datasets/<dataset_name>/targets` directory.
-1. Trains multiple versions of GPD, PhaseNet and ... models to find the best hyperparameters, producing the lowest validation loss.
+1. Trains multiple versions of the GPD, PhaseNet, BasicPhaseAE, and EQTransformer models to find the best hyperparameters, producing the lowest validation loss.
 
    This step utilizes the Weights & Biases platform to perform the hyperparameter search (called sweeping), track the training process, and store the results.
    The results are available at
@@ -126,14 +126,31 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
    The results are saved in the `scripts/pred/results.csv` file. They are additionally logged in the Weights & Biases platform as summary metrics of the corresponding runs.
-   The default settings are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for GPD model, run:
+   The default settings for the maximum number of experiments and the paths are saved in the config.json file. To change the settings, edit config.json or pass the new settings as arguments to the script. For example, to change the sweep configuration file for the GPD model, run:
 
    ```
    python pipeline.py --gpd_config <config_path>
    ```
 
    The new config file should be placed in the `experiments` folder or as specified in the `configs_path` parameter in the config.json file.
 
+   Sweep configs define the maximum number of epochs to run and the hyperparameter search space for the following parameters:
+   * `batch_size`
+   * `learning_rate`
+
+   The PhaseNet model has additional parameters:
+   * `norm` - normalization method, options: `peak`, `std`
+   * `pretrained` - pretrained SeisBench model used for transfer learning
+   * `finetuning` - the type of layers to finetune first, options: `all`, `top`, `encoder`, `decoder`
+   * `lr_reduce_factor` - factor by which to reduce the learning rate after unfreezing the remaining layers
+
+   The GPD model has additional parameters for filtering:
+   * `highpass` - highpass filter frequency
+   * `lowpass` - lowpass filter frequency
+
+   The sweep configs are saved in the `experiments` folder; a minimal example is shown below.
+
    If you have multiple datasets, you can run the pipeline for each dataset separately by specifying the dataset name as an argument:
 
    ```
    python pipeline.py --dataset <dataset_name>
    ```
 
 ### Troubleshooting
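As a concrete illustration of the sweep parameters documented in the README hunk above, here is a minimal sketch of a PhaseNet sweep config in the style of the files under `experiments/`. The file name, the metric name, and all values are illustrative assumptions, not part of this patch:

```yaml
# hypothetical experiments/sweep_phasenet_example.yaml
name: PhaseNet
method: bayes
metric:
  goal: minimize
  name: val_loss
parameters:
  model_name:
    value:
      - PhaseNet
  batch_size:
    values: [64, 128, 256]
  learning_rate:
    values: [0.01, 0.005, 0.001]
  max_epochs:
    value:
      - 30
  norm:
    values: [peak, std]
  pretrained:
    values: [instance, original, false]
  finetuning:
    values: [top, decoder, encoder]
  lr_reduce_factor:
    value:
      - 0.5
```

A file like this passes the validation introduced below, since every field matches one of the pydantic models in `scripts/input_validate.py`.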
""" +import os.path import seisbench.data as sbd +import logging + +logging.root.setLevel(logging.INFO) +logger = logging.getLogger('data') def get_dataset_by_name(name): @@ -30,3 +35,26 @@ def get_custom_dataset(path): except AttributeError: raise ValueError(f"Unknown dataset '{path}'.") + +def validate_custom_dataset(data_path): + """ + Validate the dataset + :param data_path: path to the dataset + :return: + """ + # check if path exists + if not os.path.isdir((data_path)): + raise ValueError(f"Data path {data_path} does not exist.") + + dataset = sbd.WaveformDataset(data_path) + # check if the dataset is split into train, dev and test + if len(dataset.train()) == 0: + raise ValueError(f"Training set is empty.") + if len(dataset.dev()) == 0: + raise ValueError(f"Dev set is empty.") + if len(dataset.test()) == 0: + raise ValueError(f"Test set is empty.") + + logger.info("Custom dataset validated successfully.") + + diff --git a/scripts/hyperparameter_sweep.py b/scripts/hyperparameter_sweep.py index a0a0e22..981c38d 100644 --- a/scripts/hyperparameter_sweep.py +++ b/scripts/hyperparameter_sweep.py @@ -7,8 +7,9 @@ import os import os.path import argparse from pytorch_lightning.loggers import WandbLogger, CSVLogger -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor from pytorch_lightning.callbacks.early_stopping import EarlyStopping + import pytorch_lightning as pl import wandb import torch @@ -20,6 +21,8 @@ import train import util import config_loader + + torch.multiprocessing.set_sharing_strategy('file_system') os.system("ulimit -n unlimited") @@ -31,14 +34,13 @@ host = os.environ.get("WANDB_HOST") if host is None: raise ValueError("WANDB_HOST environment variable is not set.") - wandb.login(key=wandb_api_key, host=host) wandb_project_name = os.environ.get("WANDB_PROJECT") wandb_user_name = os.environ.get("WANDB_USER") script_name = os.path.splitext(os.path.basename(__file__))[0] logger = logging.getLogger(script_name) -logger.setLevel(logging.WARNING) +logger.setLevel(logging.INFO) def set_random_seed(seed=3): @@ -54,6 +56,7 @@ def get_trainer_args(config): return trainer_args + class HyperparameterSweep: def __init__(self, project_name, sweep_config): self.project_name = project_name @@ -87,21 +90,42 @@ class HyperparameterSweep: try: - logger.debug("Starting a new run...") + logger.info("Starting a new run...") run = wandb.init( project=self.project_name, config=config_loader.config, - save_code=True + save_code=True, + entity=wandb_user_name ) run.log_code( root=".", include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"), exclude_fn=lambda path: path.endswith("template.sh") - ) # not working as expected + ) model_name = wandb.config.model_name[0] model_args = models.get_model_specific_args(wandb.config) + if "pretrained" in wandb.config: + weights = wandb.config.get("pretrained") + if type(weights) == list: + weights = weights[0] + if weights != "false": + model_args["pretrained"] = weights + if "norm" in wandb.config: + model_args["norm"] = wandb.config.norm + logger.debug(f"Initializing {model_name}") + if "finetuning" in wandb.config: + # train for a few epochs with some frozen params, then unfreeze and continue training + if type(wandb.config.finetuning) == list: + finetuning_strategy = wandb.config.finetuning[0] + else: + finetuning_strategy = wandb.config.finetuning + model_args['finetuning_strategy'] = finetuning_strategy + + if "lr_reduce_factor" in wandb.config: + 
+                model_args['steplr_gamma'] = wandb.config.lr_reduce_factor
+
             model = models.__getattribute__(model_name + "Lit")(**model_args)
 
@@ -132,10 +156,13 @@ class HyperparameterSweep:
 
             early_stopping_callback = EarlyStopping(
                 monitor="val_loss",
-                patience=3,
+                patience=5,
                 verbose=True,
                 mode="min")
-            callbacks = [checkpoint_callback, early_stopping_callback]
+
+            lr_monitor = LearningRateMonitor(logging_interval='epoch')
+
+            callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]
 
             trainer = pl.Trainer(
                 default_root_dir=config_loader.models_path,
@@ -143,7 +170,6 @@ class HyperparameterSweep:
                 callbacks=callbacks,
                 **get_trainer_args(wandb.config)
             )
-
             trainer.fit(model, train_loader, dev_loader)
 
         except Exception as e:
@@ -155,7 +181,6 @@ class HyperparameterSweep:
 
 
 def start_sweep(sweep_config):
-
     logger.info("Starting sweep with config: " + str(sweep_config))
     set_random_seed(config_loader.seed)
     sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
@@ -165,7 +190,6 @@ def start_sweep(sweep_config):
 
 
 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser()
     parser.add_argument("--sweep_config", type=str, required=True)
     args = parser.parse_args()
diff --git a/scripts/input_validate.py b/scripts/input_validate.py
new file mode 100644
index 0000000..fa6a8ef
--- /dev/null
+++ b/scripts/input_validate.py
@@ -0,0 +1,142 @@
+from pydantic import BaseModel, ConfigDict, field_validator
+from typing_extensions import Literal
+from typing import Union, List, Optional
+import yaml
+import logging
+
+logging.root.setLevel(logging.INFO)
+logger = logging.getLogger('input_validator')
+
+# todo:
+# 1. check if a single value is allowed in a sweep
+# 2. merge input params
+# 3. change names of the classes
+# 4. add constraints for PhaseNet, GPD
+
+
+model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
+norm_values = Literal["peak", "std"]
+finetuning_values = Literal["all", "top", "decoder", "encoder"]
+pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic',
+                            'original', 'scedc', False]
+
+
+class Metric(BaseModel):
+    goal: str
+    name: str
+
+
+class NumericValue(BaseModel):
+    value: Union[int, float, List[Union[int, float]]]
+
+
+class NumericValues(BaseModel):
+    values: List[Union[int, float]]
+
+
+class IntDistribution(BaseModel):
+    distribution: str = "int_uniform"
+    min: int
+    max: int
+
+
+class FloatDistribution(BaseModel):
+    distribution: str = "uniform"
+    min: float
+    max: float
+
+
+class Pretrained(BaseModel):
+    distribution: Optional[str] = "categorical"
+    values: List[pretrained_values] = None
+    value: Union[pretrained_values, List[pretrained_values]] = None
+
+
+class Finetuning(BaseModel):
+    distribution: Optional[str] = "categorical"
+    values: List[finetuning_values] = None
+    value: Union[finetuning_values, List[finetuning_values]] = None
+
+
+class Norm(BaseModel):
+    distribution: Optional[str] = "categorical"
+    values: List[norm_values] = None
+    value: Union[norm_values, List[norm_values]] = None
+
+
+class ModelType(BaseModel):
+    distribution: Optional[str] = "categorical"
+    value: Union[model_names, List[model_names]] = None
+    values: List[model_names] = None
+
+
+class Parameters(BaseModel):
+    model_config = ConfigDict(extra='forbid', protected_namespaces=())
+    model_name: ModelType
+    batch_size: Union[IntDistribution, NumericValue, NumericValues]
+    learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
+    max_epochs: Union[IntDistribution, NumericValue, NumericValues]
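The `Union`-typed fields above accept a fixed value, a list of candidate values, or a distribution spec, and pydantic picks the first matching model. A minimal sketch of how this behaves, assuming it is run next to `scripts/input_validate.py` (the config dict is illustrative):

```python
from input_validate import Parameters

# each field may be a fixed value (NumericValue), a candidate list
# (NumericValues), or a distribution spec (Int/FloatDistribution)
sweep_params = {
    "model_name": {"value": ["GPD"]},
    "batch_size": {"values": [64, 128, 256]},
    "learning_rate": {"distribution": "uniform", "min": 0.001, "max": 0.01},
    "max_epochs": {"value": 30},
}
Parameters(**sweep_params)  # raises pydantic.ValidationError if malformed
```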
+
+
+class PhaseNetParameters(Parameters):
+    model_config = ConfigDict(extra='forbid')
+    norm: Norm = None
+    pretrained: Pretrained = None
+    finetuning: Finetuning = None
+    lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
+
+    @field_validator("model_name")
+    def validate_model(cls, v):
+        if "PhaseNet" not in v.value:
+            raise ValueError("Additional parameters implemented for PhaseNet only")
+        return v
+
+
+class GPDParameters(Parameters):
+    model_config = ConfigDict(extra='forbid')
+
+    highpass: Union[NumericValue, NumericValues, FloatDistribution] = None
+    lowpass: Union[NumericValue, NumericValues, FloatDistribution] = None
+
+    @field_validator("model_name")
+    def validate_model(cls, v):
+        if "GPD" not in v.value:
+            raise ValueError("Additional parameters implemented for GPD only")
+        return v
+
+
+class InputParams(BaseModel):
+    name: str
+    method: str
+    metric: Metric
+    parameters: Union[Parameters, PhaseNetParameters, GPDParameters]
+
+
+def validate_sweep_yaml(yaml_filename, model_name=None):
+    # load the YAML configuration
+    with open(yaml_filename, 'r') as f:
+        sweep_config = yaml.safe_load(f)
+
+    validate_sweep_config(sweep_config, model_name)
+
+
+def validate_sweep_config(sweep_config, model_name=None):
+    # validate the sweep config
+    input_params = InputParams(**sweep_config)
+
+    # check consistency of the requested model and the sweep configuration
+    sweep_model_name = input_params.parameters.model_name.value
+    if model_name is not None and model_name not in sweep_model_name:
+        info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
+        logger.info(info)
+        raise ValueError(info)
+    logger.info("Input validation successful.")
+
+
+if __name__ == "__main__":
+    yaml_filename = "../experiments/sweep_phasenet_lumineos_lr_bs.yaml"
+    validate_sweep_yaml(yaml_filename, None)
diff --git a/scripts/models.py b/scripts/models.py
index 644696f..f35b0d9 100644
--- a/scripts/models.py
+++ b/scripts/models.py
@@ -7,16 +7,26 @@
 import seisbench.generate as sbg
 
 import pytorch_lightning as pl
 import torch
+from torch.optim import lr_scheduler
 import torch.nn.functional as F
 import numpy as np
 from abc import abstractmethod, ABC
 
 # Allows to import this file in both jupyter notebook and code
 try:
     from .augmentations import DuplicateEvent
 except ImportError:
     from augmentations import DuplicateEvent
 
+import os
+import logging
+
+script_name = os.path.splitext(os.path.basename(__file__))[0]
+logger = logging.getLogger(script_name)
+logger.setLevel(logging.DEBUG)
 
 # Phase dict for labelling. We only study P and S phases without differentiating between them.
 phase_dict = {
@@ -131,30 +141,84 @@ class PhaseNetLit(SeisBenchModuleLit):
         self.sigma = sigma
         self.sample_boundaries = sample_boundaries
         self.loss = vector_cross_entropy
-        self.model = sbm.PhaseNet(**kwargs)
+        # pop the lightning-module-specific kwargs before constructing the
+        # model, so they are not passed through to sbm.PhaseNet
+        self.pretrained = kwargs.pop("pretrained", None)
+        self.norm = kwargs.pop("norm", "peak")
+        self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
+        self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
+
+        if self.pretrained:
+            self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
+        else:
+            self.model = sbm.PhaseNet(**kwargs)
+
+        self.initial_epochs = 0
+
+        if self.finetuning_strategy is not None:
+            if self.finetuning_strategy == "top":
+                self.initial_epochs = 3
+            elif self.finetuning_strategy in ["decoder", "encoder"]:
+                self.initial_epochs = 6
+            self.freeze()
 
     def forward(self, x):
         return self.model(x)
 
     def shared_step(self, batch):
         x = batch["X"]
         y_true = batch["y"]
         y_pred = self.model(x)
         return self.loss(y_pred, y_true)
 
     def training_step(self, batch, batch_idx):
         loss = self.shared_step(batch)
         self.log("train_loss", loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
         loss = self.shared_step(batch)
         self.log("val_loss", loss)
         return loss
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
-        return optimizer
+        if self.finetuning_strategy is not None:
+            scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
+        else:
+            scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
+
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': {
+                'scheduler': scheduler,
+                'monitor': 'val_loss',
+                'interval': 'epoch',
+                'reduce_on_plateau': False,
+            },
+        }
+
+    def lr_lambda(self, epoch):
+        # LambdaLR multiplies the base lr by the returned factor, so keep the
+        # factor at 1.0 during the initial (partially frozen) epochs and
+        # reduce it once the remaining layers are unfrozen
+        return 1.0 if epoch < self.initial_epochs else self.steplr_gamma
+
+    def lr_scheduler_step(self, scheduler, metric):
+        scheduler.step(epoch=self.current_epoch)
 
     def get_augmentations(self):
         return [
@@ -181,19 +245,21 @@ class PhaseNetLit(SeisBenchModuleLit):
                 strategy="pad",
             ),
             sbg.ChangeDtype(np.float32),
-            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
+            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
             sbg.ProbabilisticLabeller(
                 label_columns=phase_dict, sigma=self.sigma, dim=0
             ),
         ]
 
     def get_eval_augmentations(self):
         return [
             sbg.SteeredWindow(windowlen=3001, strategy="pad"),
             sbg.ChangeDtype(np.float32),
-            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
+            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
         ]
 
     def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
         x = batch["X"]
         window_borders = batch["window_borders"]
@@ -219,6 +285,27 @@ class PhaseNetLit(SeisBenchModuleLit):
 
         return score_detection, score_p_or_s, p_sample, s_sample
 
+    def freeze(self):
+        if self.finetuning_strategy == "decoder":
+            # finetune the decoder branch first, freeze the encoder branch
+            for p in self.model.down_branch.parameters():
+                p.requires_grad = False
+        elif self.finetuning_strategy == "encoder":
+            # finetune the encoder branch first, freeze the decoder branch
+            for p in self.model.up_branch.parameters():
+                p.requires_grad = False
+        elif self.finetuning_strategy == "top":
+            # finetune the output layer first, freeze both branches
+            for p in self.model.down_branch.parameters():
+                p.requires_grad = False
+            for p in self.model.up_branch.parameters():
+                p.requires_grad = False
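For orientation, this is roughly how the sweep code ends up constructing the module when finetuning is requested; every argument value here is an assumption for illustration, not taken from the patch:

```python
# hypothetical instantiation mirroring what run_experiment() passes in,
# assuming PhaseNetLit is importable from scripts/models.py
model = PhaseNetLit(
    lr=0.01,
    pretrained="instance",          # start from SeisBench "instance" weights
    norm="std",
    finetuning_strategy="decoder",  # freeze() leaves only the decoder trainable
    steplr_gamma=0.5,               # lr factor applied once layers are unfrozen
)
```

With `finetuning_strategy="decoder"`, `initial_epochs` is 6, so the encoder stays frozen for the first six epochs and the learning rate is halved when everything is unfrozen.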
+    def unfreeze(self):
+        logger.info("Unfreezing layers")
+        for p in self.model.parameters():
+            p.requires_grad = True
+
+    def on_train_epoch_start(self):
+        # unfreeze the remaining layers once the initial epochs are over
+        if self.current_epoch == self.initial_epochs:
+            self.unfreeze()
+
 
 class GPDLit(SeisBenchModuleLit):
     """
@@ -846,7 +933,7 @@ class BasicPhaseAELit(SeisBenchModuleLit):
         # Create overlapping windows
         re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
         for i, start in enumerate(range(0, 2401, 400)):
-            re[:, :, i] = x[:, :, start : start + 600]
+            re[:, :, i] = x[:, :, start: start + 600]
         x = re
 
         x = x.permute(0, 2, 1, 3)  # --> (batch, windows, channels, samples)
@@ -862,9 +949,9 @@ class BasicPhaseAELit(SeisBenchModuleLit):
         for i, start in enumerate(range(0, 2401, 400)):
             if start == 0:
                 # Use full window (for start==0, the end will be overwritten)
-                pred[:, :, start : start + 600] = window_pred[:, i]
+                pred[:, :, start: start + 600] = window_pred[:, i]
             else:
-                pred[:, :, start + 100 : start + 600] = window_pred[:, i, :, 100:]
+                pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:]
 
         score_detection = torch.zeros(pred.shape[0])
         score_p_or_s = torch.zeros(pred.shape[0])
@@ -1112,7 +1199,6 @@ class DPPPickerLit(SeisBenchModuleLit):
 
 
 def get_model_specific_args(config):
-
     model = config.model_name[0]
     lr = config.learning_rate
     if type(lr) == list:
diff --git a/scripts/mseeds_to_seisbench.py b/scripts/mseeds_to_seisbench.py
index e36d2a3..fae08b8 100644
--- a/scripts/mseeds_to_seisbench.py
+++ b/scripts/mseeds_to_seisbench.py
@@ -198,9 +200,10 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
 
     metadata_path = output_path + "/metadata.csv"
     waveforms_path = output_path + "/waveforms.hdf5"
 
     events_to_convert = events_stats[events_stats['pick_count'] > 0]
 
+    logger.debug(f"Catalog loaded, starting converting {len(events_to_convert)} events ...")
 
     with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:
diff --git a/scripts/pipeline.py b/scripts/pipeline.py
index 34e7492..8c82c25 100644
--- a/scripts/pipeline.py
+++ b/scripts/pipeline.py
@@ -17,6 +17,8 @@ import eval
 import collect_results
 import importlib
 import config_loader
+import input_validate
+import data
 
 logging.root.setLevel(logging.INFO)
 logger = logging.getLogger('pipeline')
@@ -40,11 +42,22 @@ def load_sweep_config(model_name, args):
 
     return util.load_sweep_config(sweep_fname)
 
 
-def find_the_best_params(model_name, args):
+def validate_pipeline_input(args):
+    # validate the sweep configs for all requested models
+    for model_name in args.models:
+        sweep_config = load_sweep_config(model_name, args)
+        input_validate.validate_sweep_config(sweep_config, model_name)
+
+    # validate the dataset
+    data.validate_custom_dataset(config_loader.data_path)
+
+
+def find_the_best_params(sweep_config):
     # find the best hyperparams for the model_name
+    model_name = sweep_config['parameters']['model_name']
     logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
-    sweep_config = load_sweep_config(model_name, args)
     sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
 
     # wait for all runs to finish
@@ -91,6 +104,8 @@ def main():
     util.set_dataset(args.dataset)
     importlib.reload(config_loader)
 
+    validate_pipeline_input(args)
+
     logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
 
     # generate labels
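The new fail-fast checks can also be exercised on their own, outside `main()`. A minimal sketch, assuming it is run from the `scripts/` directory (the config file and dataset path are illustrative):

```python
import data
import input_validate

# validate one sweep config against the model it is meant for
input_validate.validate_sweep_yaml("../experiments/sweep_gpd.yaml", model_name="GPD")

# validate that the dataset exists and has non-empty train/dev/test splits
data.validate_custom_dataset("../datasets/bogdanka/seisbench_format/")
```

Both raise `ValueError` (or a pydantic `ValidationError`) on invalid input, so the pipeline stops before any training time is spent.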
@@ -101,7 +116,8 @@ def main():
     # find the best hyperparams for the models
     logger.info("Started training the models.")
     for model_name in args.models:
-        sweep_id = find_the_best_params(model_name, args)
+        sweep_config = load_sweep_config(model_name, args)
+        sweep_id = find_the_best_params(sweep_config)
         generate_predictions(sweep_id, model_name)
 
     # collect results
--
2.16.5


From f40ac35cc87f01cdc1c6d41f8d3499fc7c945412 Mon Sep 17 00:00:00 2001
From: Krystyna Milian
Date: Sun, 14 Jul 2024 10:32:37 +0200
Subject: [PATCH 2/2] Fixed PhaseNet normalization, added reducing lr on plateau

---
 config.json                          |  6 ++--
 experiments/sweep_basicphase_ae.yaml | 11 +++----
 experiments/sweep_eqtransformer.yaml |  8 ++---
 experiments/sweep_gpd.yaml           | 10 +++---
 scripts/hyperparameter_sweep.py      | 30 +++++++++---------
 scripts/input_validate.py            | 18 +++++++----
 scripts/models.py                    | 50 ++++++++++++++++++++--------
 7 files changed, 72 insertions(+), 61 deletions(-)

diff --git a/config.json b/config.json
index acd11b9..b320b67 100644
--- a/config.json
+++ b/config.json
@@ -1,6 +1,6 @@
 {
-  "dataset_name": "bogdanka",
-  "data_path": "datasets/bogdanka/seisbench_format/",
+  "dataset_name": "bogdanka_2018_2022",
+  "data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
   "targets_path": "datasets/targets",
   "models_path": "weights",
   "configs_path": "experiments",
@@ -13,5 +13,5 @@
     "BasicPhaseAE": "sweep_basicphase_ae.yaml",
     "EQTransformer": "sweep_eqtransformer.yaml"
   },
-  "experiment_count": 20
+  "experiment_count": 15
 }
\ No newline at end of file
diff --git a/experiments/sweep_basicphase_ae.yaml b/experiments/sweep_basicphase_ae.yaml
index ba45c47..7e3bda1 100644
--- a/experiments/sweep_basicphase_ae.yaml
+++ b/experiments/sweep_basicphase_ae.yaml
@@ -1,3 +1,4 @@
+name: BasicPhaseAE
 method: bayes
 metric:
   goal: minimize
@@ -7,13 +8,9 @@ parameters:
     value:
       - BasicPhaseAE
   batch_size:
-    distribution: int_uniform
-    max: 1024
-    min: 256
+    values: [64, 128, 256]
   max_epochs:
     value:
-      - 20
+      - 30
   learning_rate:
-    distribution: uniform
-    max: 0.02
-    min: 0.001
+    values: [0.01, 0.005, 0.001]
diff --git a/experiments/sweep_eqtransformer.yaml b/experiments/sweep_eqtransformer.yaml
index 21735f4..03741a2 100644
--- a/experiments/sweep_eqtransformer.yaml
+++ b/experiments/sweep_eqtransformer.yaml
@@ -8,13 +8,9 @@ parameters:
     value:
       - EQTransformer
   batch_size:
-    distribution: int_uniform
-    max: 1024
-    min: 256
+    values: [64, 128, 256]
   max_epochs:
     value:
       - 30
   learning_rate:
-    distribution: uniform
-    max: 0.02
-    min: 0.005
+    values: [0.01, 0.005, 0.001]
\ No newline at end of file
diff --git a/experiments/sweep_gpd.yaml b/experiments/sweep_gpd.yaml
index 58f3674..29db5bf 100644
--- a/experiments/sweep_gpd.yaml
+++ b/experiments/sweep_gpd.yaml
@@ -1,4 +1,4 @@
-name: GPD_fixed_highpass:2-10
+name: GPD
 method: bayes
 metric:
   goal: minimize
@@ -8,16 +8,12 @@ parameters:
     value:
       - GPD
   batch_size:
-    distribution: int_uniform
-    max: 1024
-    min: 256
+    values: [64, 128, 256]
   max_epochs:
     value:
       - 30
   learning_rate:
-    distribution: uniform
-    max: 0.02
-    min: 0.005
+    values: [0.01, 0.005, 0.001]
   highpass:
     value:
       - 1
diff --git a/scripts/hyperparameter_sweep.py b/scripts/hyperparameter_sweep.py
index 981c38d..27ef805 100644
--- a/scripts/hyperparameter_sweep.py
+++ b/scripts/hyperparameter_sweep.py
@@ -56,6 +56,10 @@ def get_trainer_args(config):
     return trainer_args
 
 
+def get_arg(arg):
+    # W&B sweep parameters may arrive as single-element lists; unwrap them
+    if type(arg) == list:
+        return arg[0]
+    return arg
+
 
 class HyperparameterSweep:
     def __init__(self, project_name, sweep_config):
         self.project_name = project_name
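The helper exists because a sweep parameter written as `value: [X]` in YAML comes back from W&B as a single-element list, while `value: X` comes back as a scalar. A quick illustration (hypothetical values):

```python
get_arg(["PhaseNet"])  # -> "PhaseNet"
get_arg([0.5])         # -> 0.5
get_arg("std")         # -> "std"  (non-list values pass through unchanged)
```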
@@ -87,10 +91,10 @@ class HyperparameterSweep:
         return all_not_running
 
     def run_experiment(self):
-
         try:
             logger.info("Starting a new run...")
 
             run = wandb.init(
                 project=self.project_name,
                 config=config_loader.config,
@@ -103,30 +107,24 @@ class HyperparameterSweep:
                 exclude_fn=lambda path: path.endswith("template.sh")
             )
 
-            model_name = wandb.config.model_name[0]
+            model_name = get_arg(wandb.config.model_name)
             model_args = models.get_model_specific_args(wandb.config)
 
             if "pretrained" in wandb.config:
-                weights = wandb.config.get("pretrained")
-                if type(weights) == list:
-                    weights = weights[0]
+                weights = get_arg(wandb.config.pretrained)
                 if weights != "false":
                     model_args["pretrained"] = weights
 
             if "norm" in wandb.config:
-                model_args["norm"] = wandb.config.norm
-
-            logger.debug(f"Initializing {model_name}")
+                model_args["norm"] = get_arg(wandb.config.norm)
 
             if "finetuning" in wandb.config:
                 # train for a few epochs with some frozen params, then unfreeze and continue training
-                if type(wandb.config.finetuning) == list:
-                    finetuning_strategy = wandb.config.finetuning[0]
-                else:
-                    finetuning_strategy = wandb.config.finetuning
-                model_args['finetuning_strategy'] = finetuning_strategy
+                model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
 
             if "lr_reduce_factor" in wandb.config:
-                model_args['steplr_gamma'] = wandb.config.lr_reduce_factor
-
+                model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
+
+            logger.debug(f"Initializing {model_name} with args: {model_args}")
             model = models.__getattribute__(model_name + "Lit")(**model_args)
 
             train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
diff --git a/scripts/input_validate.py b/scripts/input_validate.py
index fa6a8ef..c6bdfa4 100644
--- a/scripts/input_validate.py
+++ b/scripts/input_validate.py
@@ -85,6 +85,9 @@ class PhaseNetParameters(Parameters):
     finetuning: Finetuning = None
     lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
 
+    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
+    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
+
     @field_validator("model_name")
     def validate_model(cls, v):
         if "PhaseNet" not in v.value:
@@ -92,23 +95,24 @@ class PhaseNetParameters(Parameters):
         return v
 
 
-class GPDParameters(Parameters):
+class FilteringParameters(Parameters):
     model_config = ConfigDict(extra='forbid')
 
-    highpass: Union[NumericValue, NumericValues, FloatDistribution] = None
-    lowpass: Union[NumericValue, NumericValues, FloatDistribution] = None
+    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
+    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
 
     @field_validator("model_name")
     def validate_model(cls, v):
-        if "GPD" not in v.value:
-            raise ValueError("Additional parameters implemented for GPD only")
+        if v.value[0] not in ["GPD", "PhaseNet"]:
+            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
         return v
 
 
 class InputParams(BaseModel):
     name: str
     method: str
     metric: Metric
-    parameters: Union[Parameters, PhaseNetParameters, GPDParameters]
+    parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
 
 
 def validate_sweep_yaml(yaml_filename, model_name=None):
@@ -138,5 +142,5 @@ def validate_sweep_config(sweep_config, model_name=None):
 
 if __name__ == "__main__":
-    yaml_filename = "../experiments/sweep_phasenet_lumineos_lr_bs.yaml"
"../experiments/sweep_phasenet_bogdanka_lr_bs.yaml" validate_sweep_yaml(yaml_filename, None) diff --git a/scripts/models.py b/scripts/models.py index f35b0d9..7956c44 100644 --- a/scripts/models.py +++ b/scripts/models.py @@ -143,6 +143,9 @@ class PhaseNetLit(SeisBenchModuleLit): self.loss = vector_cross_entropy self.pretrained = kwargs.pop("pretrained", None) self.norm = kwargs.pop("norm", "peak") + self.highpass = kwargs.pop("highpass", None) + self.lowpass = kwargs.pop("lowpass", None) + if self.pretrained is not None: self.model = sbm.PhaseNet.from_pretrained(self.pretrained) @@ -152,6 +155,7 @@ class PhaseNetLit(SeisBenchModuleLit): self.finetuning_strategy = kwargs.pop("finetuning_strategy", None) self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1) + self.reduce_lr_on_plateau = False self.initial_epochs = 0 @@ -163,36 +167,33 @@ class PhaseNetLit(SeisBenchModuleLit): self.freeze() - def forward(self, x): return self.model(x) - def shared_step(self, batch): x = batch["X"] y_true = batch["y"] y_pred = self.model(x) return self.loss(y_pred, y_true) - def training_step(self, batch, batch_idx): loss = self.shared_step(batch) self.log("train_loss", loss) return loss - def validation_step(self, batch, batch_idx): loss = self.shared_step(batch) self.log("val_loss", loss) return loss - def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) if self.finetuning_strategy is not None: scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda) + self.reduce_lr_on_plateau = False else: scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3) + self.reduce_lr_on_plateau = True # return { 'optimizer': optimizer, @@ -200,11 +201,10 @@ class PhaseNetLit(SeisBenchModuleLit): 'scheduler': scheduler, 'monitor': 'val_loss', 'interval': 'epoch', - 'reduce_on_plateau': False, + 'reduce_on_plateau': self.reduce_lr_on_plateau, }, } - def lr_lambda(self, epoch): # reduce lr after x initial epochs if epoch == self.initial_epochs: @@ -212,15 +212,25 @@ class PhaseNetLit(SeisBenchModuleLit): return self.lr - def lr_scheduler_step(self, scheduler, metric): - scheduler.step(epoch=self.current_epoch) - + if self.reduce_lr_on_plateau: + scheduler.step(metric, epoch=self.current_epoch) + else: + scheduler.step(epoch=self.current_epoch) # def lr_scheduler_step(self, scheduler, optimizer_idx, metric): # scheduler.step(epoch=self.current_epoch) def get_augmentations(self): + filter = [] + if self.highpass is not None: + filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)] + logger.info(f"Using highpass filer {self.highpass}") + if self.lowpass is not None: + filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)] + logger.info(f"Using lowpass filer {self.lowpass}") + logger.info(filter) + return [ # In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training. # Uses strategy variable, as padding will be handled by the random window. 
@@ -244,22 +254,28 @@ class PhaseNetLit(SeisBenchModuleLit):
                 windowlen=3001,
                 strategy="pad",
             ),
-            sbg.ChangeDtype(np.float32),
             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
+            *filter,
+            sbg.ChangeDtype(np.float32),
             sbg.ProbabilisticLabeller(
                 label_columns=phase_dict, sigma=self.sigma, dim=0
             ),
         ]
 
     def get_eval_augmentations(self):
+        filter = []
+        if self.highpass is not None:
+            filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
+        if self.lowpass is not None:
+            filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
+
         return [
             sbg.SteeredWindow(windowlen=3001, strategy="pad"),
-            sbg.ChangeDtype(np.float32),
             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
+            *filter,
+            sbg.ChangeDtype(np.float32),
         ]
 
     def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
         x = batch["X"]
         window_borders = batch["window_borders"]
@@ -1211,8 +1227,12 @@ def get_model_specific_args(config):
             if 'highpass' in config:
                 args['highpass'] = config.highpass
             if 'lowpass' in config:
-                args['lowpass'] = config.lowpass[0]
+                args['lowpass'] = config.lowpass
         case 'PhaseNet':
+            if 'highpass' in config:
+                args['highpass'] = config.highpass
+            if 'lowpass' in config:
+                args['lowpass'] = config.lowpass
             if 'sample_boundaries' in config:
                 args['sample_boundaries'] = config.sample_boundaries
         case 'DPPPicker':
--
2.16.5
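For reference, a self-contained sketch of the two learning-rate regimes this series ends up with; the optimizer, parameter, and numbers are illustrative, not from the patch:

```python
import torch
from torch.optim import lr_scheduler

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.01)

# finetuning mode: LambdaLR multiplies the base lr by the returned factor,
# so the lr drops from 0.01 to 0.005 once the initial (frozen) epochs pass
initial_epochs, steplr_gamma = 3, 0.5
warmup = lr_scheduler.LambdaLR(opt, lambda e: 1.0 if e < initial_epochs else steplr_gamma)

# plain mode: ReduceLROnPlateau scales the lr by `factor` after `patience`
# epochs without improvement in the monitored metric (here val_loss)
plateau = lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.1, patience=3)
plateau.step(0.42)  # fed the current val_loss each epoch
```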