platform-demo-scripts/scripts/train.py
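# Trains SeisBench phase-picking models (PhaseNet) on the local IGF dataset or on the
# ETHZ benchmark dataset, logging metrics and model checkpoints to Weights & Biases.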

import os
import wandb
import seisbench.data as sbd
import seisbench.generate as sbg
import seisbench.models as sbm
from seisbench.util import worker_seeding
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as f
import torch.nn as nn
from torchmetrics import Metric
from torch import Tensor, tensor
import json
from dotenv import load_dotenv

load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")
wandb.login(key=wandb_api_key)
project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
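
# Torchmetrics metric: mean absolute error (in seconds) between the predicted and true
# P-pick positions, taken from the argmax of the P-probability channel.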
class PickMAE(Metric):
    higher_is_better: bool = False
    mae_error: Tensor

    def __init__(self, sampling_rate):
        super().__init__()
        self.add_state("mae_error", default=torch.tensor(0), dist_reduce_fx="sum")
        self.sampling_rate = sampling_rate

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        pred_pick_idx = torch.argmax(preds[:, 0, :], dim=1).type(torch.FloatTensor)
        true_pick_idx = torch.argmax(target[:, 0, :], dim=-1).type(torch.FloatTensor)
        mae = nn.L1Loss()
        self.mae_error = mae(pred_pick_idx, true_pick_idx) / self.sampling_rate  # MAE in seconds

    def compute(self):
        return self.mae_error.float()
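
# Signals early stopping once the validation loss has exceeded the best value seen so far
# by more than `min_delta` for `patience` epochs.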
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
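
# Builds a SeisBench generator for one split: loads the IGF dataset from `path` (or the
# ETHZ benchmark), keeps only traces with a Pg pick, and attaches the windowing,
# normalization, and probabilistic-labelling augmentations.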
def get_data_generator(split, sampling_rate, path, sb_dataset="ethz", station=None, window='random'):
    if path is not None:
        data = sbd.WaveformDataset(path, sampling_rate=sampling_rate)
        phase_dict = {
            "trace_Pg_arrival_sample": "P"
        }
    elif sb_dataset == "ethz":
        data = sbd.ETHZ(sampling_rate=sampling_rate, force=True)
        phase_dict = {
            "trace_p_arrival_sample": "P",
            "trace_pP_arrival_sample": "P",
            "trace_P_arrival_sample": "P",
            "trace_P1_arrival_sample": "P",
            "trace_Pg_arrival_sample": "P",
            "trace_Pn_arrival_sample": "P",
            "trace_PmP_arrival_sample": "P",
            "trace_pwP_arrival_sample": "P",
            "trace_pwPm_arrival_sample": "P",
            # "trace_s_arrival_sample": "S",
            # "trace_S_arrival_sample": "S",
            # "trace_S1_arrival_sample": "S",
            # "trace_Sg_arrival_sample": "S",
            # "trace_SmS_arrival_sample": "S",
            # "trace_Sn_arrival_sample": "S",
        }
    dataset = data.get_split(split)
    dataset.filter(dataset.metadata.trace_Pg_arrival_sample.notna())
    print(split, dataset.metadata.shape, sampling_rate)
    if station is not None:
        dataset.filter(dataset.metadata.station_code == station)
    data_generator = sbg.GenericGenerator(dataset)
    if window == 'random':
        print("using random window")
        window_selector = sbg.RandomWindow(windowlen=3001, strategy="pad")
    else:
        window_selector = sbg.FixedWindow(windowlen=3001, p0=0, strategy="pad")
    augmentations = [
        sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000, selection="random",
                               strategy="variable"),
        window_selector,
        sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
        sbg.ChangeDtype(np.float32),
        sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=30, dim=0)
    ]
    data_generator.add_augmentations(augmentations)
    return data_generator

def get_data_generators(sampling_rate=100, path=project_path + "/data/igf/seisbench_format", sb_dataset="ethz",
                        station=None, window='random'):
    train_generator = get_data_generator("train", sampling_rate, path, sb_dataset, station, window)
    dev_generator = get_data_generator("dev", sampling_rate, path, sb_dataset, station, window)
    test_generator = get_data_generator("test", sampling_rate, path, sb_dataset, station, window)
    return train_generator, dev_generator, test_generator
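
# Wraps the train/dev/test SeisBench generators in PyTorch DataLoaders.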
def get_data_loaders(batch_size=256, sampling_rate=100, path=project_path + "/data/igf/seisbench_format",
                     sb_dataset="ethz", window='random'):
    train_generator, dev_generator, test_generator = get_data_generators(sampling_rate, path, sb_dataset, window=window)
    num_workers = 0  # The number of worker processes used for loading data
    train_loader = DataLoader(train_generator, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                              worker_init_fn=worker_seeding)
    dev_loader = DataLoader(dev_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            worker_init_fn=worker_seeding)
    test_loader = DataLoader(test_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                             worker_init_fn=worker_seeding)
    return train_loader, dev_loader, test_loader
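
# Builds a PhaseNet model (P + noise phases), optionally loading SeisBench pretrained
# weights, and optionally replaces the output layer with a `classes`-channel convolution.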
def load_model(name="PhaseNet", pretrained=None, classes=2, modify_output=True):
    if name == "PhaseNet":
        if pretrained is not None and pretrained:
            model = sbm.PhaseNet(phases="PN", norm="peak").from_pretrained(pretrained)
        else:
            model = sbm.PhaseNet(phases="PN", norm="peak")
    if modify_output:
        model.out = nn.Conv1d(model.filters_root, classes, 1, padding="same")
    return model
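
# Runs one training epoch: forward pass, vector cross-entropy loss, pick MAE,
# wandb logging, and backpropagation.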
def train_one_epoch(model, dataloader, optimizer, pick_mae):
    size = len(dataloader.dataset)
    for batch_id, batch in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(batch["X"].to(model.device))
        loss = loss_fn(pred, batch["y"].to(model.device))
        # Compute cross entropy loss (labels moved to the model's device to match pred)
        cross_entropy_loss = f.cross_entropy(pred, batch["y"].to(model.device))
        # Compute mae
        mae = pick_mae(pred, batch['y'])
        wandb.log({"loss": loss})
        wandb.log({"batch cross entropy loss": cross_entropy_loss})
        wandb.log({"p_mae": mae})

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_id % 5 == 0:
            loss, current = loss.item(), batch_id * batch["X"].shape[0]
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
            print(f"mae: {mae:>7f}")
def test_one_epoch(model, dataloader, pick_mae, wandb_log=True):
    num_batches = len(dataloader)
    test_loss = 0
    test_mae = 0

    with torch.no_grad():
        for batch in dataloader:
            pred = model(batch["X"].to(model.device))
            test_loss += loss_fn(pred, batch["y"].to(model.device)).item()
            test_mae += pick_mae(pred, batch['y'])
            test_cross_entropy_loss = f.cross_entropy(pred, batch["y"].to(model.device))
            if wandb_log:
                wandb.log({"batch cross entropy test loss": test_cross_entropy_loss})

    test_loss /= num_batches
    test_mae /= num_batches
    wandb.log({"test_p_mae": test_mae, "test_loss": test_loss})
    print(f"Test avg loss: {test_loss:>8f}")
    print(f"Test avg mae: {test_mae:>7f}\n")
    return test_loss, test_mae
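
# Full training loop: Adam optimizer, early stopping on the dev loss, and checkpointing
# of the best model to `path_to_trained_model`.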
def train_model(model, path_to_trained_model, train_loader, dev_loader):
    wandb.watch(model, log_freq=10)
    optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
    early_stopper = EarlyStopper(patience=3, min_delta=10)
    pick_mae = PickMAE(wandb.config.sampling_rate)
    best_loss = np.inf
    best_metrics = {}

    for t in range(wandb.config.epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_one_epoch(model, train_loader, optimizer, pick_mae)
        test_loss, test_mae = test_one_epoch(model, dev_loader, pick_mae)
        if test_loss < best_loss:
            best_loss = test_loss
            best_metrics = {"test_p_mae": test_mae, "test_loss": test_loss}
            torch.save(model.state_dict(), path_to_trained_model)
        if early_stopper.early_stop(test_loss):
            break

    print("Best model: ", str(best_metrics))

def loss_fn(y_pred, y_true, eps=1e-5):
    # vector cross entropy loss
    h = y_true * torch.log(y_pred + eps)
    h = h.mean(-1).sum(-1)  # Mean along sample dimension and sum along pick dimension
    h = h.mean()  # Mean over batch axis
    return -h
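
# Trains PhaseNet from scratch on the ETHZ benchmark dataset and uploads the best
# checkpoint as a wandb artifact.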
def train_phasenet_on_sb_data():
    config = {
        "epochs": 3,
        "batch_size": 256,
        "dataset": "ethz",
        "sampling_rate": 100,
        "model_name": "PhaseNet"
    }
    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config=config
    )
    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
                                                             sampling_rate=wandb.config.sampling_rate,
                                                             path=None,
                                                             sb_dataset=wandb.config.dataset)
    model = load_model(name=wandb.config.model_name, pretrained=None, modify_output=True)
    path_to_trained_model = f"{project_path}/models/{wandb.config.model_name}_trained_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)
    run.finish()

def load_config(config_path):
    with open(config_path, 'r') as config_file:
        config = json.load(config_file)
    return config
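
# Fine-tunes an (optionally pretrained) SeisBench model on the IGF dataset, with
# hyperparameters read from experiments/config.json, and uploads the best checkpoint
# as a wandb artifact.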
def train_sbmodel_on_igf_data():
    config_path = project_path + "/experiments/config.json"
    config = load_config(config_path)
    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config=config
    )
    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))
    print(wandb.config.batch_size, wandb.config.sampling_rate)

    train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
                                                             sampling_rate=wandb.config.sampling_rate)
    model_name = wandb.config.model_name
    pretrained = wandb.config.pretrained
    print(model_name, pretrained)
    model = load_model(name=model_name, pretrained=pretrained)
    path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)
    run.finish()

if __name__ == "__main__":
    # train_phasenet_on_sb_data()
    train_sbmodel_on_igf_data()