platform-demo-scripts/scripts/collect_results.py


"""
This script collects results in a folder, calculates performance metrics and writes them to csv.
"""
import argparse
from pathlib import Path
import logging
import pandas as pd
import numpy as np
from sklearn.metrics import (
precision_recall_curve,
precision_recall_fscore_support,
roc_auc_score,
matthews_corrcoef,
)
from tqdm import tqdm


def traverse_path(path, output, cross=False, resampled=False, baer=False):
"""
Traverses the given path and extracts results for each experiment and version
:param path: Root path
:param output: Path to write results csv to
    :param cross: If true, expects cross-domain results.
    :param resampled: If true, expects cross-domain cross-sampling rate results.
    :param baer: If true, expects results from the Baer-Kradolfer picker.
    :return: None
"""
path = Path(path)
results = []
exp_dirs = [x for x in path.iterdir() if x.is_dir()]
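    # Each experiment directory is expected to contain one subdirectory per
    # training version (except for Baer-Kradolfer results, handled below).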
for exp_dir in tqdm(exp_dirs):
itr = exp_dir.iterdir()
if baer:
itr = [exp_dir] # Missing version directory in the structure
for version_dir in itr:
            if not version_dir.is_dir():
                # Skip stray files; only version directories hold results
                continue
results.append(
process_version(
version_dir, cross=cross, resampled=resampled, baer=baer
)
)
results = pd.DataFrame(results)
if cross:
sort_keys = ["data", "model", "target", "lr", "version"]
else:
sort_keys = ["data", "model", "lr", "version"]
results.sort_values(sort_keys, inplace=True)
results.to_csv(output, index=False)


def process_version(version_dir: Path, cross: bool, resampled: bool, baer: bool):
"""
Extracts statistics for the given version of the given experiment.
:param version_dir: Path to the specific version
    :param cross: If true, expects cross-domain results.
    :param resampled: If true, expects cross-domain cross-sampling rate results.
    :param baer: If true, expects results from the Baer-Kradolfer picker.
    :return: Results dictionary
"""
stats = parse_exp_name(version_dir, cross=cross, resampled=resampled, baer=baer)
stats.update(eval_task1(version_dir))
stats.update(eval_task23(version_dir))
return stats


def parse_exp_name(version_dir, cross, resampled, baer):
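    """
    Parses experiment properties (data set, model, learning rate and, depending on
    the flags, target and sampling rate) from the experiment and version directory
    names. A missing learning rate component defaults to 0.001.
    :param version_dir: Path to the specific version
    :param cross: If true, expects cross-domain results.
    :param resampled: If true, expects cross-domain cross-sampling rate results.
    :param baer: If true, expects results from the Baer-Kradolfer picker.
    :return: Dictionary with the parsed experiment properties
    """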
if baer:
exp_name = version_dir.name
version = "0"
else:
exp_name = version_dir.parent.name
version = version_dir.name.split("_")[-1]
parts = exp_name.split("_")
target = None
sampling_rate = None
if cross or baer:
if len(parts) == 4:
data, model, lr, target = parts
else:
data, model, target = parts
lr = "0.001"
elif resampled:
if len(parts) == 5:
data, model, lr, target, sampling_rate = parts
else:
data, model, target, sampling_rate = parts
lr = "0.001"
else:
if len(parts) == 3:
data, model, lr = parts
else:
data, model, *_ = parts
lr = "0.001"
# lr = float(lr)
stats = {
"experiment": exp_name,
"data": data,
"model": model,
"lr": None,
"version": version,
}
if cross or baer:
stats["target"] = target
if resampled:
stats["target"] = target
stats["sampling_rate"] = sampling_rate
return stats


def eval_task1(version_dir: Path):
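    """
    Evaluates event detection (task 1). The detection threshold is chosen to
    maximize F1 on the dev set and then applied to the test set.
    :param version_dir: Path to the specific version
    :return: Results dictionary, empty if the task 1 csv files are missing
    """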
if not (
(version_dir / "dev_task1.csv").is_file()
and (version_dir / "test_task1.csv").is_file()
):
logging.warning(f"Directory {version_dir} does not contain task 1")
return {}
stats = {}
dev_pred = pd.read_csv(version_dir / "dev_task1.csv")
dev_pred["trace_type_bin"] = dev_pred["trace_type"] == "earthquake"
test_pred = pd.read_csv(version_dir / "test_task1.csv")
test_pred["trace_type_bin"] = test_pred["trace_type"] == "earthquake"
prec, recall, thr = precision_recall_curve(
dev_pred["trace_type_bin"], dev_pred["score_detection"]
)
f1 = 2 * prec * recall / (prec + recall)
auc = roc_auc_score(dev_pred["trace_type_bin"], dev_pred["score_detection"])
opt_index = np.nanargmax(f1) # F1 optimal threshold index
opt_thr = thr[opt_index] # F1 optimal threshold value
dev_stats = {
"dev_det_precision": prec[opt_index],
"dev_det_recall": recall[opt_index],
"dev_det_f1": f1[opt_index],
"dev_det_auc": auc,
"det_threshold": opt_thr,
}
stats.update(dev_stats)
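    # Evaluate the test set using the threshold selected on the dev set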
prec, recall, f1, _ = precision_recall_fscore_support(
test_pred["trace_type_bin"],
test_pred["score_detection"] > opt_thr,
average="binary",
)
auc = roc_auc_score(test_pred["trace_type_bin"], test_pred["score_detection"])
test_stats = {
"test_det_precision": prec,
"test_det_recall": recall,
"test_det_f1": f1,
"test_det_auc": auc,
}
stats.update(test_stats)
return stats


def eval_task23(version_dir: Path):
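    """
    Evaluates P versus S phase identification (task 2) and onset time
    determination (task 3). Thresholds for task 2 are selected on the dev set
    (F1 and MCC optimal) and applied to the test set; task 3 reports onset
    residual statistics in seconds.
    :param version_dir: Path to the specific version
    :return: Results dictionary, empty if the csv files are missing or all
        predictions are NaN
    """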
if not (
(version_dir / "dev_task23.csv").is_file()
and (version_dir / "test_task23.csv").is_file()
):
logging.warning(f"Directory {version_dir} does not contain tasks 2 and 3")
return {}
stats = {}
dev_pred = pd.read_csv(version_dir / "dev_task23.csv")
dev_pred["phase_label_bin"] = dev_pred["phase_label"] == "P"
test_pred = pd.read_csv(version_dir / "test_task23.csv")
test_pred["phase_label_bin"] = test_pred["phase_label"] == "P"
def add_aux_columns(pred):
for col in ["s_sample_pred", "score_p_or_s"]:
if col not in pred.columns:
pred[col] = np.nan
add_aux_columns(dev_pred)
add_aux_columns(test_pred)
def nanmask(pred):
"""
        Returns a boolean mask selecting entries where score_p_or_s, p_sample_pred
        and s_sample_pred are all NaN
"""
mask = np.logical_and(
np.isnan(pred["p_sample_pred"]), np.isnan(pred["s_sample_pred"])
)
mask = np.logical_and(mask, np.isnan(pred["score_p_or_s"]))
return mask
if nanmask(dev_pred).all():
logging.warning(f"{version_dir} contains NaN predictions for tasks 2 and 3")
return {}
dev_pred = dev_pred[~nanmask(dev_pred)]
test_pred = test_pred[~nanmask(test_pred)]
skip_task2 = False
if (
np.logical_or(
np.isnan(dev_pred["score_p_or_s"]), np.isinf(dev_pred["score_p_or_s"])
).all()
or np.logical_or(
np.isnan(test_pred["score_p_or_s"]), np.isinf(test_pred["score_p_or_s"])
).all()
):
        # If all scores are NaN or infinite, skip task 2; thresholding such values
        # would otherwise produce misleading metrics.
        skip_task2 = True
# Clipping removes infinitely likely P waves, usually resulting from models trained without S arrivals
dev_pred["score_p_or_s"] = np.clip(dev_pred["score_p_or_s"].values, -1e100, 1e100)
test_pred["score_p_or_s"] = np.clip(test_pred["score_p_or_s"].values, -1e100, 1e100)
dev_pred_restricted = dev_pred[~np.isnan(dev_pred["score_p_or_s"])]
test_pred_restricted = test_pred[~np.isnan(test_pred["score_p_or_s"])]
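    # Task 2 (P vs S discrimination): thresholds are chosen on the dev set,
    # using only predictions with a defined score_p_or_s.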
if len(dev_pred_restricted) > 0 and not skip_task2:
prec, recall, thr = precision_recall_curve(
dev_pred_restricted["phase_label_bin"], dev_pred_restricted["score_p_or_s"]
)
f1 = 2 * prec * recall / (prec + recall)
opt_index = np.nanargmax(f1) # F1 optimal threshold index
opt_thr = thr[opt_index] # F1 optimal threshold value
# Determine (approximately) optimal MCC threshold using 50 candidates
mcc_thrs = np.sort(dev_pred["score_p_or_s"].values)
mcc_thrs = mcc_thrs[np.linspace(0, len(mcc_thrs) - 1, 50, dtype=int)]
mccs = []
for thr in mcc_thrs:
mccs.append(
matthews_corrcoef(
dev_pred["phase_label_bin"], dev_pred["score_p_or_s"] > thr
)
)
mcc = np.max(mccs)
mcc_thr = mcc_thrs[np.argmax(mccs)]
dev_stats = {
"dev_phase_precision": prec[opt_index],
"dev_phase_recall": recall[opt_index],
"dev_phase_f1": f1[opt_index],
"phase_threshold": opt_thr,
"dev_phase_mcc": mcc,
"phase_threshold_mcc": mcc_thr,
}
stats.update(dev_stats)
prec, recall, f1, _ = precision_recall_fscore_support(
test_pred_restricted["phase_label_bin"],
test_pred_restricted["score_p_or_s"] > opt_thr,
average="binary",
)
mcc = matthews_corrcoef(
test_pred["phase_label_bin"], test_pred["score_p_or_s"] > mcc_thr
)
test_stats = {
"test_phase_precision": prec,
"test_phase_recall": recall,
"test_phase_f1": f1,
"test_phase_mcc": mcc,
}
stats.update(test_stats)
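    # Onset time residuals for P and S picks, converted from samples to seconds
    # via the per-trace sampling rate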
for pred, set_str in [(dev_pred, "dev"), (test_pred, "test")]:
        for phase in ["P", "S"]:
pred_phase = pred[pred["phase_label"] == phase]
pred_col = f"{phase.lower()}_sample_pred"
if len(pred_phase) == 0:
continue
diff = (pred_phase[pred_col] - pred_phase["phase_onset"]) / pred_phase[
"sampling_rate"
]
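            # Note: *_std_s is the root-mean-square of the residuals rather than
            # the centered standard deviation.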
stats[f"{set_str}_{phase}_mean_s"] = np.mean(diff)
stats[f"{set_str}_{phase}_std_s"] = np.sqrt(np.mean(diff**2))
stats[f"{set_str}_{phase}_mae_s"] = np.mean(np.abs(diff))
return stats
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Collects results from all experiments in a folder and outputs them in condensed csv format."
)
parser.add_argument(
"path",
type=str,
help="Root path of predictions",
)
parser.add_argument(
"output",
type=str,
help="Path for the output csv",
)
parser.add_argument(
"--cross", action="store_true", help="If true, expects cross-domain results."
)
parser.add_argument(
"--resampled",
action="store_true",
help="If true, expects cross-domain cross-sampling rate results.",
)
parser.add_argument(
"--baer",
action="store_true",
help="If true, expects results from Baer-Kradolfer picker.",
)
args = parser.parse_args()
traverse_path(
args.path,
args.output,
cross=args.cross,
resampled=args.resampled,
baer=args.baer,
)