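"""Run SeisBench pretrained pickers on ObsPy waveform files.

ModelRunner wraps a SeisBench model: it annotates a stream, aggregates the
annotations into picks, and writes the results as JSON picks, QuakeML and
(optionally) MiniSEED annotation traces.
"""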
import json
import pathlib
import sys
from collections import defaultdict

import obspy
import seisbench.models as sbm
from obspy.core.event import Catalog, Event, Pick, WaveformStreamID
from obspy.io.json.default import Default


def list_pretrained_models(model_runner_class):
    m = getattr(sbm, model_runner_class.model_type)
    weights = m.list_pretrained()
    return weights


def exit_error(msg):
    print("ERROR:", msg)
    sys.exit(1)


class ModelRunner:
    model_type = "EMPTY"

    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
        # Resolve the SeisBench model class named by model_type (set by subclasses).
        self.model_name = getattr(sbm, self.model_type)
        self.model = self.load_model(weights_name)
        self.output_dir = pathlib.Path(output_dir)
        self.process_kwargs(**kwargs)

    def process_kwargs(self, **kwargs):
        # Hook for subclasses; find_picks() relies on self.annotate_kwargs and
        # self.classify_kwargs, which are expected to be set here.
        pass

    def citation(self):
        return self.model.citation

    def load_model(self, weights_name):
        return self.model_name.from_pretrained(weights_name)

    def load_stream(self, stream_file_name):
        return obspy.read(stream_file_name)

    def save_picks(self, class_picks, stream_path):
        # Serialize the picks to JSON; ObsPy's Default encoder handles
        # non-JSON types such as UTCDateTime.
        dict_picks = list(map(lambda p: p.__dict__, class_picks))
        fpath = self.output_dir / f"{stream_path.stem}_picks.json"

        with open(fpath, "w") as fp:
            json.dump(dict_picks, fp, default=Default())

    def save_quakeml(self, class_picks, stream_path):
        # Wrap all picks in a single Event so they can be written as QuakeML.
        e = Event()
        for cpick in class_picks:
            net, sta, loc = cpick.trace_id.split(".")
            p = Pick(
                time=cpick.peak_time,
                phase_hint=cpick.phase,
                waveform_id=WaveformStreamID(
                    network_code=net, station_code=sta, location_code=loc
                ),
            )
            e.picks.append(p)

        cat = Catalog([e])
        fpath = self.output_dir / f"{stream_path.stem}_picks.xml"
        cat.write(fpath, format="QUAKEML")

    def write_annotations(self, annotations, stream_path):
        ann = annotations.copy()
        for tr in ann:
            tr.stats.channel = f"G_{tr.stats.component}"
        fpath = self.output_dir / f"{stream_path.stem}_annotations.mseed"
        ann.write(fpath, format="MSEED")

    @staticmethod
    def validate_stream(stream):
        # Collect the component codes present for each station.
        groups = defaultdict(list)
        for trace in stream:
            groups[trace.stats.station].append(trace.stats.channel[-1])

        number_of_channels = list(map(len, groups.values()))

        # Require at least one station with three components.
        if max(number_of_channels) < 3:
            exit_error("Not enough traces in the stream")

    def find_picks(self, stream_file_name, save_annotations=True):
        stream_path = pathlib.Path(stream_file_name)
        stream = self.load_stream(stream_path)

        self.validate_stream(stream)

        # Run the model to produce continuous annotation traces (phase probabilities).
        annotations = self.model.annotate(stream, **self.annotate_kwargs)

        if save_annotations:
            self.write_annotations(annotations, stream_path)

        # Aggregate the annotations into discrete picks.
        class_picks = self.model.classify_aggregate(annotations, self.classify_kwargs)

        self.save_picks(class_picks, stream_path)
        self.save_quakeml(class_picks, stream_path)

        return class_picks
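

# Example usage (illustrative sketch): ModelRunner is meant to be subclassed.
# The subclass name "PhaseNetRunner", the weights name "stead", and the keyword
# arguments below are assumptions for illustration, not values from this module.
#
# class PhaseNetRunner(ModelRunner):
#     model_type = "PhaseNet"
#
#     def process_kwargs(self, **kwargs):
#         # find_picks() expects these two attributes to exist.
#         self.annotate_kwargs = {"overlap": 1500}
#         self.classify_kwargs = {"P_threshold": 0.3, "S_threshold": 0.3}
#
# runner = PhaseNetRunner(weights_name="stead", output_dir="out")
# picks = runner.find_picks("waveforms.mseed")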