import os.path
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics import Metric
from dotenv import load_dotenv

import wandb
import seisbench.data as sbd
import seisbench.generate as sbg
import seisbench.models as sbm
from seisbench.util import worker_seeding

load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")
wandb.login(key=wandb_api_key)

project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class PickMAE(Metric):
    higher_is_better: bool = False
    mae_error: Tensor

    def __init__(self, sampling_rate):
        super().__init__()
        self.add_state("mae_error", default=torch.tensor(0), dist_reduce_fx="sum")
        self.sampling_rate = sampling_rate

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape

        pred_pick_idx = torch.argmax(preds[:, 0, :], dim=-1).type(torch.FloatTensor)
        true_pick_idx = torch.argmax(target[:, 0, :], dim=-1).type(torch.FloatTensor)

        mae = nn.L1Loss()
        self.mae_error = mae(pred_pick_idx, true_pick_idx) / self.sampling_rate  # mae in seconds

    def compute(self):
        return self.mae_error.float()


class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


def get_data_generator(split, sampling_rate, path, sb_dataset="ethz", station=None, window='random'):
    if path is not None:
        data = sbd.WaveformDataset(path, sampling_rate=sampling_rate)
        phase_dict = {
            "trace_Pg_arrival_sample": "P"
        }
    elif sb_dataset == "ethz":
        data = sbd.ETHZ(sampling_rate=sampling_rate, force=True)
        phase_dict = {
            "trace_p_arrival_sample": "P",
            "trace_pP_arrival_sample": "P",
            "trace_P_arrival_sample": "P",
            "trace_P1_arrival_sample": "P",
            "trace_Pg_arrival_sample": "P",
            "trace_Pn_arrival_sample": "P",
            "trace_PmP_arrival_sample": "P",
            "trace_pwP_arrival_sample": "P",
            "trace_pwPm_arrival_sample": "P",
            # "trace_s_arrival_sample": "S",
            # "trace_S_arrival_sample": "S",
            # "trace_S1_arrival_sample": "S",
            # "trace_Sg_arrival_sample": "S",
            # "trace_SmS_arrival_sample": "S",
            # "trace_Sn_arrival_sample": "S",
        }

    dataset = data.get_split(split)
    dataset.filter(dataset.metadata.trace_Pg_arrival_sample.notna())
    print(split, dataset.metadata.shape, sampling_rate)

    if station is not None:
        dataset.filter(dataset.metadata.station_code == station)

    data_generator = sbg.GenericGenerator(dataset)

    if window == 'random':
        print("using random window")
        window_selector = sbg.RandomWindow(windowlen=3001, strategy="pad")
    else:
        window_selector = sbg.FixedWindow(windowlen=3001, p0=0, strategy="pad")

    augmentations = [
        sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000,
                               selection="random", strategy="variable"),
        window_selector,
        sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
        sbg.ChangeDtype(np.float32),
        sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=30, dim=0)
    ]
    data_generator.add_augmentations(augmentations)
    return data_generator


def get_data_generators(sampling_rate=100, path=project_path + "/data/igf/seisbench_format",
sb_dataset="ethz", station=None, window='random'): train_generator = get_data_generator("train", sampling_rate, path, sb_dataset, station, window) dev_generator = get_data_generator("dev", sampling_rate, path, sb_dataset, station, window) test_generator = get_data_generator("test", sampling_rate, path, sb_dataset, station, window) return train_generator, dev_generator, test_generator def get_data_loaders(batch_size=256, sampling_rate=100, path=project_path+"/data/igf/seisbench_format", sb_dataset="ethz", window='random'): train_generator, dev_generator, test_generator = get_data_generators(sampling_rate, path, sb_dataset, window=window) num_workers = 0 # The number of threads used for loading data train_loader = DataLoader(train_generator, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=worker_seeding) dev_loader = DataLoader(dev_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers, worker_init_fn=worker_seeding) test_loader = DataLoader(test_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers, worker_init_fn=worker_seeding) return train_loader, dev_loader, test_loader def load_model(name="PhaseNet", pretrained=None, classes=2, modify_output=True): if name == "PhaseNet": if pretrained is not None and pretrained: model = sbm.PhaseNet(phases="PN", norm="peak").from_pretrained(pretrained) else: model = sbm.PhaseNet(phases="PN", norm="peak") if modify_output: model.out = nn.Conv1d(model.filters_root, classes, 1, padding="same") return model def train_one_epoch(model, dataloader, optimizer, pick_mae): size = len(dataloader.dataset) for batch_id, batch in enumerate(dataloader): # Compute prediction and loss pred = model(batch["X"].to(model.device)) loss = loss_fn(pred, batch["y"].to(model.device)) # Compute cross entropy loss cross_entropy_loss = f.cross_entropy(pred, batch["y"]) # Compute mae mae = pick_mae(pred, batch['y']) wandb.log({"loss": loss}) wandb.log({"batch cross entropy loss": cross_entropy_loss}) wandb.log({"p_mae": mae}) # Backpropagation optimizer.zero_grad() loss.backward() optimizer.step() if batch_id % 5 == 0: loss, current = loss.item(), batch_id * batch["X"].shape[0] print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") print(f"mae: {mae:>7f}") def test_one_epoch(model, dataloader, pick_mae, wandb_log=True): num_batches = len(dataloader) test_loss = 0 test_mae = 0 with torch.no_grad(): for batch in dataloader: pred = model(batch["X"].to(model.device)) test_loss += loss_fn(pred, batch["y"].to(model.device)).item() test_mae += pick_mae(pred, batch['y']) test_cross_entropy_loss = f.cross_entropy(pred, batch["y"]) if wandb_log: wandb.log({"batch cross entropy test loss": test_cross_entropy_loss}) test_loss /= num_batches test_mae /= num_batches wandb.log({"test_p_mae": test_mae, "test_loss": test_loss}) print(f"Test avg loss: {test_loss:>8f}") print(f"Test avg mae: {test_mae:>7f}\n") return test_loss, test_mae def train_model(model, path_to_trained_model, train_loader, dev_loader): wandb.watch(model, log_freq=10) optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate) early_stopper = EarlyStopper(patience=3, min_delta=10) pick_mae = PickMAE(wandb.config.sampling_rate) best_loss = np.inf best_metrics = {} for t in range(wandb.config.epochs): print(f"Epoch {t + 1}\n-------------------------------") train_one_epoch(model, train_loader, optimizer, pick_mae) test_loss, test_mae = test_one_epoch(model, dev_loader, pick_mae) if test_loss < best_loss: best_loss = test_loss best_metrics = 
{"test_p_mae": test_mae, "test_loss": test_loss} torch.save(model.state_dict(), path_to_trained_model) if early_stopper.early_stop(test_loss): break print("Best model: ", str(best_metrics)) def loss_fn(y_pred, y_true, eps=1e-5): # vector cross entropy loss h = y_true * torch.log(y_pred + eps) h = h.mean(-1).sum(-1) # Mean along sample dimension and sum along pick dimension h = h.mean() # Mean over batch axis return -h def train_phasenet_on_sb_data(): config = { "epochs": 3, "batch_size": 256, "dataset": "ethz", "sampling_rate": 100, "model_name": "PhaseNet" } run = wandb.init( # set the wandb project where this run will be logged project="training_seisbench_models_on_igf_data", # track hyperparameters and run metadata config=config ) wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py")) train_loader, dev_loader, test = get_data_loaders(batch_size=wandb.config.batch_size, sampling_rate=wandb.config.sampling_rate, path=None, sb_dataset=wandb.config.dataset) model = load_model(name=wandb.config.model_name, pretrained=None, modify_output=True) path_to_trained_model = f"{project_path}/models/{wandb.config.model_name}_trained_on_{wandb.config.data_set}.pt" train_model(model, path_to_trained_model, train_loader, dev_loader) artifact = wandb.Artifact('model', type='model') artifact.add_file(path_to_trained_model) run.log_artifact(artifact) run.finish() def load_config(config_path): with open(config_path, 'r') as f: config = json.load(f) return config def train_sbmodel_on_igf_data(): config_path = project_path + "/experiments/config.json" config = load_config(config_path) run = wandb.init( # set the wandb project where this run will be logged project="training_seisbench_models_on_igf_data", # track hyperparameters and run metadata config=config ) wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py")) print(wandb.config.batch_size, wandb.config.sampling_rate) train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size, sampling_rate=wandb.config.sampling_rate ) model_name = wandb.config.model_name pretrained = wandb.config.pretrained print(model_name, pretrained) model = load_model(name=model_name, pretrained=pretrained) path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt" train_model(model, path_to_trained_model, train_loader, dev_loader) artifact = wandb.Artifact('model', type='model') artifact.add_file(path_to_trained_model) run.log_artifact(artifact) run.finish() if __name__ == "__main__": # train_phasenet_on_sb_data() train_sbmodel_on_igf_data()