diff --git a/scripts/perf_analysis.py b/scripts/perf_analysis.py new file mode 100644 index 0000000..7d8f82a --- /dev/null +++ b/scripts/perf_analysis.py @@ -0,0 +1,149 @@ +import json +import pathlib + +import numpy as np +import obspy +import pandas as pd +import seisbench.data as sbd +import seisbench.models as sbm +from seisbench.models.team import itertools +from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve + +datasets = [ + # path to datasets in seisbench format +] + +models = [ + # model names +] + + +def find_keys_phase(meta, phase): + phases = [] + for k in meta.keys(): + if k.startswith("trace_" + phase) and k.endswith("_arrival_sample"): + phases.append(k) + + return phases + + +def create_stream(meta, raw, start, length=30): + + st = obspy.Stream() + + for i in range(3): + tr = obspy.Trace(raw[i, :]) + tr.stats.starttime = meta["trace_start_time"] + tr.stats.sampling_rate = meta["trace_sampling_rate_hz"] + tr.stats.network = meta["station_network_code"] + tr.stats.station = meta["station_code"] + tr.stats.channel = meta["trace_channel"][:2] + meta["trace_component_order"][i] + + stop = start + length + tr = tr.slice(start, stop) + + st.append(tr) + + return st + + +def get_pred(model, stream): + ann = model.annotate(stream) + noise = ann.select(channel="PhaseNet_N")[0] + pred = max(1 - noise.data) + return pred + + +def to_short(stream): + short = [tr for tr in stream if tr.data.shape[0] < 3001] + return any(short) + + +for ds, model_name in itertools.product(datasets, models): + + data = sbd.WaveformDataset(ds, sampling_rate=100).test() + data_name = pathlib.Path(ds).stem + fname = f"roc___{model_name}___{data_name}.csv" + + print(f"{fname:.<50s}.... ", flush=True, end="") + + if pathlib.Path(fname).is_file(): + print(" ready, skipping", flush=True) + continue + + p_labels = find_keys_phase(data.metadata, "P") + s_labels = find_keys_phase(data.metadata, "S") + + model = sbm.PhaseNet().from_pretrained(model_name) + + label_true = [] + label_pred = [] + + for i in range(len(data)): + + waveform, metadata = data.get_sample(i) + m = pd.Series(metadata) + + has_p_label = m[p_labels].notna() + has_s_label = m[s_labels].notna() + + if any(has_p_label): + + trace_start_time = obspy.UTCDateTime(m["trace_start_time"]) + pick_sample = m[p_labels][has_p_label][0] + + start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15 + + try: + st_p = create_stream(m, waveform, start) + if not (to_short(st_p)): + pred_p = get_pred(model, st_p) + label_true.append(1) + label_pred.append(pred_p) + except IndexError: + pass + + try: + st_n = create_stream(m, waveform, trace_start_time + 1) + if not (to_short(st_n)): + pred_n = get_pred(model, st_n) + label_true.append(0) + label_pred.append(pred_n) + except IndexError: + pass + + if any(has_s_label): + trace_start_time = obspy.UTCDateTime(m["trace_start_time"]) + pick_sample = m[s_labels][has_s_label][0] + start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15 + + try: + st_s = create_stream(m, waveform, start) + if not (to_short(st_s)): + pred_s = get_pred(model, st_s) + label_true.append(1) + label_pred.append(pred_s) + except IndexError: + pass + + fpr, tpr, roc_thresholds = roc_curve(label_true, label_pred) + df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": roc_thresholds}) + df.to_csv(fname) + + precision, recall, prc_thresholds = precision_recall_curve(label_true, label_pred) + prc_thresholds_extra = np.append(prc_thresholds, -999) + df = pd.DataFrame( + {"pre": precision, "rec": recall, "thresholds": prc_thresholds_extra} + ) + df.to_csv(fname.replace("roc", "pr")) + + stats = { + "model": str(model_name), + "data": str(data_name), + "auc": float(roc_auc_score(label_true, label_pred)), + } + + with open(f"stats___{model_name}___{data_name}.json", "w") as fp: + json.dump(stats, fp) + + print(" finished", flush=True)