initial commit
This commit is contained in:
+10
@@ -0,0 +1,10 @@
|
||||
__pycache__/
|
||||
.idea/
|
||||
*/.ipynb_checkpoints/
|
||||
.ipynb_checkpoints/
|
||||
.env
|
||||
models/
|
||||
data/
|
||||
wip
|
||||
artifacts/
|
||||
wandb/
|
||||
@@ -0,0 +1,18 @@
|
||||
# Demo notebooks and scripts for EPOS AI Platform
|
||||
|
||||
|
||||
This repo contains notebooks and scripts demonstrating how to:
|
||||
- Prepare IGF data for training a model that detects the P phase (i.e. transform MiniSEED files into the [SeisBench data format](https://seisbench.readthedocs.io/en/stable/pages/data_format.html)); see the [notebook](utils/Transforming%20mseeds%20to%20SeisBench%20dataset.ipynb).
|
||||
The original data can be downloaded from this [Google Drive folder](https://drive.google.com/drive/folders/1InVI9DLaD7gdzraM2jMzeIrtiBSu-UIK?usp=drive_link).
|
||||
|
||||
- Explore available data, check the [notebook](notebooks/Explore%20igf%20data.ipynb)
|
||||
- Train a CNN model (SeisBench PhaseNet) to detect the P phase, check the [script](scripts/train.py)
|
||||
- Search for the best training hyperparams, check the [script](scripts/training_wandb_sweep.py)
|
||||
- Validate model performance, check the [notebook](notebooks/Check%20model%20performance%20depending%20on%20station-random%20window.ipynb)
|
||||
- Use model for detecting P phase, check the [notebook](notebooks/Present%20model%20predictions.ipynb)
|
||||
|
||||
|
||||
### Usage
|
||||
|
||||
To install all dependencies run:
|
||||
`poetry install`
|
||||
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"epochs": 10,
|
||||
"batch_size": 256,
|
||||
"dataset": "igf_1",
|
||||
"sampling_rate": 100,
|
||||
"model_names": "EQTransformer,BasicPhaseAE,GPD",
|
||||
"model_name": "PhaseNet",
|
||||
"learning_rate": 0.01,
|
||||
"pretrained": null,
|
||||
"sampling_rate": 100
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
method: bayes
|
||||
metric:
|
||||
goal: minimize
|
||||
name: test_loss
|
||||
parameters:
|
||||
batch_size:
|
||||
distribution: int_uniform
|
||||
max: 512
|
||||
min: 128
|
||||
dataset:
|
||||
distribution: categorical
|
||||
values:
|
||||
- igf_1
|
||||
epochs:
|
||||
distribution: int_uniform
|
||||
max: 60
|
||||
min: 15
|
||||
learning_rate:
|
||||
distribution: uniform
|
||||
max: 0.02
|
||||
min: 0.005
|
||||
model_name:
|
||||
distribution: categorical
|
||||
values:
|
||||
- PhaseNet
|
||||
pretrained:
|
||||
distribution: categorical
|
||||
values:
|
||||
- diting
|
||||
- ethz
|
||||
- geofon
|
||||
- instance
|
||||
- iquique
|
||||
- lendb
|
||||
- neic
|
||||
- original
|
||||
- scedc
|
||||
program: training_wandb_sweep.py
|
||||
@@ -0,0 +1,32 @@
|
||||
method: bayes
|
||||
metric:
|
||||
goal: minimize
|
||||
name: test_loss
|
||||
parameters:
|
||||
batch_size:
|
||||
distribution: int_uniform
|
||||
max: 512
|
||||
min: 256
|
||||
dataset:
|
||||
distribution: categorical
|
||||
values:
|
||||
- igf_1
|
||||
epochs:
|
||||
distribution: categorical
|
||||
values:
|
||||
- 10
|
||||
learning_rate:
|
||||
distribution: uniform
|
||||
max: 0.02
|
||||
min: 0.01
|
||||
model_name:
|
||||
distribution: categorical
|
||||
values:
|
||||
# - EQTransformer
|
||||
- PhaseNet
|
||||
pretrained:
|
||||
distribution: categorical
|
||||
values:
|
||||
- instance
|
||||
- iquique
|
||||
- false
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Generated
+1745
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,22 @@
|
||||
[tool.poetry]
|
||||
name = "ai_platform_demo_scripts"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Krystyna Milian <krystyna.milian@gmail.com>"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
seisbench = "^0.4.1"
|
||||
torch = "^2.0.1"
|
||||
PyYAML = "^6.0"
|
||||
python-dotenv = "^1.0.0"
|
||||
pandas = "^2.0.3"
|
||||
obspy = "^1.4.0"
|
||||
wandb = "^0.15.4"
|
||||
torchmetrics = "^0.11.4"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,339 @@
|
||||
import os.path
|
||||
import wandb
|
||||
import seisbench.data as sbd
|
||||
import seisbench.generate as sbg
|
||||
import seisbench.models as sbm
|
||||
from seisbench.util import worker_seeding
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn.functional as f
|
||||
import torch.nn as nn
|
||||
from torchmetrics import Metric
|
||||
from torch import Tensor, tensor
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables via python-dotenv before reading the API key.
load_dotenv()

# The Weights & Biases API key is mandatory: fail at import time when absent.
wandb_api_key = os.getenv('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")
wandb.login(key=wandb_api_key)

# Two directory levels above this file; used as the root for data/, models/
# and experiments/ paths elsewhere in this module.
project_path = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__))
)
|
||||
|
||||
|
||||
class PickMAE(Metric):
    """Mean absolute error between predicted and true pick positions, in seconds.

    For every trace the pick position is the argmax along the last axis of
    channel 0 (presumably the P-phase channel -- confirm against the labeller).
    The absolute sample difference is divided by ``sampling_rate`` to obtain
    seconds.

    The metric accumulates the *sum* of absolute errors and the number of
    traces, so ``compute()`` returns the mean over every sample seen since
    the last reset, and the ``"sum"`` reduction of both states stays
    meaningful when synchronising across processes.
    """

    higher_is_better: bool = False
    mae_error: Tensor  # running sum of absolute pick errors, in seconds
    total: Tensor  # number of traces accumulated so far

    def __init__(self, sampling_rate):
        """Create the metric.

        Args:
            sampling_rate: Samples per second, used to convert a sample
                offset into seconds.
        """
        super().__init__()
        # Float default so the accumulated error is not stored as an integer.
        self.add_state("mae_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.sampling_rate = sampling_rate

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the pick error of one batch.

        Args:
            preds: Model output, indexed as ``[batch, channel, sample]``.
            target: Labels with exactly the same shape as ``preds``.
        """
        assert preds.shape == target.shape, "preds and target must have the same shape"

        # Work on the device of the metric state: preds and target may live on
        # different devices (e.g. model output on GPU, labels on CPU).
        state_device = self.mae_error.device
        pred_pick_idx = torch.argmax(preds[:, 0, :], dim=1).to(device=state_device, dtype=torch.float32)
        true_pick_idx = torch.argmax(target[:, 0, :], dim=-1).to(device=state_device, dtype=torch.float32)

        abs_error_seconds = torch.abs(pred_pick_idx - true_pick_idx) / self.sampling_rate
        self.mae_error += abs_error_seconds.sum()
        self.total += abs_error_seconds.numel()

    def compute(self):
        """Return the mean absolute pick error in seconds (NaN if nothing was accumulated)."""
        return self.mae_error.float() / self.total
|
||||
|
||||
|
||||
class EarlyStopper:
    """Signal when the validation loss has stopped improving.

    ``counter`` counts the evaluations whose loss exceeded the best loss seen
    so far by more than ``min_delta``; it is reset to zero whenever a new best
    loss is observed.
    """

    def __init__(self, patience=1, min_delta=0):
        """Store the tolerance settings and start with no observed loss."""
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return True when training should stop."""
        # A strictly better loss becomes the new reference and clears the counter.
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # Losses within ``min_delta`` of the best one are tolerated silently.
        threshold = self.min_validation_loss + self.min_delta
        if not (validation_loss > threshold):
            return False

        self.counter += 1
        if self.counter >= self.patience:
            return True
        return False
|
||||
|
||||
|
||||
def get_data_generator(split, sampling_rate, path, sb_dataset="ethz", station=None, window='random'):
    """Build a SeisBench generator with the training augmentations for one split.

    Args:
        split: Name of the dataset split passed to ``get_split``
            (callers in this module use "train", "dev" and "test").
        sampling_rate: Sampling rate handed to the dataset constructor.
        path: Directory of a dataset in SeisBench format. When ``None`` the
            benchmark dataset named by ``sb_dataset`` is used instead.
        sb_dataset: Benchmark dataset name; only "ethz" is supported and it
            is consulted only when ``path`` is ``None``.
        station: Optional station code; when given, only traces of that
            station are kept.
        window: 'random' selects a random 3001-sample window; any other value
            selects a fixed window starting at sample 0.

    Returns:
        A ``sbg.GenericGenerator`` with the augmentations attached.

    Raises:
        ValueError: If ``path`` is ``None`` and ``sb_dataset`` is not "ethz".
    """
    if path is not None:
        data = sbd.WaveformDataset(path, sampling_rate=sampling_rate)
        phase_dict = {
            "trace_Pg_arrival_sample": "P"
        }
    elif sb_dataset == "ethz":
        data = sbd.ETHZ(sampling_rate=sampling_rate, force=True)

        # Only P-type arrival columns are mapped; S-phase columns are
        # deliberately left out because the model is trained on P picks only.
        phase_dict = {
            "trace_p_arrival_sample": "P",
            "trace_pP_arrival_sample": "P",
            "trace_P_arrival_sample": "P",
            "trace_P1_arrival_sample": "P",
            "trace_Pg_arrival_sample": "P",
            "trace_Pn_arrival_sample": "P",
            "trace_PmP_arrival_sample": "P",
            "trace_pwP_arrival_sample": "P",
            "trace_pwPm_arrival_sample": "P",
        }
    else:
        # Previously this fell through and failed later with an unbound
        # local variable; report the real problem instead.
        raise ValueError(
            f"Unsupported sb_dataset {sb_dataset!r}: pass a dataset path or use sb_dataset='ethz'."
        )

    dataset = data.get_split(split)
    # Keep only traces that actually have a Pg arrival annotated.
    dataset.filter(dataset.metadata.trace_Pg_arrival_sample.notna())

    print(split, dataset.metadata.shape, sampling_rate)

    if station is not None:
        dataset.filter(dataset.metadata.station_code == station)

    data_generator = sbg.GenericGenerator(dataset)
    if window == 'random':
        print("using random window")
        window_selector = sbg.RandomWindow(windowlen=3001, strategy="pad")
    else:
        window_selector = sbg.FixedWindow(windowlen=3001, p0=0, strategy="pad")

    augmentations = [
        # Cut a 6000-sample window around one randomly selected pick ...
        sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000, selection="random",
                               strategy="variable"),
        # ... then reduce it to the 3001-sample window chosen above.
        window_selector,
        sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
        sbg.ChangeDtype(np.float32),
        sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=30, dim=0)
    ]

    data_generator.add_augmentations(augmentations)

    return data_generator
|
||||
|
||||
|
||||
def get_data_generators(sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz", station=None,
                        window='random'):
    """Return the (train, dev, test) generators built with identical settings.

    Every argument is forwarded unchanged to ``get_data_generator``; only the
    split name differs between the three calls.
    """
    train_generator, dev_generator, test_generator = (
        get_data_generator(split_name, sampling_rate, path, sb_dataset, station, window)
        for split_name in ("train", "dev", "test")
    )
    return train_generator, dev_generator, test_generator
|
||||
|
||||
|
||||
def get_data_loaders(batch_size=256, sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz",
                     window='random'):
    """Return ``DataLoader`` objects for the train, dev and test splits.

    Only the training loader shuffles its data; the dev and test loaders keep
    the generator order. All loaders share ``batch_size`` and are seeded
    through ``worker_seeding``.
    """
    generators = get_data_generators(sampling_rate, path, sb_dataset, window=window)
    num_workers = 0  # number of DataLoader workers used for loading data
    shuffle_flags = (True, False, False)  # train, dev, test

    train_loader, dev_loader, test_loader = (
        DataLoader(generator, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                   worker_init_fn=worker_seeding)
        for generator, shuffle in zip(generators, shuffle_flags)
    )
    return train_loader, dev_loader, test_loader
|
||||
|
||||
|
||||
def load_model(name="PhaseNet", pretrained=None, classes=2, modify_output=True):
    """Create a SeisBench PhaseNet model, optionally from pretrained weights.

    Args:
        name: Model name; only "PhaseNet" is implemented.
        pretrained: Name of the pretrained weight set passed to
            ``from_pretrained``. Any falsy value (``None``, ``False``, "")
            creates an untrained model instead.
        classes: Number of output channels of the replacement output layer.
        modify_output: When true, replace ``model.out`` by a fresh 1x1
            ``Conv1d`` with ``classes`` output channels.

    Returns:
        The PhaseNet model.

    Raises:
        ValueError: If ``name`` is not "PhaseNet" (previously this surfaced
            as an unbound local variable at the ``return``).
    """
    if name != "PhaseNet":
        raise ValueError(f"Unsupported model name {name!r}: only 'PhaseNet' is implemented.")

    if pretrained:
        # from_pretrained is a classmethod, so no throwaway instance is needed.
        model = sbm.PhaseNet.from_pretrained(pretrained)
    else:
        model = sbm.PhaseNet(phases="PN", norm="peak")

    if modify_output:
        model.out = nn.Conv1d(model.filters_root, classes, 1, padding="same")

    return model
|
||||
|
||||
|
||||
def train_one_epoch(model, dataloader, optimizer, pick_mae):
    """Run one optimisation pass over ``dataloader`` and log per-batch metrics.

    Args:
        model: Model exposing a ``device`` attribute; it is switched to
            training mode for the duration of the pass.
        dataloader: Yields batches as dicts with "X" (inputs) and "y" (labels).
        optimizer: Optimizer updating ``model``'s parameters.
        pick_mae: Callable metric returning the pick error of a batch.
    """
    size = len(dataloader.dataset)
    # Make sure layers with mode-dependent behaviour are in training mode.
    model.train()

    for batch_id, batch in enumerate(dataloader):
        # Move inputs AND labels to the model device once, so every loss and
        # metric below receives tensors on the same device.
        inputs = batch["X"].to(model.device)
        targets = batch["y"].to(model.device)

        # Compute prediction and loss
        pred = model(inputs)
        loss = loss_fn(pred, targets)

        # Compute cross entropy loss
        cross_entropy_loss = f.cross_entropy(pred, targets)

        # Compute mae
        mae = pick_mae(pred, targets)

        wandb.log({"loss": loss})
        wandb.log({"batch cross entropy loss": cross_entropy_loss})
        wandb.log({"p_mae": mae})

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_id % 5 == 0:
            loss_value, current = loss.item(), batch_id * inputs.shape[0]
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")
            print(f"mae: {mae:>7f}")
|
||||
|
||||
|
||||
def test_one_epoch(model, dataloader, pick_mae, wandb_log=True):
    """Evaluate ``model`` on ``dataloader`` and return the average loss and pick error.

    Args:
        model: Model exposing a ``device`` attribute.
        dataloader: Yields batches as dicts with "X" (inputs) and "y" (labels).
        pick_mae: Callable metric returning the pick error of a batch.
        wandb_log: When false, nothing is sent to Weights & Biases, so the
            function can be used without an active run.

    Returns:
        Tuple ``(test_loss, test_mae)`` averaged over the batches.
    """
    num_batches = len(dataloader)
    test_loss = 0
    test_mae = 0

    # Evaluate in eval mode, then restore whatever mode the model was in so a
    # surrounding training loop is not affected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                # Labels must be on the same device as the predictions.
                targets = batch["y"].to(model.device)
                pred = model(batch["X"].to(model.device))

                test_loss += loss_fn(pred, targets).item()
                test_mae += pick_mae(pred, targets)
                if wandb_log:
                    test_cross_entropy_loss = f.cross_entropy(pred, targets)
                    wandb.log({"batch cross entropy test loss": test_cross_entropy_loss})
    finally:
        model.train(was_training)

    test_loss /= num_batches
    test_mae /= num_batches

    # Honour the flag for the summary metrics as well as the per-batch ones.
    if wandb_log:
        wandb.log({"test_p_mae": test_mae, "test_loss": test_loss})

    print(f"Test avg loss: {test_loss:>8f}")
    print(f"Test avg mae: {test_mae:>7f}\n")

    return test_loss, test_mae
|
||||
|
||||
|
||||
def train_model(model, path_to_trained_model, train_loader, dev_loader, patience=3, min_delta=10):
    """Train ``model`` and keep the weights with the lowest dev loss on disk.

    The learning rate, number of epochs and sampling rate are read from
    ``wandb.config``, so a W&B run must be active.

    Args:
        model: Model to train.
        path_to_trained_model: File the best ``state_dict`` is written to;
            its parent directory is created when missing.
        train_loader: Loader used for optimisation.
        dev_loader: Loader used for evaluation after every epoch.
        patience: Early-stopping patience (forwarded to ``EarlyStopper``).
        min_delta: Early-stopping tolerance (forwarded to ``EarlyStopper``).
            NOTE(review): the default of 10 looks large relative to the loss
            values this script prints, so early stopping may rarely trigger --
            confirm this is intended.
    """
    wandb.watch(model, log_freq=10)

    optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
    early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)
    pick_mae = PickMAE(wandb.config.sampling_rate)

    # Create the output directory up front so a missing folder is not
    # discovered only after the first epoch has been trained.
    model_dir = os.path.dirname(path_to_trained_model)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    best_loss = np.inf
    best_metrics = {}

    for t in range(wandb.config.epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_one_epoch(model, train_loader, optimizer, pick_mae)
        test_loss, test_mae = test_one_epoch(model, dev_loader, pick_mae)

        # Checkpoint only when the dev loss improves.
        if test_loss < best_loss:
            best_loss = test_loss
            best_metrics = {"test_p_mae": test_mae, "test_loss": test_loss}
            torch.save(model.state_dict(), path_to_trained_model)

        if early_stopper.early_stop(test_loss):
            break

    print("Best model: ", str(best_metrics))
|
||||
|
||||
|
||||
def loss_fn(y_pred, y_true, eps=1e-5):
    """Vector cross-entropy between predicted and target label traces.

    ``eps`` is added to the predictions before taking the logarithm. The
    weighted log-values are averaged over the last axis, summed over the
    next-to-last axis, averaged over what remains, and negated.
    """
    log_pred = torch.log(y_pred + eps)
    weighted = y_true * log_pred
    # Mean along the sample dimension, then sum along the pick dimension.
    per_item = weighted.mean(dim=-1).sum(dim=-1)
    # Mean over the batch axis; negate to obtain a loss to minimise.
    return -per_item.mean()
|
||||
|
||||
|
||||
def train_phasenet_on_sb_data():
    """Train PhaseNet from scratch on the SeisBench ETHZ dataset inside a W&B run.

    The best weights are saved under ``<project>/models`` and uploaded to the
    run as a model artifact.
    """
    config = {
        "epochs": 3,
        "batch_size": 256,
        "dataset": "ethz",
        "sampling_rate": 100,
        "model_name": "PhaseNet",
        # train_model reads wandb.config.learning_rate, so it must be present.
        "learning_rate": 0.01,
    }

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config=config
    )

    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    # path=None makes the loaders use the benchmark dataset named in the config.
    train_loader, dev_loader, _test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
                                                              sampling_rate=wandb.config.sampling_rate,
                                                              path=None,
                                                              sb_dataset=wandb.config.dataset)

    model = load_model(name=wandb.config.model_name, pretrained=None, modify_output=True)
    # The config key is "dataset" (it was misspelled "data_set" here before).
    path_to_trained_model = f"{project_path}/models/{wandb.config.model_name}_trained_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model,
                train_loader, dev_loader)

    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)

    run.finish()
|
||||
|
||||
|
||||
def load_config(config_path):
    """Read a JSON configuration file and return its parsed content.

    Args:
        config_path: Path of the JSON file.

    Returns:
        The deserialised JSON document.
    """
    # Explicit UTF-8 keeps parsing independent of the platform default
    # encoding; the handle is not named ``f`` so it does not shadow the
    # module-level alias for torch.nn.functional.
    with open(config_path, 'r', encoding="utf-8") as config_file:
        return json.load(config_file)
|
||||
|
||||
|
||||
def train_sbmodel_on_igf_data():
    """Train a SeisBench model on the IGF data using ``experiments/config.json``.

    The JSON config becomes the W&B run config; the best weights are written
    under ``<project>/models`` and attached to the run as a model artifact.
    """
    config = load_config(project_path + "/experiments/config.json")

    # Log the run to this W&B project and track the hyperparameters.
    run = wandb.init(project="training_seisbench_models_on_igf_data", config=config)
    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    batch_size = wandb.config.batch_size
    sampling_rate = wandb.config.sampling_rate
    print(batch_size, sampling_rate)
    train_loader, dev_loader, test_loader = get_data_loaders(batch_size=batch_size,
                                                             sampling_rate=sampling_rate)

    model_name, pretrained = wandb.config.model_name, wandb.config.pretrained
    print(model_name, pretrained)

    model = load_model(name=model_name, pretrained=pretrained)
    path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    # Upload the saved weights as a run artifact.
    model_artifact = wandb.Artifact('model', type='model')
    model_artifact.add_file(path_to_trained_model)
    run.log_artifact(model_artifact)

    run.finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Default entry point: train on the IGF data. Call
    # train_phasenet_on_sb_data() here instead to train on the ETHZ dataset.
    train_sbmodel_on_igf_data()
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
import os.path
|
||||
import wandb
|
||||
import yaml
|
||||
|
||||
from train import get_data_loaders, load_model, train_model
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables via python-dotenv before reading the API key.
load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")

wandb.login(key=wandb_api_key)

# Two directory levels above this file; experiments/ and models/ live there.
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sweep_config_path = project_path + "/experiments/sweep4.yaml"

# The sweep definition is plain YAML (mappings, lists, scalars), so the safe
# loader is sufficient and cannot construct arbitrary Python objects.
with open(sweep_config_path, encoding="utf-8") as file:
    sweep_configuration = yaml.safe_load(file)

# Register the sweep; the returned id is used by the agent in the main guard.
sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project='training_seisbench_models_on_igf_data'
)
# Not part of the sweep search space, so it is supplied to every run explicitly.
sampling_rate = 100
|
||||
|
||||
def tune_training_hyperparams():
    """Run one sweep trial: train a model with the hyperparameters in ``wandb.config``.

    Only ``sampling_rate`` is set here; the remaining values read from
    ``wandb.config`` (batch_size, dataset, model_name, pretrained, ...) are
    expected to be supplied by the sweep. The best weights are saved under
    ``<project>/models`` and uploaded as a model artifact.
    """
    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config={"sampling_rate": sampling_rate}
    )

    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
                                                             sampling_rate=wandb.config.sampling_rate,
                                                             sb_dataset=wandb.config.dataset)

    model_name = wandb.config.model_name
    pretrained = wandb.config.pretrained
    print(wandb.config)
    print(model_name, pretrained, type(pretrained), wandb.config.sampling_rate)
    # A falsy sweep value (e.g. ``false``) means "no pretrained weights";
    # normalise it to None. The original statement here was a bare expression
    # with no effect.
    if not pretrained:
        pretrained = None
    model = load_model(name=model_name, pretrained=pretrained)
    path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)

    run.finish()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Start a sweep agent that runs the tuning function for up to 10 trials.
    wandb.agent(sweep_id, function=tune_training_hyperparams, count=10)
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user