Compare commits
No commits in common. "master" and "bugfix/annotate-both-P-and-S-phases-in-one-trace-if-present" have entirely different histories.
master ... bugfix/annotate-both-P-and-S-phases-in-one-trace-if-present
README.md

````diff
@@ -22,7 +22,7 @@ Please download and install [Mambaforge](https://github.com/conda-forge/miniforg
 After successful installation and within the Mambaforge environment please clone this repository:
 
 ```
-git clone https://epos-apps.grid.cyfronet.pl/epos-ai/platform-demo-scripts.git
+git clone ssh://git@git.plgrid.pl:7999/eai/platform-demo-scripts.git
 ```
 
 and please run for Linux or Windows platforms:
 
@@ -111,7 +111,7 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
 
 The script performs the following steps:
 1. Generates evaluation targets in `datasets/<dataset_name>/targets` directory.
-1. Trains multiple versions of GPD, PhaseNet, BasicPhaseAE, and EQTransformer models to find the best hyperparameters, producing the lowest validation loss.
+1. Trains multiple versions of GPD, PhaseNet and ... models to find the best hyperparameters, producing the lowest validation loss.
 
 This step utilizes the Weights & Biases platform to perform the hyperparameters search (called sweeping) and track the training process and store the results.
 The results are available at
@@ -126,29 +126,12 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
 The results are saved in the `scripts/pred/results.csv` file. They are additionally logged in Weights & Biases platform as summary metrics of corresponding runs.
 
 <br/>
-The default settings for max number of experiments and paths are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for the GPD model, run:
+The default settings are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for GPD model, run:
 
 ```python pipeline.py --gpd_config <new config file>```
 
 The new config file should be placed in the `experiments` folder or as specified in the `configs_path` parameter in the config.json file.
 
-Sweep configs are used to define the max number of epochs to run and the hyperparameters search space for the following parameters:
-* `batch_size`
-* `learning_rate`
-
-Phasenet model has additional available parameters:
-* `norm` - normalization method, options ('peak', 'std')
-* `pretrained` - pretrained seisbench models used for transfer learning
-* `finetuning` - the type of layers to finetune first, options ('all', 'top', 'encoder', 'decoder')
-* `lr_reduce_factor` - factor to reduce learning rate after unfreezing layers
-
-GPD model has additional parameters for filtering:
-* `highpass` - highpass filter frequency
-* `lowpass` - lowpass filter frequency
-
-The sweep configs are saved in the `experiments` folder.
-
-
 If you have multiple datasets, you can run the pipeline for each dataset separately by specifying the dataset name as an argument:
 
 ```python pipeline.py --dataset <dataset_name>```
````
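The README above says that config.json defaults can be overridden from the command line (for example with `--gpd_config`). A minimal sketch of that pattern, purely illustrative: the `sweep_files` key name is an assumption, not taken from the repository.

```python
# Illustrative sketch: merge config.json defaults with CLI overrides such as --gpd_config.
# The "sweep_files" key name is assumed for this example.
import argparse
import json

with open("config.json") as f:
    defaults = json.load(f)

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", default=defaults.get("dataset_name"))
parser.add_argument("--gpd_config", default=defaults.get("sweep_files", {}).get("GPD"))
args = parser.parse_args()
print(f"dataset={args.dataset}, GPD sweep config={args.gpd_config}")
```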
config.json

```diff
@@ -1,6 +1,6 @@
 {
-  "dataset_name": "bogdanka_2018_2022",
-  "data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
+  "dataset_name": "bogdanka",
+  "data_path": "datasets/bogdanka/seisbench_format/",
   "targets_path": "datasets/targets",
   "models_path": "weights",
   "configs_path": "experiments",
@@ -13,5 +13,5 @@
     "BasicPhaseAE": "sweep_basicphase_ae.yaml",
     "EQTransformer": "sweep_eqtransformer.yaml"
   },
-  "experiment_count": 15
+  "experiment_count": 20
 }
```
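Several scripts in this comparison read these values through a `config_loader` module (`config_loader.data_path`, `config_loader.models_path`, `config_loader.seed`, ...). A hypothetical sketch of such a module, inferred from those usages rather than copied from the repository:

```python
# config_loader.py -- hypothetical sketch; attribute names follow how the scripts use them.
import json

with open("config.json") as f:
    config = json.load(f)

dataset_name = config["dataset_name"]
data_path = config["data_path"]
targets_path = config["targets_path"]
models_path = config["models_path"]
configs_path = config["configs_path"]
experiment_count = config["experiment_count"]
seed = config.get("seed", 3)  # a "seed" key is assumed; 3 mirrors set_random_seed's default
```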
sweep_basicphase_ae.yaml

```diff
@@ -1,4 +1,3 @@
-name: BasicPhaseAE
 method: bayes
 metric:
   goal: minimize
@@ -8,9 +7,13 @@ parameters:
     value:
       - BasicPhaseAE
   batch_size:
-    values: [64, 128, 256]
+    distribution: int_uniform
+    max: 1024
+    min: 256
   max_epochs:
     value:
-      - 30
+      - 20
   learning_rate:
-    values: [0.01, 0.005, 0.001]
+    distribution: uniform
+    max: 0.02
+    min: 0.001
```
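These YAML files are Weights & Biases sweep configurations; the change replaces discrete `values` lists with `distribution`/`min`/`max` search ranges. A hedged sketch of the equivalent sweep registered through the standard W&B Python API — the metric name, project name, and `train()` body are placeholders, not taken from the repository:

```python
# Sketch only: the updated BasicPhaseAE sweep expressed as a dict and launched with
# wandb.sweep/wandb.agent. Placeholders: metric name, project name, train() body.
import wandb

sweep_config = {
    "method": "bayes",
    "metric": {"name": "val_loss", "goal": "minimize"},  # metric name assumed
    "parameters": {
        "model_name": {"value": ["BasicPhaseAE"]},
        "batch_size": {"distribution": "int_uniform", "min": 256, "max": 1024},
        "max_epochs": {"value": [20]},
        "learning_rate": {"distribution": "uniform", "min": 0.001, "max": 0.02},
    },
}

def train():
    with wandb.init() as run:
        run.log({"val_loss": 1.0})  # placeholder for real training

sweep_id = wandb.sweep(sweep_config, project="platform-demo-scripts")  # project assumed
wandb.agent(sweep_id, function=train, count=20)
```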
sweep_eqtransformer.yaml

```diff
@@ -8,9 +8,13 @@ parameters:
     value:
       - EQTransformer
   batch_size:
-    values: [64, 128, 256]
+    distribution: int_uniform
+    max: 1024
+    min: 256
   max_epochs:
     value:
       - 30
   learning_rate:
-    values: [0.01, 0.005, 0.001]
+    distribution: uniform
+    max: 0.02
+    min: 0.005
```
GPD sweep configuration (YAML; file name not shown in this view)

```diff
@@ -1,4 +1,4 @@
-name: GPD
+name: GPD_fixed_highpass:2-10
 method: bayes
 metric:
   goal: minimize
@@ -8,12 +8,16 @@ parameters:
     value:
       - GPD
   batch_size:
-    values: [64, 128, 256]
+    distribution: int_uniform
+    max: 1024
+    min: 256
   max_epochs:
     value:
       - 30
   learning_rate:
-    values: [0.01, 0.005, 0.001]
+    distribution: uniform
+    max: 0.02
+    min: 0.005
   highpass:
     value:
       - 1
```
data.py

```diff
@@ -1,13 +1,8 @@
 """
 This file contains functionality related to data.
 """
-import os.path
 
 import seisbench.data as sbd
-import logging
-
-logging.root.setLevel(logging.INFO)
-logger = logging.getLogger('data')
 
 
 def get_dataset_by_name(name):
@@ -35,26 +30,3 @@ def get_custom_dataset(path):
     except AttributeError:
         raise ValueError(f"Unknown dataset '{path}'.")
 
-
-def validate_custom_dataset(data_path):
-    """
-    Validate the dataset
-    :param data_path: path to the dataset
-    :return:
-    """
-    # check if path exists
-    if not os.path.isdir((data_path)):
-        raise ValueError(f"Data path {data_path} does not exist.")
-
-    dataset = sbd.WaveformDataset(data_path)
-    # check if the dataset is split into train, dev and test
-    if len(dataset.train()) == 0:
-        raise ValueError(f"Training set is empty.")
-    if len(dataset.dev()) == 0:
-        raise ValueError(f"Dev set is empty.")
-    if len(dataset.test()) == 0:
-        raise ValueError(f"Test set is empty.")
-
-    logger.info("Custom dataset validated successfully.")
-
-
```
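data.py loads datasets stored in SeisBench format; the removed `validate_custom_dataset` simply checked that the train/dev/test splits are non-empty. A minimal usage sketch (the path follows config.json; adjust to your own dataset):

```python
# Minimal check of a custom SeisBench-format dataset and its splits.
import seisbench.data as sbd

dataset = sbd.WaveformDataset("datasets/bogdanka/seisbench_format/")
print(len(dataset.train()), len(dataset.dev()), len(dataset.test()))
```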
hyperparameter_sweep.py

```diff
@@ -7,9 +7,8 @@ import os
 import os.path
 import argparse
 from pytorch_lightning.loggers import WandbLogger, CSVLogger
-from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
-
 import pytorch_lightning as pl
 import wandb
 import torch
@@ -21,8 +20,6 @@ import train
 import util
 import config_loader
 
-
-
 torch.multiprocessing.set_sharing_strategy('file_system')
 os.system("ulimit -n unlimited")
 
@@ -34,13 +31,14 @@ host = os.environ.get("WANDB_HOST")
 if host is None:
     raise ValueError("WANDB_HOST environment variable is not set.")
 
 
 wandb.login(key=wandb_api_key, host=host)
 wandb_project_name = os.environ.get("WANDB_PROJECT")
 wandb_user_name = os.environ.get("WANDB_USER")
 
 script_name = os.path.splitext(os.path.basename(__file__))[0]
 logger = logging.getLogger(script_name)
-logger.setLevel(logging.INFO)
+logger.setLevel(logging.WARNING)
+
 
 def set_random_seed(seed=3):
@@ -56,11 +54,6 @@ def get_trainer_args(config):
     return trainer_args
 
 
-def get_arg(arg):
-    if type(arg) == list:
-        return arg[0]
-    return arg
-
 class HyperparameterSweep:
     def __init__(self, project_name, sweep_config):
         self.project_name = project_name
@@ -91,40 +84,25 @@ class HyperparameterSweep:
         return all_not_running
 
     def run_experiment(self):
         try:
 
-            logger.info("Starting a new run...")
+            logger.debug("Starting a new run...")
 
             run = wandb.init(
                 project=self.project_name,
                 config=config_loader.config,
-                save_code=True,
-                entity=wandb_user_name
+                save_code=True
             )
             run.log_code(
                 root=".",
                 include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
                 exclude_fn=lambda path: path.endswith("template.sh")
-            )
+            )  # not working as expected
 
-            model_name = get_arg(wandb.config.model_name)
+            model_name = wandb.config.model_name[0]
             model_args = models.get_model_specific_args(wandb.config)
+            logger.debug(f"Initializing {model_name}")
 
-            if "pretrained" in wandb.config:
-                weights = get_arg(wandb.config.pretrained)
-                if weights != "false":
-                    model_args["pretrained"] = weights
-
-            if "norm" in wandb.config:
-                model_args["norm"] = get_arg(wandb.config.norm)
-
-            if "finetuning" in wandb.config:
-                model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
-
-            if "lr_reduce_factor" in wandb.config:
-                model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
-
-            logger.debug(f"Initializing {model_name} with args: {model_args}")
             model = models.__getattribute__(model_name + "Lit")(**model_args)
 
             train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
@@ -154,13 +132,10 @@ class HyperparameterSweep:
 
             early_stopping_callback = EarlyStopping(
                 monitor="val_loss",
-                patience=5,
+                patience=3,
                 verbose=True,
                 mode="min")
+            callbacks = [checkpoint_callback, early_stopping_callback]
 
-            lr_monitor = LearningRateMonitor(logging_interval='epoch')
-
-            callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]
-
             trainer = pl.Trainer(
                 default_root_dir=config_loader.models_path,
@@ -168,6 +143,7 @@ class HyperparameterSweep:
                 callbacks=callbacks,
                 **get_trainer_args(wandb.config)
             )
+
             trainer.fit(model, train_loader, dev_loader)
 
         except Exception as e:
@@ -179,6 +155,7 @@ class HyperparameterSweep:
 
 
 def start_sweep(sweep_config):
+
     logger.info("Starting sweep with config: " + str(sweep_config))
     set_random_seed(config_loader.seed)
     sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
@@ -188,6 +165,7 @@ def start_sweep(sweep_config):
 
 
 if __name__ == "__main__":
+
     parser = argparse.ArgumentParser()
    parser.add_argument("--sweep_config", type=str, required=True)
     args = parser.parse_args()
```
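`HyperparameterSweep.run_experiment` is the function each Weights & Biases agent run executes. How the sweep id is created and how many runs are launched is not fully visible in this hunk; below is a hedged sketch of the usual driving code, and the assumption that the run count comes from config.json's `experiment_count` is mine, not the repository's.

```python
# Hedged sketch of driving a sweep; passing experiment_count as the agent's run count
# is an assumption for illustration only.
import wandb
import config_loader

def drive_sweep(sweep_runner, sweep_config, project_name):
    sweep_id = wandb.sweep(sweep_config, project=project_name)
    wandb.agent(sweep_id, function=sweep_runner.run_experiment,
                count=config_loader.experiment_count)
    return sweep_id
```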
Input validator module (entire file removed; imported as input_validate in the old pipeline.py)

```diff
@@ -1,146 +0,0 @@
-from pydantic import BaseModel, ConfigDict, field_validator
-from typing_extensions import Literal
-from typing import Union, List, Optional
-import yaml
-import logging
-
-logging.root.setLevel(logging.INFO)
-logger = logging.getLogger('input_validator')
-
-#todo
-# 1. check if a single value is allowed in a sweep
-# 2. merge input params
-# 3. change names of the classes
-# 4. add constraints for PhaseNet, GPD
-
-
-model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
-norm_values = Literal["peak", "std"]
-finetuning_values = Literal["all", "top", "decoder", "encoder"]
-pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic',
-                            'original', 'scedc', False]
-
-
-class Metric(BaseModel):
-    goal: str
-    name: str
-
-
-class NumericValue(BaseModel):
-    value: Union[int, float, List[Union[int, float]]]
-
-
-class NumericValues(BaseModel):
-    values: List[Union[int, float]]
-
-
-class IntDistribution(BaseModel):
-    distribution: str = "int_uniform"
-    min: int
-    max: int
-
-
-class FloatDistribution(BaseModel):
-    distribution: str = "uniform"
-    min: float
-    max: float
-
-
-class Pretrained(BaseModel):
-    distribution: Optional[str] = "categorical"
-    values: List[pretrained_values] = None
-    value: Union[pretrained_values, List[pretrained_values]] = None
-
-
-class Finetuning(BaseModel):
-    distribution: Optional[str] = "categorical"
-    values: List[finetuning_values] = None
-    value: Union[finetuning_values, List[finetuning_values]] = None
-
-
-class Norm(BaseModel):
-    distribution: Optional[str] = "categorical"
-    values: List[norm_values] = None
-    value: Union[norm_values, List[norm_values]] = None
-
-
-class ModelType(BaseModel):
-    distribution: Optional[str] = "categorical"
-    value: Union[model_names, List[model_names]] = None
-    values: List[model_names] = None
-
-
-class Parameters(BaseModel):
-    model_config = ConfigDict(extra='forbid', protected_namespaces=())
-    model_name: ModelType
-    batch_size: Union[IntDistribution, NumericValue, NumericValues]
-    learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
-    max_epochs: Union[IntDistribution, NumericValue, NumericValues]
-
-
-class PhaseNetParameters(Parameters):
-    model_config = ConfigDict(extra='forbid')
-    norm: Norm = None
-    pretrained: Pretrained = None
-    finetuning: Finetuning = None
-    lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
-
-    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
-    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
-
-    @field_validator("model_name")
-    def validate_model(cls, v):
-        if "PhaseNet" not in v.value:
-            raise ValueError("Additional parameters implemented for PhaseNet only")
-        return v
-
-
-class FilteringParameters(Parameters):
-    model_config = ConfigDict(extra='forbid')
-
-    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
-    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
-
-    @field_validator("model_name")
-    def validate_model(cls, v):
-        print(v.value)
-        if v.value[0] not in ["GPD", "PhaseNet"]:
-            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
-
-
-class InputParams(BaseModel):
-    name: str
-    method: str
-    metric: Metric
-    parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
-
-
-def validate_sweep_yaml(yaml_filename, model_name=None):
-    # Load YAML configuration
-    with open(yaml_filename, 'r') as f:
-        sweep_config = yaml.safe_load(f)
-
-    validate_sweep_config(sweep_config, model_name)
-
-
-def validate_sweep_config(sweep_config, model_name=None):
-
-    # Validate sweep config
-    input_params = InputParams(**sweep_config)
-
-    # Check consistency of input parameters and sweep configuration
-    sweep_model_name = input_params.parameters.model_name.value
-    if model_name is not None and model_name not in sweep_model_name:
-        info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
-        logger.info(info)
-        raise ValueError(info)
-    logger.info("Input validation successful.")
-
-
-if __name__ == "__main__":
-    yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
-    validate_sweep_yaml(yaml_filename, None)
```
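The removed module validated sweep YAMLs with pydantic models before a sweep was launched; the old pipeline.py further down still imports it as `input_validate`. A short usage sketch reconstructed from the module's own `__main__` block — the module/file name and YAML path are inferred, not confirmed by this view:

```python
# Illustrative usage of the removed validator; module name inferred from the old
# "import input_validate" in pipeline.py, YAML path is a placeholder.
from input_validate import validate_sweep_yaml

try:
    validate_sweep_yaml("experiments/sweep_basicphase_ae.yaml", model_name="BasicPhaseAE")
except ValueError as err:
    print(f"Sweep config rejected: {err}")
```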
models.py

```diff
@@ -7,26 +7,16 @@ import seisbench.generate as sbg
 
 import pytorch_lightning as pl
 import torch
-from torch.optim import lr_scheduler
 import torch.nn.functional as F
 import numpy as np
 from abc import abstractmethod, ABC
 
-# import lightning as L
-
 
 # Allows to import this file in both jupyter notebook and code
 try:
     from .augmentations import DuplicateEvent
 except ImportError:
     from augmentations import DuplicateEvent
 
-import os
-import logging
-
-script_name = os.path.splitext(os.path.basename(__file__))[0]
-logger = logging.getLogger(script_name)
-logger.setLevel(logging.DEBUG)
-
 # Phase dict for labelling. We only study P and S phases without differentiating between them.
 phase_dict = {
```
```diff
@@ -141,31 +131,7 @@ class PhaseNetLit(SeisBenchModuleLit):
         self.sigma = sigma
         self.sample_boundaries = sample_boundaries
         self.loss = vector_cross_entropy
-        self.pretrained = kwargs.pop("pretrained", None)
-        self.norm = kwargs.pop("norm", "peak")
-        self.highpass = kwargs.pop("highpass", None)
-        self.lowpass = kwargs.pop("lowpass", None)
-
-        if self.pretrained is not None:
-            self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
-            # self.norm = self.model.norm
-        else:
-            self.model = sbm.PhaseNet(**kwargs)
-
-        self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
-        self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
-        self.reduce_lr_on_plateau = False
-
-        self.initial_epochs = 0
-
-        if self.finetuning_strategy is not None:
-            if self.finetuning_strategy == "top":
-                self.initial_epochs = 3
-            elif self.finetuning_strategy in ["decoder", "encoder"]:
-                self.initial_epochs = 6
-
-        self.freeze()
+        self.model = sbm.PhaseNet(phases="PN", **kwargs)
 
     def forward(self, x):
         return self.model(x)
```
```diff
@@ -188,49 +154,9 @@ class PhaseNetLit(SeisBenchModuleLit):
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
-        if self.finetuning_strategy is not None:
-            scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
-            self.reduce_lr_on_plateau = False
-        else:
-            scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
-            self.reduce_lr_on_plateau = True
-        #
-        return {
-            'optimizer': optimizer,
-            'lr_scheduler': {
-                'scheduler': scheduler,
-                'monitor': 'val_loss',
-                'interval': 'epoch',
-                'reduce_on_plateau': self.reduce_lr_on_plateau,
-            },
-        }
-
-    def lr_lambda(self, epoch):
-        # reduce lr after x initial epochs
-        if epoch == self.initial_epochs:
-            self.lr *= self.steplr_gamma
-
-        return self.lr
-
-    def lr_scheduler_step(self, scheduler, metric):
-        if self.reduce_lr_on_plateau:
-            scheduler.step(metric, epoch=self.current_epoch)
-        else:
-            scheduler.step(epoch=self.current_epoch)
-
-    # def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
-    #     scheduler.step(epoch=self.current_epoch)
+        return optimizer
 
     def get_augmentations(self):
-        filter = []
-        if self.highpass is not None:
-            filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
-            logger.info(f"Using highpass filer {self.highpass}")
-        if self.lowpass is not None:
-            filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
-            logger.info(f"Using lowpass filer {self.lowpass}")
-        logger.info(filter)
-
         return [
             # In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
             # Uses strategy variable, as padding will be handled by the random window.
```
```diff
@@ -254,26 +180,18 @@ class PhaseNetLit(SeisBenchModuleLit):
                 windowlen=3001,
                 strategy="pad",
             ),
-            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
-            *filter,
             sbg.ChangeDtype(np.float32),
+            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
             sbg.ProbabilisticLabeller(
                 label_columns=phase_dict, sigma=self.sigma, dim=0
             ),
         ]
 
     def get_eval_augmentations(self):
-        filter = []
-        if self.highpass is not None:
-            filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
-        if self.lowpass is not None:
-            filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
-
         return [
             sbg.SteeredWindow(windowlen=3001, strategy="pad"),
-            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
-            *filter,
             sbg.ChangeDtype(np.float32),
+            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
         ]
 
     def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
```
```diff
@@ -301,27 +219,6 @@ class PhaseNetLit(SeisBenchModuleLit):
 
         return score_detection, score_p_or_s, p_sample, s_sample
 
-    def freeze(self):
-        if self.finetuning_strategy == "decoder": # finetune decoder branch and freeze encoder branch
-            for p in self.model.down_branch.parameters():
-                p.requires_grad = False
-        elif self.finetuning_strategy == "encoder": # finetune encoder branch and freeze decoder branch
-            for p in self.model.up_branch.parameters():
-                p.requires_grad = False
-        elif self.finetuning_strategy == "top":
-            for p in self.model.out.parameters():
-                p.requires_grad = False
-
-    def unfreeze(self):
-        logger.info("Unfreezing layers")
-        for p in self.model.parameters():
-            p.requires_grad = True
-
-    def on_train_epoch_start(self):
-        # Unfreeze some layers after x initial epochs
-        if self.current_epoch == self.initial_epochs:
-            self.unfreeze()
-
 
 class GPDLit(SeisBenchModuleLit):
     """
```
```diff
@@ -1215,6 +1112,7 @@ class DPPPickerLit(SeisBenchModuleLit):
 
 
 def get_model_specific_args(config):
+
     model = config.model_name[0]
     lr = config.learning_rate
     if type(lr) == list:
```
```diff
@@ -1227,12 +1125,8 @@ def get_model_specific_args(config):
             if 'highpass' in config:
                 args['highpass'] = config.highpass
             if 'lowpass' in config:
-                args['lowpass'] = config.lowpass
+                args['lowpass'] = config.lowpass[0]
         case 'PhaseNet':
-            if 'highpass' in config:
-                args['highpass'] = config.highpass
-            if 'lowpass' in config:
-                args['lowpass'] = config.lowpass
             if 'sample_boundaries' in config:
                 args['sample_boundaries'] = config.sample_boundaries
         case 'DPPPicker':
```
mseed-to-SeisBench converter script (file name not shown in this view)

```diff
@@ -23,11 +23,10 @@ logging.basicConfig(filename="output.out",
                     datefmt='%H:%M:%S',
                     level=logging.DEBUG)
 
-logging.root.setLevel(logging.INFO)
 logger = logging.getLogger('converter')
 
 def split_events(events, input_path):
 
     logger.info("Splitting available events into train, dev and test sets ...")
     events_stats = pd.DataFrame()
     events_stats.index.name = "event"
```
```diff
@@ -58,10 +57,6 @@ def split_events(events, input_path):
         else:
             break
 
-    logger.info(f"Split: {events_stats['split'].value_counts()}")
-
-    logger.info(f"Split: {events_stats['split'].value_counts()}")
-
     return events_stats
 
 
```
```diff
@@ -96,6 +91,7 @@ def get_event_params(event):
 
 
 def get_trace_params(pick):
+
     trace_params = {
         "station_network_code": pick.waveform_id.network_code,
         "station_code": pick.waveform_id.station_code,
```
```diff
@@ -128,18 +124,16 @@ def get_trace_path(input_path, trace_params):
     path = f"{input_path}/{year}/{net}/{station}/{tr_channel}.D/{net}.{station}..{tr_channel}.D.{year}.{day_of_year}"
     return path
 
 
 def get_three_channels_trace_paths(input_path, trace_params):
     year = trace_params["time"].year
     day_of_year = pd.Timestamp(str(trace_params["time"])).day_of_year
     net = trace_params["station_network_code"]
     station = trace_params["station_code"]
-    channel_base = trace_params["trace_channel"]
     paths = []
-    for ch in ["E", "N", "Z"]:
-        channel = channel_base[:-1] + ch
-        paths.append(
-            f"{input_path}/{year}/{net}/{station}/{channel}.D/{net}.{station}..{channel}.D.{year}.{day_of_year:03}")
+    for channel in ["EHE", "EHN", "EHZ"]:
+        paths.append(f"{input_path}/{year}/{net}/{station}/{channel}.D/{net}.{station}..{channel}.D.{year}.{day_of_year}")
     return paths
 
 
```
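The changed function hard-codes the EHE/EHN/EHZ channel codes instead of deriving them from the pick's channel, and no longer zero-pads the day-of-year. A small illustration of the produced path layout, with hypothetical station values:

```python
# Hypothetical values, for illustrating the path layout only.
input_path, year, net, station, day_of_year = "mseed", 2022, "PL", "XYZ", 34

paths = [
    f"{input_path}/{year}/{net}/{station}/{channel}.D/{net}.{station}..{channel}.D.{year}.{day_of_year}"
    for channel in ["EHE", "EHN", "EHZ"]
]
print(paths[0])  # mseed/2022/PL/XYZ/EHE.D/PL.XYZ..EHE.D.2022.34
```

Note that the removed variant padded the day number to three digits (`...2022.034`), so for days 1-99 the two versions look up different file names.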
```diff
@@ -181,7 +175,6 @@ def load_stream(input_path, trace_params, time_before=60, time_after=60):
         print(f"no data in: {trace_path}")
     else:
         sampling_rate = stream.traces[0].stats.sampling_rate
-        stream.merge()
 
     return sampling_rate, stream
 
```
```diff
@@ -201,10 +194,7 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
     metadata_path = output_path + "/metadata.csv"
     waveforms_path = output_path + "/waveforms.hdf5"
 
-    events_to_convert = events_stats[events_stats['pick_count'] > 0]
-
-    logger.debug("Catalog loaded, starting converting {events_to_convert} events ...")
-
+    logger.debug("Catalog loaded, starting conversion ...")
 
     with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:
         writer.data_format = {
```
```diff
@@ -229,7 +219,7 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
 
             trace_params = get_trace_params(picks[0])
             sampling_rate, stream = load_stream(input_path, trace_params)
-            if stream is None or len(stream.traces) == 0:
+            if stream is None:
                 continue
 
             actual_t_start, data, _ = sbu.stream_to_array(
```
```diff
@@ -249,10 +239,12 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
 
 
 if __name__ == "__main__":
+
     parser = argparse.ArgumentParser(description='Convert mseed files to seisbench format')
     parser.add_argument('--input_path', type=str, help='Path to mseed files')
     parser.add_argument('--catalog_path', type=str, help='Path to events catalog in quakeml format')
     parser.add_argument('--output_path', type=str, help='Path to output files')
     args = parser.parse_args()
 
+
     convert_mseed_to_seisbench_format(args.input_path, args.catalog_path, args.output_path)
```
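The converter is normally run from the command line with the arguments parsed above; calling it directly from Python looks like the sketch below, with all paths as placeholders.

```python
# Placeholder paths; replace with your mseed archive, QuakeML catalog and the desired
# SeisBench-format output directory.
convert_mseed_to_seisbench_format(
    input_path="datasets/bogdanka/mseed",
    catalog_path="datasets/bogdanka/catalog.xml",
    output_path="datasets/bogdanka/seisbench_format",
)
```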
ROC/PR evaluation script (entire file removed; file name not shown in this view)

```diff
@@ -1,149 +0,0 @@
-import json
-import pathlib
-
-import numpy as np
-import obspy
-import pandas as pd
-import seisbench.data as sbd
-import seisbench.models as sbm
-from seisbench.models.team import itertools
-from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve
-
-datasets = [
-    # path to datasets in seisbench format
-]
-
-models = [
-    # model names
-]
-
-
-def find_keys_phase(meta, phase):
-    phases = []
-    for k in meta.keys():
-        if k.startswith("trace_" + phase) and k.endswith("_arrival_sample"):
-            phases.append(k)
-
-    return phases
-
-
-def create_stream(meta, raw, start, length=30):
-
-    st = obspy.Stream()
-
-    for i in range(3):
-        tr = obspy.Trace(raw[i, :])
-        tr.stats.starttime = meta["trace_start_time"]
-        tr.stats.sampling_rate = meta["trace_sampling_rate_hz"]
-        tr.stats.network = meta["station_network_code"]
-        tr.stats.station = meta["station_code"]
-        tr.stats.channel = meta["trace_channel"][:2] + meta["trace_component_order"][i]
-
-        stop = start + length
-        tr = tr.slice(start, stop)
-
-        st.append(tr)
-
-    return st
-
-
-def get_pred(model, stream):
-    ann = model.annotate(stream)
-    noise = ann.select(channel="PhaseNet_N")[0]
-    pred = max(1 - noise.data)
-    return pred
-
-
-def to_short(stream):
-    short = [tr for tr in stream if tr.data.shape[0] < 3001]
-    return any(short)
-
-
-for ds, model_name in itertools.product(datasets, models):
-
-    data = sbd.WaveformDataset(ds, sampling_rate=100).test()
-    data_name = pathlib.Path(ds).stem
-    fname = f"roc___{model_name}___{data_name}.csv"
-
-    print(f"{fname:.<50s}.... ", flush=True, end="")
-
-    if pathlib.Path(fname).is_file():
-        print(" ready, skipping", flush=True)
-        continue
-
-    p_labels = find_keys_phase(data.metadata, "P")
-    s_labels = find_keys_phase(data.metadata, "S")
-
-    model = sbm.PhaseNet().from_pretrained(model_name)
-
-    label_true = []
-    label_pred = []
-
-    for i in range(len(data)):
-
-        waveform, metadata = data.get_sample(i)
-        m = pd.Series(metadata)
-
-        has_p_label = m[p_labels].notna()
-        has_s_label = m[s_labels].notna()
-
-        if any(has_p_label):
-
-            trace_start_time = obspy.UTCDateTime(m["trace_start_time"])
-            pick_sample = m[p_labels][has_p_label][0]
-
-            start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15
-
-            try:
-                st_p = create_stream(m, waveform, start)
-                if not (to_short(st_p)):
-                    pred_p = get_pred(model, st_p)
-                    label_true.append(1)
-                    label_pred.append(pred_p)
-            except IndexError:
-                pass
-
-            try:
-                st_n = create_stream(m, waveform, trace_start_time + 1)
-                if not (to_short(st_n)):
-                    pred_n = get_pred(model, st_n)
-                    label_true.append(0)
-                    label_pred.append(pred_n)
-            except IndexError:
-                pass
-
-        if any(has_s_label):
-            trace_start_time = obspy.UTCDateTime(m["trace_start_time"])
-            pick_sample = m[s_labels][has_s_label][0]
-            start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15
-
-            try:
-                st_s = create_stream(m, waveform, start)
-                if not (to_short(st_s)):
-                    pred_s = get_pred(model, st_s)
-                    label_true.append(1)
-                    label_pred.append(pred_s)
-            except IndexError:
-                pass
-
-    fpr, tpr, roc_thresholds = roc_curve(label_true, label_pred)
-    df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": roc_thresholds})
-    df.to_csv(fname)
-
-    precision, recall, prc_thresholds = precision_recall_curve(label_true, label_pred)
-    prc_thresholds_extra = np.append(prc_thresholds, -999)
-    df = pd.DataFrame(
-        {"pre": precision, "rec": recall, "thresholds": prc_thresholds_extra}
-    )
-    df.to_csv(fname.replace("roc", "pr"))
-
-    stats = {
-        "model": str(model_name),
-        "data": str(data_name),
-        "auc": float(roc_auc_score(label_true, label_pred)),
-    }
-
-    with open(f"stats___{model_name}___{data_name}.json", "w") as fp:
-        json.dump(stats, fp)
-
-    print(" finished", flush=True)
```
pipeline.py

```diff
@@ -17,8 +17,6 @@ import eval
 import collect_results
 import importlib
 import config_loader
-import input_validate
-import data
 
 logging.root.setLevel(logging.INFO)
 logger = logging.getLogger('pipeline')
```
```diff
@@ -42,22 +40,11 @@ def load_sweep_config(model_name, args):
     return util.load_sweep_config(sweep_fname)
 
 
-def validate_pipeline_input(args):
-    # validate input parameters
-    for model_name in args.models:
-        sweep_config = load_sweep_config(model_name, args)
-        input_validate.validate_sweep_config(sweep_config, model_name)
-
-    # validate dataset
-    data.validate_custom_dataset(config_loader.data_path)
-
-
-def find_the_best_params(sweep_config):
+def find_the_best_params(model_name, args):
     # find the best hyperparams for the model_name
-    model_name = sweep_config['parameters']['model_name']
     logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
 
+    sweep_config = load_sweep_config(model_name, args)
     sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
 
     # wait for all runs to finish
```
```diff
@@ -92,20 +79,12 @@ def main():
     parser.add_argument("--basic_phase_ae_config", type=str, required=False)
     parser.add_argument("--eqtransformer_config", type=str, required=False)
     parser.add_argument("--dataset", type=str, required=False)
-    available_models = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]
-    parser.add_argument("--models", nargs='*', required=False, choices=available_models, default=available_models,
-                        help="Models to train and evaluate (default: all)")
-    parser.add_argument("--collect_results", action="store_true", help="Collect and log results without training")
     args = parser.parse_args()
 
-    if not args.collect_results:
-
     if args.dataset is not None:
         util.set_dataset(args.dataset)
         importlib.reload(config_loader)
 
-    validate_pipeline_input(args)
-
     logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
 
     # generate labels
```
```diff
@@ -115,9 +94,10 @@ def main():
 
     # find the best hyperparams for the models
     logger.info("Started training the models.")
-    for model_name in args.models:
-        sweep_config = load_sweep_config(model_name, args)
-        sweep_id = find_the_best_params(sweep_config)
+    for model_name in ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]:
+        if config_loader.dataset_name == "lumineos" and model_name == "EQTransformer":
+            break
+        sweep_id = find_the_best_params(model_name, args)
         generate_predictions(sweep_id, model_name)
 
     # collect results
```