Added logging MAE for the best runs, option to run a pipeline on a specific dataset, template bash scripts, GPLv3 license. Modified behavior of generating eval targets, it is skipped if targets already exist.
This commit is contained in:
19
scripts/convert_data_template.sh
Normal file
19
scripts/convert_data_template.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=mseeds_to_seisbench
|
||||
#SBATCH --time=1:00:00
|
||||
#SBATCH --account= ### to fill
|
||||
#SBATCH --partition plgrid
|
||||
#SBATCH --cpus-per-task=1
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --mem=24gb
|
||||
|
||||
|
||||
## activate conda environment
|
||||
source /path/to/mambaforge/bin/activate ### to adjust
|
||||
conda activate epos-ai-train
|
||||
|
||||
input_path="/path/to/folder/with/mseed/files"
|
||||
catalog_path="/path/to/catolog.xml"
|
||||
output_path="/path/to/output/in/seisbench_format"
|
||||
|
||||
python mseeds_to_seisbench.py --input_path $input_path --catalog_path $catalog_path --output_path $output_path
|
@@ -39,10 +39,15 @@ from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import logging
|
||||
from models import phase_dict
|
||||
|
||||
|
||||
|
||||
logging.root.setLevel(logging.INFO)
|
||||
logger = logging.getLogger('targets generator')
|
||||
|
||||
|
||||
def main(dataset_name, output, tasks, sampling_rate, noise_before_events):
|
||||
np.random.seed(42)
|
||||
tasks = [str(i) in tasks.split(",") for i in range(1, 4)]
|
||||
@@ -64,17 +69,24 @@ def main(dataset_name, output, tasks, sampling_rate, noise_before_events):
|
||||
dataset = sbd.WaveformDataset(dataset_name, **dataset_args)
|
||||
|
||||
output = Path(output)
|
||||
output.mkdir(parents=True, exist_ok=False)
|
||||
output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if "split" in dataset.metadata.columns:
|
||||
dataset.filter(dataset["split"].isin(["dev", "test"]), inplace=True)
|
||||
|
||||
dataset.preload_waveforms(pbar=True)
|
||||
|
||||
|
||||
if tasks[0]:
|
||||
generate_task1(dataset, output, sampling_rate, noise_before_events)
|
||||
if not Path.exists(output / "task1.csv"):
|
||||
generate_task1(dataset, output, sampling_rate, noise_before_events)
|
||||
else:
|
||||
logger.info(f"{output}/task1.csv already exists. Skipping generation of targets.")
|
||||
if tasks[1] or tasks[2]:
|
||||
generate_task23(dataset, output, sampling_rate)
|
||||
if not Path.exists(output / "task23.csv"):
|
||||
generate_task23(dataset, output, sampling_rate)
|
||||
else:
|
||||
logger.info(f"{output}/task23.csv already exists. Skipping generation of targets.")
|
||||
|
||||
|
||||
|
||||
def generate_task1(dataset, output, sampling_rate, noise_before_events):
|
||||
|
@@ -18,9 +18,7 @@ from dotenv import load_dotenv
|
||||
import models
|
||||
import train
|
||||
import util
|
||||
from config_loader import config as common_config
|
||||
from config_loader import models_path, dataset_name, seed, experiment_count
|
||||
|
||||
import config_loader
|
||||
|
||||
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||
os.system("ulimit -n unlimited")
|
||||
@@ -35,8 +33,6 @@ if host is None:
|
||||
|
||||
|
||||
wandb.login(key=wandb_api_key, host=host)
|
||||
# wandb.login(key=wandb_api_key)
|
||||
|
||||
wandb_project_name = os.environ.get("WANDB_PROJECT")
|
||||
wandb_user_name = os.environ.get("WANDB_USER")
|
||||
|
||||
@@ -68,11 +64,9 @@ class HyperparameterSweep:
|
||||
|
||||
# Create the sweep
|
||||
self.sweep_id = wandb.sweep(self.sweep_config, project=self.project_name)
|
||||
|
||||
logger.info("Created sweep with ID: " + self.sweep_id)
|
||||
|
||||
# Run the sweep
|
||||
wandb.agent(self.sweep_id, function=self.run_experiment, count=experiment_count)
|
||||
wandb.agent(self.sweep_id, function=self.run_experiment, count=config_loader.experiment_count)
|
||||
|
||||
def all_runs_finished(self):
|
||||
|
||||
@@ -96,13 +90,14 @@ class HyperparameterSweep:
|
||||
logger.debug("Starting a new run...")
|
||||
run = wandb.init(
|
||||
project=self.project_name,
|
||||
config=common_config,
|
||||
)
|
||||
|
||||
wandb.run.log_code(
|
||||
".",
|
||||
include_fn=lambda path: path.endswith(os.path.basename(__file__))
|
||||
config=config_loader.config,
|
||||
save_code=True
|
||||
)
|
||||
run.log_code(
|
||||
root=".",
|
||||
include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
|
||||
exclude_fn=lambda path: path.endswith("template.sh")
|
||||
) # not working as expected
|
||||
|
||||
model_name = wandb.config.model_name[0]
|
||||
model_args = models.get_model_specific_args(wandb.config)
|
||||
@@ -116,8 +111,8 @@ class HyperparameterSweep:
|
||||
wandb_logger.watch(model)
|
||||
|
||||
# CSV logger - also used for saving configuration as yaml
|
||||
experiment_name = f"{dataset_name}_{model_name}"
|
||||
csv_logger = CSVLogger(models_path, experiment_name, version=run.id)
|
||||
experiment_name = f"{config_loader.dataset_name}_{model_name}"
|
||||
csv_logger = CSVLogger(config_loader.models_path, experiment_name, version=run.id)
|
||||
csv_logger.log_hyperparams(wandb.config)
|
||||
|
||||
loggers = [wandb_logger, csv_logger]
|
||||
@@ -131,7 +126,7 @@ class HyperparameterSweep:
|
||||
filename=experiment_signature + "-{epoch}-{val_loss:.3f}",
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
dirpath=f"{models_path}/{experiment_name}/",
|
||||
dirpath=f"{config_loader.models_path}/{experiment_name}/",
|
||||
) # save_top_k=1, monitor="val_loss", mode="min": save the best model in terms of validation loss
|
||||
checkpoint_callback.STARTING_VERSION = 1
|
||||
|
||||
@@ -143,7 +138,7 @@ class HyperparameterSweep:
|
||||
callbacks = [checkpoint_callback, early_stopping_callback]
|
||||
|
||||
trainer = pl.Trainer(
|
||||
default_root_dir=models_path,
|
||||
default_root_dir=config_loader.models_path,
|
||||
logger=loggers,
|
||||
callbacks=callbacks,
|
||||
**get_trainer_args(wandb.config)
|
||||
@@ -162,7 +157,7 @@ class HyperparameterSweep:
|
||||
def start_sweep(sweep_config):
|
||||
|
||||
logger.info("Starting sweep with config: " + str(sweep_config))
|
||||
set_random_seed(seed)
|
||||
set_random_seed(config_loader.seed)
|
||||
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
|
||||
sweep_runner.run_sweep()
|
||||
|
||||
|
250
scripts/mseeds_to_seisbench.py
Normal file
250
scripts/mseeds_to_seisbench.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import os
|
||||
import pandas as pd
|
||||
import glob
|
||||
from pathlib import Path
|
||||
|
||||
import obspy
|
||||
from obspy.core.event import read_events
|
||||
|
||||
import seisbench
|
||||
import seisbench.data as sbd
|
||||
import seisbench.util as sbu
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
|
||||
|
||||
logging.basicConfig(filename="output.out",
|
||||
filemode='a',
|
||||
format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
|
||||
datefmt='%H:%M:%S',
|
||||
level=logging.DEBUG)
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger('converter')
|
||||
|
||||
def create_traces_catalog(directory, years):
|
||||
for year in years:
|
||||
directory = f"{directory}/{year}"
|
||||
files = glob.glob(directory)
|
||||
traces = []
|
||||
for i, f in enumerate(files):
|
||||
st = obspy.read(f)
|
||||
|
||||
for tr in st.traces:
|
||||
# trace_id = tr.id
|
||||
# start = tr.meta.starttime
|
||||
# end = tr.meta.endtime
|
||||
|
||||
trs = pd.Series({
|
||||
'trace_id': tr.id,
|
||||
'trace_st': tr.meta.starttime,
|
||||
'trace_end': tr.meta.endtime,
|
||||
'stream_fname': f
|
||||
})
|
||||
traces.append(trs)
|
||||
|
||||
traces_catalog = pd.DataFrame(pd.concat(traces)).transpose()
|
||||
traces_catalog.to_csv("data/bogdanka/traces_catalog.csv", append=True, index=False)
|
||||
|
||||
|
||||
def split_events(events, input_path):
|
||||
|
||||
logger.info("Splitting available events into train, dev and test sets ...")
|
||||
events_stats = pd.DataFrame()
|
||||
events_stats.index.name = "event"
|
||||
|
||||
for i, event in enumerate(events):
|
||||
#check if mseed exists
|
||||
actual_picks = 0
|
||||
for pick in event.picks:
|
||||
trace_params = get_trace_params(pick)
|
||||
trace_path = get_trace_path(input_path, trace_params)
|
||||
if os.path.isfile(trace_path):
|
||||
actual_picks += 1
|
||||
|
||||
events_stats.loc[i, "pick_count"] = actual_picks
|
||||
|
||||
events_stats['pick_count_cumsum'] = events_stats.pick_count.cumsum()
|
||||
|
||||
train_th = 0.7 * events_stats.pick_count_cumsum.values[-1]
|
||||
dev_th = 0.85 * events_stats.pick_count_cumsum.values[-1]
|
||||
|
||||
events_stats['split'] = 'test'
|
||||
for i, event in events_stats.iterrows():
|
||||
if event['pick_count_cumsum'] < train_th:
|
||||
events_stats.loc[i, 'split'] = 'train'
|
||||
elif event['pick_count_cumsum'] < dev_th:
|
||||
events_stats.loc[i, 'split'] = 'dev'
|
||||
else:
|
||||
break
|
||||
|
||||
return events_stats
|
||||
|
||||
|
||||
def get_event_params(event):
|
||||
origin = event.preferred_origin()
|
||||
if origin is None:
|
||||
return {}
|
||||
# print(origin)
|
||||
|
||||
mag = event.preferred_magnitude()
|
||||
|
||||
source_id = str(event.resource_id)
|
||||
|
||||
event_params = {
|
||||
"source_id": source_id,
|
||||
"source_origin_uncertainty_sec": origin.time_errors["uncertainty"],
|
||||
"source_latitude_deg": origin.latitude,
|
||||
"source_latitude_uncertainty_km": origin.latitude_errors["uncertainty"],
|
||||
"source_longitude_deg": origin.longitude,
|
||||
"source_longitude_uncertainty_km": origin.longitude_errors["uncertainty"],
|
||||
"source_depth_km": origin.depth / 1e3,
|
||||
"source_depth_uncertainty_km": origin.depth_errors["uncertainty"] / 1e3 if origin.depth_errors[
|
||||
"uncertainty"] is not None else None,
|
||||
}
|
||||
|
||||
if mag is not None:
|
||||
event_params["source_magnitude"] = mag.mag
|
||||
event_params["source_magnitude_uncertainty"] = mag.mag_errors["uncertainty"]
|
||||
event_params["source_magnitude_type"] = mag.magnitude_type
|
||||
event_params["source_magnitude_author"] = mag.creation_info.agency_id if mag.creation_info is not None else None
|
||||
|
||||
return event_params
|
||||
|
||||
|
||||
def get_trace_params(pick):
|
||||
net = pick.waveform_id.network_code
|
||||
sta = pick.waveform_id.station_code
|
||||
|
||||
trace_params = {
|
||||
"station_network_code": net,
|
||||
"station_code": sta,
|
||||
"trace_channel": pick.waveform_id.channel_code,
|
||||
"station_location_code": pick.waveform_id.location_code,
|
||||
"time": pick.time
|
||||
}
|
||||
|
||||
return trace_params
|
||||
|
||||
|
||||
def find_trace(pick_time, traces):
|
||||
for tr in traces:
|
||||
if pick_time > tr.stats.endtime:
|
||||
continue
|
||||
if pick_time >= tr.stats.starttime:
|
||||
# print(pick_time, " - selected trace: ", tr)
|
||||
return tr
|
||||
|
||||
logger.warning(f"no matching trace for peak: {pick_time}")
|
||||
return None
|
||||
|
||||
|
||||
def get_trace_path(input_path, trace_params):
|
||||
year = trace_params["time"].year
|
||||
day_of_year = pd.Timestamp(str(trace_params["time"])).day_of_year
|
||||
net = trace_params["station_network_code"]
|
||||
station = trace_params["station_code"]
|
||||
tr_channel = trace_params["trace_channel"]
|
||||
|
||||
path = f"{input_path}/{year}/{net}/{station}/{tr_channel}.D/{net}.{station}..{tr_channel}.D.{year}.{day_of_year}"
|
||||
return path
|
||||
|
||||
|
||||
def load_trace(input_path, trace_params):
|
||||
trace_path = get_trace_path(input_path, trace_params)
|
||||
trace = None
|
||||
|
||||
if not os.path.isfile(trace_path):
|
||||
logger.w(trace_path + " not found")
|
||||
else:
|
||||
stream = obspy.read(trace_path)
|
||||
if len(stream.traces) > 1:
|
||||
trace = find_trace(trace_params["time"], stream.traces)
|
||||
elif len(stream.traces) == 0:
|
||||
logger.warning(f"no data in: {trace_path}")
|
||||
else:
|
||||
trace = stream.traces[0]
|
||||
|
||||
return trace
|
||||
|
||||
|
||||
def load_stream(input_path, trace_params, time_before=60, time_after=60):
|
||||
trace_path = get_trace_path(input_path, trace_params)
|
||||
sampling_rate, stream = None, None
|
||||
pick_time = trace_params["time"]
|
||||
|
||||
if not os.path.isfile(trace_path):
|
||||
print(trace_path + " not found")
|
||||
else:
|
||||
stream = obspy.read(trace_path)
|
||||
stream = stream.slice(pick_time - time_before, pick_time + time_after)
|
||||
if len(stream.traces) == 0:
|
||||
print(f"no data in: {trace_path}")
|
||||
else:
|
||||
sampling_rate = stream.traces[0].stats.sampling_rate
|
||||
|
||||
return sampling_rate, stream
|
||||
|
||||
|
||||
def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
|
||||
"""
|
||||
Convert mseed files to seisbench dataset format
|
||||
:param input_path: folder with mseed files
|
||||
:param catalog_path: path to events catalog in quakeml format
|
||||
:param output_path: folder to save seisbench dataset
|
||||
:return:
|
||||
"""
|
||||
logger.info("Loading events catalog ...")
|
||||
events = read_events(catalog_path)
|
||||
events_stats = split_events(events, input_path)
|
||||
|
||||
metadata_path = output_path + "/metadata.csv"
|
||||
waveforms_path = output_path + "/waveforms.hdf5"
|
||||
|
||||
logger.debug("Catalog loaded, starting conversion ...")
|
||||
|
||||
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:
|
||||
writer.data_format = {
|
||||
"dimension_order": "CW",
|
||||
"component_order": "ZNE",
|
||||
}
|
||||
for i, event in enumerate(events):
|
||||
logger.debug(f"Converting {i} event")
|
||||
event_params = get_event_params(event)
|
||||
event_params["split"] = events_stats.loc[i, "split"]
|
||||
|
||||
for pick in event.picks:
|
||||
trace_params = get_trace_params(pick)
|
||||
sampling_rate, stream = load_stream(input_path, trace_params)
|
||||
if stream is None:
|
||||
continue
|
||||
|
||||
actual_t_start, data, _ = sbu.stream_to_array(
|
||||
stream,
|
||||
component_order=writer.data_format["component_order"],
|
||||
)
|
||||
|
||||
trace_params["trace_sampling_rate_hz"] = sampling_rate
|
||||
trace_params["trace_start_time"] = str(actual_t_start)
|
||||
|
||||
pick_time = obspy.core.utcdatetime.UTCDateTime(trace_params["time"])
|
||||
pick_idx = (pick_time - actual_t_start) * sampling_rate
|
||||
|
||||
trace_params[f"trace_{pick.phase_hint}_arrival_sample"] = int(pick_idx)
|
||||
|
||||
writer.add_trace({**event_params, **trace_params}, data)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description='Convert mseed files to seisbench format')
|
||||
parser.add_argument('--input_path', type=str, help='Path to mseed files')
|
||||
parser.add_argument('--catalog_path', type=str, help='Path to events catalog in quakeml format')
|
||||
parser.add_argument('--output_path', type=str, help='Path to output files')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
convert_mseed_to_seisbench_format(args.input_path, args.catalog_path, args.output_path)
|
@@ -15,21 +15,25 @@ import generate_eval_targets
|
||||
import hyperparameter_sweep
|
||||
import eval
|
||||
import collect_results
|
||||
from config_loader import data_path, targets_path, sampling_rate, dataset_name, sweep_files
|
||||
import importlib
|
||||
import config_loader
|
||||
|
||||
logging.root.setLevel(logging.INFO)
|
||||
logger = logging.getLogger('pipeline')
|
||||
|
||||
|
||||
def load_sweep_config(model_name, args):
|
||||
|
||||
if model_name == "PhaseNet" and args.phasenet_config is not None:
|
||||
sweep_fname = args.phasenet_config
|
||||
elif model_name == "GPD" and args.gpd_config is not None:
|
||||
sweep_fname = args.gpd_config
|
||||
elif model_name == "BasicPhaseAE" and args.basic_phase_ae_config is not None:
|
||||
sweep_fname = args.basic_phase_ae_config
|
||||
elif model_name == "EQTransformer" and args.eqtransformer_config is not None:
|
||||
sweep_fname = args.eqtransformer_config
|
||||
else:
|
||||
# use the default sweep config for the model
|
||||
sweep_fname = sweep_files[model_name]
|
||||
sweep_fname = config_loader.sweep_files[model_name]
|
||||
|
||||
logger.info(f"Loading sweep config: {sweep_fname}")
|
||||
|
||||
@@ -37,7 +41,6 @@ def load_sweep_config(model_name, args):
|
||||
|
||||
|
||||
def find_the_best_params(model_name, args):
|
||||
|
||||
# find the best hyperparams for the model_name
|
||||
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
|
||||
|
||||
@@ -58,9 +61,9 @@ def find_the_best_params(model_name, args):
|
||||
|
||||
|
||||
def generate_predictions(sweep_id, model_name):
|
||||
experiment_name = f"{dataset_name}_{model_name}"
|
||||
experiment_name = f"{config_loader.dataset_name}_{model_name}"
|
||||
eval.main(weights=experiment_name,
|
||||
targets=targets_path,
|
||||
targets=config_loader.targets_path,
|
||||
sets='dev,test',
|
||||
batchsize=128,
|
||||
num_workers=4,
|
||||
@@ -73,22 +76,42 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--phasenet_config", type=str, required=False)
|
||||
parser.add_argument("--gpd_config", type=str, required=False)
|
||||
parser.add_argument("--basic_phase_ae_config", type=str, required=False)
|
||||
parser.add_argument("--eqtransformer_config", type=str, required=False)
|
||||
parser.add_argument("--dataset", type=str, required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dataset is not None:
|
||||
util.set_dataset(args.dataset)
|
||||
importlib.reload(config_loader)
|
||||
|
||||
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
|
||||
|
||||
# generate labels
|
||||
logger.info("Started generating labels for the dataset.")
|
||||
generate_eval_targets.main(data_path, targets_path, "2,3", sampling_rate, None)
|
||||
generate_eval_targets.main(config_loader.data_path, config_loader.targets_path, "2,3", config_loader.sampling_rate,
|
||||
None)
|
||||
|
||||
# find the best hyperparams for the models
|
||||
logger.info("Started training the models.")
|
||||
for model_name in ["GPD", "PhaseNet"]:
|
||||
for model_name in ["GPD", "PhaseNet", "BasicPhaseAE", "EQTransformer"]:
|
||||
if config_loader.dataset_name == "lumineos" and model_name == "EQTransformer":
|
||||
break
|
||||
sweep_id = find_the_best_params(model_name, args)
|
||||
generate_predictions(sweep_id, model_name)
|
||||
|
||||
# collect results
|
||||
logger.info("Collecting results.")
|
||||
collect_results.traverse_path("pred", "pred/results.csv")
|
||||
logger.info("Results saved in pred/results.csv")
|
||||
results_path = "pred/results.csv"
|
||||
collect_results.traverse_path("pred", results_path)
|
||||
logger.info(f"Results saved in {results_path}")
|
||||
|
||||
# log calculated metrics (MAE) on w&b
|
||||
logger.info("Logging MAE metrics on w&b.")
|
||||
util.log_metrics(results_path)
|
||||
|
||||
logger.info("Pipeline finished")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
19
scripts/run_pipeline_template.sh
Normal file
19
scripts/run_pipeline_template.sh
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=job_name
|
||||
#SBATCH --time=10:00:00
|
||||
#SBATCH --account= ### to fill
|
||||
#SBATCH --partition=plgrid-gpu-v100
|
||||
#SBATCH --cpus-per-task=1
|
||||
#SBATCH --ntasks-per-node=1
|
||||
#SBATCH --gres=gpu:1
|
||||
|
||||
source path/to/mambaforge/bin/activate ### to change
|
||||
conda activate epos-ai-train
|
||||
|
||||
|
||||
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
|
||||
python -c "import torch; print('Number of CUDA devices:', torch.cuda.device_count())"
|
||||
python -c "import torch; print('Name of GPU:', torch.cuda.get_device_name(torch.cuda.current_device()))"
|
||||
|
||||
|
||||
python pipeline.py --dataset "bogdanka"
|
@@ -1,5 +1,10 @@
|
||||
"""
|
||||
This script offers general functionality required in multiple places.
|
||||
-----------------
|
||||
Copyright © 2023 ACK Cyfronet AGH, Poland.
|
||||
This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
|
||||
-----------------
|
||||
|
||||
This script runs the pipeline for the training and evaluation of the models.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
@@ -7,13 +12,15 @@ import pandas as pd
|
||||
import os
|
||||
import logging
|
||||
import glob
|
||||
import json
|
||||
import wandb
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import sys
|
||||
from config_loader import models_path, configs_path
|
||||
from config_loader import models_path, configs_path, config_path
|
||||
import yaml
|
||||
load_dotenv()
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
@@ -38,8 +45,16 @@ def load_best_model_data(sweep_id, weights):
|
||||
# Get best run parameters
|
||||
best_run = sweep.best_run()
|
||||
run_id = best_run.id
|
||||
matching_models = glob.glob(f"{models_path}/{weights}/*run={run_id}*ckpt")
|
||||
if len(matching_models)!=1:
|
||||
|
||||
run = api.run(f"{wandb_user}/{wandb_project_name}/runs/{run_id}")
|
||||
dataset = run.config["dataset_name"]
|
||||
model = run.config["model_name"][0]
|
||||
experiment = f"{dataset}_{model}"
|
||||
|
||||
checkpoints_path = f"{models_path}/{experiment}/*run={run_id}*ckpt"
|
||||
logging.debug(f"Searching for checkpoints in dir: {checkpoints_path}")
|
||||
matching_models = glob.glob(checkpoints_path)
|
||||
if len(matching_models) != 1:
|
||||
raise ValueError("Unable to determine the best checkpoint for run_id: " + run_id)
|
||||
best_checkpoint_path = matching_models[0]
|
||||
|
||||
@@ -62,31 +77,6 @@ def load_best_model_data(sweep_id, weights):
|
||||
return best_checkpoint_path, run_id
|
||||
|
||||
|
||||
def load_best_model(model_cls, weights, version):
|
||||
"""
|
||||
Determines the model with lowest validation loss from the csv logs and loads it
|
||||
|
||||
:param model_cls: Class of the lightning module to load
|
||||
:param weights: Path to weights as in cmd arguments
|
||||
:param version: String of version file
|
||||
:return: Instance of lightning module that was loaded from the best checkpoint
|
||||
"""
|
||||
metrics = pd.read_csv(weights / version / "metrics.csv")
|
||||
|
||||
idx = np.nanargmin(metrics["val_loss"])
|
||||
min_row = metrics.iloc[idx]
|
||||
|
||||
# For default checkpoint filename, see https://github.com/Lightning-AI/lightning/pull/11805
|
||||
# and https://github.com/Lightning-AI/lightning/issues/16636.
|
||||
# For example, 'epoch=0-step=1.ckpt' means the 1st step has finish, but the 1st epoch hasn't
|
||||
checkpoint = f"epoch={min_row['epoch']:.0f}-step={min_row['step']+1:.0f}.ckpt"
|
||||
|
||||
# For default save path of checkpoints, see https://github.com/Lightning-AI/lightning/pull/12372
|
||||
checkpoint_path = weights / version / "checkpoints" / checkpoint
|
||||
|
||||
return model_cls.load_from_checkpoint(checkpoint_path)
|
||||
|
||||
|
||||
default_workers = os.getenv("BENCHMARK_DEFAULT_WORKERS", None)
|
||||
if default_workers is None:
|
||||
logging.warning(
|
||||
@@ -117,3 +107,51 @@ def load_sweep_config(sweep_fname):
|
||||
sys.exit(1)
|
||||
|
||||
return sweep_config
|
||||
|
||||
|
||||
def log_metrics(results_file):
|
||||
"""
|
||||
|
||||
:param results_file: csv file with calculated metrics
|
||||
:return:
|
||||
"""
|
||||
|
||||
api = wandb.Api()
|
||||
wandb_project_name = os.environ.get("WANDB_PROJECT")
|
||||
wandb_user = os.environ.get("WANDB_USER")
|
||||
|
||||
results = pd.read_csv(results_file)
|
||||
for run_id in results["version"].unique():
|
||||
try:
|
||||
run = api.run(f"{wandb_user}/{wandb_project_name}/{run_id}")
|
||||
metrics_to_log = {}
|
||||
run_results = results[results["version"] == run_id]
|
||||
for col in run_results.columns:
|
||||
if 'mae' in col:
|
||||
metrics_to_log[col] = run_results[col].values[0]
|
||||
run.summary[col] = run_results[col].values[0]
|
||||
|
||||
run.summary.update()
|
||||
logging.info(f"Logged metrics for run: {run_id}, {metrics_to_log}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}, {type(e).__name__}, {e.args}")
|
||||
|
||||
|
||||
def set_dataset(dataset_name):
|
||||
"""
|
||||
Sets the dataset name in the config file
|
||||
:param dataset_name:
|
||||
:return:
|
||||
"""
|
||||
|
||||
with open(config_path, "r+") as f:
|
||||
config = json.load(f)
|
||||
config["dataset_name"] = dataset_name
|
||||
config["data_path"] = f"datasets/{dataset_name}/seisbench_format/"
|
||||
|
||||
f.seek(0) # rewind
|
||||
json.dump(config, f, indent=4)
|
||||
f.truncate()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user