Import script for model performanace analysis
This commit is contained in:
parent
e86f131cc0
commit
281c73764d
149
scripts/perf_analysis.py
Normal file
149
scripts/perf_analysis.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
import json
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import obspy
|
||||||
|
import pandas as pd
|
||||||
|
import seisbench.data as sbd
|
||||||
|
import seisbench.models as sbm
|
||||||
|
from seisbench.models.team import itertools
|
||||||
|
from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve
|
||||||
|
|
||||||
|
datasets = [
|
||||||
|
# path to datasets in seisbench format
|
||||||
|
]
|
||||||
|
|
||||||
|
models = [
|
||||||
|
# model names
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def find_keys_phase(meta, phase):
|
||||||
|
phases = []
|
||||||
|
for k in meta.keys():
|
||||||
|
if k.startswith("trace_" + phase) and k.endswith("_arrival_sample"):
|
||||||
|
phases.append(k)
|
||||||
|
|
||||||
|
return phases
|
||||||
|
|
||||||
|
|
||||||
|
def create_stream(meta, raw, start, length=30):
|
||||||
|
|
||||||
|
st = obspy.Stream()
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
tr = obspy.Trace(raw[i, :])
|
||||||
|
tr.stats.starttime = meta["trace_start_time"]
|
||||||
|
tr.stats.sampling_rate = meta["trace_sampling_rate_hz"]
|
||||||
|
tr.stats.network = meta["station_network_code"]
|
||||||
|
tr.stats.station = meta["station_code"]
|
||||||
|
tr.stats.channel = meta["trace_channel"][:2] + meta["trace_component_order"][i]
|
||||||
|
|
||||||
|
stop = start + length
|
||||||
|
tr = tr.slice(start, stop)
|
||||||
|
|
||||||
|
st.append(tr)
|
||||||
|
|
||||||
|
return st
|
||||||
|
|
||||||
|
|
||||||
|
def get_pred(model, stream):
|
||||||
|
ann = model.annotate(stream)
|
||||||
|
noise = ann.select(channel="PhaseNet_N")[0]
|
||||||
|
pred = max(1 - noise.data)
|
||||||
|
return pred
|
||||||
|
|
||||||
|
|
||||||
|
def to_short(stream):
|
||||||
|
short = [tr for tr in stream if tr.data.shape[0] < 3001]
|
||||||
|
return any(short)
|
||||||
|
|
||||||
|
|
||||||
|
for ds, model_name in itertools.product(datasets, models):
|
||||||
|
|
||||||
|
data = sbd.WaveformDataset(ds, sampling_rate=100).test()
|
||||||
|
data_name = pathlib.Path(ds).stem
|
||||||
|
fname = f"roc___{model_name}___{data_name}.csv"
|
||||||
|
|
||||||
|
print(f"{fname:.<50s}.... ", flush=True, end="")
|
||||||
|
|
||||||
|
if pathlib.Path(fname).is_file():
|
||||||
|
print(" ready, skipping", flush=True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
p_labels = find_keys_phase(data.metadata, "P")
|
||||||
|
s_labels = find_keys_phase(data.metadata, "S")
|
||||||
|
|
||||||
|
model = sbm.PhaseNet().from_pretrained(model_name)
|
||||||
|
|
||||||
|
label_true = []
|
||||||
|
label_pred = []
|
||||||
|
|
||||||
|
for i in range(len(data)):
|
||||||
|
|
||||||
|
waveform, metadata = data.get_sample(i)
|
||||||
|
m = pd.Series(metadata)
|
||||||
|
|
||||||
|
has_p_label = m[p_labels].notna()
|
||||||
|
has_s_label = m[s_labels].notna()
|
||||||
|
|
||||||
|
if any(has_p_label):
|
||||||
|
|
||||||
|
trace_start_time = obspy.UTCDateTime(m["trace_start_time"])
|
||||||
|
pick_sample = m[p_labels][has_p_label][0]
|
||||||
|
|
||||||
|
start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15
|
||||||
|
|
||||||
|
try:
|
||||||
|
st_p = create_stream(m, waveform, start)
|
||||||
|
if not (to_short(st_p)):
|
||||||
|
pred_p = get_pred(model, st_p)
|
||||||
|
label_true.append(1)
|
||||||
|
label_pred.append(pred_p)
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
st_n = create_stream(m, waveform, trace_start_time + 1)
|
||||||
|
if not (to_short(st_n)):
|
||||||
|
pred_n = get_pred(model, st_n)
|
||||||
|
label_true.append(0)
|
||||||
|
label_pred.append(pred_n)
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if any(has_s_label):
|
||||||
|
trace_start_time = obspy.UTCDateTime(m["trace_start_time"])
|
||||||
|
pick_sample = m[s_labels][has_s_label][0]
|
||||||
|
start = trace_start_time + pick_sample / m["trace_sampling_rate_hz"] - 15
|
||||||
|
|
||||||
|
try:
|
||||||
|
st_s = create_stream(m, waveform, start)
|
||||||
|
if not (to_short(st_s)):
|
||||||
|
pred_s = get_pred(model, st_s)
|
||||||
|
label_true.append(1)
|
||||||
|
label_pred.append(pred_s)
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
fpr, tpr, roc_thresholds = roc_curve(label_true, label_pred)
|
||||||
|
df = pd.DataFrame({"fpr": fpr, "tpr": tpr, "thresholds": roc_thresholds})
|
||||||
|
df.to_csv(fname)
|
||||||
|
|
||||||
|
precision, recall, prc_thresholds = precision_recall_curve(label_true, label_pred)
|
||||||
|
prc_thresholds_extra = np.append(prc_thresholds, -999)
|
||||||
|
df = pd.DataFrame(
|
||||||
|
{"pre": precision, "rec": recall, "thresholds": prc_thresholds_extra}
|
||||||
|
)
|
||||||
|
df.to_csv(fname.replace("roc", "pr"))
|
||||||
|
|
||||||
|
stats = {
|
||||||
|
"model": str(model_name),
|
||||||
|
"data": str(data_name),
|
||||||
|
"auc": float(roc_auc_score(label_true, label_pred)),
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(f"stats___{model_name}___{data_name}.json", "w") as fp:
|
||||||
|
json.dump(stats, fp)
|
||||||
|
|
||||||
|
print(" finished", flush=True)
|
Loading…
Reference in New Issue
Block a user