initial commit with the pipeline for training and evaluating seisbench models
This commit is contained in:
parent
fd4a67f2ae
commit
915f2a2d69
9
.gitignore
vendored
9
.gitignore
vendored
@ -3,8 +3,11 @@ __pycache__/
|
||||
*/.ipynb_checkpoints/
|
||||
.ipynb_checkpoints/
|
||||
.env
|
||||
models/
|
||||
data/
|
||||
weights/
|
||||
datasets/
|
||||
wip
|
||||
artifacts/
|
||||
wandb/
|
||||
wandb/
|
||||
scripts/pred/
|
||||
scripts/pred_resampled/
|
||||
scripts/lightning_logs/
|
57
README.md
57
README.md
@ -2,20 +2,59 @@
|
||||
|
||||
|
||||
This repo contains notebooks and scripts demonstrating how to:
|
||||
- Prepare IGF data for training model detecting P phase (i.e. transform mseeds into [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)), check the [notebook](utils/Transforming%20mseeds%20to%20SeisBench%20dataset.ipynb).
|
||||
The original data can be downloaded from the [drive](https://drive.google.com/drive/folders/1InVI9DLaD7gdzraM2jMzeIrtiBSu-UIK?usp=drive_link)
|
||||
- Prepare IGF data for training a seisbench model detecting P phase (i.e. transform mseeds into [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)), check the [notebook](utils/Transforming%20mseeds%20to%20SeisBench%20dataset.ipynb).
|
||||
|
||||
- Explore available data, check the [notebook](notebooks/Explore%20igf%20data.ipynb)
|
||||
- Train cnn model (Seisbench PhaseNet) to detect P phase, check the [script](scripts/train.py)
|
||||
- Use Weights & Biases to search for the best training hyperparams, check the [script](scripts/training_wandb_sweep.py)
|
||||
- Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb)
|
||||
- Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb)
|
||||
- Train various cnn models available in seisbench library and compare their performance of detecting P phase, check the [script](scripts/pipeline.py)
|
||||
|
||||
- [to update] Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb)
|
||||
- [to update] Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb)
|
||||
|
||||
|
||||
### Acknowledgments
|
||||
This code is based on the [pick-benchmark](https://github.com/seisbench/pick-benchmark), the repository accompanying the paper:
|
||||
[Which picker fits my data? A quantitative evaluation of deep learning based seismic pickers](https://github.com/seisbench/pick-benchmark#:~:text=Which%20picker%20fits%20my%20data%3F%20A%20quantitative%20evaluation%20of%20deep%20learning%20based%20seismic%20pickers)
|
||||
### Usage
|
||||
|
||||
1. Install all dependencies with poetry, run:
|
||||
`poetry install`
|
||||
2. Prepare .env file with content:
|
||||
|
||||
`WANDB_API_KEY="your key"`
|
||||
`poetry install`
|
||||
2. Prepare .env file with content:
|
||||
```
|
||||
WANDB_HOST="https://epos-ai.grid.cyfronet.pl/"
|
||||
WANDB_API_KEY="your key"
|
||||
WANDB_USER="your user"
|
||||
WANDB_PROJECT="training_seisbench_models_on_igf_data"
|
||||
BENCHMARK_DEFAULT_WORKER=2
|
||||
|
||||
3. Transform data into seisbench format.
|
||||
* Download original data from the [drive](https://drive.google.com/drive/folders/1InVI9DLaD7gdzraM2jMzeIrtiBSu-UIK?usp=drive_link)
|
||||
* Run the notebook: `utils/Transforming mseeds to SeisBench dataset.ipynb`
|
||||
|
||||
4. Initialize poetry environment:
|
||||
|
||||
`poetry shell`
|
||||
|
||||
5. Run the pipeline script:
|
||||
|
||||
`python pipeline.py`
|
||||
|
||||
The script performs the following steps:
|
||||
* Generates evaluation targets
|
||||
* Trains multiple versions of GPD, PhaseNet and ... models to find the best hyperparameters, producing the lowest validation loss.
|
||||
This step utilizes the Weights & Biases platform to perform the hyperparameters search (called sweeping) and track the training process and store the results.
|
||||
The results are available at
|
||||
`https://epos-ai.grid.cyfronet.pl/<your user name>/<your project name>`
|
||||
* Uses the best performing model of each type to generate predictions
|
||||
* Evaluates the performance of each model by comparing the predictions with the evaluation targets
|
||||
* Saves the results in the `scripts/pred` directory
|
||||
*
|
||||
The default settings are saved in config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script.
|
||||
For example, to change the sweep configuration file for GPD model, run:
|
||||
`python pipeline.py --gpd_config <new config file>`
|
||||
The new config file should be placed in the `experiments` or as specified in the `configs_path` parameter in the config.json file.
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
* `wandb: ERROR Run .. errored: OSError(24, 'Too many open files')`
|
||||
-> https://github.com/wandb/wandb/issues/2825
|
||||
|
15
config.json
Normal file
15
config.json
Normal file
@ -0,0 +1,15 @@
|
||||
{
|
||||
"dataset_name": "igf",
|
||||
"data_path": "datasets/igf/seisbench_format/",
|
||||
"targets_path": "datasets/targets/igf",
|
||||
"models_path": "weights",
|
||||
"configs_path": "experiments",
|
||||
"sampling_rate": 100,
|
||||
"num_workers": 1,
|
||||
"seed": 10,
|
||||
"sweep_files": {
|
||||
"GPD": "sweep_gpd.yaml",
|
||||
"PhaseNet": "sweep_phasenet.yaml"
|
||||
},
|
||||
"experiment_count": 20
|
||||
}
|
@ -1,11 +0,0 @@
|
||||
{
|
||||
"epochs": 10,
|
||||
"batch_size": 256,
|
||||
"dataset": "igf_1",
|
||||
"sampling_rate": 100,
|
||||
"model_names": "EQTransformer,BasicPhaseAE,GPD",
|
||||
"model_name": "PhaseNet",
|
||||
"learning_rate": 0.01,
|
||||
"pretrained": null,
|
||||
"sampling_rate": 100
|
||||
}
|
26
experiments/sweep_gpd.yaml
Normal file
26
experiments/sweep_gpd.yaml
Normal file
@ -0,0 +1,26 @@
|
||||
name: GPD_fixed_highpass:2-10
|
||||
method: bayes
|
||||
metric:
|
||||
goal: minimize
|
||||
name: val_loss
|
||||
parameters:
|
||||
model_name:
|
||||
value:
|
||||
- GPD
|
||||
batch_size:
|
||||
distribution: int_uniform
|
||||
max: 1024
|
||||
min: 256
|
||||
max_epochs:
|
||||
value:
|
||||
- 3
|
||||
learning_rate:
|
||||
distribution: uniform
|
||||
max: 0.02
|
||||
min: 0.005
|
||||
highpass:
|
||||
value:
|
||||
- 2
|
||||
lowpass:
|
||||
value:
|
||||
- 10
|
27
experiments/sweep_gpd_highpass.yaml
Normal file
27
experiments/sweep_gpd_highpass.yaml
Normal file
@ -0,0 +1,27 @@
|
||||
name: GPD_fixed_highpass:2-10
|
||||
method: bayes
|
||||
metric:
|
||||
goal: minimize
|
||||
name: val_loss
|
||||
parameters:
|
||||
model_name:
|
||||
value:
|
||||
- GPD
|
||||
batch_size:
|
||||
distribution: int_uniform
|
||||
max: 1024
|
||||
min: 256
|
||||
max_epochs:
|
||||
value:
|
||||
- 15
|
||||
learning_rate:
|
||||
distribution: uniform
|
||||
max: 0.02
|
||||
min: 0.005
|
||||
highpass:
|
||||
distribution: uniform
|
||||
min: 0.5
|
||||
max: 2.0
|
||||
lowpass:
|
||||
value:
|
||||
- 10
|
20
experiments/sweep_phasenet.yaml
Normal file
20
experiments/sweep_phasenet.yaml
Normal file
@ -0,0 +1,20 @@
|
||||
name: PhaseNet-lr0.005-0.02-bs256-1024
|
||||
method: bayes
|
||||
metric:
|
||||
goal: minimize
|
||||
name: val_loss
|
||||
parameters:
|
||||
model_name:
|
||||
value:
|
||||
- PhaseNet
|
||||
batch_size:
|
||||
distribution: int_uniform
|
||||
max: 1024
|
||||
min: 256
|
||||
max_epochs:
|
||||
value:
|
||||
- 15
|
||||
learning_rate:
|
||||
distribution: uniform
|
||||
max: 0.02
|
||||
min: 0.005
|
1605
poetry.lock
generated
1605
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -14,6 +14,8 @@ pandas = "^2.0.3"
|
||||
obspy = "^1.4.0"
|
||||
wandb = "^0.15.4"
|
||||
torchmetrics = "^0.11.4"
|
||||
ipykernel = "^6.24.0"
|
||||
jupyterlab = "^4.0.2"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
|
||||
|
133
scripts/augmentations.py
Normal file
133
scripts/augmentations.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""
|
||||
This file contains augmentations required for the models that are too specific to be merged into SeisBench.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
|
||||
class DuplicateEvent:
|
||||
"""
|
||||
Adds a rescaled version of the event to the empty part of the trace after the event.
|
||||
Event position and empty space are determined from a detection.
|
||||
Detections can be generated for example with :py:class:`~seisbench.generate.labeling.DetectionLabeller`.
|
||||
|
||||
This implementation is modelled after the `implementation for EQTransformer <https://github.com/smousavi05/EQTransformer/blob/98676017f971efbb6f4475f42e415c3868d00c03/EQTransformer/core/EqT_utils.py#L255>`_.
|
||||
|
||||
.. warning::
|
||||
This augmentation does **not** modify the metadata, as representing multiple picks of
|
||||
the same type is currently not supported. Workflows should therefore always first generate
|
||||
labels from metadata and then pass the labels in the key `label_keys`. These keys are automatically
|
||||
adjusted by addition of the labels.
|
||||
|
||||
.. warning::
|
||||
This implementation currently has strict shape requirements:
|
||||
|
||||
- (1, samples) for detection
|
||||
- (channels, samples) for data
|
||||
- (labels, samples) for labels
|
||||
|
||||
:param inv_scale: The scale factor is defined by as 1/u, where u is uniform.
|
||||
`inv_scale` defines the minimum and maximum values for u.
|
||||
Defaults to (1, 10), e.g., scaling by factor 1 to 1/10.
|
||||
:param detection_key: Key to read detection from.
|
||||
If key is a tuple, detection will be read from the first key and written to the second one.
|
||||
:param key: The keys for reading from and writing to the state dict.
|
||||
If key is a single string, the corresponding entry in state dict is modified.
|
||||
Otherwise, a 2-tuple is expected, with the first string indicating the key
|
||||
to read from and the second one the key to write to.
|
||||
:param label_keys: Keys for the label columns.
|
||||
Labels of the original and duplicate events will be added and capped at 1.
|
||||
Note that this will lead to invalid noise traces.
|
||||
Value can either be a single key specification or a list of key specifications.
|
||||
Each key specification is either a string, for identical input and output keys,
|
||||
or as a tuple of two strings, input and output keys.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, inv_scale=(1, 10), detection_key="detections", key="X", label_keys=None
|
||||
):
|
||||
if isinstance(detection_key, str):
|
||||
self.detection_key = (detection_key, detection_key)
|
||||
else:
|
||||
self.detection_key = detection_key
|
||||
|
||||
if isinstance(key, str):
|
||||
self.key = (key, key)
|
||||
else:
|
||||
self.key = key
|
||||
|
||||
# Single key
|
||||
if not isinstance(label_keys, list):
|
||||
if label_keys is None:
|
||||
label_keys = []
|
||||
else:
|
||||
label_keys = [label_keys]
|
||||
|
||||
# Resolve identical input and output keys
|
||||
self.label_keys = []
|
||||
for key in label_keys:
|
||||
if isinstance(key, tuple):
|
||||
self.label_keys.append(key)
|
||||
else:
|
||||
self.label_keys.append((key, key))
|
||||
|
||||
self.inv_scale = inv_scale
|
||||
|
||||
def __call__(self, state_dict):
|
||||
x, metadata = state_dict[self.key[0]]
|
||||
detection, _ = state_dict[self.detection_key[0]]
|
||||
detection_mask = detection[0] > 0.5
|
||||
|
||||
if detection.shape[-1] != x.shape[-1]:
|
||||
raise ValueError("Number of samples in trace and detection disagree.")
|
||||
|
||||
if self.key[0] != self.key[1]:
|
||||
# Ensure metadata is not modified inplace unless input and output key are anyhow identical
|
||||
metadata = copy.deepcopy(metadata)
|
||||
|
||||
if detection_mask.any():
|
||||
n_samples = x.shape[-1]
|
||||
event_samples = np.arange(n_samples)[detection_mask]
|
||||
event_start, event_end = np.min(event_samples), np.max(event_samples) + 1
|
||||
|
||||
if event_end + 20 < n_samples:
|
||||
second_start = np.random.randint(event_end + 20, n_samples)
|
||||
scale = 1 / np.random.uniform(*self.inv_scale)
|
||||
|
||||
if self.key[0] != self.key[1]:
|
||||
# Avoid inplace modification if input and output keys differ
|
||||
x = x.copy()
|
||||
|
||||
space = min(event_end - event_start, n_samples - second_start)
|
||||
x[:, second_start : second_start + space] += (
|
||||
x[:, event_start : event_start + space] * scale
|
||||
)
|
||||
|
||||
shift = second_start - event_start
|
||||
|
||||
for label_key in self.label_keys + [self.detection_key]:
|
||||
y, metadata = state_dict[label_key[0]]
|
||||
if y.shape[-1] != n_samples:
|
||||
raise ValueError(
|
||||
f"Number of samples disagree between trace and label key '{label_key[0]}'."
|
||||
)
|
||||
|
||||
if label_key[0] != label_key[1]:
|
||||
metadata = copy.deepcopy(metadata)
|
||||
y = y.copy()
|
||||
|
||||
y[:, shift:] += y[:, :-shift]
|
||||
y = np.minimum(y, 1)
|
||||
state_dict[label_key[1]] = (y, metadata)
|
||||
else:
|
||||
# Copy entries
|
||||
for label_key in self.label_keys + [self.detection_key]:
|
||||
y, metadata = state_dict[label_key[0]]
|
||||
if label_key[0] != label_key[1]:
|
||||
metadata = copy.deepcopy(metadata)
|
||||
y = y.copy()
|
||||
state_dict[label_key[1]] = (y, metadata)
|
||||
|
||||
state_dict[self.key[1]] = (x, metadata)
|
335
scripts/collect_results.py
Normal file
335
scripts/collect_results.py
Normal file
@ -0,0 +1,335 @@
|
||||
"""
|
||||
This script collects results in a folder, calculates performance metrics and writes them to csv.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.metrics import (
|
||||
precision_recall_curve,
|
||||
precision_recall_fscore_support,
|
||||
roc_auc_score,
|
||||
matthews_corrcoef,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def traverse_path(path, output, cross=False, resampled=False, baer=False):
|
||||
"""
|
||||
Traverses the given path and extracts results for each experiment and version
|
||||
|
||||
:param path: Root path
|
||||
:param output: Path to write results csv to
|
||||
:param cross: If true, expects cross-domain results.
|
||||
:return: None
|
||||
"""
|
||||
path = Path(path)
|
||||
|
||||
results = []
|
||||
|
||||
exp_dirs = [x for x in path.iterdir() if x.is_dir()]
|
||||
for exp_dir in tqdm(exp_dirs):
|
||||
itr = exp_dir.iterdir()
|
||||
if baer:
|
||||
itr = [exp_dir] # Missing version directory in the structure
|
||||
for version_dir in itr:
|
||||
if not version_dir.is_dir():
|
||||
pass
|
||||
|
||||
results.append(
|
||||
process_version(
|
||||
version_dir, cross=cross, resampled=resampled, baer=baer
|
||||
)
|
||||
)
|
||||
|
||||
results = pd.DataFrame(results)
|
||||
if cross:
|
||||
sort_keys = ["data", "model", "target", "lr", "version"]
|
||||
else:
|
||||
sort_keys = ["data", "model", "lr", "version"]
|
||||
results.sort_values(sort_keys, inplace=True)
|
||||
results.to_csv(output, index=False)
|
||||
|
||||
|
||||
def process_version(version_dir: Path, cross: bool, resampled: bool, baer: bool):
|
||||
"""
|
||||
Extracts statistics for the given version of the given experiment.
|
||||
|
||||
:param version_dir: Path to the specific version
|
||||
:param cross: If true, expects cross-domain results.
|
||||
:return: Results dictionary
|
||||
"""
|
||||
stats = parse_exp_name(version_dir, cross=cross, resampled=resampled, baer=baer)
|
||||
|
||||
stats.update(eval_task1(version_dir))
|
||||
stats.update(eval_task23(version_dir))
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def parse_exp_name(version_dir, cross, resampled, baer):
|
||||
if baer:
|
||||
exp_name = version_dir.name
|
||||
version = "0"
|
||||
else:
|
||||
exp_name = version_dir.parent.name
|
||||
version = version_dir.name.split("_")[-1]
|
||||
|
||||
parts = exp_name.split("_")
|
||||
target = None
|
||||
sampling_rate = None
|
||||
if cross or baer:
|
||||
if len(parts) == 4:
|
||||
data, model, lr, target = parts
|
||||
else:
|
||||
data, model, target = parts
|
||||
lr = "0.001"
|
||||
elif resampled:
|
||||
if len(parts) == 5:
|
||||
data, model, lr, target, sampling_rate = parts
|
||||
else:
|
||||
data, model, target, sampling_rate = parts
|
||||
lr = "0.001"
|
||||
else:
|
||||
if len(parts) == 3:
|
||||
data, model, lr = parts
|
||||
else:
|
||||
data, model, *_ = parts
|
||||
lr = "0.001"
|
||||
|
||||
# lr = float(lr)
|
||||
|
||||
stats = {
|
||||
"experiment": exp_name,
|
||||
"data": data,
|
||||
"model": model,
|
||||
"lr": None,
|
||||
"version": version,
|
||||
}
|
||||
|
||||
if cross or baer:
|
||||
stats["target"] = target
|
||||
if resampled:
|
||||
stats["target"] = target
|
||||
stats["sampling_rate"] = sampling_rate
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def eval_task1(version_dir: Path):
|
||||
if not (
|
||||
(version_dir / "dev_task1.csv").is_file()
|
||||
and (version_dir / "test_task1.csv").is_file()
|
||||
):
|
||||
logging.warning(f"Directory {version_dir} does not contain task 1")
|
||||
return {}
|
||||
|
||||
stats = {}
|
||||
|
||||
dev_pred = pd.read_csv(version_dir / "dev_task1.csv")
|
||||
dev_pred["trace_type_bin"] = dev_pred["trace_type"] == "earthquake"
|
||||
test_pred = pd.read_csv(version_dir / "test_task1.csv")
|
||||
test_pred["trace_type_bin"] = test_pred["trace_type"] == "earthquake"
|
||||
|
||||
prec, recall, thr = precision_recall_curve(
|
||||
dev_pred["trace_type_bin"], dev_pred["score_detection"]
|
||||
)
|
||||
|
||||
f1 = 2 * prec * recall / (prec + recall)
|
||||
auc = roc_auc_score(dev_pred["trace_type_bin"], dev_pred["score_detection"])
|
||||
|
||||
opt_index = np.nanargmax(f1) # F1 optimal threshold index
|
||||
opt_thr = thr[opt_index] # F1 optimal threshold value
|
||||
|
||||
dev_stats = {
|
||||
"dev_det_precision": prec[opt_index],
|
||||
"dev_det_recall": recall[opt_index],
|
||||
"dev_det_f1": f1[opt_index],
|
||||
"dev_det_auc": auc,
|
||||
"det_threshold": opt_thr,
|
||||
}
|
||||
stats.update(dev_stats)
|
||||
|
||||
prec, recall, f1, _ = precision_recall_fscore_support(
|
||||
test_pred["trace_type_bin"],
|
||||
test_pred["score_detection"] > opt_thr,
|
||||
average="binary",
|
||||
)
|
||||
auc = roc_auc_score(test_pred["trace_type_bin"], test_pred["score_detection"])
|
||||
test_stats = {
|
||||
"test_det_precision": prec,
|
||||
"test_det_recall": recall,
|
||||
"test_det_f1": f1,
|
||||
"test_det_auc": auc,
|
||||
}
|
||||
stats.update(test_stats)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def eval_task23(version_dir: Path):
|
||||
print(version_dir / "dev_task23.csv")
|
||||
if not (
|
||||
(version_dir / "dev_task23.csv").is_file()
|
||||
and (version_dir / "test_task23.csv").is_file()
|
||||
):
|
||||
logging.warning(f"Directory {version_dir} does not contain tasks 2 and 3")
|
||||
return {}
|
||||
|
||||
stats = {}
|
||||
|
||||
dev_pred = pd.read_csv(version_dir / "dev_task23.csv")
|
||||
dev_pred["phase_label_bin"] = dev_pred["phase_label"] == "P"
|
||||
test_pred = pd.read_csv(version_dir / "test_task23.csv")
|
||||
test_pred["phase_label_bin"] = test_pred["phase_label"] == "P"
|
||||
|
||||
def add_aux_columns(pred):
|
||||
for col in ["s_sample_pred", "score_p_or_s"]:
|
||||
if col not in pred.columns:
|
||||
pred[col] = np.nan
|
||||
|
||||
add_aux_columns(dev_pred)
|
||||
add_aux_columns(test_pred)
|
||||
|
||||
def nanmask(pred):
|
||||
"""
|
||||
Returns all entries that are nan in score_p_or_s, p_sample_pred and s_sample_pred
|
||||
"""
|
||||
mask = np.logical_and(
|
||||
np.isnan(pred["p_sample_pred"]), np.isnan(pred["s_sample_pred"])
|
||||
)
|
||||
mask = np.logical_and(mask, np.isnan(pred["score_p_or_s"]))
|
||||
return mask
|
||||
|
||||
if nanmask(dev_pred).all():
|
||||
logging.warning(f"{version_dir} contains NaN predictions for tasks 2 and 3")
|
||||
return {}
|
||||
|
||||
dev_pred = dev_pred[~nanmask(dev_pred)]
|
||||
test_pred = test_pred[~nanmask(test_pred)]
|
||||
|
||||
skip_task2 = False
|
||||
if (
|
||||
np.logical_or(
|
||||
np.isnan(dev_pred["score_p_or_s"]), np.isinf(dev_pred["score_p_or_s"])
|
||||
).all()
|
||||
or np.logical_or(
|
||||
np.isnan(test_pred["score_p_or_s"]), np.isinf(test_pred["score_p_or_s"])
|
||||
).all()
|
||||
):
|
||||
# For unfortunate combinations of nans and infs, otherwise weird scores can occur
|
||||
skip_task2 = True
|
||||
|
||||
# Clipping removes infinitely likely P waves, usually resulting from models trained without S arrivals
|
||||
dev_pred["score_p_or_s"] = np.clip(dev_pred["score_p_or_s"].values, -1e100, 1e100)
|
||||
test_pred["score_p_or_s"] = np.clip(test_pred["score_p_or_s"].values, -1e100, 1e100)
|
||||
|
||||
dev_pred_restricted = dev_pred[~np.isnan(dev_pred["score_p_or_s"])]
|
||||
test_pred_restricted = test_pred[~np.isnan(test_pred["score_p_or_s"])]
|
||||
if len(dev_pred_restricted) > 0 and not skip_task2:
|
||||
prec, recall, thr = precision_recall_curve(
|
||||
dev_pred_restricted["phase_label_bin"], dev_pred_restricted["score_p_or_s"]
|
||||
)
|
||||
|
||||
f1 = 2 * prec * recall / (prec + recall)
|
||||
|
||||
opt_index = np.nanargmax(f1) # F1 optimal threshold index
|
||||
opt_thr = thr[opt_index] # F1 optimal threshold value
|
||||
|
||||
# Determine (approximately) optimal MCC threshold using 50 candidates
|
||||
mcc_thrs = np.sort(dev_pred["score_p_or_s"].values)
|
||||
mcc_thrs = mcc_thrs[np.linspace(0, len(mcc_thrs) - 1, 50, dtype=int)]
|
||||
mccs = []
|
||||
for thr in mcc_thrs:
|
||||
mccs.append(
|
||||
matthews_corrcoef(
|
||||
dev_pred["phase_label_bin"], dev_pred["score_p_or_s"] > thr
|
||||
)
|
||||
)
|
||||
mcc = np.max(mccs)
|
||||
mcc_thr = mcc_thrs[np.argmax(mccs)]
|
||||
|
||||
dev_stats = {
|
||||
"dev_phase_precision": prec[opt_index],
|
||||
"dev_phase_recall": recall[opt_index],
|
||||
"dev_phase_f1": f1[opt_index],
|
||||
"phase_threshold": opt_thr,
|
||||
"dev_phase_mcc": mcc,
|
||||
"phase_threshold_mcc": mcc_thr,
|
||||
}
|
||||
stats.update(dev_stats)
|
||||
|
||||
prec, recall, f1, _ = precision_recall_fscore_support(
|
||||
test_pred_restricted["phase_label_bin"],
|
||||
test_pred_restricted["score_p_or_s"] > opt_thr,
|
||||
average="binary",
|
||||
)
|
||||
mcc = matthews_corrcoef(
|
||||
test_pred["phase_label_bin"], test_pred["score_p_or_s"] > mcc_thr
|
||||
)
|
||||
test_stats = {
|
||||
"test_phase_precision": prec,
|
||||
"test_phase_recall": recall,
|
||||
"test_phase_f1": f1,
|
||||
"test_phase_mcc": mcc,
|
||||
}
|
||||
stats.update(test_stats)
|
||||
|
||||
for pred, set_str in [(dev_pred, "dev"), (test_pred, "test")]:
|
||||
for i, phase in enumerate(["P", "S"]):
|
||||
pred_phase = pred[pred["phase_label"] == phase]
|
||||
pred_col = f"{phase.lower()}_sample_pred"
|
||||
|
||||
if len(pred_phase) == 0:
|
||||
continue
|
||||
|
||||
diff = (pred_phase[pred_col] - pred_phase["phase_onset"]) / pred_phase[
|
||||
"sampling_rate"
|
||||
]
|
||||
|
||||
stats[f"{set_str}_{phase}_mean_s"] = np.mean(diff)
|
||||
stats[f"{set_str}_{phase}_std_s"] = np.sqrt(np.mean(diff**2))
|
||||
stats[f"{set_str}_{phase}_mae_s"] = np.mean(np.abs(diff))
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Collects results from all experiments in a folder and outputs them in condensed csv format."
|
||||
)
|
||||
parser.add_argument(
|
||||
"path",
|
||||
type=str,
|
||||
help="Root path of predictions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output",
|
||||
type=str,
|
||||
help="Path for the output csv",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cross", action="store_true", help="If true, expects cross-domain results."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resampled",
|
||||
action="store_true",
|
||||
help="If true, expects cross-domain cross-sampling rate results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--baer",
|
||||
action="store_true",
|
||||
help="If true, expects results from Baer-Kradolfer picker.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
traverse_path(
|
||||
args.path,
|
||||
args.output,
|
||||
cross=args.cross,
|
||||
resampled=args.resampled,
|
||||
baer=args.baer,
|
||||
)
|
26
scripts/config_loader.py
Normal file
26
scripts/config_loader.py
Normal file
@ -0,0 +1,26 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def load_config(file_path):
|
||||
with open(file_path, 'r') as config_file:
|
||||
config = json.load(config_file)
|
||||
return config
|
||||
|
||||
|
||||
project_path = str(Path.cwd().parent)
|
||||
config_path = project_path + "/config.json"
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
data_path = f"{project_path}/{config['data_path']}"
|
||||
models_path = f"{project_path}/{config['models_path']}"
|
||||
targets_path = f"{project_path}/{config['targets_path']}"
|
||||
dataset_name = config['dataset_name']
|
||||
configs_path = f"{project_path}/{config['configs_path']}"
|
||||
|
||||
sweep_files = config['sweep_files']
|
||||
sampling_rate = config['sampling_rate']
|
||||
num_workers = config['num_workers']
|
||||
seed = config['seed']
|
||||
experiment_count = config['experiment_count']
|
32
scripts/data.py
Normal file
32
scripts/data.py
Normal file
@ -0,0 +1,32 @@
|
||||
"""
|
||||
This file contains functionality related to data.
|
||||
"""
|
||||
|
||||
import seisbench.data as sbd
|
||||
|
||||
|
||||
def get_dataset_by_name(name):
|
||||
"""
|
||||
Resolve dataset name to class from seisbench.data.
|
||||
|
||||
:param name: Name of dataset as defined in seisbench.data.
|
||||
:return: Dataset class from seisbench.data
|
||||
"""
|
||||
try:
|
||||
return sbd.__getattribute__(name)
|
||||
except AttributeError:
|
||||
raise ValueError(f"Unknown dataset '{name}'.")
|
||||
|
||||
|
||||
def get_custom_dataset(path):
|
||||
"""
|
||||
Return custom dataset in seisbench format
|
||||
:param path:
|
||||
:return: Dataset class
|
||||
"""
|
||||
|
||||
try:
|
||||
return sbd.WaveformDataset(path)
|
||||
except AttributeError:
|
||||
raise ValueError(f"Unknown dataset '{path}'.")
|
||||
|
248
scripts/eval.py
Normal file
248
scripts/eval.py
Normal file
@ -0,0 +1,248 @@
|
||||
"""
|
||||
This script implements functionality for evaluating models.
|
||||
Given a model and a set of targets, it calculates and outputs predictions.
|
||||
"""
|
||||
|
||||
import seisbench.generate as sbg
|
||||
|
||||
import argparse
|
||||
import pandas as pd
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
|
||||
import models, data
|
||||
import logging
|
||||
from util import default_workers, load_best_model_data
|
||||
import time
|
||||
import datetime
|
||||
from config_loader import models_path, project_path
|
||||
import os
|
||||
|
||||
data_aliases = {
|
||||
"ethz": "ETHZ",
|
||||
"geofon": "GEOFON",
|
||||
"stead": "STEAD",
|
||||
"neic": "NEIC",
|
||||
"instance": "InstanceCountsCombined",
|
||||
"iquique": "Iquique",
|
||||
"lendb": "LenDB",
|
||||
"scedc": "SCEDC"
|
||||
}
|
||||
|
||||
|
||||
def main(weights, targets, sets, batchsize, num_workers, sampling_rate=None, sweep_id=None, test_run=False):
|
||||
weights = Path(weights)
|
||||
targets = Path(os.path.abspath(targets))
|
||||
print(targets)
|
||||
# print()
|
||||
sets = sets.split(",")
|
||||
|
||||
checkpoint_path, version = load_best_model_data(sweep_id, weights)
|
||||
logging.warning("Starting evaluation of model: \n" + checkpoint_path)
|
||||
|
||||
config_path = f"{models_path}/{weights}/{version}/hparams.yaml"
|
||||
with open(config_path, "r") as f:
|
||||
# config = yaml.safe_load(f)
|
||||
config = yaml.full_load(f)
|
||||
|
||||
model_name = config["model_name"][0] if type(config["model_name"]) == list else config["model_name"]
|
||||
|
||||
model_cls = models.__getattribute__(model_name + "Lit")
|
||||
model = model_cls.load_from_checkpoint(checkpoint_path)
|
||||
|
||||
data_name = data_aliases[targets.name] if targets.name in data_aliases else None
|
||||
|
||||
if data_name != config["dataset_name"] and targets.name in data_aliases:
|
||||
logging.warning("Detected cross-domain evaluation")
|
||||
pred_root = "pred_cross"
|
||||
parts = weights.name.split()
|
||||
weight_path_name = "_".join(parts[:2] + [targets.name] + parts[2:])
|
||||
|
||||
else:
|
||||
pred_root = "pred"
|
||||
weight_path_name = weights.name
|
||||
|
||||
if data_name is not None:
|
||||
dataset = data.get_dataset_by_name(data_name)(
|
||||
sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full"
|
||||
)
|
||||
else:
|
||||
data_path = project_path + '/' + config['data_path']
|
||||
print("Loading dataset: ", data_path)
|
||||
dataset = data.get_custom_dataset(data_path)
|
||||
|
||||
if sampling_rate is not None:
|
||||
dataset.sampling_rate = sampling_rate
|
||||
pred_root = pred_root + "_resampled"
|
||||
weight_path_name = weight_path_name + f"_{sampling_rate}"
|
||||
|
||||
for eval_set in sets:
|
||||
split = dataset.get_split(eval_set)
|
||||
if targets.name == "instance":
|
||||
logging.warning(
|
||||
"Overwriting noise trace_names to allow correct identification"
|
||||
)
|
||||
# Replace trace names for noise entries
|
||||
split._metadata["trace_name"].values[
|
||||
-len(split.datasets[-1]) :
|
||||
] = split._metadata["trace_name"][-len(split.datasets[-1]) :].apply(
|
||||
lambda x: "noise_" + x
|
||||
)
|
||||
split._build_trace_name_to_idx_dict()
|
||||
|
||||
logging.warning(f"Starting set {eval_set}")
|
||||
split.preload_waveforms(pbar=True)
|
||||
print("eval set shape", split.metadata.shape)
|
||||
|
||||
for task in ["1", "23"]:
|
||||
task_csv = targets / f"task{task}.csv"
|
||||
|
||||
print(task_csv)
|
||||
|
||||
if not task_csv.is_file():
|
||||
continue
|
||||
|
||||
logging.warning(f"Starting task {task}")
|
||||
|
||||
task_targets = pd.read_csv(task_csv)
|
||||
task_targets = task_targets[task_targets["trace_split"] == eval_set]
|
||||
if task == "1" and targets.name == "instance":
|
||||
border = _identify_instance_dataset_border(task_targets)
|
||||
task_targets["trace_name"].values[border:] = task_targets["trace_name"][
|
||||
border:
|
||||
].apply(lambda x: "noise_" + x)
|
||||
|
||||
if sampling_rate is not None:
|
||||
for key in ["start_sample", "end_sample", "phase_onset"]:
|
||||
if key not in task_targets.columns:
|
||||
continue
|
||||
task_targets[key] = (
|
||||
task_targets[key]
|
||||
* sampling_rate
|
||||
/ task_targets["sampling_rate"]
|
||||
)
|
||||
task_targets[sampling_rate] = sampling_rate
|
||||
|
||||
restrict_to_phase = config.get("restrict_to_phase", None)
|
||||
if restrict_to_phase is not None and "phase_label" in task_targets.columns:
|
||||
mask = task_targets["phase_label"].isin(list(restrict_to_phase))
|
||||
task_targets = task_targets[mask]
|
||||
|
||||
if restrict_to_phase is not None and task == "1":
|
||||
logging.warning("Skipping task 1 as restrict_to_phase is set.")
|
||||
continue
|
||||
|
||||
generator = sbg.SteeredGenerator(split, task_targets)
|
||||
generator.add_augmentations(model.get_eval_augmentations())
|
||||
|
||||
loader = DataLoader(
|
||||
generator, batch_size=batchsize, shuffle=False, num_workers=num_workers
|
||||
)
|
||||
# trainer = pl.Trainer(accelerator="gpu", devices=1)
|
||||
trainer = pl.Trainer()
|
||||
|
||||
predictions = trainer.predict(model, loader)
|
||||
|
||||
# Merge batches
|
||||
merged_predictions = []
|
||||
for i, _ in enumerate(predictions[0]):
|
||||
merged_predictions.append(torch.cat([x[i] for x in predictions]))
|
||||
|
||||
merged_predictions = [x.cpu().numpy() for x in merged_predictions]
|
||||
task_targets["score_detection"] = merged_predictions[0]
|
||||
task_targets["score_p_or_s"] = merged_predictions[1]
|
||||
task_targets["p_sample_pred"] = (
|
||||
merged_predictions[2] + task_targets["start_sample"]
|
||||
)
|
||||
task_targets["s_sample_pred"] = (
|
||||
merged_predictions[3] + task_targets["start_sample"]
|
||||
)
|
||||
|
||||
pred_path = (
|
||||
weights.parent.parent
|
||||
/ pred_root
|
||||
/ weight_path_name
|
||||
/ version
|
||||
/ f"{eval_set}_task{task}.csv"
|
||||
)
|
||||
pred_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
task_targets.to_csv(pred_path, index=False)
|
||||
|
||||
|
||||
def _identify_instance_dataset_border(task_targets):
|
||||
"""
|
||||
Calculates the dataset border between Signal and Noise for instance,
|
||||
assuming it is the only place where the bucket number does not increase
|
||||
"""
|
||||
buckets = task_targets["trace_name"].apply(lambda x: int(x.split("$")[0][6:]))
|
||||
|
||||
last_bucket = 0
|
||||
for i, bucket in enumerate(buckets):
|
||||
if bucket < last_bucket:
|
||||
return i
|
||||
last_bucket = bucket
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
code_start_time = time.perf_counter()
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate a trained model using a set of targets."
|
||||
)
|
||||
parser.add_argument(
|
||||
"weights",
|
||||
type=str,
|
||||
help="Path to weights. Expected to be in models_path directory."
|
||||
"The script will automatically load the configuration and the model. "
|
||||
"The script always uses the checkpoint with lowest validation loss."
|
||||
"If sweep_id is provided the script considers only the checkpoints generated by that sweep."
|
||||
"Predictions will be written into the weights path as csv."
|
||||
"Note: Due to pytorch lightning internals, there exist two weights folders, "
|
||||
"{weights} and {weight}_{weights}. Please use the former as parameter",
|
||||
)
|
||||
parser.add_argument(
|
||||
"targets",
|
||||
type=str,
|
||||
help="Path to evaluation targets folder. "
|
||||
"The script will detect which tasks are present base on file names.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sets",
|
||||
type=str,
|
||||
default="dev,test",
|
||||
help="Sets on which to evaluate, separated by commata. Defaults to dev and test.",
|
||||
)
|
||||
parser.add_argument("--batchsize", type=int, default=1024, help="Batch size")
|
||||
parser.add_argument(
|
||||
"--num_workers",
|
||||
default=default_workers,
|
||||
type=int,
|
||||
help="Number of workers for data loader",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling_rate", type=float, help="Overwrites the sampling rate in the data"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sweep_id", type=str, help="wandb sweep_id", required=False, default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_run", action="store_true", required=False, default=False
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(
|
||||
args.weights,
|
||||
args.targets,
|
||||
args.sets,
|
||||
batchsize=args.batchsize,
|
||||
num_workers=args.num_workers,
|
||||
sampling_rate=args.sampling_rate,
|
||||
sweep_id=args.sweep_id,
|
||||
test_run=args.test_run
|
||||
)
|
||||
running_time = str(
|
||||
datetime.timedelta(seconds=time.perf_counter() - code_start_time)
|
||||
)
|
||||
print(f"Running time: {running_time}")
|
321
scripts/generate_eval_targets.py
Normal file
321
scripts/generate_eval_targets.py
Normal file
@ -0,0 +1,321 @@
|
||||
"""
|
||||
This script generates evaluation targets for the following three tasks:
|
||||
|
||||
- Earthquake detection (Task 1): Given a 30~s window, does the window contain an earthquake signal?
|
||||
- Phase identification (Task 2): Given a 10~s window containing exactly one phase onset, identify which phase type.
|
||||
- Onset time determination (Task 3): Given a 10~s window containing exactly one phase onset, identify the onset time.
|
||||
|
||||
Each target for evaluation will consist of the following information:
|
||||
|
||||
- trace name (as in dataset)
|
||||
- trace index (in dataset)
|
||||
- split (as in dataset)
|
||||
- sampling rate (at which all information is presented)
|
||||
- start_sample
|
||||
- end_sample
|
||||
- trace_type (only task 1: earthquake/noise)
|
||||
- phase_label (only task 2/3: P/S)
|
||||
- full_phase_label (only task 2/3: phase label as in the dataset, might be Pn, Pg, etc.)
|
||||
- phase_onset_sample (only task 2/3: onset sample of the phase relative to full trace)
|
||||
|
||||
It needs to be provided with a dataset and writes a folder with two CSV files, one for task 1 and one for tasks 2 and 3.
|
||||
Each file will describe targets for train, dev and test, derived from the respective splits.
|
||||
|
||||
When using these tasks for evaluation, the models can make use of waveforms from the context, i.e.,
|
||||
before/after the start and end samples. However, make sure this does not add further bias in the evaluation,
|
||||
for example by always centring the windows on the picks using context.
|
||||
|
||||
.. warning::
|
||||
For comparability, it is strongly advised to use published evaluation targets, instead of generating new ones.
|
||||
|
||||
.. warning::
|
||||
This implementation is not optimized and loads the full waveform data for its computations.
|
||||
This will lead to very high memory usage, as the full dataset will be stored in memory.
|
||||
"""
|
||||
import seisbench.data as sbd
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from models import phase_dict
|
||||
|
||||
|
||||
def main(dataset_name, output, tasks, sampling_rate, noise_before_events):
|
||||
np.random.seed(42)
|
||||
tasks = [str(i) in tasks.split(",") for i in range(1, 4)]
|
||||
|
||||
if not any(tasks):
|
||||
raise ValueError(f"No task defined. Got tasks='{tasks}'.")
|
||||
|
||||
dataset_args = {
|
||||
"sampling_rate": sampling_rate,
|
||||
"dimension_order": "NCW",
|
||||
"cache": "full",
|
||||
}
|
||||
|
||||
try:
|
||||
# Check if dataset is available in SeisBench
|
||||
dataset = sbd.__getattribute__(dataset_name)(**dataset_args)
|
||||
except AttributeError:
|
||||
# Otherwise interpret data_in as path
|
||||
dataset = sbd.WaveformDataset(dataset_name, **dataset_args)
|
||||
|
||||
output = Path(output)
|
||||
output.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
if "split" in dataset.metadata.columns:
|
||||
dataset.filter(dataset["split"].isin(["dev", "test"]), inplace=True)
|
||||
|
||||
dataset.preload_waveforms(pbar=True)
|
||||
|
||||
if tasks[0]:
|
||||
generate_task1(dataset, output, sampling_rate, noise_before_events)
|
||||
if tasks[1] or tasks[2]:
|
||||
generate_task23(dataset, output, sampling_rate)
|
||||
|
||||
|
||||
def generate_task1(dataset, output, sampling_rate, noise_before_events):
|
||||
np.random.seed(42)
|
||||
windowlen = 30 * sampling_rate # 30 s windows
|
||||
labels = []
|
||||
|
||||
for i in tqdm(range(len(dataset)), total=len(dataset)):
|
||||
waveforms, metadata = dataset.get_sample(i)
|
||||
|
||||
if "split" in metadata:
|
||||
trace_split = metadata["split"]
|
||||
else:
|
||||
trace_split = ""
|
||||
|
||||
def checkphase(metadata, phase, phase_label, target_phase, npts):
|
||||
return (
|
||||
phase in metadata
|
||||
and phase_label == target_phase
|
||||
and not np.isnan(metadata[phase])
|
||||
and 0 <= metadata[phase] < npts
|
||||
)
|
||||
|
||||
p_arrivals = [
|
||||
metadata[phase]
|
||||
for phase, phase_label in phase_dict.items()
|
||||
if checkphase(metadata, phase, phase_label, "P", waveforms.shape[-1])
|
||||
]
|
||||
s_arrivals = [
|
||||
metadata[phase]
|
||||
for phase, phase_label in phase_dict.items()
|
||||
if checkphase(metadata, phase, phase_label, "S", waveforms.shape[-1])
|
||||
]
|
||||
|
||||
if len(p_arrivals) == 0 and len(s_arrivals) == 0:
|
||||
start_sample, end_sample = select_window_containing(
|
||||
waveforms.shape[-1], windowlen
|
||||
)
|
||||
sample = {
|
||||
"trace_name": metadata["trace_name"],
|
||||
"trace_idx": i,
|
||||
"trace_split": trace_split,
|
||||
"sampling_rate": sampling_rate,
|
||||
"start_sample": start_sample,
|
||||
"end_sample": end_sample,
|
||||
"trace_type": "noise",
|
||||
}
|
||||
labels += [sample]
|
||||
|
||||
else:
|
||||
first_arrival = min(p_arrivals + s_arrivals)
|
||||
|
||||
start_sample, end_sample = select_window_containing(
|
||||
waveforms.shape[-1], windowlen, containing=first_arrival
|
||||
)
|
||||
if end_sample - start_sample <= windowlen:
|
||||
sample = {
|
||||
"trace_name": metadata["trace_name"],
|
||||
"trace_idx": i,
|
||||
"trace_split": trace_split,
|
||||
"sampling_rate": sampling_rate,
|
||||
"start_sample": start_sample,
|
||||
"end_sample": end_sample,
|
||||
"trace_type": "earthquake",
|
||||
}
|
||||
labels += [sample]
|
||||
|
||||
if noise_before_events and first_arrival > windowlen:
|
||||
start_sample, end_sample = select_window_containing(
|
||||
min(waveforms.shape[-1], first_arrival), windowlen
|
||||
)
|
||||
if end_sample - start_sample <= windowlen:
|
||||
sample = {
|
||||
"trace_name": metadata["trace_name"],
|
||||
"trace_idx": i,
|
||||
"trace_split": trace_split,
|
||||
"sampling_rate": sampling_rate,
|
||||
"start_sample": start_sample,
|
||||
"end_sample": end_sample,
|
||||
"trace_type": "noise",
|
||||
}
|
||||
labels += [sample]
|
||||
|
||||
labels = pd.DataFrame(labels)
|
||||
diff = labels["end_sample"] - labels["start_sample"]
|
||||
labels = labels[diff > 100]
|
||||
labels.to_csv(output / "task1.csv", index=False)
|
||||
|
||||
|
||||
def generate_task23(dataset, output, sampling_rate):
|
||||
np.random.seed(42)
|
||||
windowlen = 10 * sampling_rate # 30 s windows
|
||||
labels = []
|
||||
|
||||
for idx in tqdm(range(len(dataset)), total=len(dataset)):
|
||||
waveforms, metadata = dataset.get_sample(idx)
|
||||
|
||||
if "split" in metadata:
|
||||
trace_split = metadata["split"]
|
||||
else:
|
||||
trace_split = ""
|
||||
|
||||
def checkphase(metadata, phase, npts):
|
||||
return (
|
||||
phase in metadata
|
||||
and not np.isnan(metadata[phase])
|
||||
and 0 <= metadata[phase] < npts
|
||||
)
|
||||
|
||||
# Example entry: (1031, "P", "Pg")
|
||||
arrivals = sorted(
|
||||
[
|
||||
(metadata[phase], phase_label, phase.split("_")[1])
|
||||
for phase, phase_label in phase_dict.items()
|
||||
if checkphase(metadata, phase, waveforms.shape[-1])
|
||||
]
|
||||
)
|
||||
|
||||
if len(arrivals) == 0:
|
||||
# Trace has no arrivals
|
||||
continue
|
||||
|
||||
for i, (onset, phase, full_phase) in enumerate(arrivals):
|
||||
if i == 0:
|
||||
onset_before = 0
|
||||
else:
|
||||
onset_before = int(arrivals[i - 1][0]) + int(
|
||||
0.5 * sampling_rate
|
||||
) # 0.5 s minimum spacing
|
||||
|
||||
if i == len(arrivals) - 1:
|
||||
onset_after = np.inf
|
||||
else:
|
||||
onset_after = int(arrivals[i + 1][0]) - int(
|
||||
0.5 * sampling_rate
|
||||
) # 0.5 s minimum spacing
|
||||
|
||||
if (
|
||||
onset_after - onset_before < windowlen
|
||||
or onset_before > onset
|
||||
or onset_after < onset
|
||||
):
|
||||
# Impossible to isolate pick
|
||||
continue
|
||||
|
||||
else:
|
||||
onset_after = min(onset_after, waveforms.shape[-1])
|
||||
# Shift everything to a "virtual" start at onset_before
|
||||
start_sample, end_sample = select_window_containing(
|
||||
onset_after - onset_before,
|
||||
windowlen=windowlen,
|
||||
containing=onset - onset_before,
|
||||
bounds=(50, 50),
|
||||
)
|
||||
start_sample += onset_before
|
||||
end_sample += onset_before
|
||||
if end_sample - start_sample <= windowlen:
|
||||
sample = {
|
||||
"trace_name": metadata["trace_name"],
|
||||
"trace_idx": idx,
|
||||
"trace_split": trace_split,
|
||||
"sampling_rate": sampling_rate,
|
||||
"start_sample": start_sample,
|
||||
"end_sample": end_sample,
|
||||
"phase_label": phase,
|
||||
"full_phase_label": full_phase,
|
||||
"phase_onset": onset,
|
||||
}
|
||||
|
||||
labels += [sample]
|
||||
|
||||
labels = pd.DataFrame(labels)
|
||||
diff = labels["end_sample"] - labels["start_sample"]
|
||||
labels = labels[diff > 100]
|
||||
labels.to_csv(output / "task23.csv", index=False)
|
||||
|
||||
|
||||
def select_window_containing(npts, windowlen, containing=None, bounds=(100, 100)):
|
||||
"""
|
||||
Selects a window from a larger trace.
|
||||
|
||||
:param npts: Number of points of the full trace
|
||||
:param windowlen: Desired windowlen
|
||||
:param containing: Sample number that should be contained. If None, any window within the trace is valid.
|
||||
:param bounds: The containing sample may not be in the first/last samples indicated here.
|
||||
:return: Start sample, end_sample
|
||||
"""
|
||||
if npts <= windowlen:
|
||||
# If npts is smaller than the window length, always return the full window
|
||||
return 0, npts
|
||||
|
||||
else:
|
||||
if containing is None:
|
||||
start_sample = np.random.randint(0, npts - windowlen + 1)
|
||||
return start_sample, start_sample + windowlen
|
||||
|
||||
else:
|
||||
earliest_start = max(0, containing - windowlen + bounds[1])
|
||||
latest_start = min(npts - windowlen, containing - bounds[0])
|
||||
if latest_start <= earliest_start:
|
||||
# Again, return full window
|
||||
return 0, npts
|
||||
|
||||
else:
|
||||
start_sample = np.random.randint(earliest_start, latest_start + 1)
|
||||
return start_sample, start_sample + windowlen
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate evaluation targets. See the docstring for details."
|
||||
)
|
||||
parser.add_argument(
|
||||
"dataset", type=str, help="Path to input dataset or SeisBench dataset name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"output", type=str, help="Path to write target files to. Must not exist."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tasks",
|
||||
type=str,
|
||||
default="1,2,3",
|
||||
help="Which tasks to generate data for. By default generates data for all tasks.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sampling_rate",
|
||||
type=float,
|
||||
default=100,
|
||||
help="Sampling rate in Hz to generate targets for.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_noise_before_events",
|
||||
action="store_true",
|
||||
help="If set, does not extract noise from windows before the first arrival.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(
|
||||
args.dataset,
|
||||
args.output,
|
||||
args.tasks,
|
||||
args.sampling_rate,
|
||||
not args.no_noise_before_events,
|
||||
)
|
178
scripts/hyperparameter_sweep.py
Normal file
178
scripts/hyperparameter_sweep.py
Normal file
@ -0,0 +1,178 @@
|
||||
# -----------------
|
||||
# Copyright © 2023 ACK Cyfronet AGH, Poland.
|
||||
# This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
|
||||
# -----------------
|
||||
|
||||
import os.path
|
||||
import argparse
|
||||
from pytorch_lightning.loggers import WandbLogger, CSVLogger
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
import pytorch_lightning as pl
|
||||
import wandb
|
||||
import torch
|
||||
import traceback
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
import models
|
||||
import train
|
||||
import util
|
||||
from config_loader import config as common_config
|
||||
from config_loader import models_path, dataset_name, seed, experiment_count
|
||||
|
||||
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
|
||||
load_dotenv()
|
||||
wandb_api_key = os.environ.get('WANDB_API_KEY')
|
||||
if wandb_api_key is None:
|
||||
raise ValueError("WANDB_API_KEY environment variable is not set.")
|
||||
host = os.environ.get("WANDB_HOST")
|
||||
if host is None:
|
||||
raise ValueError("WANDB_HOST environment variable is not set.")
|
||||
|
||||
|
||||
wandb.login(key=wandb_api_key, host=host)
|
||||
# wandb.login(key=wandb_api_key)
|
||||
|
||||
wandb_project_name = os.environ.get("WANDB_PROJECT")
|
||||
wandb_user_name = os.environ.get("WANDB_USER")
|
||||
|
||||
script_name = os.path.splitext(os.path.basename(__file__))[0]
|
||||
logger = logging.getLogger(script_name)
|
||||
logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def set_random_seed(seed=3):
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def get_trainer_args(config):
|
||||
trainer_args = {'max_epochs': config.max_epochs[0]}
|
||||
return trainer_args
|
||||
|
||||
|
||||
class HyperparameterSweep:
|
||||
def __init__(self, project_name, sweep_config):
|
||||
self.project_name = project_name
|
||||
self.sweep_config = sweep_config
|
||||
self.sweep_id = None
|
||||
|
||||
def run_sweep(self):
|
||||
|
||||
# Create the sweep
|
||||
self.sweep_id = wandb.sweep(self.sweep_config, project=self.project_name)
|
||||
|
||||
logger.info("Created sweep with ID: " + self.sweep_id)
|
||||
|
||||
# Run the sweep
|
||||
wandb.agent(self.sweep_id, function=self.run_experiment, count=experiment_count)
|
||||
|
||||
def all_runs_finished(self):
|
||||
|
||||
sweep_path = f"{wandb_user_name}/{wandb_project_name}/{self.sweep_id}"
|
||||
logger.debug(f"Sweep path: {sweep_path}")
|
||||
sweep_runs = wandb.Api().sweep(sweep_path).runs
|
||||
all_finished = all(run.state == "finished" for run in sweep_runs)
|
||||
if all_finished:
|
||||
logger.info("All runs finished successfully.")
|
||||
|
||||
all_not_running = all(run.state != "running" for run in sweep_runs)
|
||||
if all_not_running and not all_finished:
|
||||
logger.warning("Some runs are not finished but failed or crashed.")
|
||||
|
||||
return all_not_running
|
||||
|
||||
def run_experiment(self):
|
||||
|
||||
try:
|
||||
|
||||
logger.debug("Starting a new run...")
|
||||
run = wandb.init(
|
||||
project=self.project_name,
|
||||
config=common_config,
|
||||
)
|
||||
|
||||
wandb.run.log_code(
|
||||
".",
|
||||
include_fn=lambda path: path.endswith(os.path.basename(__file__))
|
||||
)
|
||||
|
||||
model_name = wandb.config.model_name[0]
|
||||
model_args = models.get_model_specific_args(wandb.config)
|
||||
logger.debug(f"Initializing {model_name}")
|
||||
|
||||
model = models.__getattribute__(model_name + "Lit")(**model_args)
|
||||
|
||||
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
|
||||
|
||||
wandb_logger = WandbLogger(project=self.project_name, log_model="all")
|
||||
wandb_logger.watch(model)
|
||||
|
||||
# CSV logger - also used for saving configuration as yaml
|
||||
experiment_name = f"{dataset_name}_{model_name}"
|
||||
csv_logger = CSVLogger(models_path, experiment_name, version=run.id)
|
||||
csv_logger.log_hyperparams(wandb.config)
|
||||
|
||||
loggers = [wandb_logger, csv_logger]
|
||||
|
||||
experiment_signature = f"{experiment_name}_sweep={self.sweep_id}-run={run.id}"
|
||||
|
||||
logger.debug("Experiment signature: " + experiment_signature)
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
save_top_k=1,
|
||||
filename=experiment_signature + "-{epoch}-{val_loss:.3f}",
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
dirpath=f"{models_path}/{experiment_name}/",
|
||||
) # save_top_k=1, monitor="val_loss", mode="min": save the best model in terms of validation loss
|
||||
checkpoint_callback.STARTING_VERSION = 1
|
||||
|
||||
early_stopping_callback = EarlyStopping(
|
||||
monitor="val_loss",
|
||||
patience=3,
|
||||
verbose=True,
|
||||
mode="min")
|
||||
callbacks = [checkpoint_callback, early_stopping_callback]
|
||||
|
||||
trainer = pl.Trainer(
|
||||
default_root_dir=models_path,
|
||||
logger=loggers,
|
||||
callbacks=callbacks,
|
||||
**get_trainer_args(wandb.config)
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, dev_loader)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("caught error: ", str(e))
|
||||
traceback_str = traceback.format_exc()
|
||||
logger.error(traceback_str)
|
||||
|
||||
run.finish()
|
||||
|
||||
|
||||
def start_sweep(sweep_config):
|
||||
|
||||
logger.info("Starting sweep with config: " + str(sweep_config))
|
||||
set_random_seed(seed)
|
||||
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
|
||||
sweep_runner.run_sweep()
|
||||
|
||||
return sweep_runner
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--sweep_config", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
sweep_config = util.load_sweep_config(args.sweep_config)
|
||||
start_sweep(sweep_config)
|
||||
|
1138
scripts/models.py
Normal file
1138
scripts/models.py
Normal file
File diff suppressed because it is too large
Load Diff
92
scripts/pipeline.py
Normal file
92
scripts/pipeline.py
Normal file
@ -0,0 +1,92 @@
|
||||
"""
|
||||
-----------------
|
||||
Copyright © 2023 ACK Cyfronet AGH, Poland.
|
||||
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
|
||||
-----------------
|
||||
|
||||
This script runs the pipeline for the training and evaluation of the models.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
import argparse
|
||||
|
||||
import util
|
||||
import generate_eval_targets
|
||||
import hyperparameter_sweep
|
||||
import eval
|
||||
import collect_results
|
||||
from config_loader import data_path, targets_path, sampling_rate, dataset_name, sweep_files
|
||||
|
||||
logger = logging.getLogger('pipeline')
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def load_sweep_config(model_name, args):
|
||||
|
||||
if model_name == "PhaseNet" and args.phasenet_config is not None:
|
||||
sweep_fname = args.phasenet_config
|
||||
elif model_name == "GPD" and args.gpd_config is not None:
|
||||
sweep_fname = args.gpd_config
|
||||
else:
|
||||
# use the default sweep config for the model
|
||||
sweep_fname = sweep_files[model_name]
|
||||
|
||||
logger.info(f"Loading sweep config: {sweep_fname}")
|
||||
|
||||
return util.load_sweep_config(sweep_fname)
|
||||
|
||||
|
||||
def find_the_best_params(model_name, args):
|
||||
|
||||
# find the best hyperparams for the model_name
|
||||
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
|
||||
|
||||
sweep_config = load_sweep_config(model_name, args)
|
||||
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
|
||||
|
||||
# wait for all runs to finish
|
||||
|
||||
all_finished = sweep_runner.all_runs_finished()
|
||||
while not all_finished:
|
||||
logger.info("Waiting for sweep runs to finish...")
|
||||
# Sleep for a few seconds before checking again
|
||||
time.sleep(30)
|
||||
all_finished = sweep_runner.all_runs_finished()
|
||||
|
||||
logger.info(f"Finished the sweep: {sweep_runner.sweep_id}")
|
||||
return sweep_runner.sweep_id
|
||||
|
||||
|
||||
def generate_predictions(sweep_id, model_name):
|
||||
experiment_name = f"{dataset_name}_{model_name}"
|
||||
eval.main(weights=experiment_name,
|
||||
targets=targets_path,
|
||||
sets='dev,test',
|
||||
batchsize=128,
|
||||
num_workers=4,
|
||||
# sampling_rate=sampling_rate,
|
||||
sweep_id=sweep_id
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--phasenet_config", type=str, required=False)
|
||||
parser.add_argument("--gpd_config", type=str, required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
# generate labels
|
||||
generate_eval_targets.main(data_path, targets_path, "2,3", sampling_rate, None)
|
||||
|
||||
# find the best hyperparams for the models
|
||||
for model_name in ["GPD", "PhaseNet"]:
|
||||
sweep_id = find_the_best_params(model_name, args)
|
||||
generate_predictions(sweep_id, model_name)
|
||||
|
||||
# collect results
|
||||
collect_results.traverse_path("pred", "pred/results.csv")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
532
scripts/train.py
532
scripts/train.py
@ -1,339 +1,245 @@
|
||||
import os.path
|
||||
import wandb
|
||||
import seisbench.data as sbd
|
||||
"""
|
||||
This script handles the training of models base on model configuration files.
|
||||
"""
|
||||
|
||||
import seisbench.generate as sbg
|
||||
import seisbench.models as sbm
|
||||
from seisbench.util import worker_seeding
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn.functional as f
|
||||
import torch.nn as nn
|
||||
from torchmetrics import Metric
|
||||
from torch import Tensor, tensor
|
||||
|
||||
import pytorch_lightning as pl
|
||||
# from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
|
||||
from pytorch_lightning.loggers import WandbLogger, CSVLogger
|
||||
|
||||
# https://github.com/Lightning-AI/lightning/pull/12554
|
||||
# https://github.com/Lightning-AI/lightning/issues/11796
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
import argparse
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
load_dotenv()
|
||||
wandb_api_key = os.environ.get('WANDB_API_KEY')
|
||||
if wandb_api_key is None:
|
||||
raise ValueError("WANDB_API_KEY environment variable is not set.")
|
||||
import models, data, util
|
||||
import time
|
||||
import datetime
|
||||
import wandb
|
||||
#
|
||||
# load_dotenv()
|
||||
# wandb_api_key = os.environ.get('WANDB_API_KEY')
|
||||
# if wandb_api_key is None:
|
||||
# raise ValueError("WANDB_API_KEY environment variable is not set.")
|
||||
#
|
||||
# wandb.login(key=wandb_api_key)
|
||||
|
||||
wandb.login(key=wandb_api_key)
|
||||
def train(config, experiment_name, test_run):
|
||||
"""
|
||||
Runs the model training defined by the config.
|
||||
|
||||
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
Config parameters:
|
||||
|
||||
- model: Model used as in the models.py file, but without the Lit suffix
|
||||
- data: Dataset used, as in seisbench.data
|
||||
- model_args: Arguments passed to the constructor of the model lightning module
|
||||
- trainer_args: Arguments passed to the lightning trainer
|
||||
- batch_size: Batch size for training and validation
|
||||
- num_workers: Number of workers for data loading.
|
||||
If not set, uses environment variable BENCHMARK_DEFAULT_WORKERS
|
||||
- restrict_to_phase: Filters datasets only to examples containing the given phase.
|
||||
By default, uses all phases.
|
||||
- training_fraction: Fraction of training blocks to use as float between 0 and 1. Defaults to 1.
|
||||
|
||||
:param config: Configuration parameters for training
|
||||
:param test_run: If true, makes a test run with less data and less logging. Intended for debug purposes.
|
||||
"""
|
||||
model = models.__getattribute__(config["model"] + "Lit")(
|
||||
**config.get("model_args", {})
|
||||
)
|
||||
|
||||
train_loader, dev_loader = prepare_data(config, model, test_run)
|
||||
|
||||
# CSV logger - also used for saving configuration as yaml
|
||||
csv_logger = CSVLogger("weights", experiment_name)
|
||||
csv_logger.log_hyperparams(config)
|
||||
loggers = [csv_logger]
|
||||
|
||||
default_root_dir = os.path.join(
|
||||
"weights"
|
||||
) # Experiment name is parsed from the loggers
|
||||
if not test_run:
|
||||
# tb_logger = TensorBoardLogger("tb_logs", experiment_name)
|
||||
# tb_logger.log_hyperparams(config)
|
||||
# loggers += [tb_logger]
|
||||
wandb_logger = WandbLogger()
|
||||
wandb_logger.watch(model)
|
||||
|
||||
loggers +=[wandb_logger]
|
||||
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
save_top_k=1, filename="{epoch}-{step}", monitor="val_loss", mode="min"
|
||||
) # save_top_k=1, monitor="val_loss", mode="min": save the best model in terms of validation loss
|
||||
callbacks = [checkpoint_callback]
|
||||
|
||||
## Uncomment the following 2 lines to enable
|
||||
# device_stats = DeviceStatsMonitor()
|
||||
# callbacks.append(device_stats)
|
||||
|
||||
trainer = pl.Trainer(
|
||||
default_root_dir=default_root_dir,
|
||||
logger=loggers,
|
||||
callbacks=callbacks,
|
||||
**config.get("trainer_args", {}),
|
||||
)
|
||||
|
||||
trainer.fit(model, train_loader, dev_loader)
|
||||
|
||||
|
||||
class PickMAE(Metric):
|
||||
higher_is_better: bool = False
|
||||
mae_error: Tensor
|
||||
def prepare_data(config, model, test_run):
|
||||
"""
|
||||
Returns the training and validation data loaders
|
||||
:param config:
|
||||
:param model:
|
||||
:param test_run:
|
||||
:return:
|
||||
"""
|
||||
batch_size = config.get("batch_size", 1024)
|
||||
if type(batch_size) == list:
|
||||
batch_size = batch_size[0]
|
||||
|
||||
def __init__(self, sampling_rate):
|
||||
super().__init__()
|
||||
self.add_state("mae_error", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.sampling_rate = sampling_rate
|
||||
num_workers = config.get("num_workers", util.default_workers)
|
||||
try:
|
||||
dataset = data.get_dataset_by_name(config["dataset_name"])(
|
||||
sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full"
|
||||
)
|
||||
except ValueError:
|
||||
data_path = str(Path.cwd().parent) + '/' + config['data_path']
|
||||
print(data_path)
|
||||
dataset = data.get_custom_dataset(data_path)
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
restrict_to_phase = config.get("restrict_to_phase", None)
|
||||
if restrict_to_phase is not None:
|
||||
mask = generate_phase_mask(dataset, restrict_to_phase)
|
||||
dataset.filter(mask, inplace=True)
|
||||
|
||||
assert preds.shape == target.shape
|
||||
if "split" not in dataset.metadata.columns:
|
||||
logging.warning("No split defined, adding auxiliary split.")
|
||||
split = np.array(["train"] * len(dataset))
|
||||
split[int(0.6 * len(dataset)) : int(0.7 * len(dataset))] = "dev"
|
||||
split[int(0.7 * len(dataset)) :] = "test"
|
||||
|
||||
pred_pick_idx = torch.argmax(preds[:, 0, :], dim=1).type(torch.FloatTensor)
|
||||
true_pick_idx = torch.argmax(target[:, 0, :], dim=-1).type(torch.FloatTensor)
|
||||
dataset._metadata["split"] = split
|
||||
|
||||
mae = nn.L1Loss()
|
||||
self.mae_error = mae(pred_pick_idx, true_pick_idx) / self.sampling_rate #mae in seconds
|
||||
train_data = dataset.train()
|
||||
dev_data = dataset.dev()
|
||||
|
||||
def compute(self):
|
||||
return self.mae_error.float()
|
||||
if test_run:
|
||||
# Only use a small part of the dataset
|
||||
train_mask = np.zeros(len(train_data), dtype=bool)
|
||||
train_mask[:5000] = True
|
||||
train_data.filter(train_mask, inplace=True)
|
||||
|
||||
dev_mask = np.zeros(len(dev_data), dtype=bool)
|
||||
dev_mask[:5000] = True
|
||||
dev_data.filter(dev_mask, inplace=True)
|
||||
|
||||
training_fraction = config.get("training_fraction", 1.0)
|
||||
apply_training_fraction(training_fraction, train_data)
|
||||
|
||||
train_data.preload_waveforms(pbar=True)
|
||||
dev_data.preload_waveforms(pbar=True)
|
||||
|
||||
train_generator = sbg.GenericGenerator(train_data)
|
||||
dev_generator = sbg.GenericGenerator(dev_data)
|
||||
|
||||
train_generator.add_augmentations(model.get_train_augmentations())
|
||||
dev_generator.add_augmentations(model.get_val_augmentations())
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_generator,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
worker_init_fn=worker_seeding,
|
||||
drop_last=True, # Avoid crashes from batch norm layers for batch size 1
|
||||
)
|
||||
dev_loader = DataLoader(
|
||||
dev_generator,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
worker_init_fn=worker_seeding,
|
||||
)
|
||||
|
||||
return train_loader, dev_loader
|
||||
|
||||
|
||||
class EarlyStopper:
|
||||
def __init__(self, patience=1, min_delta=0):
|
||||
self.patience = patience
|
||||
self.min_delta = min_delta
|
||||
self.counter = 0
|
||||
self.min_validation_loss = np.inf
|
||||
def apply_training_fraction(training_fraction, train_data):
|
||||
"""
|
||||
Reduces the size of train_data to train_fraction by inplace filtering.
|
||||
Filter blockwise for efficient memory savings.
|
||||
|
||||
def early_stop(self, validation_loss):
|
||||
if validation_loss < self.min_validation_loss:
|
||||
self.min_validation_loss = validation_loss
|
||||
self.counter = 0
|
||||
elif validation_loss > (self.min_validation_loss + self.min_delta):
|
||||
self.counter += 1
|
||||
if self.counter >= self.patience:
|
||||
return True
|
||||
return False
|
||||
:param training_fraction: Training fraction between 0 and 1.
|
||||
:param train_data: Training dataset
|
||||
:return: None
|
||||
"""
|
||||
|
||||
if not 0.0 < training_fraction <= 1.0:
|
||||
raise ValueError("Training fraction needs to be between 0 and 1.")
|
||||
|
||||
if training_fraction < 1:
|
||||
blocks = train_data["trace_name"].apply(lambda x: x.split("$")[0])
|
||||
unique_blocks = blocks.unique()
|
||||
np.random.shuffle(unique_blocks)
|
||||
target_blocks = unique_blocks[: int(training_fraction * len(unique_blocks))]
|
||||
target_blocks = set(target_blocks)
|
||||
mask = blocks.isin(target_blocks)
|
||||
train_data.filter(mask, inplace=True)
|
||||
|
||||
|
||||
def get_data_generator(split, sampling_rate, path, sb_dataset="ethz", station=None, window='random'):
|
||||
def generate_phase_mask(dataset, phases):
|
||||
mask = np.zeros(len(dataset), dtype=bool)
|
||||
|
||||
if path is not None:
|
||||
data = sbd.WaveformDataset(path, sampling_rate=sampling_rate)
|
||||
phase_dict = {
|
||||
"trace_Pg_arrival_sample": "P"
|
||||
}
|
||||
elif sb_dataset == "ethz":
|
||||
data = sbd.ETHZ(sampling_rate=sampling_rate, force=True)
|
||||
|
||||
phase_dict = {
|
||||
"trace_p_arrival_sample": "P",
|
||||
"trace_pP_arrival_sample": "P",
|
||||
"trace_P_arrival_sample": "P",
|
||||
"trace_P1_arrival_sample": "P",
|
||||
"trace_Pg_arrival_sample": "P",
|
||||
"trace_Pn_arrival_sample": "P",
|
||||
"trace_PmP_arrival_sample": "P",
|
||||
"trace_pwP_arrival_sample": "P",
|
||||
"trace_pwPm_arrival_sample": "P",
|
||||
# "trace_s_arrival_sample": "S",
|
||||
# "trace_S_arrival_sample": "S",
|
||||
# "trace_S1_arrival_sample": "S",
|
||||
# "trace_Sg_arrival_sample": "S",
|
||||
# "trace_SmS_arrival_sample": "S",
|
||||
# "trace_Sn_arrival_sample": "S",
|
||||
}
|
||||
|
||||
dataset = data.get_split(split)
|
||||
dataset.filter(dataset.metadata.trace_Pg_arrival_sample.notna())
|
||||
|
||||
print(split, dataset.metadata.shape, sampling_rate)
|
||||
|
||||
if station is not None:
|
||||
dataset.filter(dataset.metadata.station_code==station)
|
||||
|
||||
data_generator = sbg.GenericGenerator(dataset)
|
||||
if window == 'random':
|
||||
print("using random window")
|
||||
window_selector = sbg.RandomWindow(windowlen=3001, strategy="pad")
|
||||
else:
|
||||
window_selector = sbg.FixedWindow(windowlen=3001, p0=0, strategy="pad")
|
||||
|
||||
augmentations = [
|
||||
sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000, selection="random",
|
||||
strategy="variable"),
|
||||
window_selector,
|
||||
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
|
||||
sbg.ChangeDtype(np.float32),
|
||||
sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=30, dim=0)
|
||||
]
|
||||
|
||||
data_generator.add_augmentations(augmentations)
|
||||
|
||||
return data_generator
|
||||
|
||||
|
||||
def get_data_generators(sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz", station=None,
|
||||
window='random'):
|
||||
|
||||
train_generator = get_data_generator("train", sampling_rate, path, sb_dataset, station, window)
|
||||
dev_generator = get_data_generator("dev", sampling_rate, path, sb_dataset, station, window)
|
||||
test_generator = get_data_generator("test", sampling_rate, path, sb_dataset, station, window)
|
||||
|
||||
return train_generator, dev_generator, test_generator
|
||||
|
||||
|
||||
def get_data_loaders(batch_size=256, sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz",
|
||||
window='random'):
|
||||
|
||||
train_generator, dev_generator, test_generator = get_data_generators(sampling_rate, path, sb_dataset, window=window)
|
||||
num_workers = 0 # The number of threads used for loading data
|
||||
|
||||
train_loader = DataLoader(train_generator, batch_size=batch_size, shuffle=True, num_workers=num_workers,
|
||||
worker_init_fn=worker_seeding)
|
||||
dev_loader = DataLoader(dev_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
|
||||
worker_init_fn=worker_seeding)
|
||||
|
||||
test_loader = DataLoader(test_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
|
||||
worker_init_fn=worker_seeding)
|
||||
|
||||
return train_loader, dev_loader, test_loader
|
||||
|
||||
|
||||
def load_model(name="PhaseNet", pretrained=None, classes=2, modify_output=True):
|
||||
|
||||
if name == "PhaseNet":
|
||||
|
||||
if pretrained is not None and pretrained:
|
||||
model = sbm.PhaseNet(phases="PN", norm="peak").from_pretrained(pretrained)
|
||||
for key, phase in models.phase_dict.items():
|
||||
if phase not in phases:
|
||||
continue
|
||||
else:
|
||||
model = sbm.PhaseNet(phases="PN", norm="peak")
|
||||
if key in dataset.metadata:
|
||||
mask = np.logical_or(mask, ~np.isnan(dataset.metadata[key]))
|
||||
|
||||
if modify_output:
|
||||
model.out = nn.Conv1d(model.filters_root, classes, 1, padding="same")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def train_one_epoch(model, dataloader, optimizer, pick_mae):
|
||||
size = len(dataloader.dataset)
|
||||
for batch_id, batch in enumerate(dataloader):
|
||||
|
||||
# Compute prediction and loss
|
||||
|
||||
pred = model(batch["X"].to(model.device))
|
||||
|
||||
loss = loss_fn(pred, batch["y"].to(model.device))
|
||||
|
||||
# Compute cross entropy loss
|
||||
cross_entropy_loss = f.cross_entropy(pred, batch["y"])
|
||||
|
||||
# Compute mae
|
||||
mae = pick_mae(pred, batch['y'])
|
||||
|
||||
wandb.log({"loss": loss})
|
||||
wandb.log({"batch cross entropy loss": cross_entropy_loss})
|
||||
wandb.log({"p_mae": mae})
|
||||
|
||||
|
||||
# Backpropagation
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if batch_id % 5 == 0:
|
||||
loss, current = loss.item(), batch_id * batch["X"].shape[0]
|
||||
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
|
||||
print(f"mae: {mae:>7f}")
|
||||
|
||||
|
||||
def test_one_epoch(model, dataloader, pick_mae, wandb_log=True):
|
||||
|
||||
num_batches = len(dataloader)
|
||||
test_loss = 0
|
||||
test_mae = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
pred = model(batch["X"].to(model.device))
|
||||
|
||||
test_loss += loss_fn(pred, batch["y"].to(model.device)).item()
|
||||
test_mae += pick_mae(pred, batch['y'])
|
||||
test_cross_entropy_loss = f.cross_entropy(pred, batch["y"])
|
||||
if wandb_log:
|
||||
wandb.log({"batch cross entropy test loss": test_cross_entropy_loss})
|
||||
|
||||
test_loss /= num_batches
|
||||
test_mae /= num_batches
|
||||
|
||||
wandb.log({"test_p_mae": test_mae, "test_loss": test_loss})
|
||||
|
||||
print(f"Test avg loss: {test_loss:>8f}")
|
||||
print(f"Test avg mae: {test_mae:>7f}\n")
|
||||
|
||||
return test_loss, test_mae
|
||||
|
||||
|
||||
def train_model(model, path_to_trained_model, train_loader, dev_loader):
|
||||
|
||||
wandb.watch(model, log_freq=10)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
|
||||
early_stopper = EarlyStopper(patience=3, min_delta=10)
|
||||
pick_mae = PickMAE(wandb.config.sampling_rate)
|
||||
|
||||
best_loss = np.inf
|
||||
best_metrics = {}
|
||||
|
||||
for t in range(wandb.config.epochs):
|
||||
print(f"Epoch {t + 1}\n-------------------------------")
|
||||
train_one_epoch(model, train_loader, optimizer, pick_mae)
|
||||
test_loss, test_mae = test_one_epoch(model, dev_loader, pick_mae)
|
||||
|
||||
if test_loss < best_loss:
|
||||
best_loss = test_loss
|
||||
best_metrics = {"test_p_mae": test_mae, "test_loss": test_loss}
|
||||
torch.save(model.state_dict(), path_to_trained_model)
|
||||
|
||||
if early_stopper.early_stop(test_loss):
|
||||
break
|
||||
|
||||
print("Best model: ", str(best_metrics))
|
||||
|
||||
|
||||
def loss_fn(y_pred, y_true, eps=1e-5):
|
||||
# vector cross entropy loss
|
||||
h = y_true * torch.log(y_pred + eps)
|
||||
h = h.mean(-1).sum(-1) # Mean along sample dimension and sum along pick dimension
|
||||
h = h.mean() # Mean over batch axis
|
||||
return -h
|
||||
|
||||
|
||||
def train_phasenet_on_sb_data():
|
||||
|
||||
config = {
|
||||
"epochs": 3,
|
||||
"batch_size": 256,
|
||||
"dataset": "ethz",
|
||||
"sampling_rate": 100,
|
||||
"model_name": "PhaseNet"
|
||||
}
|
||||
|
||||
run = wandb.init(
|
||||
# set the wandb project where this run will be logged
|
||||
project="training_seisbench_models_on_igf_data",
|
||||
# track hyperparameters and run metadata
|
||||
config=config
|
||||
)
|
||||
|
||||
wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
|
||||
|
||||
train_loader, dev_loader, test = get_data_loaders(batch_size=wandb.config.batch_size,
|
||||
sampling_rate=wandb.config.sampling_rate,
|
||||
path=None,
|
||||
sb_dataset=wandb.config.dataset)
|
||||
|
||||
model = load_model(name=wandb.config.model_name, pretrained=None, modify_output=True)
|
||||
path_to_trained_model = f"{project_path}/models/{wandb.config.model_name}_trained_on_{wandb.config.data_set}.pt"
|
||||
train_model(model, path_to_trained_model,
|
||||
train_loader, dev_loader)
|
||||
|
||||
artifact = wandb.Artifact('model', type='model')
|
||||
artifact.add_file(path_to_trained_model)
|
||||
run.log_artifact(artifact)
|
||||
|
||||
run.finish()
|
||||
|
||||
|
||||
def load_config(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
return config
|
||||
|
||||
|
||||
def train_sbmodel_on_igf_data():
|
||||
|
||||
config_path = project_path + "/experiments/config.json"
|
||||
config = load_config(config_path)
|
||||
|
||||
run = wandb.init(
|
||||
# set the wandb project where this run will be logged
|
||||
project="training_seisbench_models_on_igf_data",
|
||||
# track hyperparameters and run metadata
|
||||
config=config
|
||||
)
|
||||
wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
|
||||
|
||||
print(wandb.config.batch_size, wandb.config.sampling_rate)
|
||||
train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
|
||||
sampling_rate=wandb.config.sampling_rate
|
||||
)
|
||||
|
||||
model_name = wandb.config.model_name
|
||||
pretrained = wandb.config.pretrained
|
||||
|
||||
print(model_name, pretrained)
|
||||
model = load_model(name=model_name, pretrained=pretrained)
|
||||
path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
|
||||
train_model(model, path_to_trained_model, train_loader, dev_loader)
|
||||
|
||||
artifact = wandb.Artifact('model', type='model')
|
||||
artifact.add_file(path_to_trained_model)
|
||||
run.log_artifact(artifact)
|
||||
|
||||
run.finish()
|
||||
return mask
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# train_phasenet_on_sb_data()
|
||||
train_sbmodel_on_igf_data()
|
||||
code_start_time = time.perf_counter()
|
||||
|
||||
torch.manual_seed(42)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, required=True)
|
||||
parser.add_argument("--test_run", action="store_true")
|
||||
parser.add_argument("--lr", default=None, type=float)
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
experiment_name = os.path.basename(args.config)[:-5]
|
||||
if args.lr is not None:
|
||||
logging.warning(f"Overwriting learning rate to {args.lr}")
|
||||
experiment_name += f"_{args.lr}"
|
||||
config["model_args"]["lr"] = args.lr
|
||||
|
||||
run = wandb.init(
|
||||
# set the wandb project where this run will be logged
|
||||
project="training_seisbench_models_on_igf_data_with_pick-benchmark",
|
||||
# track hyperparameters and run metadata
|
||||
config=config
|
||||
)
|
||||
|
||||
if args.test_run:
|
||||
experiment_name = experiment_name + "_test"
|
||||
train(config, experiment_name, test_run=args.test_run)
|
||||
|
||||
running_time = str(
|
||||
datetime.timedelta(seconds=time.perf_counter() - code_start_time)
|
||||
)
|
||||
print(f"Running time: {running_time}")
|
||||
|
@ -1,62 +0,0 @@
|
||||
import os.path
|
||||
import wandb
|
||||
import yaml
|
||||
|
||||
from train import get_data_loaders, load_model, train_model
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
wandb_api_key = os.environ.get('WANDB_API_KEY')
|
||||
if wandb_api_key is None:
|
||||
raise ValueError("WANDB_API_KEY environment variable is not set.")
|
||||
|
||||
wandb.login(key=wandb_api_key)
|
||||
|
||||
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sweep_config_path = project_path + "/experiments/sweep4.yaml"
|
||||
|
||||
with open(sweep_config_path) as file:
|
||||
sweep_configuration = yaml.load(file, Loader=yaml.FullLoader)
|
||||
|
||||
sweep_id = wandb.sweep(
|
||||
sweep=sweep_configuration,
|
||||
project='training_seisbench_models_on_igf_data'
|
||||
)
|
||||
sampling_rate = 100
|
||||
|
||||
def tune_training_hyperparams():
|
||||
|
||||
run = wandb.init(
|
||||
# set the wandb project where this run will be logged
|
||||
project="training_seisbench_models_on_igf_data",
|
||||
# track hyperparameters and run metadata
|
||||
config={"sampling_rate":sampling_rate}
|
||||
)
|
||||
|
||||
wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
|
||||
|
||||
train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
|
||||
sampling_rate=wandb.config.sampling_rate,
|
||||
sb_dataset=wandb.config.dataset)
|
||||
|
||||
model_name = wandb.config.model_name
|
||||
pretrained = wandb.config.pretrained
|
||||
print(wandb.config)
|
||||
print(model_name, pretrained, type(pretrained), wandb.config.sampling_rate)
|
||||
if not pretrained:
|
||||
pretrained
|
||||
model = load_model(name=model_name, pretrained=pretrained)
|
||||
path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
|
||||
train_model(model, path_to_trained_model, train_loader, dev_loader)
|
||||
|
||||
artifact = wandb.Artifact('model', type='model')
|
||||
artifact.add_file(path_to_trained_model)
|
||||
run.log_artifact(artifact)
|
||||
|
||||
run.finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
wandb.agent(sweep_id, function=tune_training_hyperparams, count=10)
|
119
scripts/util.py
Normal file
119
scripts/util.py
Normal file
@ -0,0 +1,119 @@
|
||||
"""
|
||||
This script offers general functionality required in multiple places.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import os
|
||||
import logging
|
||||
import glob
|
||||
import wandb
|
||||
from dotenv import load_dotenv
|
||||
import sys
|
||||
from config_loader import models_path, configs_path
|
||||
import yaml
|
||||
load_dotenv()
|
||||
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def load_best_model_data(sweep_id, weights):
|
||||
"""
|
||||
Determines the model with the lowest validation loss.
|
||||
If sweep_id is not provided the best model is determined based on the validation loss in checkpoint filenames in weights directory.
|
||||
|
||||
:param sweep_id:
|
||||
:param weights:
|
||||
:return:
|
||||
"""
|
||||
|
||||
if sweep_id is not None:
|
||||
wandb_project_name = os.environ.get("WANDB_PROJECT")
|
||||
wandb_user = os.environ.get("WANDB_USER")
|
||||
api = wandb.Api()
|
||||
sweep = api.sweep(f"{wandb_user}/{wandb_project_name}/{sweep_id}")
|
||||
|
||||
# Get best run parameters
|
||||
best_run = sweep.best_run()
|
||||
run_id = best_run.id
|
||||
matching_models = glob.glob(f"{models_path}/{weights}/*run={run_id}*ckpt")
|
||||
if len(matching_models)!=1:
|
||||
raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id)
|
||||
best_checkpoint_path = matching_models[0]
|
||||
|
||||
else:
|
||||
checkpoints_path = f"{models_path}/{weights}/*ckpt"
|
||||
logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
|
||||
|
||||
checkpoints = glob.glob(checkpoints_path)
|
||||
val_losses = []
|
||||
|
||||
for ckpt in checkpoints:
|
||||
i = ckpt.index("val_loss=")
|
||||
val_losses.append(float(ckpt[i + 9:-5]))
|
||||
|
||||
best_checkpoint_path = checkpoints[np.argmin(val_losses)]
|
||||
run_id_st = best_checkpoint_path.index("run=") + 4
|
||||
run_id_end = best_checkpoint_path.index("-epoch=")
|
||||
run_id = best_checkpoint_path[run_id_st:run_id_end]
|
||||
|
||||
return best_checkpoint_path, run_id
|
||||
|
||||
|
||||
def load_best_model(model_cls, weights, version):
|
||||
"""
|
||||
Determines the model with lowest validation loss from the csv logs and loads it
|
||||
|
||||
:param model_cls: Class of the lightning module to load
|
||||
:param weights: Path to weights as in cmd arguments
|
||||
:param version: String of version file
|
||||
:return: Instance of lightning module that was loaded from the best checkpoint
|
||||
"""
|
||||
metrics = pd.read_csv(weights / version / "metrics.csv")
|
||||
|
||||
idx = np.nanargmin(metrics["val_loss"])
|
||||
min_row = metrics.iloc[idx]
|
||||
|
||||
# For default checkpoint filename, see https://github.com/Lightning-AI/lightning/pull/11805
|
||||
# and https://github.com/Lightning-AI/lightning/issues/16636.
|
||||
# For example, 'epoch=0-step=1.ckpt' means the 1st step has finish, but the 1st epoch hasn't
|
||||
checkpoint = f"epoch={min_row['epoch']:.0f}-step={min_row['step']+1:.0f}.ckpt"
|
||||
|
||||
# For default save path of checkpoints, see https://github.com/Lightning-AI/lightning/pull/12372
|
||||
checkpoint_path = weights / version / "checkpoints" / checkpoint
|
||||
|
||||
return model_cls.load_from_checkpoint(checkpoint_path)
|
||||
|
||||
|
||||
default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", None)
|
||||
if default_workers is None:
|
||||
logging.warning(
|
||||
"BENCHMARK_DEFAULT_WORKERS not set. "
|
||||
"Will use 12 workers if not specified otherwise in configuration."
|
||||
)
|
||||
default_workers = 12
|
||||
else:
|
||||
default_workers = int(default_workers)
|
||||
|
||||
|
||||
def load_sweep_config(sweep_fname):
|
||||
"""
|
||||
Loads sweep config from yaml file
|
||||
|
||||
:param sweep_fname: sweep yaml file, expected to be in configs_path
|
||||
:return: Dictionary containing sweep config
|
||||
"""
|
||||
|
||||
sweep_config_path = f"{configs_path}/{sweep_fname}"
|
||||
|
||||
try:
|
||||
with open(sweep_config_path, "r") as file:
|
||||
sweep_config = yaml.load(file, Loader=yaml.FullLoader)
|
||||
except FileNotFoundError:
|
||||
logging.error(f"Could not find sweep config file: {sweep_fname}. "
|
||||
f"Please make sure the file exists and is in {configs_path} directory.")
|
||||
sys.exit(1)
|
||||
|
||||
return sweep_config
|
Loading…
Reference in New Issue
Block a user