initial commit with the pipeline for training and evaluating seisbench models

This commit is contained in:
Krystyna Milian 2023-08-29 09:59:31 +02:00
parent fd4a67f2ae
commit 915f2a2d69
21 changed files with 4589 additions and 399 deletions

7
.gitignore vendored

@@ -3,8 +3,11 @@ __pycache__/
*/.ipynb_checkpoints/
.ipynb_checkpoints/
.env
models/
data/
weights/
datasets/
wip
artifacts/
wandb/
scripts/pred/
scripts/pred_resampled/
scripts/lightning_logs/


@@ -2,20 +2,59 @@
This repo contains notebooks and scripts demonstrating how to:
- Prepare IGF data for training model detecting P phase (i.e. transform mseeds into [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)), check the [notebook](utils/Transforming%20mseeds%20to%20SeisBench%20dataset.ipynb).
The original data can be downloaded from the [drive](https://drive.google.com/drive/folders/1InVI9DLaD7gdzraM2jMzeIrtiBSu-UIK?usp=drive_link)
- Prepare IGF data for training a SeisBench model detecting the P phase (i.e. transform mseeds into [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)), check the [notebook](utils/Transforming%20mseeds%20to%20SeisBench%20dataset.ipynb).
- Explore available data, check the [notebook](notebooks/Explore%20igf%20data.ipynb)
- Train cnn model (Seisbench PhaseNet) to detect P phase, check the [script](scripts/train.py)
- Use Weights & Biases to search for the best training hyperparams, check the [script](scripts/training_wandb_sweep.py)
- Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb)
- Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb)
- Train various CNN models available in the SeisBench library and compare their performance in detecting the P phase, check the [script](scripts/pipeline.py)
- [to update] Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb)
- [to update] Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb)
### Acknowledgments
This code is based on the [pick-benchmark](https://github.com/seisbench/pick-benchmark), the repository accompanying the paper:
[Which picker fits my data? A quantitative evaluation of deep learning based seismic pickers](https://github.com/seisbench/pick-benchmark#:~:text=Which%20picker%20fits%20my%20data%3F%20A%20quantitative%20evaluation%20of%20deep%20learning%20based%20seismic%20pickers)
### Usage
1. Install all dependencies with poetry, run:
`poetry install`
2. Prepare .env file with content:
`WANDB_API_KEY="your key"`
`poetry install`
2. Prepare .env file with content:
```
WANDB_HOST="https://epos-ai.grid.cyfronet.pl/"
WANDB_API_KEY="your key"
WANDB_USER="your user"
WANDB_PROJECT="training_seisbench_models_on_igf_data"
BENCHMARK_DEFAULT_WORKER=2
```
3. Transform data into seisbench format.
* Download original data from the [drive](https://drive.google.com/drive/folders/1InVI9DLaD7gdzraM2jMzeIrtiBSu-UIK?usp=drive_link)
* Run the notebook: `utils/Transforming mseeds to SeisBench dataset.ipynb`
4. Initialize poetry environment:
`poetry shell`
5. Run the pipeline script:
`python pipeline.py`
The script performs the following steps:
* Generates evaluation targets
* Trains multiple versions of the GPD, PhaseNet and ... models to find the hyperparameters that yield the lowest validation loss.
This step uses the Weights & Biases platform to perform the hyperparameter search (called sweeping), track the training process and store the results.
The results are available at
`https://epos-ai.grid.cyfronet.pl/<your user name>/<your project name>`
* Uses the best performing model of each type to generate predictions
* Evaluates the performance of each model by comparing the predictions with the evaluation targets
* Saves the results in the `scripts/pred` directory
The default settings are saved in the config.json file. To change them, edit config.json or pass the new settings as arguments to the script.
For example, to change the sweep configuration file for GPD model, run:
`python pipeline.py --gpd_config <new config file>`
The new config file should be placed in the `experiments` directory, or in the location specified by the `configs_path` parameter in the config.json file.
### Troubleshooting
* `wandb: ERROR Run .. errored: OSError(24, 'Too many open files')`
-> https://github.com/wandb/wandb/issues/2825
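A related mitigation already used in this repo is switching the torch multiprocessing sharing strategy (see the top of `scripts/hyperparameter_sweep.py`); a minimal sketch of that workaround:
```
import torch

# Share tensors between DataLoader workers via the file system instead of
# file descriptors, which reduces the number of open descriptors per process.
torch.multiprocessing.set_sharing_strategy("file_system")
```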

15
config.json Normal file

@@ -0,0 +1,15 @@
{
"dataset_name": "igf",
"data_path": "datasets/igf/seisbench_format/",
"targets_path": "datasets/targets/igf",
"models_path": "weights",
"configs_path": "experiments",
"sampling_rate": 100,
"num_workers": 1,
"seed": 10,
"sweep_files": {
"GPD": "sweep_gpd.yaml",
"PhaseNet": "sweep_phasenet.yaml"
},
"experiment_count": 20
}


@@ -1,11 +0,0 @@
{
"epochs": 10,
"batch_size": 256,
"dataset": "igf_1",
"sampling_rate": 100,
"model_names": "EQTransformer,BasicPhaseAE,GPD",
"model_name": "PhaseNet",
"learning_rate": 0.01,
"pretrained": null,
"sampling_rate": 100
}


@@ -0,0 +1,26 @@
name: GPD_fixed_highpass:2-10
method: bayes
metric:
goal: minimize
name: val_loss
parameters:
model_name:
value:
- GPD
batch_size:
distribution: int_uniform
max: 1024
min: 256
max_epochs:
value:
- 3
learning_rate:
distribution: uniform
max: 0.02
min: 0.005
highpass:
value:
- 2
lowpass:
value:
- 10


@@ -0,0 +1,27 @@
name: GPD_fixed_highpass:2-10
method: bayes
metric:
goal: minimize
name: val_loss
parameters:
model_name:
value:
- GPD
batch_size:
distribution: int_uniform
max: 1024
min: 256
max_epochs:
value:
- 15
learning_rate:
distribution: uniform
max: 0.02
min: 0.005
highpass:
distribution: uniform
min: 0.5
max: 2.0
lowpass:
value:
- 10


@@ -0,0 +1,20 @@
name: PhaseNet-lr0.005-0.02-bs256-1024
method: bayes
metric:
goal: minimize
name: val_loss
parameters:
model_name:
value:
- PhaseNet
batch_size:
distribution: int_uniform
max: 1024
min: 256
max_epochs:
value:
- 15
learning_rate:
distribution: uniform
max: 0.02
min: 0.005
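The sweep files above use the standard Weights & Biases sweep schema: Bayesian search (`method: bayes`) minimizing `val_loss`, with fixed values given as single-element lists and ranges given as distributions. Roughly, `scripts/hyperparameter_sweep.py` turns such a file into a sweep and launches agents; a minimal sketch, assuming the YAML has been loaded into a dict and `run_one_experiment` is a hypothetical training function:
```
import yaml
import wandb

with open("experiments/sweep_phasenet.yaml") as f:
    sweep_config = yaml.safe_load(f)

def run_one_experiment():
    # wandb.config holds the sampled values (batch_size, learning_rate, max_epochs, ...)
    run = wandb.init()
    # ... build the model, train it and log val_loss ...
    run.finish()

sweep_id = wandb.sweep(sweep_config, project="training_seisbench_models_on_igf_data")
wandb.agent(sweep_id, function=run_one_experiment, count=20)  # experiment_count in config.json
```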

1605
poetry.lock generated

File diff suppressed because it is too large


@@ -14,6 +14,8 @@ pandas = "^2.0.3"
obspy = "^1.4.0"
wandb = "^0.15.4"
torchmetrics = "^0.11.4"
ipykernel = "^6.24.0"
jupyterlab = "^4.0.2"
[tool.poetry.dev-dependencies]

133
scripts/augmentations.py Normal file

@@ -0,0 +1,133 @@
"""
This file contains augmentations required for the models that are too specific to be merged into SeisBench.
"""
import numpy as np
import copy
class DuplicateEvent:
"""
Adds a rescaled version of the event to the empty part of the trace after the event.
Event position and empty space are determined from a detection.
Detections can be generated for example with :py:class:`~seisbench.generate.labeling.DetectionLabeller`.
This implementation is modelled after the `implementation for EQTransformer <https://github.com/smousavi05/EQTransformer/blob/98676017f971efbb6f4475f42e415c3868d00c03/EQTransformer/core/EqT_utils.py#L255>`_.
.. warning::
This augmentation does **not** modify the metadata, as representing multiple picks of
the same type is currently not supported. Workflows should therefore always first generate
labels from metadata and then pass the labels in the key `label_keys`. These keys are automatically
adjusted by addition of the labels.
.. warning::
This implementation currently has strict shape requirements:
- (1, samples) for detection
- (channels, samples) for data
- (labels, samples) for labels
:param inv_scale: The scale factor is defined as 1/u, where u is uniform.
`inv_scale` defines the minimum and maximum values for u.
Defaults to (1, 10), e.g., scaling by factor 1 to 1/10.
:param detection_key: Key to read detection from.
If key is a tuple, detection will be read from the first key and written to the second one.
:param key: The keys for reading from and writing to the state dict.
If key is a single string, the corresponding entry in state dict is modified.
Otherwise, a 2-tuple is expected, with the first string indicating the key
to read from and the second one the key to write to.
:param label_keys: Keys for the label columns.
Labels of the original and duplicate events will be added and capped at 1.
Note that this will lead to invalid noise traces.
Value can either be a single key specification or a list of key specifications.
Each key specification is either a string, for identical input and output keys,
or as a tuple of two strings, input and output keys.
Defaults to None.
"""
def __init__(
self, inv_scale=(1, 10), detection_key="detections", key="X", label_keys=None
):
if isinstance(detection_key, str):
self.detection_key = (detection_key, detection_key)
else:
self.detection_key = detection_key
if isinstance(key, str):
self.key = (key, key)
else:
self.key = key
# Single key
if not isinstance(label_keys, list):
if label_keys is None:
label_keys = []
else:
label_keys = [label_keys]
# Resolve identical input and output keys
self.label_keys = []
for key in label_keys:
if isinstance(key, tuple):
self.label_keys.append(key)
else:
self.label_keys.append((key, key))
self.inv_scale = inv_scale
def __call__(self, state_dict):
x, metadata = state_dict[self.key[0]]
detection, _ = state_dict[self.detection_key[0]]
detection_mask = detection[0] > 0.5
if detection.shape[-1] != x.shape[-1]:
raise ValueError("Number of samples in trace and detection disagree.")
if self.key[0] != self.key[1]:
# Ensure metadata is not modified inplace unless input and output key are anyhow identical
metadata = copy.deepcopy(metadata)
if detection_mask.any():
n_samples = x.shape[-1]
event_samples = np.arange(n_samples)[detection_mask]
event_start, event_end = np.min(event_samples), np.max(event_samples) + 1
if event_end + 20 < n_samples:
second_start = np.random.randint(event_end + 20, n_samples)
scale = 1 / np.random.uniform(*self.inv_scale)
if self.key[0] != self.key[1]:
# Avoid inplace modification if input and output keys differ
x = x.copy()
space = min(event_end - event_start, n_samples - second_start)
x[:, second_start : second_start + space] += (
x[:, event_start : event_start + space] * scale
)
shift = second_start - event_start
for label_key in self.label_keys + [self.detection_key]:
y, metadata = state_dict[label_key[0]]
if y.shape[-1] != n_samples:
raise ValueError(
f"Number of samples disagree between trace and label key '{label_key[0]}'."
)
if label_key[0] != label_key[1]:
metadata = copy.deepcopy(metadata)
y = y.copy()
y[:, shift:] += y[:, :-shift]
y = np.minimum(y, 1)
state_dict[label_key[1]] = (y, metadata)
else:
# Copy entries
for label_key in self.label_keys + [self.detection_key]:
y, metadata = state_dict[label_key[0]]
if label_key[0] != label_key[1]:
metadata = copy.deepcopy(metadata)
y = y.copy()
state_dict[label_key[1]] = (y, metadata)
state_dict[self.key[1]] = (x, metadata)
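A hedged usage sketch for `DuplicateEvent` (not part of this commit): it is meant to run after labels and detections have been generated in a SeisBench augmentation chain. The dataset path and the `trace_Pg_arrival_sample` column are taken from the IGF setup elsewhere in this repo; the exact chain is illustrative.
```
import seisbench.data as sbd
import seisbench.generate as sbg
from augmentations import DuplicateEvent

dataset = sbd.WaveformDataset("datasets/igf/seisbench_format")  # path as in config.json
generator = sbg.GenericGenerator(dataset.train())
generator.add_augmentations([
    sbg.WindowAroundSample(["trace_Pg_arrival_sample"], samples_before=3000,
                           windowlen=6000, selection="random", strategy="variable"),
    sbg.ProbabilisticLabeller(label_columns={"trace_Pg_arrival_sample": "P"}, sigma=30, dim=0),
    sbg.DetectionLabeller(p_phases=["trace_Pg_arrival_sample"]),
    # Labels ("y") and detections must already exist before the duplication is applied
    DuplicateEvent(detection_key="detections", key="X", label_keys="y"),
])
```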

335
scripts/collect_results.py Normal file

@@ -0,0 +1,335 @@
"""
This script collects results in a folder, calculates performance metrics and writes them to csv.
"""
import argparse
from pathlib import Path
import logging
import pandas as pd
import numpy as np
from sklearn.metrics import (
precision_recall_curve,
precision_recall_fscore_support,
roc_auc_score,
matthews_corrcoef,
)
from tqdm import tqdm
def traverse_path(path, output, cross=False, resampled=False, baer=False):
"""
Traverses the given path and extracts results for each experiment and version
:param path: Root path
:param output: Path to write results csv to
:param cross: If true, expects cross-domain results.
:return: None
"""
path = Path(path)
results = []
exp_dirs = [x for x in path.iterdir() if x.is_dir()]
for exp_dir in tqdm(exp_dirs):
itr = exp_dir.iterdir()
if baer:
itr = [exp_dir] # Missing version directory in the structure
for version_dir in itr:
if not version_dir.is_dir():
pass
results.append(
process_version(
version_dir, cross=cross, resampled=resampled, baer=baer
)
)
results = pd.DataFrame(results)
if cross:
sort_keys = ["data", "model", "target", "lr", "version"]
else:
sort_keys = ["data", "model", "lr", "version"]
results.sort_values(sort_keys, inplace=True)
results.to_csv(output, index=False)
def process_version(version_dir: Path, cross: bool, resampled: bool, baer: bool):
"""
Extracts statistics for the given version of the given experiment.
:param version_dir: Path to the specific version
:param cross: If true, expects cross-domain results.
:return: Results dictionary
"""
stats = parse_exp_name(version_dir, cross=cross, resampled=resampled, baer=baer)
stats.update(eval_task1(version_dir))
stats.update(eval_task23(version_dir))
return stats
def parse_exp_name(version_dir, cross, resampled, baer):
if baer:
exp_name = version_dir.name
version = "0"
else:
exp_name = version_dir.parent.name
version = version_dir.name.split("_")[-1]
parts = exp_name.split("_")
target = None
sampling_rate = None
if cross or baer:
if len(parts) == 4:
data, model, lr, target = parts
else:
data, model, target = parts
lr = "0.001"
elif resampled:
if len(parts) == 5:
data, model, lr, target, sampling_rate = parts
else:
data, model, target, sampling_rate = parts
lr = "0.001"
else:
if len(parts) == 3:
data, model, lr = parts
else:
data, model, *_ = parts
lr = "0.001"
# lr = float(lr)
stats = {
"experiment": exp_name,
"data": data,
"model": model,
"lr": None,
"version": version,
}
if cross or baer:
stats["target"] = target
if resampled:
stats["target"] = target
stats["sampling_rate"] = sampling_rate
return stats
def eval_task1(version_dir: Path):
if not (
(version_dir / "dev_task1.csv").is_file()
and (version_dir / "test_task1.csv").is_file()
):
logging.warning(f"Directory {version_dir} does not contain task 1")
return {}
stats = {}
dev_pred = pd.read_csv(version_dir / "dev_task1.csv")
dev_pred["trace_type_bin"] = dev_pred["trace_type"] == "earthquake"
test_pred = pd.read_csv(version_dir / "test_task1.csv")
test_pred["trace_type_bin"] = test_pred["trace_type"] == "earthquake"
prec, recall, thr = precision_recall_curve(
dev_pred["trace_type_bin"], dev_pred["score_detection"]
)
f1 = 2 * prec * recall / (prec + recall)
auc = roc_auc_score(dev_pred["trace_type_bin"], dev_pred["score_detection"])
opt_index = np.nanargmax(f1) # F1 optimal threshold index
opt_thr = thr[opt_index] # F1 optimal threshold value
dev_stats = {
"dev_det_precision": prec[opt_index],
"dev_det_recall": recall[opt_index],
"dev_det_f1": f1[opt_index],
"dev_det_auc": auc,
"det_threshold": opt_thr,
}
stats.update(dev_stats)
prec, recall, f1, _ = precision_recall_fscore_support(
test_pred["trace_type_bin"],
test_pred["score_detection"] > opt_thr,
average="binary",
)
auc = roc_auc_score(test_pred["trace_type_bin"], test_pred["score_detection"])
test_stats = {
"test_det_precision": prec,
"test_det_recall": recall,
"test_det_f1": f1,
"test_det_auc": auc,
}
stats.update(test_stats)
return stats
def eval_task23(version_dir: Path):
print(version_dir / "dev_task23.csv")
if not (
(version_dir / "dev_task23.csv").is_file()
and (version_dir / "test_task23.csv").is_file()
):
logging.warning(f"Directory {version_dir} does not contain tasks 2 and 3")
return {}
stats = {}
dev_pred = pd.read_csv(version_dir / "dev_task23.csv")
dev_pred["phase_label_bin"] = dev_pred["phase_label"] == "P"
test_pred = pd.read_csv(version_dir / "test_task23.csv")
test_pred["phase_label_bin"] = test_pred["phase_label"] == "P"
def add_aux_columns(pred):
for col in ["s_sample_pred", "score_p_or_s"]:
if col not in pred.columns:
pred[col] = np.nan
add_aux_columns(dev_pred)
add_aux_columns(test_pred)
def nanmask(pred):
"""
Returns all entries that are nan in score_p_or_s, p_sample_pred and s_sample_pred
"""
mask = np.logical_and(
np.isnan(pred["p_sample_pred"]), np.isnan(pred["s_sample_pred"])
)
mask = np.logical_and(mask, np.isnan(pred["score_p_or_s"]))
return mask
if nanmask(dev_pred).all():
logging.warning(f"{version_dir} contains NaN predictions for tasks 2 and 3")
return {}
dev_pred = dev_pred[~nanmask(dev_pred)]
test_pred = test_pred[~nanmask(test_pred)]
skip_task2 = False
if (
np.logical_or(
np.isnan(dev_pred["score_p_or_s"]), np.isinf(dev_pred["score_p_or_s"])
).all()
or np.logical_or(
np.isnan(test_pred["score_p_or_s"]), np.isinf(test_pred["score_p_or_s"])
).all()
):
# For unfortunate combinations of nans and infs, otherwise weird scores can occur
skip_task2 = True
# Clipping removes infinitely likely P waves, usually resulting from models trained without S arrivals
dev_pred["score_p_or_s"] = np.clip(dev_pred["score_p_or_s"].values, -1e100, 1e100)
test_pred["score_p_or_s"] = np.clip(test_pred["score_p_or_s"].values, -1e100, 1e100)
dev_pred_restricted = dev_pred[~np.isnan(dev_pred["score_p_or_s"])]
test_pred_restricted = test_pred[~np.isnan(test_pred["score_p_or_s"])]
if len(dev_pred_restricted) > 0 and not skip_task2:
prec, recall, thr = precision_recall_curve(
dev_pred_restricted["phase_label_bin"], dev_pred_restricted["score_p_or_s"]
)
f1 = 2 * prec * recall / (prec + recall)
opt_index = np.nanargmax(f1) # F1 optimal threshold index
opt_thr = thr[opt_index] # F1 optimal threshold value
# Determine (approximately) optimal MCC threshold using 50 candidates
mcc_thrs = np.sort(dev_pred["score_p_or_s"].values)
mcc_thrs = mcc_thrs[np.linspace(0, len(mcc_thrs) - 1, 50, dtype=int)]
mccs = []
for thr in mcc_thrs:
mccs.append(
matthews_corrcoef(
dev_pred["phase_label_bin"], dev_pred["score_p_or_s"] > thr
)
)
mcc = np.max(mccs)
mcc_thr = mcc_thrs[np.argmax(mccs)]
dev_stats = {
"dev_phase_precision": prec[opt_index],
"dev_phase_recall": recall[opt_index],
"dev_phase_f1": f1[opt_index],
"phase_threshold": opt_thr,
"dev_phase_mcc": mcc,
"phase_threshold_mcc": mcc_thr,
}
stats.update(dev_stats)
prec, recall, f1, _ = precision_recall_fscore_support(
test_pred_restricted["phase_label_bin"],
test_pred_restricted["score_p_or_s"] > opt_thr,
average="binary",
)
mcc = matthews_corrcoef(
test_pred["phase_label_bin"], test_pred["score_p_or_s"] > mcc_thr
)
test_stats = {
"test_phase_precision": prec,
"test_phase_recall": recall,
"test_phase_f1": f1,
"test_phase_mcc": mcc,
}
stats.update(test_stats)
for pred, set_str in [(dev_pred, "dev"), (test_pred, "test")]:
for i, phase in enumerate(["P", "S"]):
pred_phase = pred[pred["phase_label"] == phase]
pred_col = f"{phase.lower()}_sample_pred"
if len(pred_phase) == 0:
continue
diff = (pred_phase[pred_col] - pred_phase["phase_onset"]) / pred_phase[
"sampling_rate"
]
stats[f"{set_str}_{phase}_mean_s"] = np.mean(diff)
stats[f"{set_str}_{phase}_std_s"] = np.sqrt(np.mean(diff**2))
stats[f"{set_str}_{phase}_mae_s"] = np.mean(np.abs(diff))
return stats
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Collects results from all experiments in a folder and outputs them in condensed csv format."
)
parser.add_argument(
"path",
type=str,
help="Root path of predictions",
)
parser.add_argument(
"output",
type=str,
help="Path for the output csv",
)
parser.add_argument(
"--cross", action="store_true", help="If true, expects cross-domain results."
)
parser.add_argument(
"--resampled",
action="store_true",
help="If true, expects cross-domain cross-sampling rate results.",
)
parser.add_argument(
"--baer",
action="store_true",
help="If true, expects results from Baer-Kradolfer picker.",
)
args = parser.parse_args()
traverse_path(
args.path,
args.output,
cross=args.cross,
resampled=args.resampled,
baer=args.baer,
)
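`scripts/pipeline.py` calls this module programmatically; the equivalent minimal call is:
```
import collect_results

# Aggregate every experiment/version found under scripts/pred into a single CSV
collect_results.traverse_path("pred", "pred/results.csv")
```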

26
scripts/config_loader.py Normal file

@@ -0,0 +1,26 @@
import json
from pathlib import Path
def load_config(file_path):
with open(file_path, 'r') as config_file:
config = json.load(config_file)
return config
project_path = str(Path.cwd().parent)
config_path = project_path + "/config.json"
config = load_config(config_path)
data_path = f"{project_path}/{config['data_path']}"
models_path = f"{project_path}/{config['models_path']}"
targets_path = f"{project_path}/{config['targets_path']}"
dataset_name = config['dataset_name']
configs_path = f"{project_path}/{config['configs_path']}"
sweep_files = config['sweep_files']
sampling_rate = config['sampling_rate']
num_workers = config['num_workers']
seed = config['seed']
experiment_count = config['experiment_count']
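Other scripts import the resolved values directly (see the imports in `scripts/eval.py` and `scripts/hyperparameter_sweep.py`), e.g.:
```
from config_loader import models_path, targets_path, dataset_name, sweep_files

print(dataset_name)             # "igf"
print(sweep_files["PhaseNet"])  # "sweep_phasenet.yaml"
```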

32
scripts/data.py Normal file

@@ -0,0 +1,32 @@
"""
This file contains functionality related to data.
"""
import seisbench.data as sbd
def get_dataset_by_name(name):
"""
Resolve dataset name to class from seisbench.data.
:param name: Name of dataset as defined in seisbench.data.
:return: Dataset class from seisbench.data
"""
try:
return sbd.__getattribute__(name)
except AttributeError:
raise ValueError(f"Unknown dataset '{name}'.")
def get_custom_dataset(path):
"""
Return custom dataset in seisbench format
:param path:
:return: Dataset class
"""
try:
return sbd.WaveformDataset(path)
except AttributeError:
raise ValueError(f"Unknown dataset '{path}'.")
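Usage mirrors how `scripts/eval.py` resolves datasets; the constructor arguments below are the ones used there:
```
import data

# Named benchmark dataset from seisbench.data
ethz = data.get_dataset_by_name("ETHZ")(
    sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full"
)

# Custom dataset already converted to SeisBench format (path from config.json)
igf = data.get_custom_dataset("datasets/igf/seisbench_format/")
print(len(igf), list(igf.metadata.columns))
```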

248
scripts/eval.py Normal file

@@ -0,0 +1,248 @@
"""
This script implements functionality for evaluating models.
Given a model and a set of targets, it calculates and outputs predictions.
"""
import seisbench.generate as sbg
import argparse
import pandas as pd
import yaml
from pathlib import Path
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch
import models, data
import logging
from util import default_workers, load_best_model_data
import time
import datetime
from config_loader import models_path, project_path
import os
data_aliases = {
"ethz": "ETHZ",
"geofon": "GEOFON",
"stead": "STEAD",
"neic": "NEIC",
"instance": "InstanceCountsCombined",
"iquique": "Iquique",
"lendb": "LenDB",
"scedc": "SCEDC"
}
def main(weights, targets, sets, batchsize, num_workers, sampling_rate=None, sweep_id=None, test_run=False):
weights = Path(weights)
targets = Path(os.path.abspath(targets))
print(targets)
# print()
sets = sets.split(",")
checkpoint_path, version = load_best_model_data(sweep_id, weights)
logging.warning("Starting evaluation of model: \n" + checkpoint_path)
config_path = f"{models_path}/{weights}/{version}/hparams.yaml"
with open(config_path, "r") as f:
# config = yaml.safe_load(f)
config = yaml.full_load(f)
model_name = config["model_name"][0] if type(config["model_name"]) == list else config["model_name"]
model_cls = models.__getattribute__(model_name + "Lit")
model = model_cls.load_from_checkpoint(checkpoint_path)
data_name = data_aliases[targets.name] if targets.name in data_aliases else None
if data_name != config["dataset_name"] and targets.name in data_aliases:
logging.warning("Detected cross-domain evaluation")
pred_root = "pred_cross"
parts = weights.name.split("_")
weight_path_name = "_".join(parts[:2] + [targets.name] + parts[2:])
else:
pred_root = "pred"
weight_path_name = weights.name
if data_name is not None:
dataset = data.get_dataset_by_name(data_name)(
sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full"
)
else:
data_path = project_path + '/' + config['data_path']
print("Loading dataset: ", data_path)
dataset = data.get_custom_dataset(data_path)
if sampling_rate is not None:
dataset.sampling_rate = sampling_rate
pred_root = pred_root + "_resampled"
weight_path_name = weight_path_name + f"_{sampling_rate}"
for eval_set in sets:
split = dataset.get_split(eval_set)
if targets.name == "instance":
logging.warning(
"Overwriting noise trace_names to allow correct identification"
)
# Replace trace names for noise entries
split._metadata["trace_name"].values[
-len(split.datasets[-1]) :
] = split._metadata["trace_name"][-len(split.datasets[-1]) :].apply(
lambda x: "noise_" + x
)
split._build_trace_name_to_idx_dict()
logging.warning(f"Starting set {eval_set}")
split.preload_waveforms(pbar=True)
print("eval set shape", split.metadata.shape)
for task in ["1", "23"]:
task_csv = targets / f"task{task}.csv"
print(task_csv)
if not task_csv.is_file():
continue
logging.warning(f"Starting task {task}")
task_targets = pd.read_csv(task_csv)
task_targets = task_targets[task_targets["trace_split"] == eval_set]
if task == "1" and targets.name == "instance":
border = _identify_instance_dataset_border(task_targets)
task_targets["trace_name"].values[border:] = task_targets["trace_name"][
border:
].apply(lambda x: "noise_" + x)
if sampling_rate is not None:
for key in ["start_sample", "end_sample", "phase_onset"]:
if key not in task_targets.columns:
continue
task_targets[key] = (
task_targets[key]
* sampling_rate
/ task_targets["sampling_rate"]
)
task_targets["sampling_rate"] = sampling_rate
restrict_to_phase = config.get("restrict_to_phase", None)
if restrict_to_phase is not None and "phase_label" in task_targets.columns:
mask = task_targets["phase_label"].isin(list(restrict_to_phase))
task_targets = task_targets[mask]
if restrict_to_phase is not None and task == "1":
logging.warning("Skipping task 1 as restrict_to_phase is set.")
continue
generator = sbg.SteeredGenerator(split, task_targets)
generator.add_augmentations(model.get_eval_augmentations())
loader = DataLoader(
generator, batch_size=batchsize, shuffle=False, num_workers=num_workers
)
# trainer = pl.Trainer(accelerator="gpu", devices=1)
trainer = pl.Trainer()
predictions = trainer.predict(model, loader)
# Merge batches
merged_predictions = []
for i, _ in enumerate(predictions[0]):
merged_predictions.append(torch.cat([x[i] for x in predictions]))
merged_predictions = [x.cpu().numpy() for x in merged_predictions]
task_targets["score_detection"] = merged_predictions[0]
task_targets["score_p_or_s"] = merged_predictions[1]
task_targets["p_sample_pred"] = (
merged_predictions[2] + task_targets["start_sample"]
)
task_targets["s_sample_pred"] = (
merged_predictions[3] + task_targets["start_sample"]
)
pred_path = (
weights.parent.parent
/ pred_root
/ weight_path_name
/ version
/ f"{eval_set}_task{task}.csv"
)
pred_path.parent.mkdir(exist_ok=True, parents=True)
task_targets.to_csv(pred_path, index=False)
def _identify_instance_dataset_border(task_targets):
"""
Calculates the dataset border between Signal and Noise for instance,
assuming it is the only place where the bucket number does not increase
"""
buckets = task_targets["trace_name"].apply(lambda x: int(x.split("$")[0][6:]))
last_bucket = 0
for i, bucket in enumerate(buckets):
if bucket < last_bucket:
return i
last_bucket = bucket
if __name__ == "__main__":
code_start_time = time.perf_counter()
parser = argparse.ArgumentParser(
description="Evaluate a trained model using a set of targets."
)
parser.add_argument(
"weights",
type=str,
help="Path to weights. Expected to be in models_path directory."
"The script will automatically load the configuration and the model. "
"The script always uses the checkpoint with lowest validation loss."
"If sweep_id is provided the script considers only the checkpoints generated by that sweep."
"Predictions will be written into the weights path as csv."
"Note: Due to pytorch lightning internals, there exist two weights folders, "
"{weights} and {weight}_{weights}. Please use the former as parameter",
)
parser.add_argument(
"targets",
type=str,
help="Path to evaluation targets folder. "
"The script will detect which tasks are present base on file names.",
)
parser.add_argument(
"--sets",
type=str,
default="dev,test",
help="Sets on which to evaluate, separated by commata. Defaults to dev and test.",
)
parser.add_argument("--batchsize", type=int, default=1024, help="Batch size")
parser.add_argument(
"--num_workers",
default=default_workers,
type=int,
help="Number of workers for data loader",
)
parser.add_argument(
"--sampling_rate", type=float, help="Overwrites the sampling rate in the data"
)
parser.add_argument(
"--sweep_id", type=str, help="wandb sweep_id", required=False, default=None
)
parser.add_argument(
"--test_run", action="store_true", required=False, default=False
)
args = parser.parse_args()
main(
args.weights,
args.targets,
args.sets,
batchsize=args.batchsize,
num_workers=args.num_workers,
sampling_rate=args.sampling_rate,
sweep_id=args.sweep_id,
test_run=args.test_run
)
running_time = str(
datetime.timedelta(seconds=time.perf_counter() - code_start_time)
)
print(f"Running time: {running_time}")


@@ -0,0 +1,321 @@
"""
This script generates evaluation targets for the following three tasks:
- Earthquake detection (Task 1): Given a 30~s window, does the window contain an earthquake signal?
- Phase identification (Task 2): Given a 10~s window containing exactly one phase onset, identify which phase type.
- Onset time determination (Task 3): Given a 10~s window containing exactly one phase onset, identify the onset time.
Each target for evaluation will consist of the following information:
- trace name (as in dataset)
- trace index (in dataset)
- split (as in dataset)
- sampling rate (at which all information is presented)
- start_sample
- end_sample
- trace_type (only task 1: earthquake/noise)
- phase_label (only task 2/3: P/S)
- full_phase_label (only task 2/3: phase label as in the dataset, might be Pn, Pg, etc.)
- phase_onset_sample (only task 2/3: onset sample of the phase relative to full trace)
It needs to be provided with a dataset and writes a folder with two CSV files, one for task 1 and one for tasks 2 and 3.
Each file will describe targets for train, dev and test, derived from the respective splits.
When using these tasks for evaluation, the models can make use of waveforms from the context, i.e.,
before/after the start and end samples. However, make sure this does not add further bias in the evaluation,
for example by always centring the windows on the picks using context.
.. warning::
For comparability, it is strongly advised to use published evaluation targets, instead of generating new ones.
.. warning::
This implementation is not optimized and loads the full waveform data for its computations.
This will lead to very high memory usage, as the full dataset will be stored in memory.
"""
import seisbench.data as sbd
import argparse
from pathlib import Path
import pandas as pd
import numpy as np
from tqdm import tqdm
from models import phase_dict
def main(dataset_name, output, tasks, sampling_rate, noise_before_events):
np.random.seed(42)
tasks = [str(i) in tasks.split(",") for i in range(1, 4)]
if not any(tasks):
raise ValueError(f"No task defined. Got tasks='{tasks}'.")
dataset_args = {
"sampling_rate": sampling_rate,
"dimension_order": "NCW",
"cache": "full",
}
try:
# Check if dataset is available in SeisBench
dataset = sbd.__getattribute__(dataset_name)(**dataset_args)
except AttributeError:
# Otherwise interpret data_in as path
dataset = sbd.WaveformDataset(dataset_name, **dataset_args)
output = Path(output)
output.mkdir(parents=True, exist_ok=False)
if "split" in dataset.metadata.columns:
dataset.filter(dataset["split"].isin(["dev", "test"]), inplace=True)
dataset.preload_waveforms(pbar=True)
if tasks[0]:
generate_task1(dataset, output, sampling_rate, noise_before_events)
if tasks[1] or tasks[2]:
generate_task23(dataset, output, sampling_rate)
def generate_task1(dataset, output, sampling_rate, noise_before_events):
np.random.seed(42)
windowlen = 30 * sampling_rate # 30 s windows
labels = []
for i in tqdm(range(len(dataset)), total=len(dataset)):
waveforms, metadata = dataset.get_sample(i)
if "split" in metadata:
trace_split = metadata["split"]
else:
trace_split = ""
def checkphase(metadata, phase, phase_label, target_phase, npts):
return (
phase in metadata
and phase_label == target_phase
and not np.isnan(metadata[phase])
and 0 <= metadata[phase] < npts
)
p_arrivals = [
metadata[phase]
for phase, phase_label in phase_dict.items()
if checkphase(metadata, phase, phase_label, "P", waveforms.shape[-1])
]
s_arrivals = [
metadata[phase]
for phase, phase_label in phase_dict.items()
if checkphase(metadata, phase, phase_label, "S", waveforms.shape[-1])
]
if len(p_arrivals) == 0 and len(s_arrivals) == 0:
start_sample, end_sample = select_window_containing(
waveforms.shape[-1], windowlen
)
sample = {
"trace_name": metadata["trace_name"],
"trace_idx": i,
"trace_split": trace_split,
"sampling_rate": sampling_rate,
"start_sample": start_sample,
"end_sample": end_sample,
"trace_type": "noise",
}
labels += [sample]
else:
first_arrival = min(p_arrivals + s_arrivals)
start_sample, end_sample = select_window_containing(
waveforms.shape[-1], windowlen, containing=first_arrival
)
if end_sample - start_sample <= windowlen:
sample = {
"trace_name": metadata["trace_name"],
"trace_idx": i,
"trace_split": trace_split,
"sampling_rate": sampling_rate,
"start_sample": start_sample,
"end_sample": end_sample,
"trace_type": "earthquake",
}
labels += [sample]
if noise_before_events and first_arrival > windowlen:
start_sample, end_sample = select_window_containing(
min(waveforms.shape[-1], first_arrival), windowlen
)
if end_sample - start_sample <= windowlen:
sample = {
"trace_name": metadata["trace_name"],
"trace_idx": i,
"trace_split": trace_split,
"sampling_rate": sampling_rate,
"start_sample": start_sample,
"end_sample": end_sample,
"trace_type": "noise",
}
labels += [sample]
labels = pd.DataFrame(labels)
diff = labels["end_sample"] - labels["start_sample"]
labels = labels[diff > 100]
labels.to_csv(output / "task1.csv", index=False)
def generate_task23(dataset, output, sampling_rate):
np.random.seed(42)
windowlen = 10 * sampling_rate  # 10 s windows
labels = []
for idx in tqdm(range(len(dataset)), total=len(dataset)):
waveforms, metadata = dataset.get_sample(idx)
if "split" in metadata:
trace_split = metadata["split"]
else:
trace_split = ""
def checkphase(metadata, phase, npts):
return (
phase in metadata
and not np.isnan(metadata[phase])
and 0 <= metadata[phase] < npts
)
# Example entry: (1031, "P", "Pg")
arrivals = sorted(
[
(metadata[phase], phase_label, phase.split("_")[1])
for phase, phase_label in phase_dict.items()
if checkphase(metadata, phase, waveforms.shape[-1])
]
)
if len(arrivals) == 0:
# Trace has no arrivals
continue
for i, (onset, phase, full_phase) in enumerate(arrivals):
if i == 0:
onset_before = 0
else:
onset_before = int(arrivals[i - 1][0]) + int(
0.5 * sampling_rate
) # 0.5 s minimum spacing
if i == len(arrivals) - 1:
onset_after = np.inf
else:
onset_after = int(arrivals[i + 1][0]) - int(
0.5 * sampling_rate
) # 0.5 s minimum spacing
if (
onset_after - onset_before < windowlen
or onset_before > onset
or onset_after < onset
):
# Impossible to isolate pick
continue
else:
onset_after = min(onset_after, waveforms.shape[-1])
# Shift everything to a "virtual" start at onset_before
start_sample, end_sample = select_window_containing(
onset_after - onset_before,
windowlen=windowlen,
containing=onset - onset_before,
bounds=(50, 50),
)
start_sample += onset_before
end_sample += onset_before
if end_sample - start_sample <= windowlen:
sample = {
"trace_name": metadata["trace_name"],
"trace_idx": idx,
"trace_split": trace_split,
"sampling_rate": sampling_rate,
"start_sample": start_sample,
"end_sample": end_sample,
"phase_label": phase,
"full_phase_label": full_phase,
"phase_onset": onset,
}
labels += [sample]
labels = pd.DataFrame(labels)
diff = labels["end_sample"] - labels["start_sample"]
labels = labels[diff > 100]
labels.to_csv(output / "task23.csv", index=False)
def select_window_containing(npts, windowlen, containing=None, bounds=(100, 100)):
"""
Selects a window from a larger trace.
:param npts: Number of points of the full trace
:param windowlen: Desired windowlen
:param containing: Sample number that should be contained. If None, any window within the trace is valid.
:param bounds: The containing sample may not be in the first/last samples indicated here.
:return: Start sample, end_sample
"""
if npts <= windowlen:
# If npts is smaller than the window length, always return the full window
return 0, npts
else:
if containing is None:
start_sample = np.random.randint(0, npts - windowlen + 1)
return start_sample, start_sample + windowlen
else:
earliest_start = max(0, containing - windowlen + bounds[1])
latest_start = min(npts - windowlen, containing - bounds[0])
if latest_start <= earliest_start:
# Again, return full window
return 0, npts
else:
start_sample = np.random.randint(earliest_start, latest_start + 1)
return start_sample, start_sample + windowlen
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate evaluation targets. See the docstring for details."
)
parser.add_argument(
"dataset", type=str, help="Path to input dataset or SeisBench dataset name"
)
parser.add_argument(
"output", type=str, help="Path to write target files to. Must not exist."
)
parser.add_argument(
"--tasks",
type=str,
default="1,2,3",
help="Which tasks to generate data for. By default generates data for all tasks.",
)
parser.add_argument(
"--sampling_rate",
type=float,
default=100,
help="Sampling rate in Hz to generate targets for.",
)
parser.add_argument(
"--no_noise_before_events",
action="store_true",
help="If set, does not extract noise from windows before the first arrival.",
)
args = parser.parse_args()
main(
args.dataset,
args.output,
args.tasks,
args.sampling_rate,
not args.no_noise_before_events,
)
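To illustrate the windowing logic of `select_window_containing` (numbers are illustrative): for a 6000-sample trace, a 3000-sample window and a pick at sample 4200, any returned window keeps the pick at least `bounds` samples away from both window edges.
```
import numpy as np
from generate_eval_targets import select_window_containing

np.random.seed(0)
start, end = select_window_containing(npts=6000, windowlen=3000, containing=4200)
# With the default bounds=(100, 100) the pick lies >= 100 samples inside the window
assert start + 100 <= 4200 <= end - 100
print(start, end)
```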


@@ -0,0 +1,178 @@
# -----------------
# Copyright © 2023 ACK Cyfronet AGH, Poland.
# This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
# -----------------
import os.path
import argparse
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
import wandb
import torch
import traceback
import logging
from dotenv import load_dotenv
import models
import train
import util
from config_loader import config as common_config
from config_loader import models_path, dataset_name, seed, experiment_count
torch.multiprocessing.set_sharing_strategy('file_system')
load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
raise ValueError("WANDB_API_KEY environment variable is not set.")
host = os.environ.get("WANDB_HOST")
if host is None:
raise ValueError("WANDB_HOST environment variable is not set.")
wandb.login(key=wandb_api_key, host=host)
# wandb.login(key=wandb_api_key)
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user_name = os.environ.get("WANDB_USER")
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.WARNING)
def set_random_seed(seed=3):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def get_trainer_args(config):
trainer_args = {'max_epochs': config.max_epochs[0]}
return trainer_args
class HyperparameterSweep:
def __init__(self, project_name, sweep_config):
self.project_name = project_name
self.sweep_config = sweep_config
self.sweep_id = None
def run_sweep(self):
# Create the sweep
self.sweep_id = wandb.sweep(self.sweep_config, project=self.project_name)
logger.info("Created sweep with ID: " + self.sweep_id)
# Run the sweep
wandb.agent(self.sweep_id, function=self.run_experiment, count=experiment_count)
def all_runs_finished(self):
sweep_path = f"{wandb_user_name}/{wandb_project_name}/{self.sweep_id}"
logger.debug(f"Sweep path: {sweep_path}")
sweep_runs = wandb.Api().sweep(sweep_path).runs
all_finished = all(run.state == "finished" for run in sweep_runs)
if all_finished:
logger.info("All runs finished successfully.")
all_not_running = all(run.state != "running" for run in sweep_runs)
if all_not_running and not all_finished:
logger.warning("Some runs are not finished but failed or crashed.")
return all_not_running
def run_experiment(self):
try:
logger.debug("Starting a new run...")
run = wandb.init(
project=self.project_name,
config=common_config,
)
wandb.run.log_code(
".",
include_fn=lambda path: path.endswith(os.path.basename(__file__))
)
model_name = wandb.config.model_name[0]
model_args = models.get_model_specific_args(wandb.config)
logger.debug(f"Initializing {model_name}")
model = models.__getattribute__(model_name + "Lit")(**model_args)
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
wandb_logger = WandbLogger(project=self.project_name, log_model="all")
wandb_logger.watch(model)
# CSV logger - also used for saving configuration as yaml
experiment_name = f"{dataset_name}_{model_name}"
csv_logger = CSVLogger(models_path, experiment_name, version=run.id)
csv_logger.log_hyperparams(wandb.config)
loggers = [wandb_logger, csv_logger]
experiment_signature = f"{experiment_name}_sweep={self.sweep_id}-run={run.id}"
logger.debug("Experiment signature: " + experiment_signature)
checkpoint_callback = ModelCheckpoint(
save_top_k=1,
filename=experiment_signature + "-{epoch}-{val_loss:.3f}",
monitor="val_loss",
mode="min",
dirpath=f"{models_path}/{experiment_name}/",
) # save_top_k=1, monitor="val_loss", mode="min": save the best model in terms of validation loss
checkpoint_callback.STARTING_VERSION = 1
early_stopping_callback = EarlyStopping(
monitor="val_loss",
patience=3,
verbose=True,
mode="min")
callbacks = [checkpoint_callback, early_stopping_callback]
trainer = pl.Trainer(
default_root_dir=models_path,
logger=loggers,
callbacks=callbacks,
**get_trainer_args(wandb.config)
)
trainer.fit(model, train_loader, dev_loader)
except Exception as e:
logger.error("caught error: ", str(e))
traceback_str = traceback.format_exc()
logger.error(traceback_str)
run.finish()
def start_sweep(sweep_config):
logger.info("Starting sweep with config: " + str(sweep_config))
set_random_seed(seed)
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
sweep_runner.run_sweep()
return sweep_runner
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--sweep_config", type=str, required=True)
args = parser.parse_args()
sweep_config = util.load_sweep_config(args.sweep_config)
start_sweep(sweep_config)
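The pipeline drives this module through `start_sweep`; the minimal programmatic usage, mirroring `find_the_best_params` in `scripts/pipeline.py`, is:
```
import time
import util
import hyperparameter_sweep

sweep_config = util.load_sweep_config("sweep_phasenet.yaml")
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)

# Poll until every run of the sweep has stopped (finished, failed or crashed)
while not sweep_runner.all_runs_finished():
    time.sleep(30)
print("Finished sweep:", sweep_runner.sweep_id)
```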

1138
scripts/models.py Normal file

File diff suppressed because it is too large

92
scripts/pipeline.py Normal file

@@ -0,0 +1,92 @@
"""
-----------------
Copyright © 2023 ACK Cyfronet AGH, Poland.
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
-----------------
This script runs the pipeline for the training and evaluation of the models.
"""
import logging
import time
import argparse
import util
import generate_eval_targets
import hyperparameter_sweep
import eval
import collect_results
from config_loader import data_path, targets_path, sampling_rate, dataset_name, sweep_files
logger = logging.getLogger('pipeline')
logger.setLevel(logging.INFO)
def load_sweep_config(model_name, args):
if model_name == "PhaseNet" and args.phasenet_config is not None:
sweep_fname = args.phasenet_config
elif model_name == "GPD" and args.gpd_config is not None:
sweep_fname = args.gpd_config
else:
# use the default sweep config for the model
sweep_fname = sweep_files[model_name]
logger.info(f"Loading sweep config: {sweep_fname}")
return util.load_sweep_config(sweep_fname)
def find_the_best_params(model_name, args):
# find the best hyperparams for the model_name
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
sweep_config = load_sweep_config(model_name, args)
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
# wait for all runs to finish
all_finished = sweep_runner.all_runs_finished()
while not all_finished:
logger.info("Waiting for sweep runs to finish...")
# Sleep for a few seconds before checking again
time.sleep(30)
all_finished = sweep_runner.all_runs_finished()
logger.info(f"Finished the sweep: {sweep_runner.sweep_id}")
return sweep_runner.sweep_id
def generate_predictions(sweep_id, model_name):
experiment_name = f"{dataset_name}_{model_name}"
eval.main(weights=experiment_name,
targets=targets_path,
sets='dev,test',
batchsize=128,
num_workers=4,
# sampling_rate=sampling_rate,
sweep_id=sweep_id
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--phasenet_config", type=str, required=False)
parser.add_argument("--gpd_config", type=str, required=False)
args = parser.parse_args()
# generate labels
generate_eval_targets.main(data_path, targets_path, "2,3", sampling_rate, None)
# find the best hyperparams for the models
for model_name in ["GPD", "PhaseNet"]:
sweep_id = find_the_best_params(model_name, args)
generate_predictions(sweep_id, model_name)
# collect results
collect_results.traverse_path("pred", "pred/results.csv")
if __name__ == "__main__":
main()


@@ -1,339 +1,245 @@
import os.path
import wandb
import seisbench.data as sbd
"""
This script handles the training of models base on model configuration files.
"""
import seisbench.generate as sbg
import seisbench.models as sbm
from seisbench.util import worker_seeding
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as f
import torch.nn as nn
from torchmetrics import Metric
from torch import Tensor, tensor
import pytorch_lightning as pl
# from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.loggers import WandbLogger, CSVLogger
# https://github.com/Lightning-AI/lightning/pull/12554
# https://github.com/Lightning-AI/lightning/issues/11796
from pytorch_lightning.callbacks import ModelCheckpoint
import argparse
import json
from dotenv import load_dotenv
import numpy as np
from torch.utils.data import DataLoader
import torch
import os
import logging
from pathlib import Path
load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
raise ValueError("WANDB_API_KEY environment variable is not set.")
import models, data, util
import time
import datetime
import wandb
#
# load_dotenv()
# wandb_api_key = os.environ.get('WANDB_API_KEY')
# if wandb_api_key is None:
# raise ValueError("WANDB_API_KEY environment variable is not set.")
#
# wandb.login(key=wandb_api_key)
wandb.login(key=wandb_api_key)
def train(config, experiment_name, test_run):
"""
Runs the model training defined by the config.
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
Config parameters:
- model: Model used as in the models.py file, but without the Lit suffix
- dataset_name: Dataset used, as in seisbench.data; unknown names fall back to the custom dataset at data_path
- model_args: Arguments passed to the constructor of the model lightning module
- trainer_args: Arguments passed to the lightning trainer
- batch_size: Batch size for training and validation
- num_workers: Number of workers for data loading.
If not set, uses environment variable BENCHMARK_DEFAULT_WORKERS
- restrict_to_phase: Filters datasets only to examples containing the given phase.
By default, uses all phases.
- training_fraction: Fraction of training blocks to use as float between 0 and 1. Defaults to 1.
:param config: Configuration parameters for training
:param test_run: If true, makes a test run with less data and less logging. Intended for debug purposes.
"""
model = models.__getattribute__(config["model"] + "Lit")(
**config.get("model_args", {})
)
train_loader, dev_loader = prepare_data(config, model, test_run)
# CSV logger - also used for saving configuration as yaml
csv_logger = CSVLogger("weights", experiment_name)
csv_logger.log_hyperparams(config)
loggers = [csv_logger]
default_root_dir = os.path.join(
"weights"
) # Experiment name is parsed from the loggers
if not test_run:
# tb_logger = TensorBoardLogger("tb_logs", experiment_name)
# tb_logger.log_hyperparams(config)
# loggers += [tb_logger]
wandb_logger = WandbLogger()
wandb_logger.watch(model)
loggers +=[wandb_logger]
checkpoint_callback = ModelCheckpoint(
save_top_k=1, filename="{epoch}-{step}", monitor="val_loss", mode="min"
) # save_top_k=1, monitor="val_loss", mode="min": save the best model in terms of validation loss
callbacks = [checkpoint_callback]
## Uncomment the following 2 lines to enable
# device_stats = DeviceStatsMonitor()
# callbacks.append(device_stats)
trainer = pl.Trainer(
default_root_dir=default_root_dir,
logger=loggers,
callbacks=callbacks,
**config.get("trainer_args", {}),
)
trainer.fit(model, train_loader, dev_loader)
class PickMAE(Metric):
higher_is_better: bool = False
mae_error: Tensor
def prepare_data(config, model, test_run):
"""
Returns the training and validation data loaders
:param config:
:param model:
:param test_run:
:return:
"""
batch_size = config.get("batch_size", 1024)
if type(batch_size) == list:
batch_size = batch_size[0]
def __init__(self, sampling_rate):
super().__init__()
self.add_state("mae_error", default=torch.tensor(0), dist_reduce_fx="sum")
self.sampling_rate = sampling_rate
num_workers = config.get("num_workers", util.default_workers)
try:
dataset = data.get_dataset_by_name(config["dataset_name"])(
sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full"
)
except ValueError:
data_path = str(Path.cwd().parent) + '/' + config['data_path']
print(data_path)
dataset = data.get_custom_dataset(data_path)
def update(self, preds: torch.Tensor, target: torch.Tensor):
restrict_to_phase = config.get("restrict_to_phase", None)
if restrict_to_phase is not None:
mask = generate_phase_mask(dataset, restrict_to_phase)
dataset.filter(mask, inplace=True)
assert preds.shape == target.shape
if "split" not in dataset.metadata.columns:
logging.warning("No split defined, adding auxiliary split.")
split = np.array(["train"] * len(dataset))
split[int(0.6 * len(dataset)) : int(0.7 * len(dataset))] = "dev"
split[int(0.7 * len(dataset)) :] = "test"
pred_pick_idx = torch.argmax(preds[:, 0, :], dim=1).type(torch.FloatTensor)
true_pick_idx = torch.argmax(target[:, 0, :], dim=-1).type(torch.FloatTensor)
dataset._metadata["split"] = split
mae = nn.L1Loss()
self.mae_error = mae(pred_pick_idx, true_pick_idx) / self.sampling_rate #mae in seconds
train_data = dataset.train()
dev_data = dataset.dev()
def compute(self):
return self.mae_error.float()
if test_run:
# Only use a small part of the dataset
train_mask = np.zeros(len(train_data), dtype=bool)
train_mask[:5000] = True
train_data.filter(train_mask, inplace=True)
dev_mask = np.zeros(len(dev_data), dtype=bool)
dev_mask[:5000] = True
dev_data.filter(dev_mask, inplace=True)
training_fraction = config.get("training_fraction", 1.0)
apply_training_fraction(training_fraction, train_data)
train_data.preload_waveforms(pbar=True)
dev_data.preload_waveforms(pbar=True)
train_generator = sbg.GenericGenerator(train_data)
dev_generator = sbg.GenericGenerator(dev_data)
train_generator.add_augmentations(model.get_train_augmentations())
dev_generator.add_augmentations(model.get_val_augmentations())
train_loader = DataLoader(
train_generator,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
worker_init_fn=worker_seeding,
drop_last=True, # Avoid crashes from batch norm layers for batch size 1
)
dev_loader = DataLoader(
dev_generator,
batch_size=batch_size,
num_workers=num_workers,
worker_init_fn=worker_seeding,
)
return train_loader, dev_loader
class EarlyStopper:
def __init__(self, patience=1, min_delta=0):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.min_validation_loss = np.inf
def apply_training_fraction(training_fraction, train_data):
"""
Reduces the size of train_data to train_fraction by inplace filtering.
Filter blockwise for efficient memory savings.
def early_stop(self, validation_loss):
if validation_loss < self.min_validation_loss:
self.min_validation_loss = validation_loss
self.counter = 0
elif validation_loss > (self.min_validation_loss + self.min_delta):
self.counter += 1
if self.counter >= self.patience:
return True
return False
:param training_fraction: Training fraction between 0 and 1.
:param train_data: Training dataset
:return: None
"""
if not 0.0 < training_fraction <= 1.0:
raise ValueError("Training fraction needs to be between 0 and 1.")
if training_fraction < 1:
blocks = train_data["trace_name"].apply(lambda x: x.split("$")[0])
unique_blocks = blocks.unique()
np.random.shuffle(unique_blocks)
target_blocks = unique_blocks[: int(training_fraction * len(unique_blocks))]
target_blocks = set(target_blocks)
mask = blocks.isin(target_blocks)
train_data.filter(mask, inplace=True)
def get_data_generator(split, sampling_rate, path, sb_dataset="ethz", station=None, window='random'):
def generate_phase_mask(dataset, phases):
mask = np.zeros(len(dataset), dtype=bool)
if path is not None:
data = sbd.WaveformDataset(path, sampling_rate=sampling_rate)
phase_dict = {
"trace_Pg_arrival_sample": "P"
}
elif sb_dataset == "ethz":
data = sbd.ETHZ(sampling_rate=sampling_rate, force=True)
phase_dict = {
"trace_p_arrival_sample": "P",
"trace_pP_arrival_sample": "P",
"trace_P_arrival_sample": "P",
"trace_P1_arrival_sample": "P",
"trace_Pg_arrival_sample": "P",
"trace_Pn_arrival_sample": "P",
"trace_PmP_arrival_sample": "P",
"trace_pwP_arrival_sample": "P",
"trace_pwPm_arrival_sample": "P",
# "trace_s_arrival_sample": "S",
# "trace_S_arrival_sample": "S",
# "trace_S1_arrival_sample": "S",
# "trace_Sg_arrival_sample": "S",
# "trace_SmS_arrival_sample": "S",
# "trace_Sn_arrival_sample": "S",
}
dataset = data.get_split(split)
dataset.filter(dataset.metadata.trace_Pg_arrival_sample.notna())
print(split, dataset.metadata.shape, sampling_rate)
if station is not None:
dataset.filter(dataset.metadata.station_code==station)
data_generator = sbg.GenericGenerator(dataset)
if window == 'random':
print("using random window")
window_selector = sbg.RandomWindow(windowlen=3001, strategy="pad")
for key, phase in models.phase_dict.items():
if phase not in phases:
continue
else:
window_selector = sbg.FixedWindow(windowlen=3001, p0=0, strategy="pad")
if key in dataset.metadata:
mask = np.logical_or(mask, ~np.isnan(dataset.metadata[key]))
augmentations = [
sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000, selection="random",
strategy="variable"),
window_selector,
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
sbg.ChangeDtype(np.float32),
sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=30, dim=0)
]
data_generator.add_augmentations(augmentations)
return data_generator
def get_data_generators(sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz", station=None,
window='random'):
train_generator = get_data_generator("train", sampling_rate, path, sb_dataset, station, window)
dev_generator = get_data_generator("dev", sampling_rate, path, sb_dataset, station, window)
test_generator = get_data_generator("test", sampling_rate, path, sb_dataset, station, window)
return train_generator, dev_generator, test_generator
def get_data_loaders(batch_size=256, sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz",
window='random'):
train_generator, dev_generator, test_generator = get_data_generators(sampling_rate, path, sb_dataset, window=window)
num_workers = 0 # The number of threads used for loading data
train_loader = DataLoader(train_generator, batch_size=batch_size, shuffle=True, num_workers=num_workers,
worker_init_fn=worker_seeding)
dev_loader = DataLoader(dev_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
worker_init_fn=worker_seeding)
test_loader = DataLoader(test_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
worker_init_fn=worker_seeding)
return train_loader, dev_loader, test_loader
def load_model(name="PhaseNet", pretrained=None, classes=2, modify_output=True):
if name == "PhaseNet":
if pretrained is not None and pretrained:
model = sbm.PhaseNet(phases="PN", norm="peak").from_pretrained(pretrained)
else:
model = sbm.PhaseNet(phases="PN", norm="peak")
if modify_output:
model.out = nn.Conv1d(model.filters_root, classes, 1, padding="same")
return model
def train_one_epoch(model, dataloader, optimizer, pick_mae):
size = len(dataloader.dataset)
for batch_id, batch in enumerate(dataloader):
# Compute prediction and loss
pred = model(batch["X"].to(model.device))
loss = loss_fn(pred, batch["y"].to(model.device))
# Compute cross entropy loss
cross_entropy_loss = f.cross_entropy(pred, batch["y"])
# Compute mae
mae = pick_mae(pred, batch['y'])
wandb.log({"loss": loss})
wandb.log({"batch cross entropy loss": cross_entropy_loss})
wandb.log({"p_mae": mae})
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch_id % 5 == 0:
loss, current = loss.item(), batch_id * batch["X"].shape[0]
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
print(f"mae: {mae:>7f}")
def test_one_epoch(model, dataloader, pick_mae, wandb_log=True):
num_batches = len(dataloader)
test_loss = 0
test_mae = 0
with torch.no_grad():
for batch in dataloader:
pred = model(batch["X"].to(model.device))
test_loss += loss_fn(pred, batch["y"].to(model.device)).item()
test_mae += pick_mae(pred, batch['y'])
test_cross_entropy_loss = f.cross_entropy(pred, batch["y"])
if wandb_log:
wandb.log({"batch cross entropy test loss": test_cross_entropy_loss})
test_loss /= num_batches
test_mae /= num_batches
wandb.log({"test_p_mae": test_mae, "test_loss": test_loss})
print(f"Test avg loss: {test_loss:>8f}")
print(f"Test avg mae: {test_mae:>7f}\n")
return test_loss, test_mae
def train_model(model, path_to_trained_model, train_loader, dev_loader):
wandb.watch(model, log_freq=10)
optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
early_stopper = EarlyStopper(patience=3, min_delta=10)
pick_mae = PickMAE(wandb.config.sampling_rate)
best_loss = np.inf
best_metrics = {}
for t in range(wandb.config.epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train_one_epoch(model, train_loader, optimizer, pick_mae)
test_loss, test_mae = test_one_epoch(model, dev_loader, pick_mae)
if test_loss < best_loss:
best_loss = test_loss
best_metrics = {"test_p_mae": test_mae, "test_loss": test_loss}
torch.save(model.state_dict(), path_to_trained_model)
if early_stopper.early_stop(test_loss):
break
print("Best model: ", str(best_metrics))
def loss_fn(y_pred, y_true, eps=1e-5):
# vector cross entropy loss
h = y_true * torch.log(y_pred + eps)
h = h.mean(-1).sum(-1) # Mean along sample dimension and sum along pick dimension
h = h.mean() # Mean over batch axis
return -h
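# Minimal sanity-check sketch for loss_fn (the shape (batch, classes, samples) is an assumption
# following the SeisBench generator convention): a prediction identical to the target should
# yield a loss of approximately zero.
def _loss_fn_sanity_check():
    y = torch.zeros((2, 2, 3001))
    y[:, 0, :] = 1.0  # all probability mass on the first class
    return loss_fn(y, y)  # expected to be close to 0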
def train_phasenet_on_sb_data():
config = {
    "epochs": 3,
    "batch_size": 256,
    "learning_rate": 1e-3,  # assumed default; train_model reads wandb.config.learning_rate
    "dataset": "ethz",
    "sampling_rate": 100,
    "model_name": "PhaseNet"
}
run = wandb.init(
# set the wandb project where this run will be logged
project="training_seisbench_models_on_igf_data",
# track hyperparameters and run metadata
config=config
)
wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
train_loader, dev_loader, test = get_data_loaders(batch_size=wandb.config.batch_size,
sampling_rate=wandb.config.sampling_rate,
path=None,
sb_dataset=wandb.config.dataset)
model = load_model(name=wandb.config.model_name, pretrained=None, modify_output=True)
path_to_trained_model = f"{project_path}/models/{wandb.config.model_name}_trained_on_{wandb.config.data_set}.pt"
train_model(model, path_to_trained_model,
train_loader, dev_loader)
artifact = wandb.Artifact('model', type='model')
artifact.add_file(path_to_trained_model)
run.log_artifact(artifact)
run.finish()
def load_config(config_path):
with open(config_path, 'r') as f:
config = json.load(f)
return config
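# Illustrative experiments/config.json content (an assumption based on the wandb.config fields
# read below and in train_model; the actual file may use different values):
# {
#     "model_name": "PhaseNet",
#     "pretrained": "original",
#     "dataset": "igf",
#     "batch_size": 256,
#     "sampling_rate": 100,
#     "learning_rate": 0.001,
#     "epochs": 30
# }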
def train_sbmodel_on_igf_data():
config_path = project_path + "/experiments/config.json"
config = load_config(config_path)
run = wandb.init(
# set the wandb project where this run will be logged
project="training_seisbench_models_on_igf_data",
# track hyperparameters and run metadata
config=config
)
wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
print(wandb.config.batch_size, wandb.config.sampling_rate)
train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
sampling_rate=wandb.config.sampling_rate
)
model_name = wandb.config.model_name
pretrained = wandb.config.pretrained
print(model_name, pretrained)
model = load_model(name=model_name, pretrained=pretrained)
path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
train_model(model, path_to_trained_model, train_loader, dev_loader)
artifact = wandb.Artifact('model', type='model')
artifact.add_file(path_to_trained_model)
run.log_artifact(artifact)
run.finish()
if __name__ == "__main__":
# train_phasenet_on_sb_data()
train_sbmodel_on_igf_data()
code_start_time = time.perf_counter()
torch.manual_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--test_run", action="store_true")
parser.add_argument("--lr", default=None, type=float)
args = parser.parse_args()
with open(args.config, "r") as f:
config = json.load(f)
experiment_name = os.path.basename(args.config)[:-5]
if args.lr is not None:
logging.warning(f"Overwriting learning rate to {args.lr}")
experiment_name += f"_{args.lr}"
config["model_args"]["lr"] = args.lr
run = wandb.init(
# set the wandb project where this run will be logged
project="training_seisbench_models_on_igf_data_with_pick-benchmark",
# track hyperparameters and run metadata
config=config
)
if args.test_run:
experiment_name = experiment_name + "_test"
train(config, experiment_name, test_run=args.test_run)
running_time = str(
datetime.timedelta(seconds=time.perf_counter() - code_start_time)
)
print(f"Running time: {running_time}")

View File

@ -1,62 +0,0 @@
import os.path
import wandb
import yaml
from train import get_data_loaders, load_model, train_model
from dotenv import load_dotenv
load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
raise ValueError("WANDB_API_KEY environment variable is not set.")
wandb.login(key=wandb_api_key)
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sweep_config_path = project_path + "/experiments/sweep4.yaml"
with open(sweep_config_path) as file:
sweep_configuration = yaml.load(file, Loader=yaml.FullLoader)
sweep_id = wandb.sweep(
sweep=sweep_configuration,
project='training_seisbench_models_on_igf_data'
)
sampling_rate = 100
def tune_training_hyperparams():
run = wandb.init(
# set the wandb project where this run will be logged
project="training_seisbench_models_on_igf_data",
# track hyperparameters and run metadata
config={"sampling_rate":sampling_rate}
)
wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
sampling_rate=wandb.config.sampling_rate,
sb_dataset=wandb.config.dataset)
model_name = wandb.config.model_name
pretrained = wandb.config.pretrained
print(wandb.config)
print(model_name, pretrained, type(pretrained), wandb.config.sampling_rate)
if not pretrained:
    # Normalize falsy sweep values (e.g. False or "") to None so load_model trains from scratch
    pretrained = None
model = load_model(name=model_name, pretrained=pretrained)
path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
train_model(model, path_to_trained_model, train_loader, dev_loader)
artifact = wandb.Artifact('model', type='model')
artifact.add_file(path_to_trained_model)
run.log_artifact(artifact)
run.finish()
if __name__ == "__main__":
wandb.agent(sweep_id, function=tune_training_hyperparams, count=10)

119
scripts/util.py Normal file
View File

@ -0,0 +1,119 @@
"""
This script offers general functionality required in multiple places.
"""
import numpy as np
import pandas as pd
import os
import logging
import glob
import wandb
from dotenv import load_dotenv
import sys
from config_loader import models_path, configs_path
import yaml
load_dotenv()
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
def load_best_model_data(sweep_id, weights):
"""
Determines the model with the lowest validation loss.
If sweep_id is not provided the best model is determined based on the validation loss in checkpoint filenames in weights directory.
:param sweep_id:
:param weights:
:return:
"""
if sweep_id is not None:
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user = os.environ.get("WANDB_USER")
api = wandb.Api()
sweep = api.sweep(f"{wandb_user}/{wandb_project_name}/{sweep_id}")
# Get best run parameters
best_run = sweep.best_run()
run_id = best_run.id
matching_models = glob.glob(f"{models_path}/{weights}/*run={run_id}*ckpt")
if len(matching_models) != 1:
    raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id)
best_checkpoint_path = matching_models[0]
else:
checkpoints_path = f"{models_path}/{weights}/*ckpt"
logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
checkpoints = glob.glob(checkpoints_path)
val_losses = []
for ckpt in checkpoints:
i = ckpt.index("val_loss=")
val_losses.append(float(ckpt[i + 9:-5]))
best_checkpoint_path = checkpoints[np.argmin(val_losses)]
run_id_st = best_checkpoint_path.index("run=") + 4
run_id_end = best_checkpoint_path.index("-epoch=")
run_id = best_checkpoint_path[run_id_st:run_id_end]
return best_checkpoint_path, run_id
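# Illustrative usage sketch (sweep id and directory name are assumptions):
# best_ckpt, run_id = load_best_model_data(sweep_id="abc123xy", weights="PhaseNet_igf")
# best_ckpt, run_id = load_best_model_data(sweep_id=None, weights="PhaseNet_igf")  # falls back to parsing filenames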
def load_best_model(model_cls, weights, version):
"""
Determines the model with the lowest validation loss from the csv logs and loads it
:param model_cls: Class of the lightning module to load
:param weights: Path to weights as in cmd arguments
:param version: String of version file
:return: Instance of lightning module that was loaded from the best checkpoint
"""
metrics = pd.read_csv(weights / version / "metrics.csv")
idx = np.nanargmin(metrics["val_loss"])
min_row = metrics.iloc[idx]
# For default checkpoint filename, see https://github.com/Lightning-AI/lightning/pull/11805
# and https://github.com/Lightning-AI/lightning/issues/16636.
# For example, 'epoch=0-step=1.ckpt' means the 1st step has finished, but the 1st epoch hasn't
checkpoint = f"epoch={min_row['epoch']:.0f}-step={min_row['step']+1:.0f}.ckpt"
# For default save path of checkpoints, see https://github.com/Lightning-AI/lightning/pull/12372
checkpoint_path = weights / version / "checkpoints" / checkpoint
return model_cls.load_from_checkpoint(checkpoint_path)
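# Illustrative usage sketch (class name and paths are assumptions; `weights` must be a
# pathlib.Path because of the `/` joins above):
# from pathlib import Path
# best_model = load_best_model(PhaseNetLit, Path("weights/PhaseNet"), "version_0")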
default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", None)
if default_workers is None:
logging.warning(
"BENCHMARK_DEFAULT_WORKERS not set. "
"Will use 12 workers if not specified otherwise in configuration."
)
default_workers = 12
else:
default_workers = int(default_workers)
def load_sweep_config(sweep_fname):
"""
Loads sweep config from yaml file
:param sweep_fname: sweep yaml file, expected to be in configs_path
:return: Dictionary containing sweep config
"""
sweep_config_path = f"{configs_path}/{sweep_fname}"
try:
with open(sweep_config_path, "r") as file:
sweep_config = yaml.load(file, Loader=yaml.FullLoader)
except FileNotFoundError:
logging.error(f"Could not find sweep config file: {sweep_fname}. "
f"Please make sure the file exists and is in {configs_path} directory.")
sys.exit(1)
return sweep_config
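# Illustrative usage sketch (the file name is an assumption): load a sweep definition and
# register it with Weights & Biases before launching agents.
# sweep_config = load_sweep_config("sweep_phasenet.yaml")
# sweep_id = wandb.sweep(sweep=sweep_config, project=os.environ.get("WANDB_PROJECT"))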