platform-demo-scripts/scripts/eval.py

"""
This script implements functionality for evaluating models.
Given a model and a set of targets, it calculates and outputs predictions.
"""
import seisbench.generate as sbg
import argparse
import pandas as pd
import yaml
from pathlib import Path
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch
import models, data
import logging
from util import default_workers, load_best_model_data
import time
import datetime
from config_loader import models_path, project_path
import os
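
# Maps evaluation target folder names to the dataset names used by data.get_dataset_by_name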
data_aliases = {
    "ethz": "ETHZ",
    "geofon": "GEOFON",
    "stead": "STEAD",
    "neic": "NEIC",
    "instance": "InstanceCountsCombined",
    "iquique": "Iquique",
    "lendb": "LenDB",
    "scedc": "SCEDC",
}


def main(weights, targets, sets, batchsize, num_workers, sampling_rate=None, sweep_id=None):
    weights = Path(weights)
    targets = Path(os.path.abspath(targets))
    print(targets)
    sets = sets.split(",")

    checkpoint_path, version = load_best_model_data(sweep_id, weights)
    logging.warning("Starting evaluation of model: \n" + checkpoint_path)

    config_path = f"{models_path}/{weights}/{version}/hparams.yaml"
    with open(config_path, "r") as f:
        # config = yaml.safe_load(f)
        config = yaml.full_load(f)

    model_name = (
        config["model_name"][0]
        if isinstance(config["model_name"], list)
        else config["model_name"]
    )
    model_cls = models.__getattribute__(model_name + "Lit")
    model = model_cls.load_from_checkpoint(checkpoint_path)

    data_name = data_aliases[targets.name] if targets.name in data_aliases else None

    if data_name != config["dataset_name"] and targets.name in data_aliases:
        logging.warning("Detected cross-domain evaluation")
        pred_root = "pred_cross"
        parts = weights.name.split()
        weight_path_name = "_".join(parts[:2] + [targets.name] + parts[2:])
    else:
        pred_root = "pred"
        weight_path_name = weights.name

    if data_name is not None:
        dataset = data.get_dataset_by_name(data_name)(
            sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full"
        )
    else:
        data_path = project_path + "/" + config["data_path"]
        print("Loading dataset: ", data_path)
        dataset = data.get_custom_dataset(data_path)

    if sampling_rate is not None:
        dataset.sampling_rate = sampling_rate
        pred_root = pred_root + "_resampled"
        weight_path_name = weight_path_name + f"_{sampling_rate}"

    for eval_set in sets:
        split = dataset.get_split(eval_set)
        if targets.name == "instance":
            logging.warning(
                "Overwriting noise trace_names to allow correct identification"
            )
            # Replace trace names for noise entries
            split._metadata["trace_name"].values[
                -len(split.datasets[-1]) :
            ] = split._metadata["trace_name"][-len(split.datasets[-1]) :].apply(
                lambda x: "noise_" + x
            )
            split._build_trace_name_to_idx_dict()

        logging.warning(f"Starting set {eval_set}")
        split.preload_waveforms(pbar=True)
        print("eval set shape", split.metadata.shape)
for task in ["1", "23"]:
task_csv = targets / f"task{task}.csv"
if not task_csv.is_file():
continue
logging.warning(f"Starting task {task}")
task_targets = pd.read_csv(task_csv)
task_targets = task_targets[task_targets["trace_split"] == eval_set]
if task == "1" and targets.name == "instance":
border = _identify_instance_dataset_border(task_targets)
task_targets["trace_name"].values[border:] = task_targets["trace_name"][
border:
].apply(lambda x: "noise_" + x)
if sampling_rate is not None:
for key in ["start_sample", "end_sample", "phase_onset"]:
if key not in task_targets.columns:
continue
task_targets[key] = (
task_targets[key]
* sampling_rate
/ task_targets["sampling_rate"]
)
                task_targets["sampling_rate"] = sampling_rate

            restrict_to_phase = config.get("restrict_to_phase", None)
            if restrict_to_phase is not None and "phase_label" in task_targets.columns:
                mask = task_targets["phase_label"].isin(list(restrict_to_phase))
                task_targets = task_targets[mask]

            if restrict_to_phase is not None and task == "1":
                logging.warning("Skipping task 1 as restrict_to_phase is set.")
                continue

            generator = sbg.SteeredGenerator(split, task_targets)
            generator.add_augmentations(model.get_eval_augmentations())

            loader = DataLoader(
                generator, batch_size=batchsize, shuffle=False, num_workers=num_workers
            )
            # trainer = pl.Trainer(accelerator="gpu", devices=1)
            trainer = pl.Trainer()

            predictions = trainer.predict(model, loader)

            # Merge batches
            merged_predictions = []
            for i, _ in enumerate(predictions[0]):
                merged_predictions.append(torch.cat([x[i] for x in predictions]))

            merged_predictions = [x.cpu().numpy() for x in merged_predictions]
            task_targets["score_detection"] = merged_predictions[0]
            task_targets["score_p_or_s"] = merged_predictions[1]
            task_targets["p_sample_pred"] = (
                merged_predictions[2] + task_targets["start_sample"]
            )
            task_targets["s_sample_pred"] = (
                merged_predictions[3] + task_targets["start_sample"]
            )

            pred_path = (
                weights.parent.parent
                / pred_root
                / weight_path_name
                / version
                / f"{eval_set}_task{task}.csv"
            )
            pred_path.parent.mkdir(exist_ok=True, parents=True)
            task_targets.to_csv(pred_path, index=False)


def _identify_instance_dataset_border(task_targets):
    """
    Calculates the dataset border between Signal and Noise for instance,
    assuming it is the only place where the bucket number does not increase
    """
    # Trace names start with "bucket<N>$", so strip the "bucket" prefix to get the bucket number
    buckets = task_targets["trace_name"].apply(lambda x: int(x.split("$")[0][6:]))

    last_bucket = 0
    for i, bucket in enumerate(buckets):
        if bucket < last_bucket:
            return i
        last_bucket = bucket
if __name__ == "__main__":
code_start_time = time.perf_counter()
parser = argparse.ArgumentParser(
description="Evaluate a trained model using a set of targets."
)
    parser.add_argument(
        "weights",
        type=str,
        help="Path to weights. Expected to be in the models_path directory. "
        "The script will automatically load the configuration and the model. "
        "The script always uses the checkpoint with the lowest validation loss. "
        "If sweep_id is provided, the script considers only the checkpoints generated by that sweep. "
        "Predictions will be written into the weights path as csv. "
        "Note: Due to pytorch lightning internals, there exist two weights folders, "
        "{weights} and {weight}_{weights}. Please use the former as parameter.",
    )
    parser.add_argument(
        "targets",
        type=str,
        help="Path to evaluation targets folder. "
        "The script will detect which tasks are present based on file names.",
    )
    parser.add_argument(
        "--sets",
        type=str,
        default="dev,test",
        help="Sets on which to evaluate, separated by commas. Defaults to dev and test.",
    )
parser.add_argument("--batchsize", type=int, default=1024, help="Batch size")
parser.add_argument(
"--num_workers",
default=default_workers,
type=int,
help="Number of workers for data loader",
)
parser.add_argument(
"--sampling_rate", type=float, help="Overwrites the sampling rate in the data"
)
parser.add_argument(
"--sweep_id", type=str, help="wandb sweep_id", required=False, default=None
)
    args = parser.parse_args()

    main(
        args.weights,
        args.targets,
        args.sets,
        batchsize=args.batchsize,
        num_workers=args.num_workers,
        sampling_rate=args.sampling_rate,
        sweep_id=args.sweep_id,
    )
    running_time = str(
        datetime.timedelta(seconds=time.perf_counter() - code_start_time)
    )
    print(f"Running time: {running_time}")