Compare commits

...

10 Commits

22 changed files with 886 additions and 459 deletions

.gitignore vendored
View File

@ -3,6 +3,7 @@ __pycache__/
*/.ipynb_checkpoints/
.ipynb_checkpoints/
.env
*.out
weights/
datasets/
wip
@ -10,4 +11,4 @@ artifacts/
wandb/
scripts/pred/
scripts/pred_resampled/
scripts/lightning_logs/
scripts/lightning_logs/

View File

@ -2,12 +2,13 @@
This repo contains notebooks and scripts demonstrating how to:
- Prepare data for training a seisbench model detecting P and S waves (i.e. transform mseeds into [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)), check the [notebook](utils/Transforming%20mseeds%20from%20Bogdanka%20to%20Seisbench%20format.ipynb) and the [script](utils/mseeds_to_seisbench.py)
- [to update] Explore available data, check the [notebook](notebooks/Explore%20igf%20data.ipynb)
- Prepare data for training a seisbench model detecting P and S waves (i.e. transform mseeds into [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)), check the [notebook](notebooks/Transforming%20mseeds%20from%20Bogdanka%20to%20Seisbench%20format.ipynb) and the [script](scripts/mseeds_to_seisbench.py)
[//]: # (- [to update] Explore available data, check the [notebook](notebooks/Explore%20igf%20data.ipynb))
- Train various CNN models available in the SeisBench library and compare their performance in detecting P and S waves, check the [script](scripts/pipeline.py)
- [to update] Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb)
- [to update] Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb)
[//]: # (- [to update] Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb))
[//]: # (- [to update] Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb))
### Acknowledgments
@ -21,7 +22,7 @@ Please download and install [Mambaforge](https://github.com/conda-forge/miniforg
After successful installation and within the Mambaforge environment please clone this repository:
```
git clone ssh://git@git.plgrid.pl:7999/eai/platform-demo-scripts.git
git clone https://epos-apps.grid.cyfronet.pl/epos-ai/platform-demo-scripts.git
```
and please run for Linux or Windows platforms:
@ -69,10 +70,13 @@ poetry shell
WANDB_USER="your user"
WANDB_PROJECT="training_seisbench_models"
BENCHMARK_DEFAULT_WORKER=2
```
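These variables are read at runtime via `python-dotenv`; a minimal sketch of how the scripts pick them up (compare `hyperparameter_sweep.py` and `util.py` in this changeset), assuming the `.env` file is in the working directory:
```
import os

from dotenv import load_dotenv

# Load WANDB_USER, WANDB_PROJECT, etc. from the .env file into the process environment
load_dotenv()

wandb_user_name = os.environ.get("WANDB_USER")
wandb_project_name = os.environ.get("WANDB_PROJECT")
print(wandb_user_name, wandb_project_name)
```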
2. Transform data into seisbench format.
To utilize functionality of Seisbench library, data need to be transformed to [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)). If your data is in the MSEED format, you can use the prepared script `mseeds_to_seisbench.py` to perform the transformation. Please make sure that your data has the same structure as the data used in this project.
To utilize the functionality of the SeisBench library, the data needs to be transformed into the [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html).
If your data is stored in the MSEED format and your catalog in the QuakeML format, you can use the prepared script `mseeds_to_seisbench.py` to perform the transformation. Please make sure that your data has the same structure as the data used in this project.
The script assumes that:
* the data is stored in the following directory structure:
`input_path/year/station_network_code/station_code/trace_channel.D` e.g.
@ -80,24 +84,20 @@ poetry shell
* the file names follow the pattern:
`station_network_code.station_code..trace_channel.D.year.day_of_year`
e.g. `PL.ALBE..EHE.D.2018.282`
* the events catalog is stored in QuakeML format
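For reference, the full path of a single trace file under these assumptions can be reconstructed as follows (a sketch mirroring the path logic in `mseeds_to_seisbench.py`; the helper name is illustrative):
```
def trace_path(input_path, year, net, station, channel, day_of_year):
    # e.g. input_path/2018/PL/ALBE/EHE.D/PL.ALBE..EHE.D.2018.282
    return (f"{input_path}/{year}/{net}/{station}/{channel}.D/"
            f"{net}.{station}..{channel}.D.{year}.{day_of_year:03}")
```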
Run the script `mseeds_to_seisbench` located in the `utils` directory
Run the `mseeds_to_seisbench.py` script with the following arguments:
```
cd utils
python mseeds_to_seisbench.py --input_path $input_path --catalog_path $catalog_path --output_path $output_path
```
If you want to run the script on a cluster, you can use the script `convert_data.sh` as a template (adjust the grant name, computing name and paths) and send the job to queue using sbatch command on login node of e.g. Ares:
```
cd utils
sbatch convert_data.sh
If you want to run the script on a cluster, you can use the template script `convert_data_template.sh`.
After adjusting the grant name, the path to the conda environment and the paths to the data, send the job to the queue using the `sbatch` command on a login node of e.g. Ares:
```
sbatch convert_data_template.sh
```
If your data has a different structure or format, use the notebooks to gain an understanding of the Seisbench format and what needs to be done to transform your data:
If your data has a different structure or format, check the notebooks to gain an understanding of the Seisbench format and what needs to be done to transform your data:
* [Seisbench example](https://colab.research.google.com/github/seisbench/seisbench/blob/main/examples/01a_dataset_basics.ipynb) or
* [Transforming mseeds from Bogdanka to Seisbench format](utils/Transforming mseeds from Bogdanka to Seisbench format.ipynb) notebook
* [Transforming mseeds from Bogdanka to Seisbench format](notebooks/Transforming mseeds from Bogdanka to Seisbench format.ipynb) notebook
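A quick way to sanity-check the conversion is to load the result back with SeisBench; a minimal sketch, assuming the `output_path` used above:
```
import seisbench.data as sbd

# Path written by mseeds_to_seisbench.py (adjust to your output_path)
data = sbd.WaveformDataset("datasets/bogdanka_2018_2022/seisbench_format/")

print(data)                                                   # dataset summary
print(len(data.train()), len(data.dev()), len(data.test()))   # split sizes
waveform, metadata = data.get_sample(0)
print(waveform.shape, metadata["trace_channel"])
```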
3. Adjust the `config.json` and specify:
@ -110,34 +110,65 @@ poetry shell
`python pipeline.py`
The script performs the following steps:
* Generates evaluation targets in `datasets/<dataset_name>/targets` directory.
* Trains multiple versions of GPD, PhaseNet and ... models to find the best hyperparameters, producing the lowest validation loss.
1. Generates evaluation targets in the `datasets/<dataset_name>/targets` directory.
1. Trains multiple versions of GPD, PhaseNet, BasicPhaseAE, and EQTransformer models to find the best hyperparameters, producing the lowest validation loss.
This step utilizes the Weights & Biases platform to perform the hyperparameters search (called sweeping) and track the training process and store the results.
The results are available at
`https://epos-ai.grid.cyfronet.pl/<WANDB_USER>/<WANDB_PROJECT>`
Weights and training logs can be downloaded from the platform.
Additionally, the most important data are saved locally in `weights/<dataset_name>_<model_name>/ ` directory:
* Weights of the best checkpoint of each model are saved as `<dataset_name>_<model_name>_sweep=<sweep_id>-run=<run_id>-epoch=<epoch_number>-val_loss=<val_loss>.ckpt`
* Metrics and hyperparams are saved in <run_id> folders
This step utilizes the Weights & Biases platform to perform the hyperparameter search (called sweeping), track the training process and store the results.
The results are available at
`https://epos-ai.grid.cyfronet.pl/<WANDB_USER>/<WANDB_PROJECT>`
Weights and training logs can be downloaded from the platform.
Additionally, the most important data are saved locally in the `weights/<dataset_name>_<model_name>/` directory:
* Weights of the best checkpoint of each model are saved as `<dataset_name>_<model_name>_sweep=<sweep_id>-run=<run_id>-epoch=<epoch_number>-val_loss=<val_loss>.ckpt`
* Metrics and hyperparameters are saved in `<run_id>` folders
* Uses the best performing model of each type to generate predictions. The predictons are saved in the `scripts/pred/<dataset_name>_<model_name>/<run_id>` directory.
* Evaluates the performance of each model by comparing the predictions with the evaluation targets.
The results are saved in the `scripts/pred/results.csv` file.
1. Uses the best-performing model of each type to generate predictions. The predictions are saved in the `scripts/pred/<dataset_name>_<model_name>/<run_id>` directory.
1. Evaluates the performance of each model by comparing the predictions with the evaluation targets and calculating MAE metrics.
The results are saved in the `scripts/pred/results.csv` file (a short snippet for inspecting it is shown at the end of this section). They are additionally logged to the Weights & Biases platform as summary metrics of the corresponding runs.
<br/>
The default settings for max number of experiments and paths are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for the GPD model, run:
The default settings are saved in the `config.json` file. To change the settings, edit `config.json` or pass the new settings as arguments to the script.
For example, to change the sweep configuration file for the GPD model, run:
`python pipeline.py --gpd_config <new config file>`
The new config file should be placed in the `experiments` folder or as specified in the `configs_path` parameter in the config.json file.
```python pipeline.py --gpd_config <new config file>```
The new config file should be placed in the `experiments` folder, or in the location specified by the `configs_path` parameter in the `config.json` file.
Sweep configs define the maximum number of epochs to run and the hyperparameter search space for the following parameters:
* `batch_size`
* `learning_rate`
The PhaseNet model has additional parameters:
* `norm` - normalization method, options ('peak', 'std')
* `pretrained` - pretrained seisbench models used for transfer learning
* `finetuning` - the type of layers to finetune first, options ('all', 'top', 'encoder', 'decoder')
* `lr_reduce_factor` - factor to reduce learning rate after unfreezing layers
The GPD model has additional parameters for filtering:
* `highpass` - highpass filter frequency
* `lowpass` - lowpass filter frequency
The sweep configs are saved in the `experiments` folder.
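A custom sweep config can be checked against these constraints with the `input_validate` module added in this changeset; a minimal sketch, run from the `scripts` directory:
```
from input_validate import validate_sweep_yaml

# Raises a validation error if the YAML contains parameters
# that are not allowed for the given model
validate_sweep_yaml("../experiments/sweep_phasenet.yaml", "PhaseNet")
```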
If you have multiple datasets, you can run the pipeline for each dataset separately by specifying the dataset name as an argument:
```python pipeline.py --dataset <dataset_name>```
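Once the pipeline has finished, `scripts/pred/results.csv` aggregates the evaluation metrics of all runs; a minimal sketch for inspecting it with pandas (the exact set of MAE columns depends on the evaluation tasks, so they are selected by name):
```
import pandas as pd

results = pd.read_csv("scripts/pred/results.csv")

# Keep the run identifier plus all mean-absolute-error columns
mae_cols = [c for c in results.columns if "mae" in c]
print(results[["version"] + mae_cols].sort_values(mae_cols[0]))
```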
### Troubleshooting
* Problem with reading the catalog file: please make sure that your QuakeML XML file has the following opening and closing tags:
```
<?xml version="1.0"?>
<q:quakeml xmlns="http://quakeml.org/xmlns/bed/1.2" xmlns:q="http://quakeml.org/xmlns/quakeml/1.2">
....
</q:quakeml>
```
* `wandb: ERROR Run .. errored: OSError(24, 'Too many open files')`
-> https://github.com/wandb/wandb/issues/2825
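A common workaround for this error is to raise the per-process open-file limit before starting the sweep, e.g. `ulimit -n 4096` in the shell, or from Python (a sketch; the hard limit of your system may cap the value):
```
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
# Raise the soft limit as far as the hard limit allows
resource.setrlimit(resource.RLIMIT_NOFILE, (min(4096, hard), hard))
```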
### Licence
TODO
The code is licensed under the GNU General Public License v3.0. See the [LICENSE](LICENSE.txt) file for details.
### Copyright

View File

@ -1,15 +1,17 @@
{
"dataset_name": "bogdanka",
"data_path": "datasets/bogdanka/seisbench_format/",
"targets_path": "datasets/targets",
"models_path": "weights",
"configs_path": "experiments",
"sampling_rate": 100,
"num_workers": 1,
"seed": 10,
"sweep_files": {
"GPD": "sweep_gpd.yaml",
"PhaseNet": "sweep_phasenet.yaml"
},
"experiment_count": 20
"dataset_name": "bogdanka_2018_2022",
"data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
"targets_path": "datasets/targets",
"models_path": "weights",
"configs_path": "experiments",
"sampling_rate": 100,
"num_workers": 1,
"seed": 10,
"sweep_files": {
"GPD": "sweep_gpd.yaml",
"PhaseNet": "sweep_phasenet.yaml",
"BasicPhaseAE": "sweep_basicphase_ae.yaml",
"EQTransformer": "sweep_eqtransformer.yaml"
},
"experiment_count": 15
}

View File

@ -0,0 +1,16 @@
name: BasicPhaseAE
method: bayes
metric:
goal: minimize
name: val_loss
parameters:
model_name:
value:
- BasicPhaseAE
batch_size:
values: [64, 128, 256]
max_epochs:
value:
- 30
learning_rate:
values: [0.01, 0.005, 0.001]

View File

@ -0,0 +1,16 @@
name: EQTransformer
method: bayes
metric:
goal: minimize
name: val_loss
parameters:
model_name:
value:
- EQTransformer
batch_size:
values: [64, 128, 256]
max_epochs:
value:
- 30
learning_rate:
values: [0.01, 0.005, 0.001]

View File

@ -1,4 +1,4 @@
name: GPD_fixed_highpass:2-10
name: GPD
method: bayes
metric:
goal: minimize
@ -8,19 +8,15 @@ parameters:
value:
- GPD
batch_size:
distribution: int_uniform
max: 1024
min: 256
values: [64, 128, 256]
max_epochs:
value:
- 3
- 30
learning_rate:
distribution: uniform
max: 0.02
min: 0.005
values: [0.01, 0.005, 0.001]
highpass:
value:
- 2
- 1
lowpass:
value:
- 10
- 10

View File

@ -13,7 +13,7 @@ parameters:
min: 256
max_epochs:
value:
- 15
- 30
learning_rate:
distribution: uniform
max: 0.02

View File

@ -18,9 +18,7 @@
"import seisbench.data as sbd\n",
"import seisbench.util as sbu\n",
"import numpy as np\n",
"\n",
"\n",
"import utils\n"
"\n"
]
},
{
@ -1126,7 +1124,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.10.6"
}
},
"nbformat": 4,

View File

@ -36,6 +36,16 @@
"\n"
]
},
{
"cell_type": "markdown",
"id": "70c64dc6-e4dd-4c01-939d-a28914866f5d",
"metadata": {},
"source": [
"##### The catalog has a custom format with the following properties: \n",
"###### 'Datetime', 'X', 'Y', 'Depth', 'Mw', 'Phases', 'mseed_name'\n",
"###### Phases is a string with detected phases seperated by comma: <Phase> <Station> <Datetime> e.g. \"Pg BRDW 2020-01-01 10:09:44.400, Sg BRDW 2020-01-01 10:09:45.696\""
]
},
{
"cell_type": "code",
"execution_count": 2,
@ -106,6 +116,27 @@
"catalog.head(1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "03257d45-299d-4ed1-bc64-03303d2a9873",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Pg BRDW 2020-01-01 10:09:44.400, Sg BRDW 2020-01-01 10:09:45.696, Pg GROD 2020-01-01 10:09:45.206, Sg GROD 2020-01-01 10:09:46.655, Pg GUZI 2020-01-01 10:09:45.116, Sg GUZI 2020-01-01 10:09:46.561, Pg JEDR 2020-01-01 10:09:44.920, Sg JEDR 2020-01-01 10:09:46.285, Pg MOSK2 2020-01-01 10:09:45.417, Sg MOSK2 2020-01-01 10:09:46.921, Pg NWLU 2020-01-01 10:09:45.686, Sg NWLU 2020-01-01 10:09:47.175, Pg PCHB 2020-01-01 10:09:45.213, Sg PCHB 2020-01-01 10:09:46.565, Pg PPOL 2020-01-01 10:09:44.755, Sg PPOL 2020-01-01 10:09:46.069, Pg RUDN 2020-01-01 10:09:44.502, Sg RUDN 2020-01-01 10:09:45.756, Pg RYNR 2020-01-01 10:09:43.442, Sg RYNR 2020-01-01 10:09:44.394, Pg RZEC 2020-01-01 10:09:46.075, Sg RZEC 2020-01-01 10:09:47.587, Pg SGOR 2020-01-01 10:09:45.817, Sg SGOR 2020-01-01 10:09:47.284, Pg TRBC2 2020-01-01 10:09:44.833, Sg TRBC2 2020-01-01 10:09:46.095, Pg TRN2 2020-01-01 10:09:44.488, Sg TRN2 2020-01-01 10:09:45.698, Pg TRZS 2020-01-01 10:09:46.232, Sg TRZS 2020-01-01 10:09:47.727, Pg ZMST 2020-01-01 10:09:43.592, Sg ZMST 2020-01-01 10:09:44.553, Pg LUBW 2020-01-01 10:09:43.119, Sg LUBW 2020-01-01 10:09:43.929'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"catalog.Phases[0]"
]
},
{
"cell_type": "markdown",
"id": "fe0627b1-6fa0-4b5a-8a60-d80626b5c9be",

View File

@ -0,0 +1,19 @@
#!/bin/bash
#SBATCH --job-name=mseeds_to_seisbench
#SBATCH --time=1:00:00
#SBATCH --account= ### to fill
#SBATCH --partition plgrid
#SBATCH --cpus-per-task=1
#SBATCH --ntasks-per-node=1
#SBATCH --mem=24gb
## activate conda environment
source /path/to/mambaforge/bin/activate ### to adjust
conda activate epos-ai-train
input_path="/path/to/folder/with/mseed/files"
catalog_path="/path/to/catolog.xml"
output_path="/path/to/output/in/seisbench_format"
python mseeds_to_seisbench.py --input_path $input_path --catalog_path $catalog_path --output_path $output_path

View File

@ -1,8 +1,13 @@
"""
This file contains functionality related to data.
"""
import os.path
import seisbench.data as sbd
import logging
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('data')
def get_dataset_by_name(name):
@ -30,3 +35,26 @@ def get_custom_dataset(path):
except AttributeError:
raise ValueError(f"Unknown dataset '{path}'.")
def validate_custom_dataset(data_path):
"""
Validate the dataset
:param data_path: path to the dataset
:return:
"""
# check if path exists
if not os.path.isdir(data_path):
raise ValueError(f"Data path {data_path} does not exist.")
dataset = sbd.WaveformDataset(data_path)
# check if the dataset is split into train, dev and test
if len(dataset.train()) == 0:
raise ValueError(f"Training set is empty.")
if len(dataset.dev()) == 0:
raise ValueError(f"Dev set is empty.")
if len(dataset.test()) == 0:
raise ValueError(f"Test set is empty.")
logger.info("Custom dataset validated successfully.")

View File

@ -39,10 +39,15 @@ from pathlib import Path
import pandas as pd
import numpy as np
from tqdm import tqdm
import logging
from models import phase_dict
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('targets generator')
def main(dataset_name, output, tasks, sampling_rate, noise_before_events):
np.random.seed(42)
tasks = [str(i) in tasks.split(",") for i in range(1, 4)]
@ -64,17 +69,24 @@ def main(dataset_name, output, tasks, sampling_rate, noise_before_events):
dataset = sbd.WaveformDataset(dataset_name, **dataset_args)
output = Path(output)
output.mkdir(parents=True, exist_ok=False)
output.mkdir(parents=True, exist_ok=True)
if "split" in dataset.metadata.columns:
dataset.filter(dataset["split"].isin(["dev", "test"]), inplace=True)
dataset.preload_waveforms(pbar=True)
if tasks[0]:
generate_task1(dataset, output, sampling_rate, noise_before_events)
if not Path.exists(output / "task1.csv"):
generate_task1(dataset, output, sampling_rate, noise_before_events)
else:
logger.info(f"{output}/task1.csv already exists. Skipping generation of targets.")
if tasks[1] or tasks[2]:
generate_task23(dataset, output, sampling_rate)
if not Path.exists(output / "task23.csv"):
generate_task23(dataset, output, sampling_rate)
else:
logger.info(f"{output}/task23.csv already exists. Skipping generation of targets.")
def generate_task1(dataset, output, sampling_rate, noise_before_events):

View File

@ -7,8 +7,9 @@ import os
import os.path
import argparse
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
import wandb
import torch
@ -18,8 +19,8 @@ from dotenv import load_dotenv
import models
import train
import util
from config_loader import config as common_config
from config_loader import models_path, dataset_name, seed, experiment_count
import config_loader
torch.multiprocessing.set_sharing_strategy('file_system')
@ -33,16 +34,13 @@ host = os.environ.get("WANDB_HOST")
if host is None:
raise ValueError("WANDB_HOST environment variable is not set.")
wandb.login(key=wandb_api_key, host=host)
# wandb.login(key=wandb_api_key)
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user_name = os.environ.get("WANDB_USER")
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.WARNING)
logger.setLevel(logging.INFO)
def set_random_seed(seed=3):
@ -58,6 +56,11 @@ def get_trainer_args(config):
return trainer_args
def get_arg(arg):
if type(arg) == list:
return arg[0]
return arg
class HyperparameterSweep:
def __init__(self, project_name, sweep_config):
self.project_name = project_name
@ -68,11 +71,9 @@ class HyperparameterSweep:
# Create the sweep
self.sweep_id = wandb.sweep(self.sweep_config, project=self.project_name)
logger.info("Created sweep with ID: " + self.sweep_id)
# Run the sweep
wandb.agent(self.sweep_id, function=self.run_experiment, count=experiment_count)
wandb.agent(self.sweep_id, function=self.run_experiment, count=config_loader.experiment_count)
def all_runs_finished(self):
@ -90,24 +91,40 @@ class HyperparameterSweep:
return all_not_running
def run_experiment(self):
try:
logger.debug("Starting a new run...")
logger.info("Starting a new run...")
run = wandb.init(
project=self.project_name,
config=common_config,
config=config_loader.config,
save_code=True,
entity=wandb_user_name
)
run.log_code(
root=".",
include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
exclude_fn=lambda path: path.endswith("template.sh")
)
wandb.run.log_code(
".",
include_fn=lambda path: path.endswith(os.path.basename(__file__))
)
model_name = wandb.config.model_name[0]
model_name = get_arg(wandb.config.model_name)
model_args = models.get_model_specific_args(wandb.config)
logger.debug(f"Initializing {model_name}")
if "pretrained" in wandb.config:
weights = get_arg(wandb.config.pretrained)
if weights != "false":
model_args["pretrained"] = weights
if "norm" in wandb.config:
model_args["norm"] = get_arg(wandb.config.norm)
if "finetuning" in wandb.config:
model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
if "lr_reduce_factor" in wandb.config:
model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
logger.debug(f"Initializing {model_name} with args: {model_args}")
model = models.__getattribute__(model_name + "Lit")(**model_args)
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
@ -116,8 +133,8 @@ class HyperparameterSweep:
wandb_logger.watch(model)
# CSV logger - also used for saving configuration as yaml
experiment_name = f"{dataset_name}_{model_name}"
csv_logger = CSVLogger(models_path, experiment_name, version=run.id)
experiment_name = f"{config_loader.dataset_name}_{model_name}"
csv_logger = CSVLogger(config_loader.models_path, experiment_name, version=run.id)
csv_logger.log_hyperparams(wandb.config)
loggers = [wandb_logger, csv_logger]
@ -131,24 +148,26 @@ class HyperparameterSweep:
filename=experiment_signature + "-{epoch}-{val_loss:.3f}",
monitor="val_loss",
mode="min",
dirpath=f"{models_path}/{experiment_name}/",
dirpath=f"{config_loader.models_path}/{experiment_name}/",
) # save_top_k=1, monitor="val_loss", mode="min": save the best model in terms of validation loss
checkpoint_callback.STARTING_VERSION = 1
early_stopping_callback = EarlyStopping(
monitor="val_loss",
patience=3,
patience=5,
verbose=True,
mode="min")
callbacks = [checkpoint_callback, early_stopping_callback]
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]
trainer = pl.Trainer(
default_root_dir=models_path,
default_root_dir=config_loader.models_path,
logger=loggers,
callbacks=callbacks,
**get_trainer_args(wandb.config)
)
trainer.fit(model, train_loader, dev_loader)
except Exception as e:
@ -160,9 +179,8 @@ class HyperparameterSweep:
def start_sweep(sweep_config):
logger.info("Starting sweep with config: " + str(sweep_config))
set_random_seed(seed)
set_random_seed(config_loader.seed)
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
sweep_runner.run_sweep()
@ -170,7 +188,6 @@ def start_sweep(sweep_config):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--sweep_config", type=str, required=True)
args = parser.parse_args()

scripts/input_validate.py Normal file
View File

@ -0,0 +1,146 @@
from pydantic import BaseModel, ConfigDict, field_validator
from typing_extensions import Literal
from typing import Union, List, Optional
import yaml
import logging
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('input_validator')
#todo
# 1. check if a single value is allowed in a sweep
# 2. merge input params
# 3. change names of the classes
# 4. add constraints for PhaseNet, GPD
model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
norm_values = Literal["peak", "std"]
finetuning_values = Literal["all", "top", "decoder", "encoder"]
pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic',
'original', 'scedc', False]
class Metric(BaseModel):
goal: str
name: str
class NumericValue(BaseModel):
value: Union[int, float, List[Union[int, float]]]
class NumericValues(BaseModel):
values: List[Union[int, float]]
class IntDistribution(BaseModel):
distribution: str = "int_uniform"
min: int
max: int
class FloatDistribution(BaseModel):
distribution: str = "uniform"
min: float
max: float
class Pretrained(BaseModel):
distribution: Optional[str] = "categorical"
values: List[pretrained_values] = None
value: Union[pretrained_values, List[pretrained_values]] = None
class Finetuning(BaseModel):
distribution: Optional[str] = "categorical"
values: List[finetuning_values] = None
value: Union[finetuning_values, List[finetuning_values]] = None
class Norm(BaseModel):
distribution: Optional[str] = "categorical"
values: List[norm_values] = None
value: Union[norm_values, List[norm_values]] = None
class ModelType(BaseModel):
distribution: Optional[str] = "categorical"
value: Union[model_names, List[model_names]] = None
values: List[model_names] = None
class Parameters(BaseModel):
model_config = ConfigDict(extra='forbid', protected_namespaces=())
model_name: ModelType
batch_size: Union[IntDistribution, NumericValue, NumericValues]
learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
max_epochs: Union[IntDistribution, NumericValue, NumericValues]
class PhaseNetParameters(Parameters):
model_config = ConfigDict(extra='forbid')
norm: Norm = None
pretrained: Pretrained = None
finetuning: Finetuning = None
lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
@field_validator("model_name")
def validate_model(cls, v):
if "PhaseNet" not in v.value:
raise ValueError("Additional parameters implemented for PhaseNet only")
return v
class FilteringParameters(Parameters):
model_config = ConfigDict(extra='forbid')
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
@field_validator("model_name")
def validate_model(cls, v):
if v.value[0] not in ["GPD", "PhaseNet"]:
raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
return v
class InputParams(BaseModel):
name: str
method: str
metric: Metric
parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
def validate_sweep_yaml(yaml_filename, model_name=None):
# Load YAML configuration
with open(yaml_filename, 'r') as f:
sweep_config = yaml.safe_load(f)
validate_sweep_config(sweep_config, model_name)
def validate_sweep_config(sweep_config, model_name=None):
# Validate sweep config
input_params = InputParams(**sweep_config)
# Check consistency of input parameters and sweep configuration
sweep_model_name = input_params.parameters.model_name.value
if model_name is not None and model_name not in sweep_model_name:
info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
logger.info(info)
raise ValueError(info)
logger.info("Input validation successful.")
if __name__ == "__main__":
yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
validate_sweep_yaml(yaml_filename, None)

View File

@ -7,16 +7,26 @@ import seisbench.generate as sbg
import pytorch_lightning as pl
import torch
from torch.optim import lr_scheduler
import torch.nn.functional as F
import numpy as np
from abc import abstractmethod, ABC
# import lightning as L
# Allows to import this file in both jupyter notebook and code
try:
from .augmentations import DuplicateEvent
except ImportError:
from augmentations import DuplicateEvent
import os
import logging
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.DEBUG)
# Phase dict for labelling. We only study P and S phases without differentiating between them.
phase_dict = {
@ -131,7 +141,31 @@ class PhaseNetLit(SeisBenchModuleLit):
self.sigma = sigma
self.sample_boundaries = sample_boundaries
self.loss = vector_cross_entropy
self.model = sbm.PhaseNet(phases="PN", **kwargs)
self.pretrained = kwargs.pop("pretrained", None)
self.norm = kwargs.pop("norm", "peak")
self.highpass = kwargs.pop("highpass", None)
self.lowpass = kwargs.pop("lowpass", None)
if self.pretrained is not None:
self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
# self.norm = self.model.norm
else:
self.model = sbm.PhaseNet(**kwargs)
self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
self.reduce_lr_on_plateau = False
self.initial_epochs = 0
if self.finetuning_strategy is not None:
if self.finetuning_strategy == "top":
self.initial_epochs = 3
elif self.finetuning_strategy in ["decoder", "encoder"]:
self.initial_epochs = 6
self.freeze()
def forward(self, x):
return self.model(x)
@ -154,9 +188,49 @@ class PhaseNetLit(SeisBenchModuleLit):
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
if self.finetuning_strategy is not None:
scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
self.reduce_lr_on_plateau = False
else:
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
self.reduce_lr_on_plateau = True
#
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'monitor': 'val_loss',
'interval': 'epoch',
'reduce_on_plateau': self.reduce_lr_on_plateau,
},
}
def lr_lambda(self, epoch):
# reduce lr after x initial epochs
if epoch == self.initial_epochs:
self.lr *= self.steplr_gamma
return self.lr
def lr_scheduler_step(self, scheduler, metric):
if self.reduce_lr_on_plateau:
scheduler.step(metric, epoch=self.current_epoch)
else:
scheduler.step(epoch=self.current_epoch)
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
# scheduler.step(epoch=self.current_epoch)
def get_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
logger.info(f"Using highpass filer {self.highpass}")
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
logger.info(f"Using lowpass filer {self.lowpass}")
logger.info(filter)
return [
# In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
# Uses strategy variable, as padding will be handled by the random window.
@ -180,18 +254,26 @@ class PhaseNetLit(SeisBenchModuleLit):
windowlen=3001,
strategy="pad",
),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
*filter,
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
sbg.ProbabilisticLabeller(
label_columns=phase_dict, sigma=self.sigma, dim=0
),
]
def get_eval_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
return [
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
*filter,
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
@ -219,6 +301,27 @@ class PhaseNetLit(SeisBenchModuleLit):
return score_detection, score_p_or_s, p_sample, s_sample
def freeze(self):
if self.finetuning_strategy == "decoder": # finetune decoder branch and freeze encoder branch
for p in self.model.down_branch.parameters():
p.requires_grad = False
elif self.finetuning_strategy == "encoder": # finetune encoder branch and freeze decoder branch
for p in self.model.up_branch.parameters():
p.requires_grad = False
elif self.finetuning_strategy == "top":
for p in self.model.out.parameters():
p.requires_grad = False
def unfreeze(self):
logger.info("Unfreezing layers")
for p in self.model.parameters():
p.requires_grad = True
def on_train_epoch_start(self):
# Unfreeze some layers after x initial epochs
if self.current_epoch == self.initial_epochs:
self.unfreeze()
class GPDLit(SeisBenchModuleLit):
"""
@ -846,7 +949,7 @@ class BasicPhaseAELit(SeisBenchModuleLit):
# Create overlapping windows
re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
for i, start in enumerate(range(0, 2401, 400)):
re[:, :, i] = x[:, :, start : start + 600]
re[:, :, i] = x[:, :, start: start + 600]
x = re
x = x.permute(0, 2, 1, 3) # --> (batch, windows, channels, samples)
@ -862,9 +965,9 @@ class BasicPhaseAELit(SeisBenchModuleLit):
for i, start in enumerate(range(0, 2401, 400)):
if start == 0:
# Use full window (for start==0, the end will be overwritten)
pred[:, :, start : start + 600] = window_pred[:, i]
pred[:, :, start: start + 600] = window_pred[:, i]
else:
pred[:, :, start + 100 : start + 600] = window_pred[:, i, :, 100:]
pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:]
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0])
@ -1112,7 +1215,6 @@ class DPPPickerLit(SeisBenchModuleLit):
def get_model_specific_args(config):
model = config.model_name[0]
lr = config.learning_rate
if type(lr) == list:
@ -1125,8 +1227,12 @@ def get_model_specific_args(config):
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass[0]
args['lowpass'] = config.lowpass
case 'PhaseNet':
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass
if 'sample_boundaries' in config:
args['sample_boundaries'] = config.sample_boundaries
case 'DPPPicker':

View File

@ -1,66 +1,44 @@
"""
-----------------
Copyright © 2023 ACK Cyfronet AGH, Poland.
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
-----------------
"""
import os
import pandas as pd
import glob
from pathlib import Path
import obspy
from obspy.core.event import read_events
import seisbench
import seisbench.data as sbd
import seisbench.util as sbu
import sys
import logging
import argparse
logging.basicConfig(filename="output.out",
filemode='a',
format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
datefmt='%H:%M:%S',
level=logging.DEBUG)
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('converter')
def create_traces_catalog(directory, years):
for year in years:
directory = f"{directory}/{year}"
files = glob.glob(directory)
traces = []
for i, f in enumerate(files):
st = obspy.read(f)
for tr in st.traces:
# trace_id = tr.id
# start = tr.meta.starttime
# end = tr.meta.endtime
trs = pd.Series({
'trace_id': tr.id,
'trace_st': tr.meta.starttime,
'trace_end': tr.meta.endtime,
'stream_fname': f
})
traces.append(trs)
traces_catalog = pd.DataFrame(pd.concat(traces)).transpose()
traces_catalog.to_csv("data/bogdanka/traces_catalog.csv", append=True, index=False)
def split_events(events, input_path):
logger.info("Splitting available events into train, dev and test sets ...")
events_stats = pd.DataFrame()
events_stats.index.name = "event"
for i, event in enumerate(events):
#check if mseed exists
# check if mseed exists
actual_picks = 0
for pick in event.picks:
trace_params = get_trace_params(pick)
trace_path = get_trace_path(input_path, trace_params)
if os.path.isfile(trace_path):
actual_picks += 1
@ -79,6 +57,10 @@ def split_events(events, input_path):
events_stats.loc[i, 'split'] = 'dev'
else:
break
logger.info(f"Split: {events_stats['split'].value_counts()}")
logger.info(f"Split: {events_stats['split'].value_counts()}")
return events_stats
@ -87,7 +69,6 @@ def get_event_params(event):
origin = event.preferred_origin()
if origin is None:
return {}
# print(origin)
mag = event.preferred_magnitude()
@ -115,12 +96,9 @@ def get_event_params(event):
def get_trace_params(pick):
net = pick.waveform_id.network_code
sta = pick.waveform_id.station_code
trace_params = {
"station_network_code": net,
"station_code": sta,
"station_network_code": pick.waveform_id.network_code,
"station_code": pick.waveform_id.station_code,
"trace_channel": pick.waveform_id.channel_code,
"station_location_code": pick.waveform_id.location_code,
"time": pick.time
@ -134,7 +112,6 @@ def find_trace(pick_time, traces):
if pick_time > tr.stats.endtime:
continue
if pick_time >= tr.stats.starttime:
# print(pick_time, " - selected trace: ", tr)
return tr
logger.warning(f"no matching trace for peak: {pick_time}")
@ -152,12 +129,26 @@ def get_trace_path(input_path, trace_params):
return path
def get_three_channels_trace_paths(input_path, trace_params):
year = trace_params["time"].year
day_of_year = pd.Timestamp(str(trace_params["time"])).day_of_year
net = trace_params["station_network_code"]
station = trace_params["station_code"]
channel_base = trace_params["trace_channel"]
paths = []
for ch in ["E", "N", "Z"]:
channel = channel_base[:-1] + ch
paths.append(
f"{input_path}/{year}/{net}/{station}/{channel}.D/{net}.{station}..{channel}.D.{year}.{day_of_year:03}")
return paths
def load_trace(input_path, trace_params):
trace_path = get_trace_path(input_path, trace_params)
trace = None
if not os.path.isfile(trace_path):
logger.w(trace_path + " not found")
logger.warning(trace_path + " not found")
else:
stream = obspy.read(trace_path)
if len(stream.traces) > 1:
@ -171,19 +162,26 @@ def load_trace(input_path, trace_params):
def load_stream(input_path, trace_params, time_before=60, time_after=60):
trace_path = get_trace_path(input_path, trace_params)
sampling_rate, stream = None, None
pick_time = trace_params["time"]
if not os.path.isfile(trace_path):
print(trace_path + " not found")
else:
stream = obspy.read(trace_path)
trace_paths = get_three_channels_trace_paths(input_path, trace_params)
for trace_path in trace_paths:
if not os.path.isfile(trace_path):
logger.warning(trace_path + " not found")
else:
if stream is None:
stream = obspy.read(trace_path)
else:
stream += obspy.read(trace_path)
if stream is not None:
stream = stream.slice(pick_time - time_before, pick_time + time_after)
if len(stream.traces) == 0:
print(f"no data in: {trace_path}")
else:
sampling_rate = stream.traces[0].stats.sampling_rate
stream.merge()
return sampling_rate, stream
@ -202,23 +200,36 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
metadata_path = output_path + "/metadata.csv"
waveforms_path = output_path + "/waveforms.hdf5"
events_to_convert = events_stats[events_stats['pick_count'] > 0]
logger.debug("Catalog loaded, starting conversion ...")
logger.debug("Catalog loaded, starting converting {events_to_convert} events ...")
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:
writer.data_format = {
"dimension_order": "CW",
"component_order": "ZNE",
}
for i, event in enumerate(events):
logger.debug(f"Converting {i} event")
event_params = get_event_params(event)
event_params["split"] = events_stats.loc[i, "split"]
picks_per_station = {}
for pick in event.picks:
trace_params = get_trace_params(pick)
station = pick.waveform_id.station_code
if station in picks_per_station:
picks_per_station[station].append(pick)
else:
picks_per_station[station] = [pick]
for picks in picks_per_station.values():
trace_params = get_trace_params(picks[0])
sampling_rate, stream = load_stream(input_path, trace_params)
if stream is None:
if stream is None or len(stream.traces) == 0:
continue
actual_t_start, data, _ = sbu.stream_to_array(
@ -229,22 +240,19 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
trace_params["trace_sampling_rate_hz"] = sampling_rate
trace_params["trace_start_time"] = str(actual_t_start)
pick_time = obspy.core.utcdatetime.UTCDateTime(trace_params["time"])
pick_idx = (pick_time - actual_t_start) * sampling_rate
trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(pick_idx)
for pick in picks:
pick_time = obspy.core.utcdatetime.UTCDateTime(pick.time)
pick_idx = (pick_time - actual_t_start) * sampling_rate
trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(pick_idx)
writer.add_trace({**event_params, **trace_params}, data)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Convert mseed files to seisbench format')
parser.add_argument('--input_path', type=str, help='Path to mseed files')
parser.add_argument('--catalog_path', type=str, help='Path to events catalog in quakeml format')
parser.add_argument('--output_path', type=str, help='Path to output files')
args = parser.parse_args()
convert_mseed_to_seisbench_format(args.input_path, args.catalog_path, args.output_path)

scripts/perf_analysis.py Normal file
View File

@ -0,0 +1,149 @@
import json
import pathlib
import numpy as np
import obspy
import pandas as pd
import seisbench.data as sbd
import seisbench.models as sbm
import itertools
from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve
datasets = [
# path to datasets in seisbench format
]
models = [
# model names
]
def find_keys_phase(meta, phase):
phases = []
for k in meta.keys():
if k.startswith("trace_" + phase) and k.endswith("_arrival_sample"):
phases.append(k)
return phases
def create_stream(meta, raw, start, length=30):
st = obspy.Stream()
for i in range(3):
tr = obspy.Trace(raw[i, :])
tr.stats.starttime = meta["trace_start_time"]
tr.stats.sampling_rate = meta["trace_sampling_rate_hz"]
tr.stats.network = meta["station_network_code"]
tr.stats.station = meta["station_code"]
tr.stats.channel = meta["trace_channel"][:2] + meta["trace_component_order"][i]
stop = start + length
tr = tr.slice(start, stop)
st.append(tr)
return st
def get_pred(model, stream):
ann = model.annotate(stream)
noise = ann.select(channel="PhaseNet_N")[0]
pred = max(1 - noise.data)
return pred
def to_short(stream):
short = [tr for tr in stream if tr.data.shape[0] < 3001]
return any(short)
for ds, model_name in itertools.product(datasets, models):
data = sbd.WaveformDataset(ds, sampling_rate=100).test()
data_name = pathlib.Path(ds).stem
fname = f"roc___{model_name}___{data_name}.csv"
print(f"{fname:.<50s}.... ", flush=True, end="")
if pathlib.Path(fname).is_file():
print(" ready, skipping", flush=True)
continue
p_labels = find_keys_phase(data.metadata, "P")
s_labels = find_keys_phase(data.metadata, "S")
model = sbm.PhaseNet().from_pretrained(model_name)
label_true = []
label_pred = []
for i in range(len(data)):
waveform, metadata = data.get_sample(i)
m = pd.Series(metadata)
has_p_label = m[p_labels].notna()
has_s_label = m[s_labels].notna()
if any(has_p_label):
trace_start_time = obspy.UTCDateTime(m["trace_start_time"])
pick_sample = m[p_labels][has_p_label][0]
start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15
try:
st_p = create_stream(m, waveform, start)
if not (to_short(st_p)):
pred_p = get_pred(model, st_p)
label_true.append(1)
label_pred.append(pred_p)
except IndexError:
pass
try:
st_n = create_stream(m, waveform, trace_start_time + 1)
if not (to_short(st_n)):
pred_n = get_pred(model, st_n)
label_true.append(0)
label_pred.append(pred_n)
except IndexError:
pass
if any(has_s_label):
trace_start_time = obspy.UTCDateTime(m["trace_start_time"])
pick_sample = m[s_labels][has_s_label][0]
start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15
try:
st_s = create_stream(m, waveform, start)
if not (to_short(st_s)):
pred_s = get_pred(model, st_s)
label_true.append(1)
label_pred.append(pred_s)
except IndexError:
pass
fpr, tpr, roc_thresholds = roc_curve(label_true, label_pred)
df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": roc_thresholds})
df.to_csv(fname)
precision, recall, prc_thresholds = precision_recall_curve(label_true, label_pred)
prc_thresholds_extra = np.append(prc_thresholds, -999)
df = pd.DataFrame(
{"pre": precision, "rec": recall, "thresholds": prc_thresholds_extra}
)
df.to_csv(fname.replace("roc", "pr"))
stats = {
"model": str(model_name),
"data": str(data_name),
"auc": float(roc_auc_score(label_true, label_pred)),
}
with open(f"stats___{model_name}___{data_name}.json", "w") as fp:
json.dump(stats, fp)
print(" finished", flush=True)

View File

@ -15,33 +15,49 @@ import generate_eval_targets
import hyperparameter_sweep
import eval
import collect_results
from config_loader import data_path, targets_path, sampling_rate, dataset_name, sweep_files
import importlib
import config_loader
import input_validate
import data
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('pipeline')
def load_sweep_config(model_name, args):
if model_name == "PhaseNet" and args.phasenet_config is not None:
sweep_fname = args.phasenet_config
elif model_name == "GPD" and args.gpd_config is not None:
sweep_fname = args.gpd_config
elif model_name == "BasicPhaseAE" and args.basic_phase_ae_config is not None:
sweep_fname = args.basic_phase_ae_config
elif model_name == "EQTransformer" and args.eqtransformer_config is not None:
sweep_fname = args.eqtransformer_config
else:
# use the default sweep config for the model
sweep_fname = sweep_files[model_name]
sweep_fname = config_loader.sweep_files[model_name]
logger.info(f"Loading sweep config: {sweep_fname}")
return util.load_sweep_config(sweep_fname)
def find_the_best_params(model_name, args):
def validate_pipeline_input(args):
# validate input parameters
for model_name in args.models:
sweep_config = load_sweep_config(model_name, args)
input_validate.validate_sweep_config(sweep_config, model_name)
# validate dataset
data.validate_custom_dataset(config_loader.data_path)
def find_the_best_params(sweep_config):
# find the best hyperparams for the model_name
model_name = sweep_config['parameters']['model_name']
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
sweep_config = load_sweep_config(model_name, args)
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
# wait for all runs to finish
@ -58,9 +74,9 @@ def find_the_best_params(model_name, args):
def generate_predictions(sweep_id, model_name):
experiment_name = f"{dataset_name}_{model_name}"
experiment_name = f"{config_loader.dataset_name}_{model_name}"
eval.main(weights=experiment_name,
targets=targets_path,
targets=config_loader.targets_path,
sets='dev,test',
batchsize=128,
num_workers=4,
@ -73,22 +89,49 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--phasenet_config", type=str, required=False)
parser.add_argument("--gpd_config", type=str, required=False)
parser.add_argument("--basic_phase_ae_config", type=str, required=False)
parser.add_argument("--eqtransformer_config", type=str, required=False)
parser.add_argument("--dataset", type=str, required=False)
available_models = ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]
parser.add_argument("--models", nargs='*', required=False, choices=available_models, default=available_models,
help="Models to train and evaluate (default: all)")
parser.add_argument("--collect_results", action="store_true", help="Collect and log results without training")
args = parser.parse_args()
# generate labels
logger.info("Started generating labels for the dataset.")
generate_eval_targets.main(data_path, targets_path, "2,3", sampling_rate, None)
if not args.collect_results:
# find the best hyperparams for the models
logger.info("Started training the models.")
for model_name in ["GPD", "PhaseNet"]:
sweep_id = find_the_best_params(model_name, args)
generate_predictions(sweep_id, model_name)
if args.dataset is not None:
util.set_dataset(args.dataset)
importlib.reload(config_loader)
validate_pipeline_input(args)
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
# generate labels
logger.info("Started generating labels for the dataset.")
generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3", config_loader.sampling_rate,
None)
# find the best hyperparams for the models
logger.info("Started training the models.")
for model_name in args.models:
sweep_config = load_sweep_config(model_name, args)
sweep_id = find_the_best_params(sweep_config)
generate_predictions(sweep_id, model_name)
# collect results
logger.info("Collecting results.")
collect_results.traverse_path("pred", "pred/results.csv")
logger.info("Results saved in pred/results.csv")
results_path = "pred/results.csv"
collect_results.traverse_path("pred", results_path)
logger.info(f"Results saved in {results_path}")
# log calculated metrics (MAE) on w&b
logger.info("Logging MAE metrics on w&b.")
util.log_metrics(results_path)
logger.info("Pipeline finished")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,19 @@
#!/bin/bash
#SBATCH --job-name=job_name
#SBATCH --time=10:00:00
#SBATCH --account= ### to fill
#SBATCH --partition=plgrid-gpu-v100
#SBATCH --cpus-per-task=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
source path/to/mambaforge/bin/activate ### to change
conda activate epos-ai-train
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
python -c "import torch; print('Number of CUDA devices:', torch.cuda.device_count())"
python -c "import torch; print('Name of GPU:', torch.cuda.get_device_name(torch.cuda.current_device()))"
python pipeline.py --dataset "bogdanka"

View File

@ -1,5 +1,10 @@
"""
This script offers general functionality required in multiple places.
-----------------
Copyright © 2023 ACK Cyfronet AGH, Poland.
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
-----------------
This script runs the pipeline for the training and evaluation of the models.
"""
import numpy as np
@ -7,13 +12,15 @@ import pandas as pd
import os
import logging
import glob
import json
import wandb
from dotenv import load_dotenv
import sys
from config_loader import models_path, configs_path
from config_loader import models_path, configs_path, config_path
import yaml
load_dotenv()
load_dotenv()
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
@ -38,8 +45,16 @@ def load_best_model_data(sweep_id, weights):
# Get best run parameters
best_run = sweep.best_run()
run_id = best_run.id
matching_models = glob.glob(f"{models_path}/{weights}/*run={run_id}*ckpt")
if len(matching_models)!=1:
run = api.run(f"{wandb_user}/{wandb_project_name}/runs/{run_id}")
dataset = run.config["dataset_name"]
model = run.config["model_name"][0]
experiment = f"{dataset}_{model}"
checkpoints_path = f"{models_path}/{experiment}/*run={run_id}*ckpt"
logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
matching_models = glob.glob(checkpoints_path)
if len(matching_models) != 1:
raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id)
best_checkpoint_path = matching_models[0]
@ -62,31 +77,6 @@ def load_best_model_data(sweep_id, weights):
return best_checkpoint_path, run_id
def load_best_model(model_cls, weights, version):
"""
Determines the model with lowest validation loss from the csv logs and loads it
:param model_cls: Class of the lightning module to load
:param weights: Path to weights as in cmd arguments
:param version: String of version file
:return: Instance of lightning module that was loaded from the best checkpoint
"""
metrics = pd.read_csv(weights / version / "metrics.csv")
idx = np.nanargmin(metrics["val_loss"])
min_row = metrics.iloc[idx]
# For default checkpoint filename, see https://github.com/Lightning-AI/lightning/pull/11805
# and https://github.com/Lightning-AI/lightning/issues/16636.
# For example, 'epoch=0-step=1.ckpt' means the 1st step has finish, but the 1st epoch hasn't
checkpoint = f"epoch={min_row['epoch']:.0f}-step={min_row['step']+1:.0f}.ckpt"
# For default save path of checkpoints, see https://github.com/Lightning-AI/lightning/pull/12372
checkpoint_path = weights / version / "checkpoints" / checkpoint
return model_cls.load_from_checkpoint(checkpoint_path)
default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", None)
if default_workers is None:
logging.warning(
@ -117,3 +107,51 @@ def load_sweep_config(sweep_fname):
sys.exit(1)
return sweep_config
def log_metrics(results_file):
"""
:param results_file: csv file with calculated metrics
:return:
"""
api = wandb.Api()
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user = os.environ.get("WANDB_USER")
results = pd.read_csv(results_file)
for run_id in results["version"].unique():
try:
run = api.run(f"{wandb_user}/{wandb_project_name}/{run_id}")
metrics_to_log = {}
run_results = results[results["version"] == run_id]
for col in run_results.columns:
if 'mae' in col:
metrics_to_log[col] = run_results[col].values[0]
run.summary[col] = run_results[col].values[0]
run.summary.update()
logging.info(f"Logged metrics for run: {run_id}, {metrics_to_log}")
except Exception as e:
print(f"An error occurred: {e}, {type(e).__name__}, {e.args}")
def set_dataset(dataset_name):
"""
Sets the dataset name in the config file
:param dataset_name:
:return:
"""
with open(config_path, "r+") as f:
config = json.load(f)
config["dataset_name"] = dataset_name
config["data_path"] = f"datasets/{dataset_name}/seisbench_format/"
f.seek(0) # rewind
json.dump(config, f, indent=4)
f.truncate()

View File

@ -1,19 +0,0 @@
#!/bin/bash
#SBATCH --job-name=mseeds_to_seisbench
#SBATCH --time=1:00:00
#SBATCH --account=plgeposai22gpu-gpu
#SBATCH --partition plgrid
#SBATCH --cpus-per-task=1
#SBATCH --ntasks-per-node=1
#SBATCH --mem=24gb
## activate conda environment
source /net/pr2/projects/plgrid/plggeposai/kmilian/mambaforge/bin/activate
conda activate epos-ai-train
input_path="/net/pr2/projects/plgrid/plggeposai/datasets/bogdanka"
catalog_path="/net/pr2/projects/plgrid/plggeposai/datasets/bogdanka/BOIS_all.xml"
output_path="/net/pr2/projects/plgrid/plggeposai/kmilian/platform-demo-scripts/datasets/bogdanka/seisbench_format"
python mseeds_to_seisbench.py --input_path $input_path --catalog_path $catalog_path --output_path $output_path

View File

@ -1,230 +0,0 @@
import os
import pandas as pd
import glob
from pathlib import Path
import obspy
from obspy.core.event import read_events
import seisbench.data as sbd
import seisbench.util as sbu
import sys
import logging
logging.basicConfig(filename="output.out",
filemode='a',
format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
datefmt='%H:%M:%S',
level=logging.DEBUG)
logger = logging.getLogger('converter')
def create_traces_catalog(directory, years):
for year in years:
directory = f"{directory}/{year}"
files = glob.glob(directory)
traces = []
for i, f in enumerate(files):
st = obspy.read(f)
for tr in st.traces:
# trace_id = tr.id
# start = tr.meta.starttime
# end = tr.meta.endtime
trs = pd.Series({
'trace_id': tr.id,
'trace_st': tr.meta.starttime,
'trace_end': tr.meta.endtime,
'stream_fname': f
})
traces.append(trs)
traces_catalog = pd.DataFrame(pd.concat(traces)).transpose()
traces_catalog.to_csv("data/bogdanka/traces_catalog.csv", append=True, index=False)
def split_events(events, input_path):
logger.info("Splitting available events into train, dev and test sets ...")
events_stats = pd.DataFrame()
events_stats.index.name = "event"
for i, event in enumerate(events):
#check if mseed exists
actual_picks = 0
for pick in event.picks:
trace_params = get_trace_params(pick)
trace_path = get_trace_path(input_path, trace_params)
if os.path.isfile(trace_path):
actual_picks += 1
events_stats.loc[i, "pick_count"] = actual_picks
events_stats['pick_count_cumsum'] = events_stats.pick_count.cumsum()
train_th = 0.7 * events_stats.pick_count_cumsum.values[-1]
dev_th = 0.85 * events_stats.pick_count_cumsum.values[-1]
events_stats['split'] = 'test'
for i, event in events_stats.iterrows():
if event['pick_count_cumsum'] < train_th:
events_stats.loc[i, 'split'] = 'train'
elif event['pick_count_cumsum'] < dev_th:
events_stats.loc[i, 'split'] = 'dev'
else:
break
return events_stats
def get_event_params(event):
origin = event.preferred_origin()
if origin is None:
return {}
# print(origin)
mag = event.preferred_magnitude()
source_id = str(event.resource_id)
event_params = {
"source_id": source_id,
"source_origin_uncertainty_sec": origin.time_errors["uncertainty"],
"source_latitude_deg": origin.latitude,
"source_latitude_uncertainty_km": origin.latitude_errors["uncertainty"],
"source_longitude_deg": origin.longitude,
"source_longitude_uncertainty_km": origin.longitude_errors["uncertainty"],
"source_depth_km": origin.depth / 1e3,
"source_depth_uncertainty_km": origin.depth_errors["uncertainty"] / 1e3 if origin.depth_errors[
"uncertainty"] is not None else None,
}
if mag is not None:
event_params["source_magnitude"] = mag.mag
event_params["source_magnitude_uncertainty"] = mag.mag_errors["uncertainty"]
event_params["source_magnitude_type"] = mag.magnitude_type
event_params["source_magnitude_author"] = mag.creation_info.agency_id if mag.creation_info is not None else None
return event_params
def get_trace_params(pick):
net = pick.waveform_id.network_code
sta = pick.waveform_id.station_code
trace_params = {
"station_network_code": net,
"station_code": sta,
"trace_channel": pick.waveform_id.channel_code,
"station_location_code": pick.waveform_id.location_code,
"time": pick.time
}
return trace_params
def find_trace(pick_time, traces):
for tr in traces:
if pick_time > tr.stats.endtime:
continue
if pick_time >= tr.stats.starttime:
# print(pick_time, " - selected trace: ", tr)
return tr
logger.warning(f"no matching trace for peak: {pick_time}")
return None
def get_trace_path(input_path, trace_params):
year = trace_params["time"].year
day_of_year = pd.Timestamp(str(trace_params["time"])).day_of_year
net = trace_params["station_network_code"]
station = trace_params["station_code"]
tr_channel = trace_params["trace_channel"]
path = f"{input_path}/{year}/{net}/{station}/{tr_channel}.D/{net}.{station}..{tr_channel}.D.{year}.{day_of_year}"
return path
def load_trace(input_path, trace_params):
trace_path = get_trace_path(input_path, trace_params)
trace = None
if not os.path.isfile(trace_path):
logger.w(trace_path + " not found")
else:
stream = obspy.read(trace_path)
if len(stream.traces) > 1:
trace = find_trace(trace_params["time"], stream.traces)
elif len(stream.traces) == 0:
logger.warning(f"no data in: {trace_path}")
else:
trace = stream.traces[0]
return trace
def load_stream(input_path, trace_params, time_before=60, time_after=60):
trace_path = get_trace_path(input_path, trace_params)
sampling_rate, stream = None, None
pick_time = trace_params["time"]
if not os.path.isfile(trace_path):
print(trace_path + " not found")
else:
stream = obspy.read(trace_path)
stream = stream.slice(pick_time - time_before, pick_time + time_after)
if len(stream.traces) == 0:
print(f"no data in: {trace_path}")
else:
sampling_rate = stream.traces[0].stats.sampling_rate
return sampling_rate, stream
def convert_mseed_to_seisbench_format():
input_path = "/net/pr2/projects/plgrid/plggeposai"
logger.info("Loading events catalog ...")
events = read_events(input_path + "/BOIS_all.xml")
events_stats = split_events(events)
output_path = input_path + "/seisbench_format"
metadata_path = output_path + "/metadata.csv"
waveforms_path = output_path + "/waveforms.hdf5"
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:
writer.data_format = {
"dimension_order": "CW",
"component_order": "ZNE",
}
for i, event in enumerate(events):
logger.debug(f"Converting {i} event")
event_params = get_event_params(event)
event_params["split"] = events_stats.loc[i, "split"]
# b = False
for pick in event.picks:
trace_params = get_trace_params(pick)
sampling_rate, stream = load_stream(input_path, trace_params)
if stream is None:
continue
actual_t_start, data, _ = sbu.stream_to_array(
stream,
component_order=writer.data_format["component_order"],
)
trace_params["trace_sampling_rate_hz"] = sampling_rate
trace_params["trace_start_time"] = str(actual_t_start)
pick_time = obspy.core.utcdatetime.UTCDateTime(trace_params["time"])
pick_idx = (pick_time - actual_t_start) * sampling_rate
trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(pick_idx)
writer.add_trace({**event_params, **trace_params}, data)
if __name__ == "__main__":
convert_mseed_to_seisbench_format()
# create_traces_catalog("/net/pr2/projects/plgrid/plggeposai/", ["2018", "2019"])