initial commit
scripts/__init__.py (new file, 0 lines)

scripts/train.py (new file, 339 lines)
@@ -0,0 +1,339 @@
import os.path
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics import Metric

import seisbench.data as sbd
import seisbench.generate as sbg
import seisbench.models as sbm
from seisbench.util import worker_seeding

import wandb
from dotenv import load_dotenv

load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")

wandb.login(key=wandb_api_key)

project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class PickMAE(Metric):
    """Mean absolute error of the P-pick position, reported in seconds."""

    higher_is_better: bool = False
    mae_error: Tensor

    def __init__(self, sampling_rate):
        super().__init__()
        self.add_state("mae_error", default=torch.tensor(0), dist_reduce_fx="sum")
        self.sampling_rate = sampling_rate

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape

        # Pick position = sample with the highest probability in the P channel (index 0).
        pred_pick_idx = torch.argmax(preds[:, 0, :], dim=-1).type(torch.FloatTensor)
        true_pick_idx = torch.argmax(target[:, 0, :], dim=-1).type(torch.FloatTensor)

        # Stores the MAE of the most recent batch, converted from samples to seconds.
        mae = nn.L1Loss()
        self.mae_error = mae(pred_pick_idx, true_pick_idx) / self.sampling_rate

    def compute(self):
        return self.mae_error.float()
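
# Rough sanity check for PickMAE (my example, not part of the original runs):
# at sampling_rate=100 Hz, a predicted pick 50 samples away from the labelled pick
# gives an MAE of 50 / 100 = 0.5 s. Sketch, assuming (batch, channel, sample) tensors
# with the P channel at index 0:
#   metric = PickMAE(sampling_rate=100)
#   preds = torch.zeros(1, 2, 3001); preds[0, 0, 150] = 1.0
#   target = torch.zeros(1, 2, 3001); target[0, 0, 100] = 1.0
#   metric.update(preds, target)
#   metric.compute()  # tensor(0.5), i.e. 0.5 seconds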


class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


def get_data_generator(split, sampling_rate, path, sb_dataset="ethz", station=None, window='random'):

    if path is not None:
        data = sbd.WaveformDataset(path, sampling_rate=sampling_rate)
        phase_dict = {
            "trace_Pg_arrival_sample": "P"
        }
    elif sb_dataset == "ethz":
        data = sbd.ETHZ(sampling_rate=sampling_rate, force=True)

        phase_dict = {
            "trace_p_arrival_sample": "P",
            "trace_pP_arrival_sample": "P",
            "trace_P_arrival_sample": "P",
            "trace_P1_arrival_sample": "P",
            "trace_Pg_arrival_sample": "P",
            "trace_Pn_arrival_sample": "P",
            "trace_PmP_arrival_sample": "P",
            "trace_pwP_arrival_sample": "P",
            "trace_pwPm_arrival_sample": "P",
            # "trace_s_arrival_sample": "S",
            # "trace_S_arrival_sample": "S",
            # "trace_S1_arrival_sample": "S",
            # "trace_Sg_arrival_sample": "S",
            # "trace_SmS_arrival_sample": "S",
            # "trace_Sn_arrival_sample": "S",
        }

    dataset = data.get_split(split)
    dataset.filter(dataset.metadata.trace_Pg_arrival_sample.notna())

    print(split, dataset.metadata.shape, sampling_rate)

    if station is not None:
        dataset.filter(dataset.metadata.station_code == station)

    data_generator = sbg.GenericGenerator(dataset)
    if window == 'random':
        print("using random window")
        window_selector = sbg.RandomWindow(windowlen=3001, strategy="pad")
    else:
        window_selector = sbg.FixedWindow(windowlen=3001, p0=0, strategy="pad")

    augmentations = [
        sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000, selection="random",
                               strategy="variable"),
        window_selector,
        sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
        sbg.ChangeDtype(np.float32),
        sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=30, dim=0)
    ]

    data_generator.add_augmentations(augmentations)

    return data_generator


def get_data_generators(sampling_rate=100, path=project_path + "/data/igf/seisbench_format", sb_dataset="ethz", station=None,
                        window='random'):

    train_generator = get_data_generator("train", sampling_rate, path, sb_dataset, station, window)
    dev_generator = get_data_generator("dev", sampling_rate, path, sb_dataset, station, window)
    test_generator = get_data_generator("test", sampling_rate, path, sb_dataset, station, window)

    return train_generator, dev_generator, test_generator


def get_data_loaders(batch_size=256, sampling_rate=100, path=project_path + "/data/igf/seisbench_format", sb_dataset="ethz",
                     window='random'):

    train_generator, dev_generator, test_generator = get_data_generators(sampling_rate, path, sb_dataset, window=window)
    num_workers = 0  # The number of threads used for loading data

    train_loader = DataLoader(train_generator, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                              worker_init_fn=worker_seeding)
    dev_loader = DataLoader(dev_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            worker_init_fn=worker_seeding)

    test_loader = DataLoader(test_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                             worker_init_fn=worker_seeding)

    return train_loader, dev_loader, test_loader
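
# Shape note (an assumption based on the augmentations above, not asserted by this commit):
# each batch from these loaders is a dict, roughly
#   batch = next(iter(train_loader))
#   batch["X"]  # waveforms cut to windowlen=3001 samples, peak-normalized per trace
#   batch["y"]  # probabilistic labels from ProbabilisticLabeller (P channel plus noise channel)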


def load_model(name="PhaseNet", pretrained=None, classes=2, modify_output=True):

    if name == "PhaseNet":

        if pretrained is not None and pretrained:
            model = sbm.PhaseNet(phases="PN", norm="peak").from_pretrained(pretrained)
        else:
            model = sbm.PhaseNet(phases="PN", norm="peak")

        if modify_output:
            model.out = nn.Conv1d(model.filters_root, classes, 1, padding="same")

        return model
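
# Minimal usage sketch (my illustration; values are placeholders, not from this commit):
#   model = load_model(name="PhaseNet", pretrained="ethz", classes=2)
# "ethz" stands for one of the pretrained weight names published by SeisBench for
# PhaseNet; passing pretrained=None keeps randomly initialised weights.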


def train_one_epoch(model, dataloader, optimizer, pick_mae):
    size = len(dataloader.dataset)
    for batch_id, batch in enumerate(dataloader):

        # Compute prediction and loss
        pred = model(batch["X"].to(model.device))

        loss = loss_fn(pred, batch["y"].to(model.device))

        # Compute cross entropy loss (labels moved to the model device to match pred)
        cross_entropy_loss = f.cross_entropy(pred, batch["y"].to(model.device))

        # Compute mae
        mae = pick_mae(pred, batch['y'])

        wandb.log({"loss": loss})
        wandb.log({"batch cross entropy loss": cross_entropy_loss})
        wandb.log({"p_mae": mae})

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_id % 5 == 0:
            loss, current = loss.item(), batch_id * batch["X"].shape[0]
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
            print(f"mae: {mae:>7f}")


def test_one_epoch(model, dataloader, pick_mae, wandb_log=True):

    num_batches = len(dataloader)
    test_loss = 0
    test_mae = 0

    with torch.no_grad():
        for batch in dataloader:
            pred = model(batch["X"].to(model.device))

            test_loss += loss_fn(pred, batch["y"].to(model.device)).item()
            test_mae += pick_mae(pred, batch['y'])
            test_cross_entropy_loss = f.cross_entropy(pred, batch["y"].to(model.device))
            if wandb_log:
                wandb.log({"batch cross entropy test loss": test_cross_entropy_loss})

    test_loss /= num_batches
    test_mae /= num_batches

    if wandb_log:
        wandb.log({"test_p_mae": test_mae, "test_loss": test_loss})

    print(f"Test avg loss: {test_loss:>8f}")
    print(f"Test avg mae: {test_mae:>7f}\n")

    return test_loss, test_mae


def train_model(model, path_to_trained_model, train_loader, dev_loader):

    wandb.watch(model, log_freq=10)

    optimizer = torch.optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
    early_stopper = EarlyStopper(patience=3, min_delta=10)
    pick_mae = PickMAE(wandb.config.sampling_rate)

    best_loss = np.inf
    best_metrics = {}

    for t in range(wandb.config.epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_one_epoch(model, train_loader, optimizer, pick_mae)
        test_loss, test_mae = test_one_epoch(model, dev_loader, pick_mae)

        # Keep the checkpoint with the lowest validation loss
        if test_loss < best_loss:
            best_loss = test_loss
            best_metrics = {"test_p_mae": test_mae, "test_loss": test_loss}
            torch.save(model.state_dict(), path_to_trained_model)

        if early_stopper.early_stop(test_loss):
            break

    print("Best model: ", str(best_metrics))


def loss_fn(y_pred, y_true, eps=1e-5):
    # vector cross entropy loss
    h = y_true * torch.log(y_pred + eps)
    h = h.mean(-1).sum(-1)  # Mean along sample dimension and sum along pick dimension
    h = h.mean()  # Mean over batch axis
    return -h
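
# Worked example for loss_fn (my illustration): y_pred and y_true have shape
# (batch, phases, samples); the loss takes the mean of y_true * log(y_pred + eps)
# over the sample axis, sums over the phase axis, averages over the batch, and
# negates. At a single sample where y_true = 1 and y_pred = 0.9, the contribution
# before averaging is -log(0.9 + 1e-5) ≈ 0.105.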


def train_phasenet_on_sb_data():

    config = {
        "epochs": 3,
        "batch_size": 256,
        "dataset": "ethz",
        "sampling_rate": 100,
        "model_name": "PhaseNet",
        "learning_rate": 1e-3  # read by train_model; value is an assumed placeholder
    }

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config=config
    )

    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
                                                             sampling_rate=wandb.config.sampling_rate,
                                                             path=None,
                                                             sb_dataset=wandb.config.dataset)

    model = load_model(name=wandb.config.model_name, pretrained=None, modify_output=True)
    path_to_trained_model = f"{project_path}/models/{wandb.config.model_name}_trained_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model,
                train_loader, dev_loader)

    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)

    run.finish()


def load_config(config_path):
    with open(config_path, 'r') as config_file:
        config = json.load(config_file)
    return config
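
# The experiments/config.json consumed below is not part of this commit. A minimal
# example, assuming only the keys this script reads from wandb.config (values are
# placeholders):
# {
#     "epochs": 20,
#     "batch_size": 256,
#     "learning_rate": 0.001,
#     "sampling_rate": 100,
#     "dataset": "igf",
#     "model_name": "PhaseNet",
#     "pretrained": "ethz"
# }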


def train_sbmodel_on_igf_data():

    config_path = project_path + "/experiments/config.json"
    config = load_config(config_path)

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config=config
    )
    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    print(wandb.config.batch_size, wandb.config.sampling_rate)
    train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
                                                             sampling_rate=wandb.config.sampling_rate)

    model_name = wandb.config.model_name
    pretrained = wandb.config.pretrained

    print(model_name, pretrained)
    model = load_model(name=model_name, pretrained=pretrained)
    path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)

    run.finish()


if __name__ == "__main__":
    # train_phasenet_on_sb_data()
    train_sbmodel_on_igf_data()
scripts/training_wandb_sweep.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import os.path
import wandb
import yaml

from train import get_data_loaders, load_model, train_model

from dotenv import load_dotenv

load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
    raise ValueError("WANDB_API_KEY environment variable is not set.")

wandb.login(key=wandb_api_key)

project_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sweep_config_path = project_path + "/experiments/sweep4.yaml"

with open(sweep_config_path) as file:
    sweep_configuration = yaml.load(file, Loader=yaml.FullLoader)
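
# experiments/sweep4.yaml is not part of this commit. A minimal example of a W&B
# sweep definition covering only the keys read from wandb.config below (values are
# placeholders):
#   method: grid
#   metric:
#     name: test_loss
#     goal: minimize
#   parameters:
#     batch_size:
#       values: [256]
#     dataset:
#       values: ["ethz"]
#     model_name:
#       values: ["PhaseNet"]
#     pretrained:
#       values: [false, "ethz"]
#     learning_rate:
#       values: [0.001, 0.0001]
#     epochs:
#       values: [20]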

sweep_id = wandb.sweep(
    sweep=sweep_configuration,
    project='training_seisbench_models_on_igf_data'
)
sampling_rate = 100


def tune_training_hyperparams():

    run = wandb.init(
        # set the wandb project where this run will be logged
        project="training_seisbench_models_on_igf_data",
        # track hyperparameters and run metadata
        config={"sampling_rate": sampling_rate}
    )

    wandb.run.log_code(".", include_fn=lambda path: path.endswith("training_wandb_sweep.py"))

    train_loader, dev_loader, test_loader = get_data_loaders(batch_size=wandb.config.batch_size,
                                                             sampling_rate=wandb.config.sampling_rate,
                                                             sb_dataset=wandb.config.dataset)

    model_name = wandb.config.model_name
    pretrained = wandb.config.pretrained
    print(wandb.config)
    print(model_name, pretrained, type(pretrained), wandb.config.sampling_rate)
    if not pretrained:
        # Sweep configs cannot express None directly; map falsy values (e.g. false, "")
        # to None so load_model skips loading pretrained weights.
        pretrained = None
    model = load_model(name=model_name, pretrained=pretrained)
    path_to_trained_model = f"{project_path}/models/{model_name}_pretrained_on_{pretrained}_finetuned_on_{wandb.config.dataset}.pt"
    train_model(model, path_to_trained_model, train_loader, dev_loader)

    artifact = wandb.Artifact('model', type='model')
    artifact.add_file(path_to_trained_model)
    run.log_artifact(artifact)

    run.finish()


if __name__ == "__main__":

    wandb.agent(sweep_id, function=tune_training_hyperparams, count=10)