# platform-demo-scripts/scripts/models.py
"""
This file contains the model specifications.
"""
import seisbench.models as sbm
import seisbench.generate as sbg
import pytorch_lightning as pl
import torch
from torch.optim import lr_scheduler
import torch.nn.functional as F
import numpy as np
from abc import abstractmethod, ABC
# import lightning as L
# Allows importing this file both from a Jupyter notebook and as a plain script
try:
from .augmentations import DuplicateEvent
except ImportError:
from augmentations import DuplicateEvent
import os
import logging
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.DEBUG)
# Phase dict for labelling. We only study P and S phases without differentiating between them.
phase_dict = {
"trace_p_arrival_sample": "P",
"trace_pP_arrival_sample": "P",
"trace_P_arrival_sample": "P",
"trace_P1_arrival_sample": "P",
"trace_Pg_arrival_sample": "P",
"trace_Pn_arrival_sample": "P",
"trace_PmP_arrival_sample": "P",
"trace_pwP_arrival_sample": "P",
"trace_pwPm_arrival_sample": "P",
"trace_s_arrival_sample": "S",
"trace_S_arrival_sample": "S",
"trace_S1_arrival_sample": "S",
"trace_Sg_arrival_sample": "S",
"trace_SmS_arrival_sample": "S",
"trace_Sn_arrival_sample": "S",
}
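# Note (descriptive comment): the keys above are metadata columns of SeisBench datasets
# (e.g. "trace_Pg_arrival_sample"); the window selectors and labellers below use this mapping
# to collapse all P-type and S-type sub-phases into the two classes "P" and "S".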
def vector_cross_entropy(y_pred, y_true, eps=1e-5):
"""
Cross entropy loss
:param y_true: True label probabilities
:param y_pred: Predicted label probabilities
:param eps: Epsilon to clip values for stability
:return: Average loss across batch
"""
h = y_true * torch.log(y_pred + eps)
if y_pred.ndim == 3:
h = h.mean(-1).sum(
-1
) # Mean along sample dimension and sum along pick dimension
else:
h = h.sum(-1) # Sum along pick dimension
h = h.mean() # Mean over batch axis
return -h
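# Worked example (illustrative): for uniform targets and predictions over three classes,
# y_true = y_pred = 1/3 everywhere, the loss is approximately -log(1/3) ~= 1.0986 independent
# of the number of samples, because the mean over samples is taken before summing over classes.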
class SeisBenchModuleLit(pl.LightningModule, ABC):
"""
Abstract interface for SeisBench lightning modules.
    Adds generic functions, e.g., get_augmentations
"""
@abstractmethod
def get_augmentations(self):
"""
Returns a list of augmentations that can be passed to the seisbench.generate.GenericGenerator
:return: List of augmentations
"""
pass
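    # Illustrative usage sketch (not part of the training pipeline; assumes a SeisBench
    # WaveformDataset named `data` and a module instance named `model`):
    #   train_gen = sbg.GenericGenerator(data.train())
    #   train_gen.add_augmentations(model.get_train_augmentations())
    #   train_loader = torch.utils.data.DataLoader(train_gen, batch_size=256, shuffle=True)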
def get_train_augmentations(self):
"""
Returns the set of training augmentations.
"""
return self.get_augmentations()
def get_val_augmentations(self):
"""
        Returns the set of validation augmentations for validation during training.
"""
return self.get_augmentations()
@abstractmethod
def get_eval_augmentations(self):
"""
Returns the set of evaluation augmentations for evaluation after training.
These augmentations will be passed to a SteeredGenerator and should usually contain a steered window.
"""
pass
@abstractmethod
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
"""
Predict step for the lightning module. Returns results for three tasks:
- earthquake detection (score, higher means more likely detection)
- P to S phase discrimination (score, high means P, low means S)
- phase location in samples (two integers, first for P, second for S wave)
All predictions should only take the window defined by batch["window_borders"] into account.
        :param batch: Batch of windows and metadata produced by the evaluation generator
        :return: Tuple (score_detection, score_p_or_s, p_sample, s_sample)
"""
score_detection = None
score_p_or_s = None
p_sample = None
s_sample = None
return score_detection, score_p_or_s, p_sample, s_sample
class PhaseNetLit(SeisBenchModuleLit):
"""
LightningModule for PhaseNet
:param lr: Learning rate, defaults to 1e-2
:param sigma: Standard deviation passed to the ProbabilisticPickLabeller
:param sample_boundaries: Low and high boundaries for the RandomWindow selection.
:param kwargs: Kwargs are passed to the SeisBench.models.PhaseNet constructor.
"""
def __init__(self, lr=1e-2, sigma=20, sample_boundaries=(None, None), **kwargs):
super().__init__()
self.save_hyperparameters()
self.lr = lr
self.sigma = sigma
self.sample_boundaries = sample_boundaries
self.loss = vector_cross_entropy
        # Pop the options handled by this module so that only genuine PhaseNet arguments
        # are forwarded to the SeisBench constructor below.
        self.pretrained = kwargs.pop("pretrained", None)
        self.norm = kwargs.pop("norm", "peak")
        self.highpass = kwargs.pop("highpass", None)
        self.lowpass = kwargs.pop("lowpass", None)
        self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
        self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
        if self.pretrained is not None:
            self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
            # self.norm = self.model.norm
        else:
            self.model = sbm.PhaseNet(**kwargs)
        self.reduce_lr_on_plateau = False
        self.initial_epochs = 0
        if self.finetuning_strategy is not None:
            if self.finetuning_strategy == "top":
                self.initial_epochs = 3
            elif self.finetuning_strategy in ["decoder", "encoder"]:
                self.initial_epochs = 6
            self.freeze()
def forward(self, x):
return self.model(x)
def shared_step(self, batch):
x = batch["X"]
y_true = batch["y"]
y_pred = self.model(x)
return self.loss(y_pred, y_true)
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("val_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
if self.finetuning_strategy is not None:
scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
self.reduce_lr_on_plateau = False
else:
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
self.reduce_lr_on_plateau = True
#
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'monitor': 'val_loss',
'interval': 'epoch',
'reduce_on_plateau': self.reduce_lr_on_plateau,
},
}
    def lr_lambda(self, epoch):
        # LambdaLR expects a multiplicative factor relative to the initial learning rate:
        # keep the full rate during the initial (partially frozen) epochs, then reduce it
        # by steplr_gamma once all layers are trainable.
        if epoch < self.initial_epochs:
            return 1.0
        return self.steplr_gamma
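    # Resulting schedule (illustrative, assuming lr=1e-2, steplr_gamma=0.1 and
    # finetuning_strategy="top", i.e. initial_epochs=3):
    #   epochs 0-2: lr = 1e-2 (only the unfrozen layers are updated)
    #   epochs >=3: lr = 1e-3 (all layers are trainable, see on_train_epoch_start)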
def lr_scheduler_step(self, scheduler, metric):
if self.reduce_lr_on_plateau:
scheduler.step(metric, epoch=self.current_epoch)
else:
scheduler.step(epoch=self.current_epoch)
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
# scheduler.step(epoch=self.current_epoch)
def get_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
logger.info(f"Using highpass filer {self.highpass}")
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
logger.info(f"Using lowpass filer {self.lowpass}")
logger.info(filter)
return [
            # In 2/3 of the cases, select a window around a pick to reduce the number of noise traces in training.
            # Uses strategy="variable", as padding will be handled by the random window.
            # In 1/3 of the cases, just return the original trace to keep diversity high.
sbg.OneOf(
[
sbg.WindowAroundSample(
list(phase_dict.keys()),
samples_before=3000,
windowlen=6000,
selection="random",
strategy="variable",
),
sbg.NullAugmentation(),
],
probabilities=[2, 1],
),
sbg.RandomWindow(
low=self.sample_boundaries[0],
high=self.sample_boundaries[1],
windowlen=3001,
strategy="pad",
),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
*filter,
sbg.ChangeDtype(np.float32),
sbg.ProbabilisticLabeller(
label_columns=phase_dict, sigma=self.sigma, dim=0
),
]
def get_eval_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
return [
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
*filter,
sbg.ChangeDtype(np.float32),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
x = batch["X"]
window_borders = batch["window_borders"]
pred = self.model(x)
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0])
p_sample = torch.zeros(pred.shape[0], dtype=int)
s_sample = torch.zeros(pred.shape[0], dtype=int)
for i in range(pred.shape[0]):
start_sample, end_sample = window_borders[i]
local_pred = pred[i, :, start_sample:end_sample]
score_detection[i] = torch.max(1 - local_pred[-1]) # 1 - noise
score_p_or_s[i] = torch.max(local_pred[0]) / torch.max(
local_pred[1]
) # most likely P by most likely S
p_sample[i] = torch.argmax(local_pred[0])
s_sample[i] = torch.argmax(local_pred[1])
return score_detection, score_p_or_s, p_sample, s_sample
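    # Fine-tuning schedule (descriptive comment): with finetuning_strategy="decoder" the encoder
    # (down_branch) stays frozen for the first 6 epochs while the decoder adapts; "encoder"
    # mirrors this by freezing the up_branch; "top" freezes only the output layer (model.out)
    # for the first 3 epochs. Afterwards, on_train_epoch_start unfreezes the whole network.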
def freeze(self):
if self.finetuning_strategy == "decoder": # finetune decoder branch and freeze encoder branch
for p in self.model.down_branch.parameters():
p.requires_grad = False
elif self.finetuning_strategy == "encoder": # finetune encoder branch and freeze decoder branch
for p in self.model.up_branch.parameters():
p.requires_grad = False
elif self.finetuning_strategy == "top":
for p in self.model.out.parameters():
p.requires_grad = False
def unfreeze(self):
logger.info("Unfreezing layers")
for p in self.model.parameters():
p.requires_grad = True
def on_train_epoch_start(self):
# Unfreeze some layers after x initial epochs
if self.current_epoch == self.initial_epochs:
self.unfreeze()
class GPDLit(SeisBenchModuleLit):
"""
LightningModule for GPD
:param lr: Learning rate, defaults to 1e-3
    :param sigma: Standard deviation passed to the ProbabilisticPointLabeller. If None, deterministic labels are used,
        i.e., whether a pick is contained.
:param highpass: If not None, cutoff frequency for highpass filter in Hz.
:param lowpass: If not None, cutoff frequency for lowpass filter in Hz.
:param kwargs: Kwargs are passed to the SeisBench.models.GPD constructor.
"""
def __init__(self, lr=1e-3, highpass=None, lowpass=None, sigma=None, **kwargs):
super().__init__()
self.save_hyperparameters()
self.lr = lr
self.sigma = sigma
self.model = sbm.GPD(**kwargs)
if sigma is None:
self.nllloss = torch.nn.NLLLoss()
self.loss = self.nll_with_probabilities
else:
self.loss = vector_cross_entropy
self.highpass = highpass
self.lowpass = lowpass
self.predict_stride = 5
def nll_with_probabilities(self, y_pred, y_true):
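        # GPD outputs class probabilities of shape (batch, classes); torch.nn.NLLLoss expects
        # log-probabilities together with integer class targets, hence the explicit log here.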
y_pred = torch.log(y_pred)
return self.nllloss(y_pred, y_true)
def forward(self, x):
return self.model(x)
def shared_step(self, batch):
x = batch["X"]
y_true = batch["y"].squeeze()
y_pred = self.model(x)
return self.loss(y_pred, y_true)
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("val_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def get_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass")]
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass")]
if self.sigma is None:
labeller = sbg.StandardLabeller(
label_columns=phase_dict,
on_overlap="fixed-relevance",
low=100,
high=-100,
)
else:
labeller = sbg.ProbabilisticPointLabeller(
label_columns=phase_dict, position=0.5, sigma=self.sigma
)
return (
[
                # In 2/3 of the cases, select a window around a pick to reduce the number of noise traces in training.
                # Uses strategy="variable", as padding will be handled by the random window.
                # In 1/3 of the cases, just return the original trace to keep diversity high.
sbg.OneOf(
[
sbg.WindowAroundSample(
list(phase_dict.keys()),
samples_before=400,
windowlen=800,
selection="random",
strategy="variable",
),
sbg.NullAugmentation(),
],
probabilities=[2, 1],
),
sbg.RandomWindow(
windowlen=400,
strategy="pad",
),
sbg.Normalize(detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
labeller,
]
+ filter
+ [sbg.ChangeDtype(np.float32)]
)
def get_eval_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass")]
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass")]
return [
# Larger window length ensures a sliding window covering full trace can be applied
sbg.SteeredWindow(windowlen=3400, strategy="pad"),
sbg.SlidingWindow(timestep=self.predict_stride, windowlen=400),
sbg.Normalize(detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
*filter,
sbg.ChangeDtype(np.float32),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
x = batch["X"]
window_borders = batch["window_borders"]
shape_save = x.shape
x = x.reshape(
(-1,) + shape_save[2:]
) # Merge batch and sliding window dimensions
pred = self.model(x)
pred = pred.reshape(shape_save[:2] + (-1,))
pred = torch.repeat_interleave(
pred, self.predict_stride, dim=1
) # Counteract stride
pred = F.pad(pred, (0, 0, 200, 200))
pred = pred.permute(0, 2, 1)
        # Otherwise windows shorter than 30 s would automatically produce detections
pred[:, 2, :200] = 1
pred[:, 2, -200:] = 1
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0])
p_sample = torch.zeros(pred.shape[0], dtype=int)
s_sample = torch.zeros(pred.shape[0], dtype=int)
for i in range(pred.shape[0]):
start_sample, end_sample = window_borders[i]
local_pred = pred[i, :, start_sample:end_sample]
score_detection[i] = torch.max(1 - local_pred[-1]) # 1 - noise
score_p_or_s[i] = torch.max(local_pred[0]) / torch.max(
local_pred[1]
) # most likely P by most likely S
# Adjust for prediction stride by choosing the sample in the middle of each block
p_sample[i] = torch.argmax(local_pred[0]) + self.predict_stride // 2
s_sample[i] = torch.argmax(local_pred[1]) + self.predict_stride // 2
return score_detection, score_p_or_s, p_sample, s_sample
def pick_mae(self, batch, batch_idx):
x = batch["X"]
window_borders = batch["window_borders"]
shape_save = x.shape
x = x.reshape(
(-1,) + shape_save[2:]
) # Merge batch and sliding window dimensions
pred = self.model(x)
pred = pred.reshape(shape_save[:2] + (-1,))
pred = torch.repeat_interleave(
pred, self.predict_stride, dim=1
) # Counteract stride
pred = F.pad(pred, (0, 0, 200, 200))
pred = pred.permute(0, 2, 1)
        # Otherwise windows shorter than 30 s would automatically produce detections
pred[:, 2, :200] = 1
pred[:, 2, -200:] = 1
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0])
p_sample = torch.zeros(pred.shape[0], dtype=int)
s_sample = torch.zeros(pred.shape[0], dtype=int)
for i in range(pred.shape[0]):
start_sample, end_sample = window_borders[i]
local_pred = pred[i, :, start_sample:end_sample]
score_detection[i] = torch.max(1 - local_pred[-1]) # 1 - noise
score_p_or_s[i] = torch.max(local_pred[0]) / torch.max(
local_pred[1]
) # most likely P by most likely S
# Adjust for prediction stride by choosing the sample in the middle of each block
p_sample[i] = torch.argmax(local_pred[0]) + self.predict_stride // 2
s_sample[i] = torch.argmax(local_pred[1]) + self.predict_stride // 2
return score_detection, score_p_or_s, p_sample, s_sample
class EQTransformerLit(SeisBenchModuleLit):
"""
LightningModule for EQTransformer
:param lr: Learning rate, defaults to 1e-2
    :param sigma: Standard deviation passed to the ProbabilisticLabeller
:param sample_boundaries: Low and high boundaries for the RandomWindow selection.
:param loss_weights: Loss weights for detection, P and S phase.
:param rotate_array: If true, rotate array along sample axis.
    :param detection_fixed_window: Passed as parameter fixed_window to the DetectionLabeller
:param kwargs: Kwargs are passed to the SeisBench.models.EQTransformer constructor.
"""
def __init__(
self,
lr=1e-2,
sigma=20,
sample_boundaries=(None, None),
loss_weights=(0.05, 0.40, 0.55),
rotate_array=False,
detection_fixed_window=None,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
self.lr = lr
self.sigma = sigma
self.sample_boundaries = sample_boundaries
self.loss = torch.nn.BCELoss()
self.loss_weights = loss_weights
self.rotate_array = rotate_array
self.detection_fixed_window = detection_fixed_window
self.model = sbm.EQTransformer(**kwargs)
def forward(self, x):
return self.model(x)
def shared_step(self, batch):
x = batch["X"]
p_true = batch["y"][:, 0]
s_true = batch["y"][:, 1]
det_true = batch["detections"][:, 0]
det_pred, p_pred, s_pred = self.model(x)
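        # The three BCE terms are combined with the (detection, P, S) weights; the default
        # loss_weights=(0.05, 0.40, 0.55) emphasises the phase picks over the detection mask.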
return (
self.loss_weights[0] * self.loss(det_pred, det_true)
+ self.loss_weights[1] * self.loss(p_pred, p_true)
+ self.loss_weights[2] * self.loss(s_pred, s_true)
)
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("val_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def get_joint_augmentations(self):
p_phases = [key for key, val in phase_dict.items() if val == "P"]
s_phases = [key for key, val in phase_dict.items() if val == "S"]
if self.detection_fixed_window is not None:
detection_labeller = sbg.DetectionLabeller(
p_phases,
fixed_window=self.detection_fixed_window,
key=("X", "detections"),
)
else:
detection_labeller = sbg.DetectionLabeller(
p_phases, s_phases=s_phases, key=("X", "detections")
)
block1 = [
            # In 2/3 of the cases, select a window around a pick to reduce the number of noise traces in training.
            # Uses strategy="variable", as padding will be handled by the random window.
            # In 1/3 of the cases, just return the original trace to keep diversity high.
sbg.OneOf(
[
sbg.WindowAroundSample(
list(phase_dict.keys()),
samples_before=6000,
windowlen=12000,
selection="random",
strategy="variable",
),
sbg.NullAugmentation(),
],
probabilities=[2, 1],
),
sbg.RandomWindow(
low=self.sample_boundaries[0],
high=self.sample_boundaries[1],
windowlen=6000,
strategy="pad",
),
sbg.ProbabilisticLabeller(
label_columns=phase_dict, sigma=self.sigma, dim=0
),
detection_labeller,
# Normalize to ensure correct augmentation behavior
sbg.Normalize(detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]
block2 = [
sbg.ChangeDtype(np.float32, "X"),
sbg.ChangeDtype(np.float32, "y"),
sbg.ChangeDtype(np.float32, "detections"),
]
return block1, block2
def get_train_augmentations(self):
if self.rotate_array:
rotation_block = [
sbg.OneOf(
[
sbg.RandomArrayRotation(["X", "y", "detections"]),
sbg.NullAugmentation(),
],
[0.99, 0.01],
)
]
else:
rotation_block = []
augmentation_block = [
# Add secondary event
sbg.OneOf(
[DuplicateEvent(label_keys="y"), sbg.NullAugmentation()],
probabilities=[0.3, 0.7],
),
# Gaussian noise
sbg.OneOf([sbg.GaussianNoise(), sbg.NullAugmentation()], [0.5, 0.5]),
# Array rotation
*rotation_block,
# Gaps
sbg.OneOf([sbg.AddGap(), sbg.NullAugmentation()], [0.2, 0.8]),
# Channel dropout
sbg.OneOf([sbg.ChannelDropout(), sbg.NullAugmentation()], [0.3, 0.7]),
# Augmentations make second normalize necessary
sbg.Normalize(detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]
block1, block2 = self.get_joint_augmentations()
return block1 + augmentation_block + block2
def get_val_augmentations(self):
block1, block2 = self.get_joint_augmentations()
return block1 + block2
def get_augmentations(self):
raise NotImplementedError("Use get_train/val_augmentations instead.")
def get_eval_augmentations(self):
return [
sbg.SteeredWindow(windowlen=6000, strategy="pad"),
sbg.ChangeDtype(np.float32),
sbg.Normalize(detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
x = batch["X"]
window_borders = batch["window_borders"]
det_pred, p_pred, s_pred = self.model(x)
score_detection = torch.zeros(det_pred.shape[0])
score_p_or_s = torch.zeros(det_pred.shape[0])
p_sample = torch.zeros(det_pred.shape[0], dtype=int)
s_sample = torch.zeros(det_pred.shape[0], dtype=int)
for i in range(det_pred.shape[0]):
start_sample, end_sample = window_borders[i]
local_det_pred = det_pred[i, start_sample:end_sample]
local_p_pred = p_pred[i, start_sample:end_sample]
local_s_pred = s_pred[i, start_sample:end_sample]
score_detection[i] = torch.max(local_det_pred)
score_p_or_s[i] = torch.max(local_p_pred) / torch.max(
local_s_pred
) # most likely P by most likely S
p_sample[i] = torch.argmax(local_p_pred)
s_sample[i] = torch.argmax(local_s_pred)
return score_detection, score_p_or_s, p_sample, s_sample
class CREDLit(SeisBenchModuleLit):
"""
LightningModule for CRED
:param lr: Learning rate, defaults to 1e-2
    :param sample_boundaries: Low and high boundaries for the RandomWindow selection.
    :param detection_fixed_window: Passed as parameter fixed_window to the DetectionLabeller
    :param kwargs: Kwargs are passed to the SeisBench.models.CRED constructor.
"""
def __init__(
self,
lr=1e-2,
sample_boundaries=(None, None),
detection_fixed_window=None,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
self.lr = lr
self.sample_boundaries = sample_boundaries
self.detection_fixed_window = detection_fixed_window
self.loss = torch.nn.BCELoss()
self.model = sbm.CRED(**kwargs)
def forward(self, x):
return self.model(x)
def shared_step(self, batch):
x = batch["spec"]
y_true = batch["y"][:, 0]
y_pred = self.model(x)[:, :, 0]
return self.loss(y_pred, y_true)
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("val_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def get_augmentations(self):
p_phases = [key for key, val in phase_dict.items() if val == "P"]
s_phases = [key for key, val in phase_dict.items() if val == "S"]
def spectrogram(state_dict):
x, metadata = state_dict["X"]
spec = self.model.waveforms_to_spectrogram(x)
state_dict["spec"] = (spec, metadata)
def resample_detections(state_dict):
# Resample detections to 19 samples as in the output of CRED
# Each sample represents the average over 158 original samples
y, metadata = state_dict["y"]
y = np.pad(y, [(0, 0), (0, 2)], mode="constant", constant_values=0)
y = np.reshape(y, (1, 19, 158))
y = np.mean(y, axis=-1)
state_dict["y"] = (y, metadata)
if self.detection_fixed_window is not None:
detection_labeller = sbg.DetectionLabeller(
p_phases, fixed_window=self.detection_fixed_window
)
else:
detection_labeller = sbg.DetectionLabeller(p_phases, s_phases=s_phases)
augmentations = [
            # In 2/3 of the cases, select a window around a pick to reduce the number of noise traces in training.
            # Uses strategy="variable", as padding will be handled by the random window.
            # In 1/3 of the cases, just return the original trace to keep diversity high.
sbg.OneOf(
[
sbg.WindowAroundSample(
list(phase_dict.keys()),
samples_before=3000,
windowlen=6000,
selection="random",
strategy="variable",
),
sbg.NullAugmentation(),
],
probabilities=[2, 1],
),
sbg.RandomWindow(
low=self.sample_boundaries[0],
high=self.sample_boundaries[1],
windowlen=3000,
strategy="pad",
),
detection_labeller,
# Normalize to ensure correct augmentation behavior
sbg.Normalize(detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
spectrogram,
resample_detections,
sbg.ChangeDtype(np.float32, "y"),
sbg.ChangeDtype(np.float32, "spec"),
]
return augmentations
def get_eval_augmentations(self):
def spectrogram(state_dict):
x, metadata = state_dict["X"]
spec = self.model.waveforms_to_spectrogram(x)
state_dict["spec"] = (spec, metadata)
return [
sbg.SteeredWindow(windowlen=3000, strategy="pad"),
sbg.Normalize(detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
spectrogram,
sbg.ChangeDtype(np.float32, "spec"),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
x = batch["spec"]
window_borders = batch["window_borders"]
pred = self.model(x)
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0]) * np.nan
p_sample = torch.zeros(pred.shape[0], dtype=int) * np.nan
s_sample = torch.zeros(pred.shape[0], dtype=int) * np.nan
for i in range(pred.shape[0]):
start_sample, end_sample = window_borders[i]
# We go for a slightly broader window, i.e., all output prediction points encompassing the target window
start_resampled = start_sample // 158
end_resampled = end_sample // 158 + 1
local_pred = pred[i, start_resampled:end_resampled, 0]
            score_detection[i] = torch.max(local_pred)  # CRED directly outputs a detection probability
return score_detection, score_p_or_s, p_sample, s_sample
class BasicPhaseAELit(SeisBenchModuleLit):
"""
LightningModule for BasicPhaseAE
:param lr: Learning rate, defaults to 1e-2
    :param sigma: Standard deviation passed to the ProbabilisticLabeller
:param sample_boundaries: Low and high boundaries for the RandomWindow selection.
:param kwargs: Kwargs are passed to the SeisBench.models.BasicPhaseAE constructor.
"""
def __init__(self, lr=1e-2, sigma=20, sample_boundaries=(None, None), **kwargs):
super().__init__()
self.save_hyperparameters()
self.lr = lr
self.sigma = sigma
self.sample_boundaries = sample_boundaries
self.loss = vector_cross_entropy
self.model = sbm.BasicPhaseAE(**kwargs)
def forward(self, x):
return self.model(x)
def shared_step(self, batch):
x = batch["X"]
y_true = batch["y"]
y_pred = self.model(x)
return self.loss(y_pred, y_true)
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("val_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def get_augmentations(self):
return [
            # In 2/3 of the cases, select a window around a pick to reduce the number of noise traces in training.
            # Uses strategy="variable", as padding will be handled by the random window.
            # In 1/3 of the cases, just return the original trace to keep diversity high.
sbg.OneOf(
[
sbg.WindowAroundSample(
list(phase_dict.keys()),
samples_before=700,
windowlen=700,
selection="random",
strategy="variable",
),
sbg.NullAugmentation(),
],
probabilities=[2, 1],
),
sbg.RandomWindow(
low=self.sample_boundaries[0],
high=self.sample_boundaries[1],
windowlen=600,
strategy="pad",
),
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
sbg.ProbabilisticLabeller(
label_columns=phase_dict, sigma=self.sigma, dim=0
),
]
def get_eval_augmentations(self):
return [
sbg.SteeredWindow(windowlen=3000, strategy="pad"),
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
x = batch["X"]
window_borders = batch["window_borders"]
# Create overlapping windows
re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
for i, start in enumerate(range(0, 2401, 400)):
re[:, :, i] = x[:, :, start: start + 600]
x = re
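        # Window geometry (illustrative): starts 0, 400, ..., 2400 give 7 windows of 600 samples
        # that tile the 3000-sample steered window, with 200 samples of overlap between neighbours.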
x = x.permute(0, 2, 1, 3) # --> (batch, windows, channels, samples)
shape_save = x.shape
x = x.reshape(-1, 3, 600) # --> (batch * windows, channels, samples)
window_pred = self.model(x)
window_pred = window_pred.reshape(
shape_save[:2] + (3, 600)
) # --> (batch, window, channels, samples)
# Connect predictions again, ignoring first and last second of each prediction
pred = torch.zeros((window_pred.shape[0], window_pred.shape[2], 3000))
for i, start in enumerate(range(0, 2401, 400)):
if start == 0:
# Use full window (for start==0, the end will be overwritten)
pred[:, :, start: start + 600] = window_pred[:, i]
else:
pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:]
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0])
p_sample = torch.zeros(pred.shape[0], dtype=int)
s_sample = torch.zeros(pred.shape[0], dtype=int)
for i in range(pred.shape[0]):
start_sample, end_sample = window_borders[i]
local_pred = pred[i, :, start_sample:end_sample]
score_detection[i] = torch.max(1 - local_pred[-1]) # 1 - noise
score_p_or_s[i] = torch.max(local_pred[0]) / torch.max(
local_pred[1]
) # most likely P by most likely S
p_sample[i] = torch.argmax(local_pred[0])
s_sample[i] = torch.argmax(local_pred[1])
return score_detection, score_p_or_s, p_sample, s_sample
class DPPDetectorLit(SeisBenchModuleLit):
"""
LightningModule for DPPDetector
    :param lr: Learning rate, defaults to 1e-3
    :param kwargs: Kwargs are passed to the SeisBench.models.DPPDetector constructor.
"""
def __init__(self, lr=1e-3, **kwargs):
super().__init__()
self.save_hyperparameters()
self.lr = lr
self.nllloss = torch.nn.NLLLoss()
self.loss = self.nll_with_probabilities
self.model = sbm.DPPDetector(**kwargs)
def forward(self, x):
return self.model(x)
def nll_with_probabilities(self, y_pred, y_true):
y_pred = torch.log(y_pred + 1e-5)
return self.nllloss(y_pred, y_true)
def shared_step(self, batch):
x = batch["X"]
y_true = batch["y"].squeeze()
y_pred = self.model(x)
return self.loss(y_pred, y_true)
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("val_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def get_augmentations(self):
return [
            # In 2/3 of the cases, select a window around a pick to reduce the number of noise traces in training.
            # Uses strategy="variable", as padding will be handled by the random window.
            # In 1/3 of the cases, just return the original trace to keep diversity high.
sbg.OneOf(
[
sbg.WindowAroundSample(
list(phase_dict.keys()),
samples_before=650,
windowlen=1300,
selection="random",
strategy="variable",
),
sbg.NullAugmentation(),
],
probabilities=[2, 1],
),
sbg.RandomWindow(
windowlen=500,
strategy="pad",
),
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
sbg.StandardLabeller(
label_columns=phase_dict, on_overlap="fixed-relevance"
),
]
def get_eval_augmentations(self):
return [
sbg.SteeredWindow(windowlen=3000, strategy="pad"),
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
x = batch["X"]
window_borders = batch["window_borders"]
# Create windows
x = x.reshape(x.shape[:-1] + (6, 500)) # Split into 6 windows of length 500
x = x.permute(0, 2, 1, 3) # --> (batch, windows, channels, samples)
shape_save = x.shape
x = x.reshape(-1, 3, 500) # --> (batch * windows, channels, samples)
pred = self.model(x)
pred = pred.reshape(shape_save[:2] + (-1,)) # --> (batch, windows, label)
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0])
p_sample = torch.zeros(pred.shape[0], dtype=int) * np.nan
s_sample = torch.zeros(pred.shape[0], dtype=int) * np.nan
for i in range(pred.shape[0]):
start_sample, end_sample = window_borders[i]
start_resampled = start_sample.cpu() // 500
end_resampled = int(np.ceil(end_sample.cpu() / 500))
local_pred = pred[i, start_resampled:end_resampled, :]
score_p_or_s[i] = torch.max(local_pred[:, 0]) / torch.max(local_pred[:, 1])
score_detection[i] = torch.max(1 - local_pred[:, -1])
return score_detection, score_p_or_s, p_sample, s_sample
class DPPPickerLit(SeisBenchModuleLit):
"""
LightningModule for DPPPicker
    :param mode: Phase type to pick, either "P" or "S"
    :param lr: Learning rate, defaults to 1e-3
    :param kwargs: Kwargs are passed to the SeisBench.models.DPPPicker constructor.
"""
def __init__(self, mode, lr=1e-3, **kwargs):
super().__init__()
self.save_hyperparameters()
self.mode = mode
self.lr = lr
self.loss = torch.nn.BCELoss()
        logger.debug(f"DPPPicker args: mode={mode}, lr={lr}, kwargs={kwargs}")
self.model = sbm.DPPPicker(mode=mode, **kwargs)
def forward(self, x):
if self.mode == "P":
x = x[:, 0:1] # Select vertical component
elif self.mode == "S":
x = x[:, 1:3] # Select horizontal components
else:
raise ValueError(f"Unknown mode {self.mode}")
return self.model(x)
def shared_step(self, batch):
x = batch["X"]
y_true = batch["y"]
if self.mode == "P":
y_true = y_true[:, 0] # P wave
x = x[:, 0:1] # Select vertical component
elif self.mode == "S":
y_true = y_true[:, 1] # S wave
x = x[:, 1:3] # Select horizontal components
else:
raise ValueError(f"Unknown mode {self.mode}")
y_pred = self.model(x)
loss = self.loss(y_pred, y_true)
return loss
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("val_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def get_augmentations(self):
return [
sbg.WindowAroundSample(
[key for key, val in phase_dict.items() if val == self.mode],
samples_before=1000,
windowlen=2000,
selection="random",
strategy="variable",
),
sbg.RandomWindow(
windowlen=1000,
strategy="pad",
),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
sbg.StepLabeller(label_columns=phase_dict),
sbg.ChangeDtype(np.float32),
sbg.ChangeDtype(np.float32, "y"),
]
def get_eval_augmentations(self):
return [
sbg.SteeredWindow(windowlen=1000, strategy="pad"),
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
x = batch["X"]
window_borders = batch["window_borders"]
pred = self(x)
score_detection = torch.zeros(pred.shape[0]) * np.nan
score_p_or_s = torch.zeros(pred.shape[0]) * np.nan
p_sample = torch.zeros(pred.shape[0], dtype=int) * np.nan
s_sample = torch.zeros(pred.shape[0], dtype=int) * np.nan
for i in range(pred.shape[0]):
start_sample, end_sample = window_borders[i]
local_pred = pred[i, start_sample:end_sample]
if (local_pred > 0.5).any():
sample = torch.argmax(
(local_pred > 0.5).float()
) # First sample exceeding 0.5
else:
sample = 500 # Simply guess the middle
if self.mode == "P":
p_sample[i] = sample
elif self.mode == "S":
s_sample[i] = sample
else:
raise ValueError(f"Unknown mode {self.mode}")
return score_detection, score_p_or_s, p_sample, s_sample
def get_model_specific_args(config):
model = config.model_name[0]
lr = config.learning_rate
    if isinstance(lr, list):
lr = lr[0]
args = {'lr': lr}
match model:
case 'GPD':
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass
case 'PhaseNet':
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass
if 'sample_boundaries' in config:
args['sample_boundaries'] = config.sample_boundaries
case 'DPPPicker':
if 'mode' in config:
args['mode'] = config.mode[0]
print("args", args, model)
return args
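# Illustrative example (hypothetical, assumes a dict-like config object that supports both
# attribute access and `in` membership tests, e.g. an omegaconf DictConfig):
#   config.model_name = ["PhaseNet"]; config.learning_rate = [1e-3]; config.highpass = 1.0
#   get_model_specific_args(config) -> {"lr": 1e-3, "highpass": 1.0}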