platform-demo-scripts/scripts/perf_analysis.py

150 lines
4.2 KiB
Python
Raw Permalink Normal View History

import json
import pathlib
import numpy as np
import obspy
import pandas as pd
import seisbench.data as sbd
import seisbench.models as sbm
from seisbench.models.team import itertools
from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve
# Dataset paths to evaluate, in SeisBench format; each entry is opened with
# ``sbd.WaveformDataset`` and its stem is used in the output file names.
datasets = []

# Pretrained PhaseNet weight names to evaluate; each entry is passed to
# PhaseNet's ``from_pretrained``.
models = []
def find_keys_phase(meta, phase):
    """Return the metadata keys that hold arrival samples for ``phase``.

    A key matches when it starts with ``"trace_" + phase`` and ends with
    ``"_arrival_sample"``. Because the match is on a prefix, ``phase="P"``
    also picks up keys such as ``trace_Pg_arrival_sample``.

    Args:
        meta: any object with ``.keys()`` (a dict, or a pandas DataFrame whose
            column names are checked).
        phase: phase prefix, e.g. ``"P"`` or ``"S"``.

    Returns:
        list of matching keys, in ``meta.keys()`` order.
    """
    return [
        key
        for key in meta.keys()
        if key.startswith("trace_" + phase) and key.endswith("_arrival_sample")
    ]
def create_stream(meta, raw, start, length=30):
    """Turn the first three rows of ``raw`` into an obspy Stream cut to a window.

    Row ``raw[k, :]`` becomes a Trace whose start time, sampling rate, network
    and station are read from ``meta``. Its channel code is the first two
    characters of ``trace_channel`` followed by the k-th letter of
    ``trace_component_order``. Each trace is then sliced to
    ``[start, start + length]``.

    Args:
        meta: metadata for one sample (a pandas Series in the main loop).
        raw: waveform array indexed as ``raw[component, sample]``.
        start: window start time (an obspy UTCDateTime in the main loop).
        length: window length in seconds.

    Returns:
        obspy.Stream with three sliced traces.

    Raises:
        IndexError: if ``raw`` has fewer than three rows or
            ``trace_component_order`` has fewer than three letters.
    """
    traces = []
    for comp in range(3):
        trace = obspy.Trace(raw[comp, :])
        stats = trace.stats
        stats.starttime = meta["trace_start_time"]
        stats.sampling_rate = meta["trace_sampling_rate_hz"]
        stats.network = meta["station_network_code"]
        stats.station = meta["station_code"]
        stats.channel = meta["trace_channel"][:2] + meta["trace_component_order"][comp]
        traces.append(trace.slice(start, start + length))
    return obspy.Stream(traces=traces)
def get_pred(model, stream, *, noise_channel="PhaseNet_N"):
    """Return the model's peak "not noise" score over ``stream``.

    Runs ``model.annotate(stream)`` and takes the first annotated trace on
    ``noise_channel``. The score is ``1 - noise`` maximised over time.

    Args:
        model: object exposing ``annotate(stream)`` (a SeisBench PhaseNet here).
        stream: waveform stream to annotate.
        noise_channel: channel name of the noise-probability trace in the
            annotation. Defaults to PhaseNet's noise output.

    Returns:
        float score, ``max(1 - noise_probability)``.

    Raises:
        IndexError: if the annotation has no ``noise_channel`` trace. The main
            loop catches this to skip unusable windows.
        ValueError: if that trace has no samples.
    """
    annotations = model.annotate(stream)
    noise = annotations.select(channel=noise_channel)[0]
    # np.max reduces in C; the builtin max() iterated the array element by
    # element in Python.
    return float(np.max(1 - noise.data))
def to_short(stream, min_samples=3001):
    """Return True if any trace in ``stream`` has fewer than ``min_samples`` samples.

    The default 3001 equals a 30 s window at 100 Hz plus the end sample,
    which matches the window length and sampling rate used by this script.

    Args:
        stream: iterable of traces exposing a numpy ``data`` array.
        min_samples: minimum number of samples a trace needs to count as a
            full window.

    Returns:
        bool. False for an empty stream.
    """
    # Compare sample counts directly. The original filtered the traces and
    # then called any() on the Trace objects themselves, which uses their
    # truthiness. A zero-sample trace is falsy, so a window sliced entirely
    # outside the data was reported as long enough.
    return any(tr.data.shape[0] < min_samples for tr in stream)
def _pick_window_start(meta, labels, has_label, before=15):
    """Return a window start time ``before`` seconds ahead of a labelled pick.

    The pick comes from the first entry of ``labels`` (metadata column order)
    whose value is present. Its sample index is converted to seconds with
    ``trace_sampling_rate_hz`` and offset from ``trace_start_time``.
    """
    trace_start_time = obspy.UTCDateTime(meta["trace_start_time"])
    # .iloc gives explicit positional access. ``series[0]`` on a
    # string-indexed Series relies on pandas' deprecated integer-key fallback.
    pick_sample = meta[labels][has_label].iloc[0]
    return trace_start_time + pick_sample / meta["trace_sampling_rate_hz"] - before


def _score_window(model, meta, waveform, start):
    """Return the model score for the window at ``start``, or None if unusable.

    A window is skipped when one of its traces is shorter than a full window
    (``to_short``). It is also skipped when building or scoring it raises
    IndexError, for instance for a waveform with fewer than three components
    or an annotation without a noise trace.
    """
    try:
        stream = create_stream(meta, waveform, start)
        if to_short(stream):
            return None
        return get_pred(model, stream)
    except IndexError:
        return None


# Evaluate every (dataset, model) pair. Score 30 s windows centred on P and S
# picks (label 1) and a window near the trace start (label 0). Then write the
# ROC curve, the precision-recall curve and the ROC AUC to disk.
for ds, model_name in itertools.product(datasets, models):
    data_name = pathlib.Path(ds).stem
    roc_fname = f"roc___{model_name}___{data_name}.csv"
    # Built explicitly. roc_fname.replace("roc", "pr") would also rewrite any
    # "roc" inside the model or dataset name.
    pr_fname = f"pr___{model_name}___{data_name}.csv"
    stats_fname = f"stats___{model_name}___{data_name}.json"
    print(f"{roc_fname:.<50s}.... ", flush=True, end="")
    # Check before loading anything. Only skip when every output exists, so a
    # run interrupted after writing the ROC file is redone.
    if all(pathlib.Path(f).is_file() for f in (roc_fname, pr_fname, stats_fname)):
        print(" ready, skipping", flush=True)
        continue

    data = sbd.WaveformDataset(ds, sampling_rate=100).test()
    p_labels = find_keys_phase(data.metadata, "P")
    s_labels = find_keys_phase(data.metadata, "S")
    # from_pretrained is a classmethod, so there is no need to build a
    # throwaway PhaseNet first.
    model = sbm.PhaseNet.from_pretrained(model_name)

    label_true = []
    label_pred = []
    for i in range(len(data)):
        waveform, metadata = data.get_sample(i)
        m = pd.Series(metadata)
        has_p_label = m[p_labels].notna()
        has_s_label = m[s_labels].notna()

        windows = []  # (label, window start) pairs to score for this trace
        if has_p_label.any():
            windows.append((1, _pick_window_start(m, p_labels, has_p_label)))
            # Noise window: starts 1 s into the trace.
            # NOTE(review): it is labelled 0 without checking that no P/S pick
            # falls inside it. Confirm the datasets have enough pre-event noise.
            windows.append((0, obspy.UTCDateTime(m["trace_start_time"]) + 1))
        if has_s_label.any():
            windows.append((1, _pick_window_start(m, s_labels, has_s_label)))

        for label, start in windows:
            score = _score_window(model, m, waveform, start)
            if score is not None:
                label_true.append(label)
                label_pred.append(score)

    # roc_auc_score raises when y_true holds a single class (or is empty), so
    # skip this pair instead of crashing the whole run.
    if len(set(label_true)) < 2:
        print(" skipped (need both signal and noise windows)", flush=True)
        continue

    fpr, tpr, roc_thresholds = roc_curve(label_true, label_pred)
    pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": roc_thresholds}).to_csv(roc_fname)

    precision, recall, prc_thresholds = precision_recall_curve(label_true, label_pred)
    # precision/recall have one more entry than thresholds; pad with -999.
    prc_thresholds_extra = np.append(prc_thresholds, -999)
    pd.DataFrame(
        {"pre": precision, "rec": recall, "thresholds": prc_thresholds_extra}
    ).to_csv(pr_fname)

    stats = {
        "model": str(model_name),
        "data": str(data_name),
        "auc": float(roc_auc_score(label_true, label_pred)),
    }
    with open(stats_fname, "w", encoding="utf-8") as fp:
        json.dump(stats, fp)
    print(" finished", flush=True)