platform-demo-scripts/scripts/augmentations.py

134 lines
5.7 KiB
Python
Raw Permalink Normal View History

"""
This file contains augmentations required for the models that are too specific to be merged into SeisBench.
"""
import numpy as np
import copy
class DuplicateEvent:
    """
    Adds a rescaled version of the event to the empty part of the trace after the event.
    Event position and empty space are determined from a detection.
    Detections can be generated for example with :py:class:`~seisbench.generate.labeling.DetectionLabeller`.

    This implementation is modelled after the `implementation for EQTransformer <https://github.com/smousavi05/EQTransformer/blob/98676017f971efbb6f4475f42e415c3868d00c03/EQTransformer/core/EqT_utils.py#L255>`_.

    The duplicate starts at a random sample at least 20 samples after the end of the detected
    event. If there is no detection, or no such room is left in the trace, nothing is duplicated.
    In that case the trace, labels and detection are passed through to their output keys unchanged.

    .. warning::
        This augmentation does **not** modify the metadata, as representing multiple picks of
        the same type is currently not supported. Workflows should therefore always first generate
        labels from metadata and then pass the labels in the key `label_keys`. These keys are automatically
        adjusted by addition of the labels.

    .. warning::
        This implementation currently has strict shape requirements:

        - (1, samples) for detection
        - (channels, samples) for data
        - (labels, samples) for labels

    :param inv_scale: The scale factor is defined as 1/u, where u is drawn uniformly.
                      `inv_scale` defines the minimum and maximum values for u.
                      Defaults to (1, 10), i.e., scaling by a factor between 1 and 1/10.
    :param detection_key: Key to read detection from.
                          If key is a tuple, detection will be read from the first key and written to the second one.
    :param key: The keys for reading from and writing to the state dict.
                If key is a single string, the corresponding entry in state dict is modified.
                Otherwise, a 2-tuple is expected, with the first string indicating the key
                to read from and the second one the key to write to.
    :param label_keys: Keys for the label columns.
                       Labels of the original and duplicate events will be added and capped at 1.
                       Note that this will lead to invalid noise traces.
                       Value can either be a single key specification or a list of key specifications.
                       Each key specification is either a string, for identical input and output keys,
                       or a tuple of two strings, input and output keys.
                       Defaults to None.
    :raises ValueError: From ``__call__``, if the number of samples of the trace disagrees with the
                        detection or, when an event is duplicated, with any label entry.
    """

    # Minimum number of samples between the end of the detected event and the start of the duplicate.
    _MIN_GAP = 20

    def __init__(
        self, inv_scale=(1, 10), detection_key="detections", key="X", label_keys=None
    ):
        if isinstance(detection_key, str):
            self.detection_key = (detection_key, detection_key)
        else:
            self.detection_key = detection_key

        if isinstance(key, str):
            self.key = (key, key)
        else:
            self.key = key

        # Normalize a single key specification (or None) to a list
        if not isinstance(label_keys, list):
            if label_keys is None:
                label_keys = []
            else:
                label_keys = [label_keys]

        # Resolve identical input and output keys
        self.label_keys = []
        for label_key in label_keys:
            if isinstance(label_key, tuple):
                self.label_keys.append(label_key)
            else:
                self.label_keys.append((label_key, label_key))

        self.inv_scale = inv_scale

    def __call__(self, state_dict):
        x, metadata = state_dict[self.key[0]]
        detection, _ = state_dict[self.detection_key[0]]

        if detection.shape[-1] != x.shape[-1]:
            raise ValueError("Number of samples in trace and detection disagree.")

        if self.key[0] != self.key[1]:
            # Never modify the input entry in place (data or metadata) when writing to a
            # different key; this applies whether or not an event gets duplicated.
            metadata = copy.deepcopy(metadata)
            x = x.copy()

        n_samples = x.shape[-1]
        # Sorted indices of samples flagged as event by the detection
        event_samples = np.flatnonzero(detection[0] > 0.5)

        duplicated = False
        if event_samples.size > 0:
            event_start = event_samples[0]
            event_end = event_samples[-1] + 1

            if event_end + self._MIN_GAP < n_samples:
                # Draw order (randint, then uniform) is kept so seeded runs stay reproducible.
                second_start = np.random.randint(event_end + self._MIN_GAP, n_samples)
                scale = 1 / np.random.uniform(*self.inv_scale)

                # The duplicate is truncated if it would run past the end of the trace
                space = min(event_end - event_start, n_samples - second_start)
                x[:, second_start : second_start + space] += (
                    x[:, event_start : event_start + space] * scale
                )

                self._shift_labels(state_dict, second_start - event_start, n_samples)
                duplicated = True

        if not duplicated:
            # Output label/detection keys must be populated on every path
            self._passthrough_labels(state_dict)

        # `metadata` still belongs to the trace; label metadata uses separate variables.
        state_dict[self.key[1]] = (x, metadata)

    def _shift_labels(self, state_dict, shift, n_samples):
        """Add each label/detection shifted by `shift` (> 0) samples, cap at 1, write to its output key."""
        for in_key, out_key in self.label_keys + [self.detection_key]:
            y, label_metadata = state_dict[in_key]
            if y.shape[-1] != n_samples:
                raise ValueError(
                    f"Number of samples disagree between trace and label key '{in_key}'."
                )

            if in_key != out_key:
                label_metadata = copy.deepcopy(label_metadata)
                y = y.copy()

            y[:, shift:] += y[:, :-shift]
            state_dict[out_key] = (np.minimum(y, 1), label_metadata)

    def _passthrough_labels(self, state_dict):
        """Write each label/detection unchanged to its output key, copying if the keys differ."""
        for in_key, out_key in self.label_keys + [self.detection_key]:
            y, label_metadata = state_dict[in_key]
            if in_key != out_key:
                label_metadata = copy.deepcopy(label_metadata)
                y = y.copy()
            state_dict[out_key] = (y, label_metadata)