"""Compute ROC / precision-recall curves for PhaseNet detection on SeisBench data.

For every (dataset, model) combination the script cuts fixed-length windows
around labelled P and S picks (positive examples) and one window near the start
of each P-labelled trace (negative example), scores each window with the model,
and writes the ROC curve, the precision-recall curve and the AUC to disk.
"""
import itertools
import json
import pathlib

import numpy as np
import obspy
import pandas as pd
import seisbench.data as sbd
import seisbench.models as sbm
from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve

datasets = [
    # path to datasets in seisbench format
]
models = [
    # model names
]

# Windows with fewer samples than this on any component are not scored.
_MIN_SAMPLES = 3001
# Seconds of data kept before a pick when cutting a window around it.
_PRE_PICK_S = 15
# Offset (seconds) from the trace start at which the negative window begins.
_NOISE_OFFSET_S = 1


def find_keys_phase(meta, phase):
    """Return the metadata keys holding arrival samples for *phase*.

    A key matches when it starts with ``"trace_" + phase`` and ends with
    ``"_arrival_sample"``.

    Args:
        meta: Any object exposing ``keys()`` (e.g. the dataset metadata table).
        phase: Phase prefix such as ``"P"`` or ``"S"``.

    Returns:
        List of matching keys, in the order ``meta.keys()`` yields them.
    """
    prefix = "trace_" + phase
    return [
        key
        for key in meta.keys()
        if key.startswith(prefix) and key.endswith("_arrival_sample")
    ]


def create_stream(meta, raw, start, length=30):
    """Build a three-component ObsPy stream cut to ``[start, start + length]``.

    Args:
        meta: Mapping with the trace metadata (start time, sampling rate,
            network/station codes, channel and component order).
        raw: Array indexed as ``raw[component, :]``; the first three
            components are used.
        start: Start time of the window to cut.
        length: Window length in seconds.

    Returns:
        ``obspy.Stream`` holding the three sliced traces.
    """
    stop = start + length
    st = obspy.Stream()
    for i in range(3):
        tr = obspy.Trace(raw[i, :])
        tr.stats.starttime = meta["trace_start_time"]
        tr.stats.sampling_rate = meta["trace_sampling_rate_hz"]
        tr.stats.network = meta["station_network_code"]
        tr.stats.station = meta["station_code"]
        # Channel code = first two characters of the channel + component letter.
        tr.stats.channel = meta["trace_channel"][:2] + meta["trace_component_order"][i]
        st.append(tr.slice(start, stop))
    return st


def get_pred(model, stream):
    """Return the detection score of *stream*: the maximum of ``1 - noise``.

    The noise trace is the first annotation whose channel is ``"PhaseNet_N"``.

    Raises:
        IndexError: If the annotations contain no ``"PhaseNet_N"`` trace.
    """
    ann = model.annotate(stream)
    noise = ann.select(channel="PhaseNet_N")[0]
    return (1 - noise.data).max()


def to_short(stream):
    """Return True if any trace in *stream* has fewer than 3001 samples.

    The length predicate is evaluated directly instead of relying on the
    truthiness of the trace objects, so zero-sample traces are also reported
    as too short.
    """
    return any(tr.data.shape[0] < _MIN_SAMPLES for tr in stream)


def _score_window(model, meta, waveform, start, label, label_true, label_pred):
    """Score one window and append ``(label, prediction)`` to the result lists.

    Windows that are too short, or for which slicing/annotation raises
    ``IndexError``, are skipped on purpose (best-effort evaluation).
    """
    try:
        stream = create_stream(meta, waveform, start)
        if to_short(stream):
            return
        pred = get_pred(model, stream)
    except IndexError:
        # Deliberate skip: no usable annotation for this window.
        return
    label_true.append(label)
    label_pred.append(pred)


def _first_pick(meta, keys):
    """Return the first non-null pick sample among *keys*, or None if none is set."""
    picks = meta[keys]
    picks = picks[picks.notna()]
    if picks.empty:
        return None
    # Positional access: the series is indexed by the metadata key names.
    return picks.iloc[0]


for ds, model_name in itertools.product(datasets, models):
    data = sbd.WaveformDataset(ds, sampling_rate=100).test()
    data_name = pathlib.Path(ds).stem
    fname = f"roc___{model_name}___{data_name}.csv"
    print(f"{fname:.<50s}.... ", flush=True, end="")
    # An existing ROC file marks this combination as already processed.
    if pathlib.Path(fname).is_file():
        print(" ready, skipping", flush=True)
        continue

    p_labels = find_keys_phase(data.metadata, "P")
    s_labels = find_keys_phase(data.metadata, "S")
    model = sbm.PhaseNet.from_pretrained(model_name)

    label_true = []
    label_pred = []
    for i in range(len(data)):
        waveform, metadata = data.get_sample(i)
        m = pd.Series(metadata)
        trace_start_time = obspy.UTCDateTime(m["trace_start_time"])
        sampling_rate = m["trace_sampling_rate_hz"]

        p_pick = _first_pick(m, p_labels)
        if p_pick is not None:
            start = trace_start_time + p_pick / sampling_rate - _PRE_PICK_S
            _score_window(model, m, waveform, start, 1, label_true, label_pred)
            # Negative example taken near the trace start.
            # NOTE(review): assumes no arrival falls inside this window - confirm.
            _score_window(
                model,
                m,
                waveform,
                trace_start_time + _NOISE_OFFSET_S,
                0,
                label_true,
                label_pred,
            )

        s_pick = _first_pick(m, s_labels)
        if s_pick is not None:
            start = trace_start_time + s_pick / sampling_rate - _PRE_PICK_S
            _score_window(model, m, waveform, start, 1, label_true, label_pred)

    fpr, tpr, roc_thresholds = roc_curve(label_true, label_pred)
    df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": roc_thresholds})
    df.to_csv(fname)

    precision, recall, prc_thresholds = precision_recall_curve(label_true, label_pred)
    # precision/recall have one more entry than the thresholds; pad with a
    # sentinel so all columns have equal length.
    prc_thresholds_extra = np.append(prc_thresholds, -999)
    df = pd.DataFrame(
        {"pre": precision, "rec": recall, "thresholds": prc_thresholds_extra}
    )
    # Build the name explicitly: str.replace("roc", "pr") would also rewrite
    # any "roc" occurring inside the model or dataset name.
    df.to_csv(f"pr___{model_name}___{data_name}.csv")

    stats = {
        "model": str(model_name),
        "data": str(data_name),
        "auc": float(roc_auc_score(label_true, label_pred)),
    }
    with open(f"stats___{model_name}___{data_name}.json", "w") as fp:
        json.dump(stats, fp)
    print(" finished", flush=True)