336 lines
10 KiB
Python
336 lines
10 KiB
Python
|
"""
|
||
|
This script collects results in a folder, calculates performance metrics and writes them to csv.
|
||
|
"""
|
||
|
|
||
|
import argparse
|
||
|
from pathlib import Path
|
||
|
import logging
|
||
|
import pandas as pd
|
||
|
import numpy as np
|
||
|
from sklearn.metrics import (
|
||
|
precision_recall_curve,
|
||
|
precision_recall_fscore_support,
|
||
|
roc_auc_score,
|
||
|
matthews_corrcoef,
|
||
|
)
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
|
||
|
def traverse_path(path, output, cross=False, resampled=False, baer=False):
|
||
|
"""
|
||
|
Traverses the given path and extracts results for each experiment and version
|
||
|
|
||
|
:param path: Root path
|
||
|
:param output: Path to write results csv to
|
||
|
:param cross: If true, expects cross-domain results.
|
||
|
:return: None
|
||
|
"""
|
||
|
path = Path(path)
|
||
|
|
||
|
results = []
|
||
|
|
||
|
exp_dirs = [x for x in path.iterdir() if x.is_dir()]
|
||
|
for exp_dir in tqdm(exp_dirs):
|
||
|
itr = exp_dir.iterdir()
|
||
|
if baer:
|
||
|
itr = [exp_dir] # Missing version directory in the structure
|
||
|
for version_dir in itr:
|
||
|
if not version_dir.is_dir():
|
||
|
pass
|
||
|
|
||
|
results.append(
|
||
|
process_version(
|
||
|
version_dir, cross=cross, resampled=resampled, baer=baer
|
||
|
)
|
||
|
)
|
||
|
|
||
|
results = pd.DataFrame(results)
|
||
|
if cross:
|
||
|
sort_keys = ["data", "model", "target", "lr", "version"]
|
||
|
else:
|
||
|
sort_keys = ["data", "model", "lr", "version"]
|
||
|
results.sort_values(sort_keys, inplace=True)
|
||
|
results.to_csv(output, index=False)
|
||
|
|
||
|
|
||
|
def process_version(version_dir: Path, cross: bool, resampled: bool, baer: bool):
|
||
|
"""
|
||
|
Extracts statistics for the given version of the given experiment.
|
||
|
|
||
|
:param version_dir: Path to the specific version
|
||
|
:param cross: If true, expects cross-domain results.
|
||
|
:return: Results dictionary
|
||
|
"""
|
||
|
stats = parse_exp_name(version_dir, cross=cross, resampled=resampled, baer=baer)
|
||
|
|
||
|
stats.update(eval_task1(version_dir))
|
||
|
stats.update(eval_task23(version_dir))
|
||
|
|
||
|
return stats
|
||
|
|
||
|
|
||
|
def parse_exp_name(version_dir, cross, resampled, baer):
|
||
|
if baer:
|
||
|
exp_name = version_dir.name
|
||
|
version = "0"
|
||
|
else:
|
||
|
exp_name = version_dir.parent.name
|
||
|
version = version_dir.name.split("_")[-1]
|
||
|
|
||
|
parts = exp_name.split("_")
|
||
|
target = None
|
||
|
sampling_rate = None
|
||
|
if cross or baer:
|
||
|
if len(parts) == 4:
|
||
|
data, model, lr, target = parts
|
||
|
else:
|
||
|
data, model, target = parts
|
||
|
lr = "0.001"
|
||
|
elif resampled:
|
||
|
if len(parts) == 5:
|
||
|
data, model, lr, target, sampling_rate = parts
|
||
|
else:
|
||
|
data, model, target, sampling_rate = parts
|
||
|
lr = "0.001"
|
||
|
else:
|
||
|
if len(parts) == 3:
|
||
|
data, model, lr = parts
|
||
|
else:
|
||
|
data, model, *_ = parts
|
||
|
lr = "0.001"
|
||
|
|
||
|
# lr = float(lr)
|
||
|
|
||
|
stats = {
|
||
|
"experiment": exp_name,
|
||
|
"data": data,
|
||
|
"model": model,
|
||
|
"lr": None,
|
||
|
"version": version,
|
||
|
}
|
||
|
|
||
|
if cross or baer:
|
||
|
stats["target"] = target
|
||
|
if resampled:
|
||
|
stats["target"] = target
|
||
|
stats["sampling_rate"] = sampling_rate
|
||
|
|
||
|
return stats
|
||
|
|
||
|
|
||
|
def eval_task1(version_dir: Path):
|
||
|
if not (
|
||
|
(version_dir / "dev_task1.csv").is_file()
|
||
|
and (version_dir / "test_task1.csv").is_file()
|
||
|
):
|
||
|
logging.warning(f"Directory {version_dir} does not contain task 1")
|
||
|
return {}
|
||
|
|
||
|
stats = {}
|
||
|
|
||
|
dev_pred = pd.read_csv(version_dir / "dev_task1.csv")
|
||
|
dev_pred["trace_type_bin"] = dev_pred["trace_type"] == "earthquake"
|
||
|
test_pred = pd.read_csv(version_dir / "test_task1.csv")
|
||
|
test_pred["trace_type_bin"] = test_pred["trace_type"] == "earthquake"
|
||
|
|
||
|
prec, recall, thr = precision_recall_curve(
|
||
|
dev_pred["trace_type_bin"], dev_pred["score_detection"]
|
||
|
)
|
||
|
|
||
|
f1 = 2 * prec * recall / (prec + recall)
|
||
|
auc = roc_auc_score(dev_pred["trace_type_bin"], dev_pred["score_detection"])
|
||
|
|
||
|
opt_index = np.nanargmax(f1) # F1 optimal threshold index
|
||
|
opt_thr = thr[opt_index] # F1 optimal threshold value
|
||
|
|
||
|
dev_stats = {
|
||
|
"dev_det_precision": prec[opt_index],
|
||
|
"dev_det_recall": recall[opt_index],
|
||
|
"dev_det_f1": f1[opt_index],
|
||
|
"dev_det_auc": auc,
|
||
|
"det_threshold": opt_thr,
|
||
|
}
|
||
|
stats.update(dev_stats)
|
||
|
|
||
|
prec, recall, f1, _ = precision_recall_fscore_support(
|
||
|
test_pred["trace_type_bin"],
|
||
|
test_pred["score_detection"] > opt_thr,
|
||
|
average="binary",
|
||
|
)
|
||
|
auc = roc_auc_score(test_pred["trace_type_bin"], test_pred["score_detection"])
|
||
|
test_stats = {
|
||
|
"test_det_precision": prec,
|
||
|
"test_det_recall": recall,
|
||
|
"test_det_f1": f1,
|
||
|
"test_det_auc": auc,
|
||
|
}
|
||
|
stats.update(test_stats)
|
||
|
|
||
|
return stats
|
||
|
|
||
|
|
||
|
def eval_task23(version_dir: Path):
|
||
|
print(version_dir / "dev_task23.csv")
|
||
|
if not (
|
||
|
(version_dir / "dev_task23.csv").is_file()
|
||
|
and (version_dir / "test_task23.csv").is_file()
|
||
|
):
|
||
|
logging.warning(f"Directory {version_dir} does not contain tasks 2 and 3")
|
||
|
return {}
|
||
|
|
||
|
stats = {}
|
||
|
|
||
|
dev_pred = pd.read_csv(version_dir / "dev_task23.csv")
|
||
|
dev_pred["phase_label_bin"] = dev_pred["phase_label"] == "P"
|
||
|
test_pred = pd.read_csv(version_dir / "test_task23.csv")
|
||
|
test_pred["phase_label_bin"] = test_pred["phase_label"] == "P"
|
||
|
|
||
|
def add_aux_columns(pred):
|
||
|
for col in ["s_sample_pred", "score_p_or_s"]:
|
||
|
if col not in pred.columns:
|
||
|
pred[col] = np.nan
|
||
|
|
||
|
add_aux_columns(dev_pred)
|
||
|
add_aux_columns(test_pred)
|
||
|
|
||
|
def nanmask(pred):
|
||
|
"""
|
||
|
Returns all entries that are nan in score_p_or_s, p_sample_pred and s_sample_pred
|
||
|
"""
|
||
|
mask = np.logical_and(
|
||
|
np.isnan(pred["p_sample_pred"]), np.isnan(pred["s_sample_pred"])
|
||
|
)
|
||
|
mask = np.logical_and(mask, np.isnan(pred["score_p_or_s"]))
|
||
|
return mask
|
||
|
|
||
|
if nanmask(dev_pred).all():
|
||
|
logging.warning(f"{version_dir} contains NaN predictions for tasks 2 and 3")
|
||
|
return {}
|
||
|
|
||
|
dev_pred = dev_pred[~nanmask(dev_pred)]
|
||
|
test_pred = test_pred[~nanmask(test_pred)]
|
||
|
|
||
|
skip_task2 = False
|
||
|
if (
|
||
|
np.logical_or(
|
||
|
np.isnan(dev_pred["score_p_or_s"]), np.isinf(dev_pred["score_p_or_s"])
|
||
|
).all()
|
||
|
or np.logical_or(
|
||
|
np.isnan(test_pred["score_p_or_s"]), np.isinf(test_pred["score_p_or_s"])
|
||
|
).all()
|
||
|
):
|
||
|
# For unfortunate combinations of nans and infs, otherwise weird scores can occur
|
||
|
skip_task2 = True
|
||
|
|
||
|
# Clipping removes infinitely likely P waves, usually resulting from models trained without S arrivals
|
||
|
dev_pred["score_p_or_s"] = np.clip(dev_pred["score_p_or_s"].values, -1e100, 1e100)
|
||
|
test_pred["score_p_or_s"] = np.clip(test_pred["score_p_or_s"].values, -1e100, 1e100)
|
||
|
|
||
|
dev_pred_restricted = dev_pred[~np.isnan(dev_pred["score_p_or_s"])]
|
||
|
test_pred_restricted = test_pred[~np.isnan(test_pred["score_p_or_s"])]
|
||
|
if len(dev_pred_restricted) > 0 and not skip_task2:
|
||
|
prec, recall, thr = precision_recall_curve(
|
||
|
dev_pred_restricted["phase_label_bin"], dev_pred_restricted["score_p_or_s"]
|
||
|
)
|
||
|
|
||
|
f1 = 2 * prec * recall / (prec + recall)
|
||
|
|
||
|
opt_index = np.nanargmax(f1) # F1 optimal threshold index
|
||
|
opt_thr = thr[opt_index] # F1 optimal threshold value
|
||
|
|
||
|
# Determine (approximately) optimal MCC threshold using 50 candidates
|
||
|
mcc_thrs = np.sort(dev_pred["score_p_or_s"].values)
|
||
|
mcc_thrs = mcc_thrs[np.linspace(0, len(mcc_thrs) - 1, 50, dtype=int)]
|
||
|
mccs = []
|
||
|
for thr in mcc_thrs:
|
||
|
mccs.append(
|
||
|
matthews_corrcoef(
|
||
|
dev_pred["phase_label_bin"], dev_pred["score_p_or_s"] > thr
|
||
|
)
|
||
|
)
|
||
|
mcc = np.max(mccs)
|
||
|
mcc_thr = mcc_thrs[np.argmax(mccs)]
|
||
|
|
||
|
dev_stats = {
|
||
|
"dev_phase_precision": prec[opt_index],
|
||
|
"dev_phase_recall": recall[opt_index],
|
||
|
"dev_phase_f1": f1[opt_index],
|
||
|
"phase_threshold": opt_thr,
|
||
|
"dev_phase_mcc": mcc,
|
||
|
"phase_threshold_mcc": mcc_thr,
|
||
|
}
|
||
|
stats.update(dev_stats)
|
||
|
|
||
|
prec, recall, f1, _ = precision_recall_fscore_support(
|
||
|
test_pred_restricted["phase_label_bin"],
|
||
|
test_pred_restricted["score_p_or_s"] > opt_thr,
|
||
|
average="binary",
|
||
|
)
|
||
|
mcc = matthews_corrcoef(
|
||
|
test_pred["phase_label_bin"], test_pred["score_p_or_s"] > mcc_thr
|
||
|
)
|
||
|
test_stats = {
|
||
|
"test_phase_precision": prec,
|
||
|
"test_phase_recall": recall,
|
||
|
"test_phase_f1": f1,
|
||
|
"test_phase_mcc": mcc,
|
||
|
}
|
||
|
stats.update(test_stats)
|
||
|
|
||
|
for pred, set_str in [(dev_pred, "dev"), (test_pred, "test")]:
|
||
|
for i, phase in enumerate(["P", "S"]):
|
||
|
pred_phase = pred[pred["phase_label"] == phase]
|
||
|
pred_col = f"{phase.lower()}_sample_pred"
|
||
|
|
||
|
if len(pred_phase) == 0:
|
||
|
continue
|
||
|
|
||
|
diff = (pred_phase[pred_col] - pred_phase["phase_onset"]) / pred_phase[
|
||
|
"sampling_rate"
|
||
|
]
|
||
|
|
||
|
stats[f"{set_str}_{phase}_mean_s"] = np.mean(diff)
|
||
|
stats[f"{set_str}_{phase}_std_s"] = np.sqrt(np.mean(diff**2))
|
||
|
stats[f"{set_str}_{phase}_mae_s"] = np.mean(np.abs(diff))
|
||
|
|
||
|
return stats
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
parser = argparse.ArgumentParser(
|
||
|
description="Collects results from all experiments in a folder and outputs them in condensed csv format."
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"path",
|
||
|
type=str,
|
||
|
help="Root path of predictions",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"output",
|
||
|
type=str,
|
||
|
help="Path for the output csv",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--cross", action="store_true", help="If true, expects cross-domain results."
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--resampled",
|
||
|
action="store_true",
|
||
|
help="If true, expects cross-domain cross-sampling rate results.",
|
||
|
)
|
||
|
parser.add_argument(
|
||
|
"--baer",
|
||
|
action="store_true",
|
||
|
help="If true, expects results from Baer-Kradolfer picker.",
|
||
|
)
|
||
|
|
||
|
args = parser.parse_args()
|
||
|
|
||
|
traverse_path(
|
||
|
args.path,
|
||
|
args.output,
|
||
|
cross=args.cross,
|
||
|
resampled=args.resampled,
|
||
|
baer=args.baer,
|
||
|
)
|