"""
|
|
This script implements functionality for evaluating models.
|
|
Given a model and a set of targets, it calculates and outputs predictions.
|
|
"""

import seisbench.generate as sbg

import argparse
import pandas as pd
import yaml
from pathlib import Path
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch

import models, data
import logging
from util import default_workers, load_best_model_data
import time
import datetime
from config_loader import models_path, project_path
import os

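# Maps a lowercase targets-folder name to the corresponding SeisBench benchmark
# dataset name used by data.get_dataset_by_name; targets folders not listed here
# are loaded as custom datasets via the data_path from the training configuration.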
data_aliases = {
    "ethz": "ETHZ",
    "geofon": "GEOFON",
    "stead": "STEAD",
    "neic": "NEIC",
    "instance": "InstanceCountsCombined",
    "iquique": "Iquique",
    "lendb": "LenDB",
    "scedc": "SCEDC",
}


def main(weights, targets, sets, batchsize, num_workers, sampling_rate=None, sweep_id=None):
    weights = Path(weights)
    targets = Path(os.path.abspath(targets))
    print(targets)
    sets = sets.split(",")

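    # load_best_model_data selects the checkpoint with the lowest validation
    # loss (restricted to the given sweep if sweep_id is set) and returns its
    # path together with the version directory it belongs to.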
    checkpoint_path, version = load_best_model_data(sweep_id, weights)
    logging.warning("Starting evaluation of model: \n" + checkpoint_path)

    config_path = f"{models_path}/{weights}/{version}/hparams.yaml"
    with open(config_path, "r") as f:
        config = yaml.full_load(f)

    model_name = config["model_name"][0] if isinstance(config["model_name"], list) else config["model_name"]

    model_cls = getattr(models, model_name + "Lit")
    model = model_cls.load_from_checkpoint(checkpoint_path)

    data_name = data_aliases.get(targets.name)

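    # Cross-domain case: the targets come from a different benchmark dataset
    # than the one the model was trained on. Predictions then go to a separate
    # "pred_cross" tree and the target dataset name is spliced into the output
    # folder name.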
    if data_name is not None and data_name != config["dataset_name"]:
        logging.warning("Detected cross-domain evaluation")
        pred_root = "pred_cross"
        parts = weights.name.split()
        weight_path_name = "_".join(parts[:2] + [targets.name] + parts[2:])

    else:
        pred_root = "pred"
        weight_path_name = weights.name

    if data_name is not None:
        dataset = data.get_dataset_by_name(data_name)(
            sampling_rate=100, component_order="ZNE", dimension_order="NCW", cache="full"
        )
    else:
        data_path = project_path + '/' + config['data_path']
        print("Loading dataset: ", data_path)
        dataset = data.get_custom_dataset(data_path)

    if sampling_rate is not None:
        dataset.sampling_rate = sampling_rate
        pred_root = pred_root + "_resampled"
        weight_path_name = weight_path_name + f"_{sampling_rate}"

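    # Evaluate each requested split; within a split, predictions are generated
    # for every target file (task1.csv, task23.csv) present in the targets folder.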
    for eval_set in sets:
        split = dataset.get_split(eval_set)
        if targets.name == "instance":
            logging.warning(
                "Overwriting noise trace_names to allow correct identification"
            )
            # Replace trace names for noise entries
            split._metadata["trace_name"].values[
                -len(split.datasets[-1]) :
            ] = split._metadata["trace_name"][-len(split.datasets[-1]) :].apply(
                lambda x: "noise_" + x
            )
            split._build_trace_name_to_idx_dict()

        logging.warning(f"Starting set {eval_set}")
        split.preload_waveforms(pbar=True)
        print("eval set shape", split.metadata.shape)

        for task in ["1", "23"]:
            task_csv = targets / f"task{task}.csv"

            if not task_csv.is_file():
                continue

            logging.warning(f"Starting task {task}")

            task_targets = pd.read_csv(task_csv)
            task_targets = task_targets[task_targets["trace_split"] == eval_set]
            if task == "1" and targets.name == "instance":
                border = _identify_instance_dataset_border(task_targets)
                task_targets["trace_name"].values[border:] = task_targets["trace_name"][
                    border:
                ].apply(lambda x: "noise_" + x)

            if sampling_rate is not None:
                for key in ["start_sample", "end_sample", "phase_onset"]:
                    if key not in task_targets.columns:
                        continue
                    task_targets[key] = (
                        task_targets[key]
                        * sampling_rate
                        / task_targets["sampling_rate"]
                    )
                task_targets["sampling_rate"] = sampling_rate

            restrict_to_phase = config.get("restrict_to_phase", None)
            if restrict_to_phase is not None and "phase_label" in task_targets.columns:
                mask = task_targets["phase_label"].isin(list(restrict_to_phase))
                task_targets = task_targets[mask]

            if restrict_to_phase is not None and task == "1":
                logging.warning("Skipping task 1 as restrict_to_phase is set.")
                continue

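            # SteeredGenerator yields exactly the windows described by the rows
            # of task_targets; the model contributes its own evaluation-time
            # augmentations (windowing, normalization, etc., depending on the model).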
            generator = sbg.SteeredGenerator(split, task_targets)
            generator.add_augmentations(model.get_eval_augmentations())

            loader = DataLoader(
                generator, batch_size=batchsize, shuffle=False, num_workers=num_workers
            )
            # For GPU evaluation, use pl.Trainer(accelerator="gpu", devices=1) instead.
            trainer = pl.Trainer()

            predictions = trainer.predict(model, loader)

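            # trainer.predict returns one tuple of output tensors per batch; the
            # loop below concatenates them component-wise. The components are
            # interpreted further down as detection score, P-vs-S score, and the
            # predicted P and S sample offsets within the window.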
            # Merge batches
            merged_predictions = []
            for i, _ in enumerate(predictions[0]):
                merged_predictions.append(torch.cat([x[i] for x in predictions]))

            merged_predictions = [x.cpu().numpy() for x in merged_predictions]
            task_targets["score_detection"] = merged_predictions[0]
            task_targets["score_p_or_s"] = merged_predictions[1]
            task_targets["p_sample_pred"] = (
                merged_predictions[2] + task_targets["start_sample"]
            )
            task_targets["s_sample_pred"] = (
                merged_predictions[3] + task_targets["start_sample"]
            )

            pred_path = (
                weights.parent.parent
                / pred_root
                / weight_path_name
                / version
                / f"{eval_set}_task{task}.csv"
            )
            pred_path.parent.mkdir(exist_ok=True, parents=True)
            task_targets.to_csv(pred_path, index=False)


def _identify_instance_dataset_border(task_targets):
    """
    Calculates the dataset border between signal and noise traces for INSTANCE,
    assuming it is the only place where the bucket number does not increase.
    """
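    # Assumes trace names of the form "bucketN$..." (as in the INSTANCE dataset):
    # the characters after "bucket" and before "$" give the bucket number.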
    buckets = task_targets["trace_name"].apply(lambda x: int(x.split("$")[0][6:]))

    last_bucket = 0
    for i, bucket in enumerate(buckets):
        if bucket < last_bucket:
            return i
        last_bucket = bucket


if __name__ == "__main__":
    code_start_time = time.perf_counter()
    parser = argparse.ArgumentParser(
        description="Evaluate a trained model using a set of targets."
    )
    parser.add_argument(
        "weights",
        type=str,
        help="Path to weights. Expected to be in the models_path directory. "
        "The script will automatically load the configuration and the model. "
        "The script always uses the checkpoint with the lowest validation loss. "
        "If sweep_id is provided, the script considers only the checkpoints generated by that sweep. "
        "Predictions will be written into the weights path as CSV. "
        "Note: Due to pytorch lightning internals, there exist two weights folders, "
        "{weights} and {weight}_{weights}. Please use the former as parameter.",
    )
    parser.add_argument(
        "targets",
        type=str,
        help="Path to evaluation targets folder. "
        "The script will detect which tasks are present based on file names.",
    )
    parser.add_argument(
        "--sets",
        type=str,
        default="dev,test",
        help="Sets on which to evaluate, separated by commas. Defaults to dev and test.",
    )
    parser.add_argument("--batchsize", type=int, default=1024, help="Batch size")
    parser.add_argument(
        "--num_workers",
        default=default_workers,
        type=int,
        help="Number of workers for data loader",
    )
    parser.add_argument(
        "--sampling_rate", type=float, help="Overwrites the sampling rate in the data"
    )
    parser.add_argument(
        "--sweep_id", type=str, help="wandb sweep_id", required=False, default=None
    )

    args = parser.parse_args()

    main(
        args.weights,
        args.targets,
        args.sets,
        batchsize=args.batchsize,
        num_workers=args.num_workers,
        sampling_rate=args.sampling_rate,
        sweep_id=args.sweep_id,
    )
    running_time = str(
        datetime.timedelta(seconds=time.perf_counter() - code_start_time)
    )
    print(f"Running time: {running_time}")