finetuning #1
21
README.md
21
README.md
@ -111,7 +111,7 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
|
|||||||
|
|
||||||
The script performs the following steps:
|
The script performs the following steps:
|
||||||
1. Generates evaluation targets in `datasets/<dataset_name>/targets` directory.
|
1. Generates evaluation targets in `datasets/<dataset_name>/targets` directory.
|
||||||
1. Trains multiple versions of GPD, PhaseNet and ... models to find the best hyperparameters, producing the lowest validation loss.
|
1. Trains multiple versions of GPD, PhaseNet, BasicPhaseAE, and EQTransformer models to find the best hyperparameters, producing the lowest validation loss.
|
||||||
|
|
||||||
This step utilizes the Weights & Biases platform to perform the hyperparameters search (called sweeping) and track the training process and store the results.
|
This step utilizes the Weights & Biases platform to perform the hyperparameters search (called sweeping) and track the training process and store the results.
|
||||||
The results are available at
|
The results are available at
|
||||||
@ -126,12 +126,29 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
|
|||||||
The results are saved in the `scripts/pred/results.csv` file. They are additionally logged in Weights & Biases platform as summary metrics of corresponding runs.
|
The results are saved in the `scripts/pred/results.csv` file. They are additionally logged in Weights & Biases platform as summary metrics of corresponding runs.
|
||||||
|
|
||||||
<br/>
|
<br/>
|
||||||
The default settings are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for GPD model, run:
|
The default settings for max number of experiments and paths are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for the GPD model, run:
|
||||||
|
|
||||||
```python pipeline.py --gpd_config <new config file>```
|
```python pipeline.py --gpd_config <new config file>```
|
||||||
|
|
||||||
The new config file should be placed in the `experiments` folder or as specified in the `configs_path` parameter in the config.json file.
|
The new config file should be placed in the `experiments` folder or as specified in the `configs_path` parameter in the config.json file.
|
||||||
|
|
||||||
|
Sweep configs are used to define the max number of epochs to run and the hyperparameters search space for the following parameters:
|
||||||
|
* `batch_size`
|
||||||
|
* `learning_rate`
|
||||||
|
|
||||||
|
Phasenet model has additional available parameters:
|
||||||
|
* `norm` - normalization method, options ('peak', 'std')
|
||||||
|
* `pretrained` - pretrained seisbench models used for transfer learning
|
||||||
|
* `finetuning` - the type of layers to finetune first, options ('all', 'top', 'encoder', 'decoder')
|
||||||
|
* `lr_reduce_factor` - factor to reduce learning rate after unfreezing layers
|
||||||
|
|
||||||
|
GPD model has additional parameters for filtering:
|
||||||
|
* `highpass` - highpass filter frequency
|
||||||
|
* `lowpass` - lowpass filter frequency
|
||||||
|
|
||||||
|
The sweep configs are saved in the `experiments` folder.
|
||||||
|
|
||||||
|
|
||||||
If you have multiple datasets, you can run the pipeline for each dataset separately by specifying the dataset name as an argument:
|
If you have multiple datasets, you can run the pipeline for each dataset separately by specifying the dataset name as an argument:
|
||||||
|
|
||||||
```python pipeline.py --dataset <dataset_name>```
|
```python pipeline.py --dataset <dataset_name>```
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"dataset_name": "bogdanka",
|
"dataset_name": "bogdanka_2018_2022",
|
||||||
"data_path": "datasets/bogdanka/seisbench_format/",
|
"data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
|
||||||
"targets_path": "datasets/targets",
|
"targets_path": "datasets/targets",
|
||||||
"models_path": "weights",
|
"models_path": "weights",
|
||||||
"configs_path": "experiments",
|
"configs_path": "experiments",
|
||||||
@ -13,5 +13,5 @@
|
|||||||
"BasicPhaseAE": "sweep_basicphase_ae.yaml",
|
"BasicPhaseAE": "sweep_basicphase_ae.yaml",
|
||||||
"EQTransformer": "sweep_eqtransformer.yaml"
|
"EQTransformer": "sweep_eqtransformer.yaml"
|
||||||
},
|
},
|
||||||
"experiment_count": 20
|
"experiment_count": 15
|
||||||
}
|
}
|
@ -1,3 +1,4 @@
|
|||||||
|
name: BasicPhaseAE
|
||||||
method: bayes
|
method: bayes
|
||||||
metric:
|
metric:
|
||||||
goal: minimize
|
goal: minimize
|
||||||
@ -7,13 +8,9 @@ parameters:
|
|||||||
value:
|
value:
|
||||||
- BasicPhaseAE
|
- BasicPhaseAE
|
||||||
batch_size:
|
batch_size:
|
||||||
distribution: int_uniform
|
values: [64, 128, 256]
|
||||||
max: 1024
|
|
||||||
min: 256
|
|
||||||
max_epochs:
|
max_epochs:
|
||||||
value:
|
value:
|
||||||
- 20
|
- 30
|
||||||
learning_rate:
|
learning_rate:
|
||||||
distribution: uniform
|
values: [0.01, 0.005, 0.001]
|
||||||
max: 0.02
|
|
||||||
min: 0.001
|
|
||||||
|
@ -8,13 +8,9 @@ parameters:
|
|||||||
value:
|
value:
|
||||||
- EQTransformer
|
- EQTransformer
|
||||||
batch_size:
|
batch_size:
|
||||||
distribution: int_uniform
|
values: [64, 128, 256]
|
||||||
max: 1024
|
|
||||||
min: 256
|
|
||||||
max_epochs:
|
max_epochs:
|
||||||
value:
|
value:
|
||||||
- 30
|
- 30
|
||||||
learning_rate:
|
learning_rate:
|
||||||
distribution: uniform
|
values: [0.01, 0.005, 0.001]
|
||||||
max: 0.02
|
|
||||||
min: 0.005
|
|
@ -1,4 +1,4 @@
|
|||||||
name: GPD_fixed_highpass:2-10
|
name: GPD
|
||||||
method: bayes
|
method: bayes
|
||||||
metric:
|
metric:
|
||||||
goal: minimize
|
goal: minimize
|
||||||
@ -8,16 +8,12 @@ parameters:
|
|||||||
value:
|
value:
|
||||||
- GPD
|
- GPD
|
||||||
batch_size:
|
batch_size:
|
||||||
distribution: int_uniform
|
values: [64, 128, 256]
|
||||||
max: 1024
|
|
||||||
min: 256
|
|
||||||
max_epochs:
|
max_epochs:
|
||||||
value:
|
value:
|
||||||
- 30
|
- 30
|
||||||
learning_rate:
|
learning_rate:
|
||||||
distribution: uniform
|
values: [0.01, 0.005, 0.001]
|
||||||
max: 0.02
|
|
||||||
min: 0.005
|
|
||||||
highpass:
|
highpass:
|
||||||
value:
|
value:
|
||||||
- 1
|
- 1
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
This file contains functionality related to data.
|
This file contains functionality related to data.
|
||||||
"""
|
"""
|
||||||
|
import os.path
|
||||||
|
|
||||||
import seisbench.data as sbd
|
import seisbench.data as sbd
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.root.setLevel(logging.INFO)
|
||||||
|
logger = logging.getLogger('data')
|
||||||
|
|
||||||
|
|
||||||
def get_dataset_by_name(name):
|
def get_dataset_by_name(name):
|
||||||
@ -30,3 +35,26 @@ def get_custom_dataset(path):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise ValueError(f"Unknown dataset '{path}'.")
|
raise ValueError(f"Unknown dataset '{path}'.")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_custom_dataset(data_path):
|
||||||
|
"""
|
||||||
|
Validate the dataset
|
||||||
|
:param data_path: path to the dataset
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# check if path exists
|
||||||
|
if not os.path.isdir((data_path)):
|
||||||
|
raise ValueError(f"Data path {data_path} does not exist.")
|
||||||
|
|
||||||
|
dataset = sbd.WaveformDataset(data_path)
|
||||||
|
# check if the dataset is split into train, dev and test
|
||||||
|
if len(dataset.train()) == 0:
|
||||||
|
raise ValueError(f"Training set is empty.")
|
||||||
|
if len(dataset.dev()) == 0:
|
||||||
|
raise ValueError(f"Dev set is empty.")
|
||||||
|
if len(dataset.test()) == 0:
|
||||||
|
raise ValueError(f"Test set is empty.")
|
||||||
|
|
||||||
|
logger.info("Custom dataset validated successfully.")
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,8 +7,9 @@ import os
|
|||||||
import os.path
|
import os.path
|
||||||
import argparse
|
import argparse
|
||||||
from pytorch_lightning.loggers import WandbLogger, CSVLogger
|
from pytorch_lightning.loggers import WandbLogger, CSVLogger
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
||||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import wandb
|
import wandb
|
||||||
import torch
|
import torch
|
||||||
@ -20,6 +21,8 @@ import train
|
|||||||
import util
|
import util
|
||||||
import config_loader
|
import config_loader
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||||
os.system("ulimit -n unlimited")
|
os.system("ulimit -n unlimited")
|
||||||
|
|
||||||
@ -31,14 +34,13 @@ host = os.environ.get("WANDB_HOST")
|
|||||||
if host is None:
|
if host is None:
|
||||||
raise ValueError("WANDB_HOST environment variable is not set.")
|
raise ValueError("WANDB_HOST environment variable is not set.")
|
||||||
|
|
||||||
|
|
||||||
wandb.login(key=wandb_api_key, host=host)
|
wandb.login(key=wandb_api_key, host=host)
|
||||||
wandb_project_name = os.environ.get("WANDB_PROJECT")
|
wandb_project_name = os.environ.get("WANDB_PROJECT")
|
||||||
wandb_user_name = os.environ.get("WANDB_USER")
|
wandb_user_name = os.environ.get("WANDB_USER")
|
||||||
|
|
||||||
script_name = os.path.splitext(os.path.basename(__file__))[0]
|
script_name = os.path.splitext(os.path.basename(__file__))[0]
|
||||||
logger = logging.getLogger(script_name)
|
logger = logging.getLogger(script_name)
|
||||||
logger.setLevel(logging.WARNING)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
def set_random_seed(seed=3):
|
def set_random_seed(seed=3):
|
||||||
@ -54,6 +56,11 @@ def get_trainer_args(config):
|
|||||||
return trainer_args
|
return trainer_args
|
||||||
|
|
||||||
|
|
||||||
|
def get_arg(arg):
|
||||||
|
if type(arg) == list:
|
||||||
|
return arg[0]
|
||||||
|
return arg
|
||||||
|
|
||||||
class HyperparameterSweep:
|
class HyperparameterSweep:
|
||||||
def __init__(self, project_name, sweep_config):
|
def __init__(self, project_name, sweep_config):
|
||||||
self.project_name = project_name
|
self.project_name = project_name
|
||||||
@ -84,25 +91,40 @@ class HyperparameterSweep:
|
|||||||
return all_not_running
|
return all_not_running
|
||||||
|
|
||||||
def run_experiment(self):
|
def run_experiment(self):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
logger.debug("Starting a new run...")
|
logger.info("Starting a new run...")
|
||||||
|
|
||||||
run = wandb.init(
|
run = wandb.init(
|
||||||
project=self.project_name,
|
project=self.project_name,
|
||||||
config=config_loader.config,
|
config=config_loader.config,
|
||||||
save_code=True
|
save_code=True,
|
||||||
|
entity=wandb_user_name
|
||||||
)
|
)
|
||||||
run.log_code(
|
run.log_code(
|
||||||
root=".",
|
root=".",
|
||||||
include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
|
include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
|
||||||
exclude_fn=lambda path: path.endswith("template.sh")
|
exclude_fn=lambda path: path.endswith("template.sh")
|
||||||
) # not working as expected
|
)
|
||||||
|
|
||||||
model_name = wandb.config.model_name[0]
|
model_name = get_arg(wandb.config.model_name)
|
||||||
model_args = models.get_model_specific_args(wandb.config)
|
model_args = models.get_model_specific_args(wandb.config)
|
||||||
logger.debug(f"Initializing {model_name}")
|
|
||||||
|
|
||||||
|
if "pretrained" in wandb.config:
|
||||||
|
weights = get_arg(wandb.config.pretrained)
|
||||||
|
if weights != "false":
|
||||||
|
model_args["pretrained"] = weights
|
||||||
|
|
||||||
|
if "norm" in wandb.config:
|
||||||
|
model_args["norm"] = get_arg(wandb.config.norm)
|
||||||
|
|
||||||
|
if "finetuning" in wandb.config:
|
||||||
|
model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
|
||||||
|
|
||||||
|
if "lr_reduce_factor" in wandb.config:
|
||||||
|
model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
|
||||||
|
|
||||||
|
logger.debug(f"Initializing {model_name} with args: {model_args}")
|
||||||
model = models.__getattribute__(model_name + "Lit")(**model_args)
|
model = models.__getattribute__(model_name + "Lit")(**model_args)
|
||||||
|
|
||||||
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
|
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
|
||||||
@ -132,10 +154,13 @@ class HyperparameterSweep:
|
|||||||
|
|
||||||
early_stopping_callback = EarlyStopping(
|
early_stopping_callback = EarlyStopping(
|
||||||
monitor="val_loss",
|
monitor="val_loss",
|
||||||
patience=3,
|
patience=5,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
mode="min")
|
mode="min")
|
||||||
callbacks = [checkpoint_callback, early_stopping_callback]
|
|
||||||
|
lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||||
|
|
||||||
|
callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]
|
||||||
|
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
default_root_dir=config_loader.models_path,
|
default_root_dir=config_loader.models_path,
|
||||||
@ -143,7 +168,6 @@ class HyperparameterSweep:
|
|||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
**get_trainer_args(wandb.config)
|
**get_trainer_args(wandb.config)
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.fit(model, train_loader, dev_loader)
|
trainer.fit(model, train_loader, dev_loader)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -155,7 +179,6 @@ class HyperparameterSweep:
|
|||||||
|
|
||||||
|
|
||||||
def start_sweep(sweep_config):
|
def start_sweep(sweep_config):
|
||||||
|
|
||||||
logger.info("Starting sweep with config: " + str(sweep_config))
|
logger.info("Starting sweep with config: " + str(sweep_config))
|
||||||
set_random_seed(config_loader.seed)
|
set_random_seed(config_loader.seed)
|
||||||
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
|
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
|
||||||
@ -165,7 +188,6 @@ def start_sweep(sweep_config):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--sweep_config", type=str, required=True)
|
parser.add_argument("--sweep_config", type=str, required=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
146
scripts/input_validate.py
Normal file
146
scripts/input_validate.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
from typing_extensions import Literal
|
||||||
|
from typing import Union, List, Optional
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.root.setLevel(logging.INFO)
|
||||||
|
logger = logging.getLogger('input_validator')
|
||||||
|
|
||||||
|
#todo
|
||||||
|
# 1. check if a single value is allowed in a sweep
|
||||||
|
# 2. merge input params
|
||||||
|
# 3. change names of the classes
|
||||||
|
# 4. add constraints for PhaseNet, GPD
|
||||||
|
|
||||||
|
|
||||||
|
model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
|
||||||
|
norm_values = Literal["peak", "std"]
|
||||||
|
finetuning_values = Literal["all", "top", "decoder", "encoder"]
|
||||||
|
pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic',
|
||||||
|
'original', 'scedc', False]
|
||||||
|
|
||||||
|
|
||||||
|
class Metric(BaseModel):
|
||||||
|
goal: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class NumericValue(BaseModel):
|
||||||
|
value: Union[int, float, List[Union[int, float]]]
|
||||||
|
|
||||||
|
|
||||||
|
class NumericValues(BaseModel):
|
||||||
|
values: List[Union[int, float]]
|
||||||
|
|
||||||
|
|
||||||
|
class IntDistribution(BaseModel):
|
||||||
|
distribution: str = "int_uniform"
|
||||||
|
min: int
|
||||||
|
max: int
|
||||||
|
|
||||||
|
|
||||||
|
class FloatDistribution(BaseModel):
|
||||||
|
distribution: str = "uniform"
|
||||||
|
min: float
|
||||||
|
max: float
|
||||||
|
|
||||||
|
|
||||||
|
class Pretrained(BaseModel):
|
||||||
|
distribution: Optional[str] = "categorical"
|
||||||
|
values: List[pretrained_values] = None
|
||||||
|
value: Union[pretrained_values, List[pretrained_values]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Finetuning(BaseModel):
|
||||||
|
distribution: Optional[str] = "categorical"
|
||||||
|
values: List[finetuning_values] = None
|
||||||
|
value: Union[finetuning_values, List[finetuning_values]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Norm(BaseModel):
|
||||||
|
distribution: Optional[str] = "categorical"
|
||||||
|
values: List[norm_values] = None
|
||||||
|
value: Union[norm_values, List[norm_values]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(BaseModel):
|
||||||
|
distribution: Optional[str] = "categorical"
|
||||||
|
value: Union[model_names, List[model_names]] = None
|
||||||
|
values: List[model_names] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Parameters(BaseModel):
|
||||||
|
model_config = ConfigDict(extra='forbid', protected_namespaces=())
|
||||||
|
model_name: ModelType
|
||||||
|
batch_size: Union[IntDistribution, NumericValue, NumericValues]
|
||||||
|
learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
|
||||||
|
max_epochs: Union[IntDistribution, NumericValue, NumericValues]
|
||||||
|
|
||||||
|
|
||||||
|
class PhaseNetParameters(Parameters):
|
||||||
|
model_config = ConfigDict(extra='forbid')
|
||||||
|
norm: Norm = None
|
||||||
|
pretrained: Pretrained = None
|
||||||
|
finetuning: Finetuning = None
|
||||||
|
lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
|
||||||
|
|
||||||
|
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||||
|
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||||
|
|
||||||
|
@field_validator("model_name")
|
||||||
|
def validate_model(cls, v):
|
||||||
|
if "PhaseNet" not in v.value:
|
||||||
|
raise ValueError("Additional parameters implemented for PhaseNet only")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class FilteringParameters(Parameters):
|
||||||
|
model_config = ConfigDict(extra='forbid')
|
||||||
|
|
||||||
|
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||||
|
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||||
|
|
||||||
|
@field_validator("model_name")
|
||||||
|
def validate_model(cls, v):
|
||||||
|
print(v.value)
|
||||||
|
if v.value[0] not in ["GPD", "PhaseNet"]:
|
||||||
|
raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
|
||||||
|
|
||||||
|
|
||||||
|
class InputParams(BaseModel):
|
||||||
|
name: str
|
||||||
|
method: str
|
||||||
|
metric: Metric
|
||||||
|
parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_sweep_yaml(yaml_filename, model_name=None):
|
||||||
|
# Load YAML configuration
|
||||||
|
with open(yaml_filename, 'r') as f:
|
||||||
|
sweep_config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
validate_sweep_config(sweep_config, model_name)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_sweep_config(sweep_config, model_name=None):
|
||||||
|
|
||||||
|
# Validate sweep config
|
||||||
|
|
||||||
|
input_params = InputParams(**sweep_config)
|
||||||
|
|
||||||
|
# Check consistency of input parameters and sweep configuration
|
||||||
|
sweep_model_name = input_params.parameters.model_name.value
|
||||||
|
if model_name is not None and model_name not in sweep_model_name:
|
||||||
|
info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
|
||||||
|
logger.info(info)
|
||||||
|
raise ValueError(info)
|
||||||
|
logger.info("Input validation successful.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
|
||||||
|
validate_sweep_yaml(yaml_filename, None)
|
@ -7,16 +7,26 @@ import seisbench.generate as sbg
|
|||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from torch.optim import lr_scheduler
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
|
|
||||||
|
# import lightning as L
|
||||||
|
|
||||||
|
|
||||||
# Allows to import this file in both jupyter notebook and code
|
# Allows to import this file in both jupyter notebook and code
|
||||||
try:
|
try:
|
||||||
from .augmentations import DuplicateEvent
|
from .augmentations import DuplicateEvent
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from augmentations import DuplicateEvent
|
from augmentations import DuplicateEvent
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
script_name = os.path.splitext(os.path.basename(__file__))[0]
|
||||||
|
logger = logging.getLogger(script_name)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
# Phase dict for labelling. We only study P and S phases without differentiating between them.
|
# Phase dict for labelling. We only study P and S phases without differentiating between them.
|
||||||
phase_dict = {
|
phase_dict = {
|
||||||
@ -131,8 +141,32 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
self.sample_boundaries = sample_boundaries
|
self.sample_boundaries = sample_boundaries
|
||||||
self.loss = vector_cross_entropy
|
self.loss = vector_cross_entropy
|
||||||
|
self.pretrained = kwargs.pop("pretrained", None)
|
||||||
|
self.norm = kwargs.pop("norm", "peak")
|
||||||
|
self.highpass = kwargs.pop("highpass", None)
|
||||||
|
self.lowpass = kwargs.pop("lowpass", None)
|
||||||
|
|
||||||
|
|
||||||
|
if self.pretrained is not None:
|
||||||
|
self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
|
||||||
|
# self.norm = self.model.norm
|
||||||
|
else:
|
||||||
self.model = sbm.PhaseNet(**kwargs)
|
self.model = sbm.PhaseNet(**kwargs)
|
||||||
|
|
||||||
|
self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
|
||||||
|
self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
|
||||||
|
self.reduce_lr_on_plateau = False
|
||||||
|
|
||||||
|
self.initial_epochs = 0
|
||||||
|
|
||||||
|
if self.finetuning_strategy is not None:
|
||||||
|
if self.finetuning_strategy == "top":
|
||||||
|
self.initial_epochs = 3
|
||||||
|
elif self.finetuning_strategy in ["decoder", "encoder"]:
|
||||||
|
self.initial_epochs = 6
|
||||||
|
|
||||||
|
self.freeze()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.model(x)
|
return self.model(x)
|
||||||
|
|
||||||
@ -154,9 +188,49 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||||
return optimizer
|
if self.finetuning_strategy is not None:
|
||||||
|
scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
|
||||||
|
self.reduce_lr_on_plateau = False
|
||||||
|
else:
|
||||||
|
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
|
||||||
|
self.reduce_lr_on_plateau = True
|
||||||
|
#
|
||||||
|
return {
|
||||||
|
'optimizer': optimizer,
|
||||||
|
'lr_scheduler': {
|
||||||
|
'scheduler': scheduler,
|
||||||
|
'monitor': 'val_loss',
|
||||||
|
'interval': 'epoch',
|
||||||
|
'reduce_on_plateau': self.reduce_lr_on_plateau,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def lr_lambda(self, epoch):
|
||||||
|
# reduce lr after x initial epochs
|
||||||
|
if epoch == self.initial_epochs:
|
||||||
|
self.lr *= self.steplr_gamma
|
||||||
|
|
||||||
|
return self.lr
|
||||||
|
|
||||||
|
def lr_scheduler_step(self, scheduler, metric):
|
||||||
|
if self.reduce_lr_on_plateau:
|
||||||
|
scheduler.step(metric, epoch=self.current_epoch)
|
||||||
|
else:
|
||||||
|
scheduler.step(epoch=self.current_epoch)
|
||||||
|
|
||||||
|
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
|
||||||
|
# scheduler.step(epoch=self.current_epoch)
|
||||||
|
|
||||||
def get_augmentations(self):
|
def get_augmentations(self):
|
||||||
|
filter = []
|
||||||
|
if self.highpass is not None:
|
||||||
|
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
|
||||||
|
logger.info(f"Using highpass filer {self.highpass}")
|
||||||
|
if self.lowpass is not None:
|
||||||
|
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
|
||||||
|
logger.info(f"Using lowpass filer {self.lowpass}")
|
||||||
|
logger.info(filter)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
# In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
|
# In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
|
||||||
# Uses strategy variable, as padding will be handled by the random window.
|
# Uses strategy variable, as padding will be handled by the random window.
|
||||||
@ -180,18 +254,26 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
windowlen=3001,
|
windowlen=3001,
|
||||||
strategy="pad",
|
strategy="pad",
|
||||||
),
|
),
|
||||||
|
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
|
||||||
|
*filter,
|
||||||
sbg.ChangeDtype(np.float32),
|
sbg.ChangeDtype(np.float32),
|
||||||
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
|
|
||||||
sbg.ProbabilisticLabeller(
|
sbg.ProbabilisticLabeller(
|
||||||
label_columns=phase_dict, sigma=self.sigma, dim=0
|
label_columns=phase_dict, sigma=self.sigma, dim=0
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_eval_augmentations(self):
|
def get_eval_augmentations(self):
|
||||||
|
filter = []
|
||||||
|
if self.highpass is not None:
|
||||||
|
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
|
||||||
|
if self.lowpass is not None:
|
||||||
|
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
|
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
|
||||||
|
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
|
||||||
|
*filter,
|
||||||
sbg.ChangeDtype(np.float32),
|
sbg.ChangeDtype(np.float32),
|
||||||
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
|
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
|
||||||
@ -219,6 +301,27 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
|
|
||||||
return score_detection, score_p_or_s, p_sample, s_sample
|
return score_detection, score_p_or_s, p_sample, s_sample
|
||||||
|
|
||||||
|
def freeze(self):
|
||||||
|
if self.finetuning_strategy == "decoder": # finetune decoder branch and freeze encoder branch
|
||||||
|
for p in self.model.down_branch.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
elif self.finetuning_strategy == "encoder": # finetune encoder branch and freeze decoder branch
|
||||||
|
for p in self.model.up_branch.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
elif self.finetuning_strategy == "top":
|
||||||
|
for p in self.model.out.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
|
||||||
|
def unfreeze(self):
|
||||||
|
logger.info("Unfreezing layers")
|
||||||
|
for p in self.model.parameters():
|
||||||
|
p.requires_grad = True
|
||||||
|
|
||||||
|
def on_train_epoch_start(self):
|
||||||
|
# Unfreeze some layers after x initial epochs
|
||||||
|
if self.current_epoch == self.initial_epochs:
|
||||||
|
self.unfreeze()
|
||||||
|
|
||||||
|
|
||||||
class GPDLit(SeisBenchModuleLit):
|
class GPDLit(SeisBenchModuleLit):
|
||||||
"""
|
"""
|
||||||
@ -846,7 +949,7 @@ class BasicPhaseAELit(SeisBenchModuleLit):
|
|||||||
# Create overlapping windows
|
# Create overlapping windows
|
||||||
re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
|
re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
|
||||||
for i, start in enumerate(range(0, 2401, 400)):
|
for i, start in enumerate(range(0, 2401, 400)):
|
||||||
re[:, :, i] = x[:, :, start : start + 600]
|
re[:, :, i] = x[:, :, start: start + 600]
|
||||||
x = re
|
x = re
|
||||||
|
|
||||||
x = x.permute(0, 2, 1, 3) # --> (batch, windows, channels, samples)
|
x = x.permute(0, 2, 1, 3) # --> (batch, windows, channels, samples)
|
||||||
@ -862,9 +965,9 @@ class BasicPhaseAELit(SeisBenchModuleLit):
|
|||||||
for i, start in enumerate(range(0, 2401, 400)):
|
for i, start in enumerate(range(0, 2401, 400)):
|
||||||
if start == 0:
|
if start == 0:
|
||||||
# Use full window (for start==0, the end will be overwritten)
|
# Use full window (for start==0, the end will be overwritten)
|
||||||
pred[:, :, start : start + 600] = window_pred[:, i]
|
pred[:, :, start: start + 600] = window_pred[:, i]
|
||||||
else:
|
else:
|
||||||
pred[:, :, start + 100 : start + 600] = window_pred[:, i, :, 100:]
|
pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:]
|
||||||
|
|
||||||
score_detection = torch.zeros(pred.shape[0])
|
score_detection = torch.zeros(pred.shape[0])
|
||||||
score_p_or_s = torch.zeros(pred.shape[0])
|
score_p_or_s = torch.zeros(pred.shape[0])
|
||||||
@ -1112,7 +1215,6 @@ class DPPPickerLit(SeisBenchModuleLit):
|
|||||||
|
|
||||||
|
|
||||||
def get_model_specific_args(config):
|
def get_model_specific_args(config):
|
||||||
|
|
||||||
model = config.model_name[0]
|
model = config.model_name[0]
|
||||||
lr = config.learning_rate
|
lr = config.learning_rate
|
||||||
if type(lr) == list:
|
if type(lr) == list:
|
||||||
@ -1125,8 +1227,12 @@ def get_model_specific_args(config):
|
|||||||
if 'highpass' in config:
|
if 'highpass' in config:
|
||||||
args['highpass'] = config.highpass
|
args['highpass'] = config.highpass
|
||||||
if 'lowpass' in config:
|
if 'lowpass' in config:
|
||||||
args['lowpass'] = config.lowpass[0]
|
args['lowpass'] = config.lowpass
|
||||||
case 'PhaseNet':
|
case 'PhaseNet':
|
||||||
|
if 'highpass' in config:
|
||||||
|
args['highpass'] = config.highpass
|
||||||
|
if 'lowpass' in config:
|
||||||
|
args['lowpass'] = config.lowpass
|
||||||
if 'sample_boundaries' in config:
|
if 'sample_boundaries' in config:
|
||||||
args['sample_boundaries'] = config.sample_boundaries
|
args['sample_boundaries'] = config.sample_boundaries
|
||||||
case 'DPPPicker':
|
case 'DPPPicker':
|
||||||
|
@ -60,6 +60,8 @@ def split_events(events, input_path):
|
|||||||
|
|
||||||
logger.info(f"Split: {events_stats['split'].value_counts()}")
|
logger.info(f"Split: {events_stats['split'].value_counts()}")
|
||||||
|
|
||||||
|
logger.info(f"Split: {events_stats['split'].value_counts()}")
|
||||||
|
|
||||||
return events_stats
|
return events_stats
|
||||||
|
|
||||||
|
|
||||||
@ -201,6 +203,7 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
|
|||||||
|
|
||||||
events_to_convert = events_stats[events_stats['pick_count'] > 0]
|
events_to_convert = events_stats[events_stats['pick_count'] > 0]
|
||||||
|
|
||||||
|
|
||||||
logger.debug("Catalog loaded, starting converting {events_to_convert} events ...")
|
logger.debug("Catalog loaded, starting converting {events_to_convert} events ...")
|
||||||
|
|
||||||
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:
|
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:
|
||||||
|
@ -17,6 +17,8 @@ import eval
|
|||||||
import collect_results
|
import collect_results
|
||||||
import importlib
|
import importlib
|
||||||
import config_loader
|
import config_loader
|
||||||
|
import input_validate
|
||||||
|
import data
|
||||||
|
|
||||||
logging.root.setLevel(logging.INFO)
|
logging.root.setLevel(logging.INFO)
|
||||||
logger = logging.getLogger('pipeline')
|
logger = logging.getLogger('pipeline')
|
||||||
@ -40,11 +42,22 @@ def load_sweep_config(model_name, args):
|
|||||||
return util.load_sweep_config(sweep_fname)
|
return util.load_sweep_config(sweep_fname)
|
||||||
|
|
||||||
|
|
||||||
def find_the_best_params(model_name, args):
|
def validate_pipeline_input(args):
|
||||||
|
|
||||||
|
# validate input parameters
|
||||||
|
for model_name in args.models:
|
||||||
|
sweep_config = load_sweep_config(model_name, args)
|
||||||
|
input_validate.validate_sweep_config(sweep_config, model_name)
|
||||||
|
|
||||||
|
# validate dataset
|
||||||
|
data.validate_custom_dataset(config_loader.data_path)
|
||||||
|
|
||||||
|
|
||||||
|
def find_the_best_params(sweep_config):
|
||||||
# find the best hyperparams for the model_name
|
# find the best hyperparams for the model_name
|
||||||
|
model_name = sweep_config['parameters']['model_name']
|
||||||
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
|
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
|
||||||
|
|
||||||
sweep_config = load_sweep_config(model_name, args)
|
|
||||||
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
|
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
|
||||||
|
|
||||||
# wait for all runs to finish
|
# wait for all runs to finish
|
||||||
@ -91,6 +104,8 @@ def main():
|
|||||||
util.set_dataset(args.dataset)
|
util.set_dataset(args.dataset)
|
||||||
importlib.reload(config_loader)
|
importlib.reload(config_loader)
|
||||||
|
|
||||||
|
validate_pipeline_input(args)
|
||||||
|
|
||||||
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
|
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
|
||||||
|
|
||||||
# generate labels
|
# generate labels
|
||||||
@ -101,7 +116,8 @@ def main():
|
|||||||
# find the best hyperparams for the models
|
# find the best hyperparams for the models
|
||||||
logger.info("Started training the models.")
|
logger.info("Started training the models.")
|
||||||
for model_name in args.models:
|
for model_name in args.models:
|
||||||
sweep_id = find_the_best_params(model_name, args)
|
sweep_config = load_sweep_config(model_name, args)
|
||||||
|
sweep_id = find_the_best_params(sweep_config)
|
||||||
generate_predictions(sweep_id, model_name)
|
generate_predictions(sweep_id, model_name)
|
||||||
|
|
||||||
# collect results
|
# collect results
|
||||||
|
Loading…
Reference in New Issue
Block a user