"""Helpers and a base runner class for phase picking with SeisBench models."""

import json
import pathlib
import sys
from collections import defaultdict

import obspy
import seisbench.models as sbm
from obspy.core.event import Catalog, Event, Pick, WaveformStreamID
from obspy.io.json.default import Default


def list_pretrained_models(model_runner_class):
    """Return the pretrained weights listed for a runner's model class.

    Args:
        model_runner_class: Class exposing a ``model_type`` attribute that
            names a model class inside ``seisbench.models``.

    Returns:
        The result of ``list_pretrained()`` on that model class.

    Raises:
        AttributeError: If ``seisbench.models`` has no attribute named
            ``model_type``.
    """
    model_cls = getattr(sbm, model_runner_class.model_type)
    return model_cls.list_pretrained()


def exit_error(msg):
    """Print ``ERROR: <msg>`` and terminate the process with exit status 1."""
    print("ERROR:", msg)
    sys.exit(1)


class ModelRunner:
    """Run a pretrained SeisBench model on a waveform file and store the results.

    ``find_picks`` is the main entry point. It reads ``self.annotate_kwargs``
    and ``self.classify_kwargs``, which this base class never assigns;
    presumably subclasses set them in ``process_kwargs`` (confirm against the
    concrete runners).
    """

    # Name of the model class inside ``seisbench.models``; "EMPTY" is only a
    # placeholder that concrete runners are expected to replace.
    model_type = "EMPTY"

    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
        """Load the model, remember the output directory and process options.

        Args:
            weights_name: Name of the pretrained weights to load.
            output_dir: Directory in which all result files are written.
            **kwargs: Forwarded unchanged to ``process_kwargs``.
        """
        # The model class is resolved inside load_model (see there), because
        # ``__class__`` in this method would always refer to ModelRunner itself
        # rather than to the concrete subclass.
        self.model = self.load_model(weights_name)
        self.output_dir = pathlib.Path(output_dir)
        self.process_kwargs(**kwargs)

    def process_kwargs(self, **kwargs):
        """Hook for handling extra constructor options; the base class ignores them."""
        pass

    def citation(self):
        """Return the ``citation`` attribute of the loaded model."""
        return self.model.citation

    def load_model(self, weights_name):
        """Instantiate the model with the given pretrained weights.

        The model class is taken from ``self.model_name`` when that attribute
        exists. Otherwise it is looked up in ``seisbench.models`` under the
        concrete class's ``model_type``, the same lookup that
        ``list_pretrained_models`` performs.

        Args:
            weights_name: Name passed to ``from_pretrained``.

        Returns:
            The object returned by ``from_pretrained``.

        Raises:
            AttributeError: If no ``model_name`` is set and ``model_type`` does
                not name an attribute of ``seisbench.models``.
        """
        model_cls = getattr(self, "model_name", None)
        if model_cls is None:
            # type(self) (not __class__) so that a subclass's model_type is used.
            model_cls = getattr(sbm, type(self).model_type)
        return model_cls.from_pretrained(weights_name)

    def load_stream(self, stream_file_name):
        """Read a waveform file with ``obspy.read`` and return the result."""
        return obspy.read(stream_file_name)

    def save_picks(self, classs_picks, stream_path):
        """Dump the picks as JSON to ``<output_dir>/<stem>_picks.json``.

        Args:
            classs_picks: Iterable of pick objects; each one is serialized via
                its ``__dict__``.
            stream_path: ``pathlib.Path`` of the input file; only its ``stem``
                is used to build the output file name.
        """
        dict_picks = [pick.__dict__ for pick in classs_picks]
        fpath = self.output_dir / f"{stream_path.stem}_picks.json"
        with open(fpath, "w", encoding="utf-8") as fp:
            # obspy's Default is used as the json ``default`` hook for values
            # the json module cannot encode on its own.
            json.dump(dict_picks, fp, default=Default())

    def save_quakeml(self, classs_picks, stream_path):
        """Write the picks as one QuakeML event to ``<output_dir>/<stem>_picks.xml``.

        Args:
            classs_picks: Iterable of pick objects providing ``trace_id``,
                ``peak_time`` and ``phase``.
            stream_path: ``pathlib.Path`` of the input file; only its ``stem``
                is used to build the output file name.

        Raises:
            ValueError: If a ``trace_id`` does not consist of exactly three
                dot-separated parts.
        """
        event = Event()
        for cpick in classs_picks:
            # trace_id must split into exactly network, station and location.
            net, sta, loc = cpick.trace_id.split(".")
            waveform_id = WaveformStreamID(
                network_code=net, station_code=sta, location_code=loc
            )
            event.picks.append(
                Pick(
                    time=cpick.peak_time,
                    phase_hint=cpick.phase,
                    waveform_id=waveform_id,
                )
            )
        catalog = Catalog([event])
        fpath = self.output_dir / f"{stream_path.stem}_picks.xml"
        catalog.write(fpath, format="QUAKEML")

    def write_annotations(self, annotations, stream_path):
        """Write the annotation traces to ``<output_dir>/<stem>_annotations.mseed``.

        The channel renaming is applied to ``annotations.copy()``, not to the
        object passed in.

        Args:
            annotations: Annotation traces as returned by the model's
                ``annotate`` call.
            stream_path: ``pathlib.Path`` of the input file; only its ``stem``
                is used to build the output file name.
        """
        ann = annotations.copy()
        for tr in ann:
            # Replace each channel code with "G_" followed by the trace's component.
            tr.stats.channel = f"G_{tr.stats.component}"
        fpath = self.output_dir / f"{stream_path.stem}_annotations.mseed"
        ann.write(fpath, format="MSEED")

    @staticmethod
    def validate_stream(stream):
        """Exit with an error unless some station has at least three traces.

        Traces are grouped by ``stats.station`` only. An empty stream is
        reported with the same error instead of failing inside ``max``.

        Args:
            stream: Iterable of traces exposing ``stats.station`` and
                ``stats.channel``.
        """
        components_by_station = defaultdict(list)
        for trace in stream:
            # Slice rather than index so an empty channel code does not raise;
            # only the number of entries per station matters below.
            components_by_station[trace.stats.station].append(
                trace.stats.channel[-1:]
            )
        most_traces = max(map(len, components_by_station.values()), default=0)
        if most_traces < 3:
            exit_error("Not enough traces in the stream")

    def find_picks(self, stream_file_name, save_annotations=True):
        """Pick phases in a waveform file and write all result files.

        Steps: read the file, validate it, annotate it with the model,
        optionally store the annotations, aggregate the annotations into picks,
        then store the picks as JSON and as QuakeML.

        Args:
            stream_file_name: Path of the waveform file to process.
            save_annotations: When true, also write the annotation traces.

        Returns:
            The value returned by the model's ``classify_aggregate``.
        """
        stream_path = pathlib.Path(stream_file_name)
        stream = self.load_stream(stream_path)
        self.validate_stream(stream)

        annotations = self.model.annotate(stream, **self.annotate_kwargs)
        if save_annotations:
            self.write_annotations(annotations, stream_path)

        # classify_aggregate receives its options as a single dict argument,
        # unlike annotate, which takes them as keyword arguments.
        classs_picks = self.model.classify_aggregate(annotations, self.classify_kwargs)
        self.save_picks(classs_picks, stream_path)
        self.save_quakeml(classs_picks, stream_path)
        return classs_picks