340 lines
11 KiB
Python
340 lines
11 KiB
Python
|
import os.path
|
||
|
import wandb
|
||
|
import seisbench.data as sbd
|
||
|
import seisbench.generate as sbg
|
||
|
import seisbench.models as sbm
|
||
|
from seisbench.util import worker_seeding
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from torch.utils.data import DataLoader
|
||
|
import torch.nn.functional as f
|
||
|
import torch.nn as nn
|
||
|
from torchmetrics import Metric
|
||
|
from torch import Tensor, tensor
|
||
|
import json
|
||
|
from dotenv import load_dotenv
|
||
|
|
||
|
# Populate os.environ from a .env file (if present) before reading credentials.
load_dotenv()

# Weights & Biases credentials are mandatory for this script; fail fast at import.
wandb_api_key = os.getenv('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")
wandb.login(key=wandb_api_key)

# Repository root: the parent of the directory containing this file.
project_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
|
||
|
|
||
|
|
||
|
class PickMAE(Metric):
    """Mean absolute error of the channel-0 pick position, in seconds.

    For every trace, the pick index is the argmax over the sample axis of
    channel 0 (``preds[:, 0, :]`` and ``target[:, 0, :]``). The absolute index
    difference is divided by ``sampling_rate`` to convert samples to seconds.

    State is accumulated across ``update`` calls as a running sum of absolute
    errors plus a trace count. ``compute`` therefore returns the MAE over all
    traces seen since the last ``reset``, while calling the metric directly
    (torchmetrics ``forward``) returns the MAE of the current batch.
    """

    higher_is_better: bool = False
    # Both states are plain sums, so torchmetrics' cheaper reduce-state forward path is valid.
    full_state_update: bool = False

    mae_error: Tensor  # running sum of absolute pick errors, in seconds
    total: Tensor  # running number of traces that contributed to mae_error

    def __init__(self, sampling_rate):
        """
        Args:
            sampling_rate: Samples per second, used to convert index errors to seconds.
        """
        super().__init__()
        # Float default: an integer tensor cannot accumulate fractional seconds in place.
        self.add_state("mae_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.sampling_rate = sampling_rate

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate absolute pick errors for one batch.

        Args:
            preds: Model output; indexed as ``preds[:, 0, :]``, so assumed
                (batch, channel, sample) — TODO confirm for other models.
            target: Labels with the same shape as ``preds``.

        Raises:
            ValueError: if ``preds`` and ``target`` shapes differ.
        """
        if preds.shape != target.shape:
            raise ValueError(
                f"preds and target must have the same shape, got {tuple(preds.shape)} and {tuple(target.shape)}"
            )

        # preds and target may live on different devices (e.g. model output on GPU,
        # labels on CPU), so move both index vectors to the state's device.
        device = self.mae_error.device
        pred_pick_idx = torch.argmax(preds[:, 0, :], dim=-1).to(device=device, dtype=torch.float32)
        true_pick_idx = torch.argmax(target[:, 0, :], dim=-1).to(device=device, dtype=torch.float32)

        self.mae_error += (pred_pick_idx - true_pick_idx).abs().sum() / self.sampling_rate
        self.total += pred_pick_idx.numel()

    def compute(self):
        """Return the accumulated MAE in seconds (0.0 if nothing was accumulated)."""
        if self.total == 0:
            return torch.zeros_like(self.mae_error)
        return self.mae_error / self.total
|
||
|
|
||
|
|
||
|
class EarlyStopper:
    """Patience-based early-stopping helper driven by a validation loss.

    A loss strictly below the best seen so far becomes the new best and resets
    the patience counter. A loss more than ``min_delta`` above the best
    increments the counter. Anything in between leaves the counter unchanged.
    """

    def __init__(self, patience=1, min_delta=0):
        """
        Args:
            patience: Number of counted "worse" epochs after which to stop.
            min_delta: Tolerance above the best loss before an epoch is counted as worse.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss``; return True once training should stop."""
        best_so_far = self.min_validation_loss

        if validation_loss < best_so_far:
            # New best: remember it and restart the patience count.
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        if validation_loss > best_so_far + self.min_delta:
            # Clearly worse than the best: spend one unit of patience.
            self.counter += 1
            return bool(self.counter >= self.patience)

        # Within tolerance of the best (or incomparable, e.g. NaN): no change.
        return False
|
||
|
|
||
|
|
||
|
def _load_dataset_and_phase_dict(sampling_rate, path, sb_dataset):
    """Return ``(data, phase_dict)`` for a local dataset directory or a named SeisBench dataset.

    ``phase_dict`` maps metadata arrival-sample columns to the label they feed
    (every column maps to ``"P"`` here).

    Raises:
        ValueError: if ``path`` is None and ``sb_dataset`` is not ``"ethz"``.
    """
    if path is not None:
        data = sbd.WaveformDataset(path, sampling_rate=sampling_rate)
        phase_dict = {
            "trace_Pg_arrival_sample": "P",
        }
    elif sb_dataset == "ethz":
        # NOTE(review): force=True is forwarded to SeisBench on every call (once per
        # split) — confirm it does not trigger a fresh download each time.
        data = sbd.ETHZ(sampling_rate=sampling_rate, force=True)
        # All P-type arrival columns feed the single "P" label. S-type columns
        # (trace_s/S/S1/Sg/SmS/Sn_arrival_sample) are deliberately not listed,
        # so no S label is generated.
        phase_dict = {
            "trace_p_arrival_sample": "P",
            "trace_pP_arrival_sample": "P",
            "trace_P_arrival_sample": "P",
            "trace_P1_arrival_sample": "P",
            "trace_Pg_arrival_sample": "P",
            "trace_Pn_arrival_sample": "P",
            "trace_PmP_arrival_sample": "P",
            "trace_pwP_arrival_sample": "P",
            "trace_pwPm_arrival_sample": "P",
        }
    else:
        raise ValueError(
            f"Unsupported sb_dataset {sb_dataset!r}: pass a local dataset path or use 'ethz'"
        )
    return data, phase_dict


def get_data_generator(split, sampling_rate, path, sb_dataset="ethz", station=None, window='random'):
    """Build a SeisBench ``GenericGenerator`` for one split with probabilistic P labels.

    Args:
        split: Split name passed to ``get_split`` (e.g. ``"train"``, ``"dev"``, ``"test"``).
        sampling_rate: Sampling rate in Hz requested from the dataset.
        path: Local SeisBench-format dataset directory; if None, ``sb_dataset`` is loaded instead.
        sb_dataset: Name of the SeisBench benchmark dataset used when ``path`` is None.
        station: If given, keep only traces whose ``station_code`` equals it.
        window: ``'random'`` selects a random 3001-sample window; any other value
            uses a fixed 3001-sample window starting at sample 0.

    Returns:
        The configured ``sbg.GenericGenerator``.

    Raises:
        ValueError: if ``path`` is None and ``sb_dataset`` is unsupported.
    """
    data, phase_dict = _load_dataset_and_phase_dict(sampling_rate, path, sb_dataset)

    dataset = data.get_split(split)
    # NOTE(review): only traces with a Pg arrival are kept, even when phase_dict
    # lists other P columns (ETHZ) — confirm this is intended.
    dataset.filter(dataset.metadata.trace_Pg_arrival_sample.notna())

    print(split, dataset.metadata.shape, sampling_rate)

    if station is not None:
        dataset.filter(dataset.metadata.station_code == station)

    data_generator = sbg.GenericGenerator(dataset)
    if window == 'random':
        print("using random window")
        window_selector = sbg.RandomWindow(windowlen=3001, strategy="pad")
    else:
        window_selector = sbg.FixedWindow(windowlen=3001, p0=0, strategy="pad")

    augmentations = [
        # Coarse 6000-sample window starting 3000 samples before a randomly
        # selected labelled arrival ...
        sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000, selection="random",
                               strategy="variable"),
        # ... from which the final 3001-sample model input is cut.
        window_selector,
        sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
        sbg.ChangeDtype(np.float32),
        sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=30, dim=0),
    ]
    data_generator.add_augmentations(augmentations)

    return data_generator
|
||
|
|
||
|
|
||
|
def get_data_generators(sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz", station=None,
                        window='random'):
    """Return ``(train, dev, test)`` generators that share identical settings.

    Each split is built by ``get_data_generator`` with the given arguments,
    in the order train, dev, test.
    """
    return tuple(
        get_data_generator(split_name, sampling_rate, path, sb_dataset, station, window)
        for split_name in ("train", "dev", "test")
    )
|
||
|
|
||
|
|
||
|
def get_data_loaders(batch_size=256, sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz",
                     window='random', num_workers=0, station=None):
    """Build train/dev/test ``DataLoader`` objects over the SeisBench generators.

    Args:
        batch_size: Batch size for all three loaders.
        sampling_rate: Sampling rate in Hz passed to the generators.
        path: Local SeisBench-format dataset directory; None selects ``sb_dataset``.
        sb_dataset: SeisBench benchmark dataset name used when ``path`` is None.
        window: Window strategy forwarded to ``get_data_generator``.
        num_workers: DataLoader worker processes (0 loads in the main process).
        station: Optional station code filter forwarded to the generators.

    Returns:
        ``(train_loader, dev_loader, test_loader)``. Only the train loader shuffles.
    """
    train_generator, dev_generator, test_generator = get_data_generators(
        sampling_rate, path, sb_dataset, station=station, window=window
    )

    # worker_init_fn is only invoked by DataLoader when num_workers > 0.
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, worker_init_fn=worker_seeding)

    train_loader = DataLoader(train_generator, shuffle=True, **loader_kwargs)
    dev_loader = DataLoader(dev_generator, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_generator, shuffle=False, **loader_kwargs)

    return train_loader, dev_loader, test_loader
|
||
|
|
||
|
|
||
|
def load_model(name="PhaseNet", pretrained=None, classes=2, modify_output=True):
    """Create a SeisBench picking model, optionally from pretrained weights.

    Args:
        name: Model architecture; only ``"PhaseNet"`` is supported.
        pretrained: Name of pretrained weights passed to ``from_pretrained``.
            Any falsy value (``None``, ``""``, ``False``) builds a freshly
            initialised ``PhaseNet(phases="PN", norm="peak")``.
        classes: Output channels of the replacement head when ``modify_output`` is True.
        modify_output: Replace ``model.out`` with a new kernel-size-1 ``Conv1d``
            mapping ``model.filters_root`` channels to ``classes``.

    Returns:
        The model.

    Raises:
        ValueError: if ``name`` is not supported.
    """
    if name != "PhaseNet":
        raise ValueError(f"Unsupported model name {name!r}; only 'PhaseNet' is supported")

    if pretrained:
        # from_pretrained is a classmethod: call it on the class instead of
        # constructing a throwaway instance first.
        model = sbm.PhaseNet.from_pretrained(pretrained)
    else:
        model = sbm.PhaseNet(phases="PN", norm="peak")

    if modify_output:
        # NOTE(review): only model.out is replaced; confirm no other model
        # attribute needs to agree with `classes`.
        model.out = nn.Conv1d(model.filters_root, classes, 1, padding="same")

    return model
|
||
|
|
||
|
|
||
|
def train_one_epoch(model, dataloader, optimizer, pick_mae):
    """Run one optimisation pass over ``dataloader`` and log per-batch metrics to W&B.

    The training objective is ``loss_fn``. ``F.cross_entropy`` and ``pick_mae``
    are computed only for monitoring. All three values are logged in a single
    ``wandb.log`` call per batch.

    Args:
        model: Model exposing ``.device``; batches are moved there.
        dataloader: Yields dicts with ``"X"`` (inputs) and ``"y"`` (labels).
        optimizer: Optimizer over ``model``'s parameters.
        pick_mae: Metric called as ``pick_mae(pred, y)`` returning a scalar.
    """
    # Ensure training-mode behaviour (e.g. BatchNorm batch statistics) even if an
    # earlier evaluation left the model in eval mode.
    model.train()
    size = len(dataloader.dataset)

    for batch_id, batch in enumerate(dataloader):
        # Move inputs and labels to the model's device once and reuse them.
        x = batch["X"].to(model.device)
        y = batch["y"].to(model.device)

        # Compute prediction and loss
        pred = model(x)
        loss = loss_fn(pred, y)

        # Monitoring-only metrics: keep them out of the autograd graph.
        with torch.no_grad():
            # NOTE(review): F.cross_entropy expects logits, while loss_fn takes
            # log(pred) (i.e. treats pred as probabilities) — this is a rough signal only.
            cross_entropy_loss = f.cross_entropy(pred, y)
            mae = float(pick_mae(pred, y))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        # A single log call so all per-batch values share one W&B step.
        wandb.log({
            "loss": loss_value,
            "batch cross entropy loss": cross_entropy_loss.item(),
            "p_mae": mae,
        })

        if batch_id % 5 == 0:
            current = batch_id * batch["X"].shape[0]
            print(f"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")
            print(f"mae: {mae:>7f}")
|
||
|
|
||
|
|
||
|
def test_one_epoch(model, dataloader, pick_mae, wandb_log=True):
    """Evaluate ``model`` on ``dataloader`` and return its batch-averaged loss and pick MAE.

    The model is put in eval mode for the pass and restored to its previous mode
    afterwards.

    Args:
        model: Model exposing ``.device``; batches are moved there.
        dataloader: Yields dicts with ``"X"`` (inputs) and ``"y"`` (labels).
        pick_mae: Metric called as ``pick_mae(pred, y)`` returning a scalar.
        wandb_log: If True, log per-batch cross entropy and the epoch summary to W&B.

    Returns:
        ``(test_loss, test_mae)`` as Python floats, each the mean over batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("test_one_epoch received an empty dataloader")

    test_loss = 0.0
    test_mae = 0.0

    was_training = model.training
    model.eval()  # disable dropout / use running BatchNorm statistics during evaluation
    try:
        with torch.no_grad():
            for batch in dataloader:
                x = batch["X"].to(model.device)
                y = batch["y"].to(model.device)
                pred = model(x)

                test_loss += loss_fn(pred, y).item()
                test_mae += float(pick_mae(pred, y))

                if wandb_log:
                    test_cross_entropy_loss = f.cross_entropy(pred, y)
                    wandb.log({"batch cross entropy test loss": test_cross_entropy_loss.item()})
    finally:
        model.train(was_training)

    test_loss /= num_batches
    test_mae /= num_batches

    if wandb_log:
        wandb.log({"test_p_mae": test_mae, "test_loss": test_loss})

    print(f"Test avg loss: {test_loss:>8f}")
    print(f"Test avg mae: {test_mae:>7f}\n")

    return test_loss, test_mae
|
||
|
|
||
|
|
||
|
def train_model(model, path_to_trained_model, train_loader, dev_loader):
    """Train ``model`` with Adam, saving the weights with the lowest dev loss.

    Hyperparameters are read from ``wandb.config``:
    - ``learning_rate`` and ``epochs`` (required);
    - ``sampling_rate`` for the pick MAE (required);
    - optional ``early_stopping_patience`` (default 3) and
      ``early_stopping_min_delta`` (default 10).

    Args:
        model: Model to train in place.
        path_to_trained_model: File the best ``state_dict`` is written to. Its
            parent directory is created if missing.
        train_loader: Training batches.
        dev_loader: Validation batches used for checkpointing and early stopping.
    """
    wandb.watch(model, log_freq=10)

    optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
    # NOTE(review): min_delta is in loss_fn units; a value as large as 10 may
    # effectively disable early stopping — tune via the config keys above.
    early_stopper = EarlyStopper(
        patience=wandb.config.get("early_stopping_patience", 3),
        min_delta=wandb.config.get("early_stopping_min_delta", 10),
    )
    pick_mae = PickMAE(wandb.config.sampling_rate)

    # torch.save does not create missing parent directories.
    checkpoint_dir = os.path.dirname(path_to_trained_model)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    best_loss = np.inf
    best_metrics = {}

    for t in range(wandb.config.epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_one_epoch(model, train_loader, optimizer, pick_mae)
        test_loss, test_mae = test_one_epoch(model, dev_loader, pick_mae)

        if test_loss < best_loss:
            best_loss = test_loss
            best_metrics = {"test_p_mae": test_mae, "test_loss": test_loss}
            torch.save(model.state_dict(), path_to_trained_model)

        if early_stopper.early_stop(test_loss):
            break

    print("Best model: ", str(best_metrics))
|
||
|
|
||
|
|
||
|
def loss_fn(y_pred, y_true, eps=1e-5):
    """Vector cross-entropy between label curves and predicted probabilities.

    Per trace, the term ``y_true * log(y_pred + eps)`` is averaged over the last
    (sample) axis and summed over the second-to-last (pick) axis. The result is
    averaged over the remaining batch entries and negated.

    Args:
        y_pred: Predicted probabilities; ``eps`` guards against ``log(0)``.
        y_true: Target label curves with the same shape as ``y_pred``.
        eps: Small constant added inside the logarithm.

    Returns:
        A 0-dim tensor with the loss.
    """
    log_probs = torch.log(y_pred + eps)
    per_trace = (y_true * log_probs).mean(-1).sum(-1)
    return -per_trace.mean()
|
||
|
|
||
|
|
||
|
def train_phasenet_on_sb_data():
    """Train PhaseNet from scratch on a SeisBench benchmark dataset and log it to W&B.

    The hard-coded config below is passed to ``wandb.init``. The best
    checkpoint is uploaded as a ``model`` artifact.
    """
    config = {
        "epochs": 3,
        "batch_size": 256,
        "dataset": "ethz",
        "sampling_rate": 100,
        "model_name": "PhaseNet",
        # Read by train_model; 1e-3 is torch.optim.Adam's default learning rate.
        "learning_rate": 1e-3,
    }

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config=config,
    )

    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    # path=None selects the SeisBench dataset named by the config instead of local IGF data.
    # The test loader is built but not used here.
    train_loader, dev_loader, _test_loader = get_data_loaders(
        batch_size=wandb.config.batch_size,
        sampling_rate=wandb.config.sampling_rate,
        path=None,
        sb_dataset=wandb.config.dataset,
    )

    model = load_model(name=wandb.config.model_name, pretrained=None, modify_output=True)
    path_to_trained_model = f"{project_path}/models/{wandb.config.model_name}_trained_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)

    run.finish()
|
||
|
|
||
|
|
||
|
def load_config(config_path):
    """Read a JSON file and return the decoded value.

    Args:
        config_path: Path (str or path-like) to a UTF-8 encoded JSON file.

    Returns:
        The decoded JSON value (expected to be a dict of hyperparameters here).
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default
    # encoding. The handle is named so it does not shadow the module-level
    # torch.nn.functional alias `f`.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)
|
||
|
|
||
|
|
||
|
def train_sbmodel_on_igf_data():
    """Fine-tune a SeisBench model on the IGF data described by ``experiments/config.json``.

    The JSON config is handed to ``wandb.init``. This function reads
    ``batch_size``, ``sampling_rate``, ``model_name``, ``pretrained`` and
    ``dataset`` from it; ``train_model`` additionally reads ``learning_rate``
    and ``epochs``. The best checkpoint is uploaded as a ``model`` artifact.
    """
    run = wandb.init(
        # W&B project that collects these runs
        project="training_seisbench_models_on_igf_data",
        # hyperparameters / run metadata
        config=load_config(project_path + "/experiments/config.json"),
    )
    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    cfg = wandb.config
    print(cfg.batch_size, cfg.sampling_rate)

    # Default data path (local IGF data); only train and dev are used below.
    train_loader, dev_loader, _ = get_data_loaders(
        batch_size=cfg.batch_size,
        sampling_rate=cfg.sampling_rate,
    )

    model_name, pretrained = cfg.model_name, cfg.pretrained
    print(model_name, pretrained)
    model = load_model(name=model_name, pretrained=pretrained)

    path_to_trained_model = (
        f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{cfg.dataset}.pt"
    )
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)
    run.finish()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Alternative entry point: train_phasenet_on_sb_data() trains PhaseNet from
    # scratch on a SeisBench benchmark dataset instead of the local IGF data.
    train_sbmodel_on_igf_data()
|
||
|
|