PandSWavesDetectionTool/src/epos_ai_picking_tools/.ipynb_checkpoints/model_runner-checkpoint.py
2023-09-20 09:44:18 +00:00

103 lines
3.1 KiB
Python

import json
import pathlib
import sys
from collections import defaultdict
import obspy
import seisbench.models as sbm
from obspy.core.event import Catalog, Event, Pick, WaveformStreamID
from obspy.io.json.default import Default
def list_pretrained_models(model_runner_class):
    """Return the pretrained weight names available for a runner's model.

    ``model_runner_class.model_type`` names a model class inside
    ``seisbench.models``; the result of that class's ``list_pretrained()``
    is returned unchanged.
    """
    seisbench_model_cls = getattr(sbm, model_runner_class.model_type)
    return seisbench_model_cls.list_pretrained()
def exit_error(msg):
    """Report a fatal error and terminate the process with exit status 1.

    The message is written to standard error rather than standard output,
    so it is not mixed into regular program output.

    Raises:
        SystemExit: always, with exit code 1.
    """
    print("ERROR:", msg, file=sys.stderr)
    sys.exit(1)
class ModelRunner:
    """Run a SeisBench picking model on a waveform file and save the results.

    Subclasses select the model by setting ``model_type`` to the name of a
    class in ``seisbench.models`` (the attribute ``list_pretrained_models``
    also reads). They may instead set ``model_name`` to the model class
    itself, and may set ``annotate_kwargs`` / ``classify_kwargs`` (e.g. in
    ``process_kwargs``) to customise the ``annotate`` and
    ``classify_aggregate`` calls.
    """

    model_type = "EMPTY"

    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
        """Load the pretrained model and remember where outputs are written.

        Args:
            weights_name: name of the pretrained weights passed to the model
                class's ``from_pretrained``.
            output_dir: directory that receives every output file. It is not
                created here; it must already exist when results are saved.
            **kwargs: forwarded to ``process_kwargs`` for subclass options.
        """
        self.model = self.load_model(weights_name)
        self.output_dir = pathlib.Path(output_dir)
        self.process_kwargs(**kwargs)

    def process_kwargs(self, **kwargs):
        """Hook for subclasses to consume extra constructor options; no-op here."""

    def citation(self):
        """Return the citation of the loaded model."""
        return self.model.citation

    def load_model(self, weights_name):
        """Instantiate the model class with the requested pretrained weights.

        Uses ``self.model_name`` when it is defined; otherwise the class is
        looked up in ``seisbench.models`` by ``self.model_type``.
        """
        # The base class never sets ``model_name`` (deriving it from the
        # lexical ``__class__`` would always read ModelRunner.model_type), so
        # fall back to the runtime ``model_type`` instead of raising
        # AttributeError for runners that only define ``model_type``.
        model_cls = getattr(self, "model_name", None)
        if model_cls is None:
            model_cls = getattr(sbm, self.model_type)
        return model_cls.from_pretrained(weights_name)

    def load_stream(self, stream_file_name):
        """Read a waveform file with ``obspy.read`` and return the stream."""
        return obspy.read(stream_file_name)

    def save_picks(self, classs_picks, stream_path):
        """Write the picks as a JSON list of attribute dicts to ``<stem>_picks.json``.

        obspy's ``Default`` hook is passed as ``json``'s ``default`` so that
        attribute values ``json`` cannot encode natively (presumably the
        pick timestamps) are still serialised.
        """
        dict_picks = [vars(pick) for pick in classs_picks]
        fpath = self.output_dir / f"{stream_path.stem}_picks.json"
        with open(fpath, "w", encoding="utf-8") as fp:
            json.dump(dict_picks, fp, default=Default())

    def save_quakeml(self, classs_picks, stream_path):
        """Store all picks in one Event and write it as QuakeML to ``<stem>_picks.xml``."""
        event = Event()
        for cpick in classs_picks:
            # trace_id is expected to have exactly three dot-separated parts:
            # network, station, location (no channel code).
            net, sta, loc = cpick.trace_id.split(".")
            event.picks.append(
                Pick(
                    time=cpick.peak_time,
                    phase_hint=cpick.phase,
                    waveform_id=WaveformStreamID(
                        network_code=net, station_code=sta, location_code=loc
                    ),
                )
            )
        fpath = self.output_dir / f"{stream_path.stem}_picks.xml"
        Catalog([event]).write(fpath, format="QUAKEML")

    def write_annotations(self, annotations, stream_path):
        """Save annotation traces as miniSEED to ``<stem>_annotations.mseed``.

        A copy is written, so the caller's stream is left untouched. Each
        trace's channel is renamed to ``G_<component>`` before writing.
        """
        # NOTE(review): the rename presumably exists because the annotation
        # channel names produced by the model do not fit miniSEED's channel
        # field — confirm.
        ann = annotations.copy()
        for tr in ann:
            tr.stats.channel = f"G_{tr.stats.component}"
        fpath = self.output_dir / f"{stream_path.stem}_annotations.mseed"
        ann.write(fpath, format="MSEED")

    @staticmethod
    def validate_stream(stream):
        """Exit with an error unless some station has at least three components.

        A component is the last character of a trace's channel code; each one
        is counted once per station. Several traces of the same component
        (e.g. a channel split by data gaps) therefore do not count as extra
        channels. An empty stream is rejected with the same message instead
        of crashing in ``max()``.
        """
        components = defaultdict(set)
        for trace in stream:
            components[trace.stats.station].add(trace.stats.channel[-1])
        if max(map(len, components.values()), default=0) < 3:
            exit_error("Not enough traces in the stream")

    def find_picks(self, stream_file_name, save_annotations=True):
        """Pick phases in a waveform file and save the results.

        Reads and validates the stream, then runs the model's ``annotate`` and
        ``classify_aggregate`` steps. Into ``self.output_dir`` it writes the
        picks as JSON and QuakeML, plus the annotations as miniSEED when
        ``save_annotations`` is true.

        Returns:
            The value returned by ``classify_aggregate`` (the saved picks).
        """
        stream_path = pathlib.Path(stream_file_name)
        stream = self.load_stream(stream_path)
        self.validate_stream(stream)
        # Runners that never set these get an empty dict, i.e. no overrides,
        # instead of an AttributeError.
        annotate_kwargs = getattr(self, "annotate_kwargs", None) or {}
        classify_kwargs = getattr(self, "classify_kwargs", None) or {}
        annotations = self.model.annotate(stream, **annotate_kwargs)
        if save_annotations:
            self.write_annotations(annotations, stream_path)
        # classify_aggregate takes its options as a single dict argument.
        classs_picks = self.model.classify_aggregate(annotations, classify_kwargs)
        self.save_picks(classs_picks, stream_path)
        self.save_quakeml(classs_picks, stream_path)
        return classs_picks