Added scripts converting mseeds from Bogdanka to seisbench format, extended readme, modified logging

This commit is contained in:
2023-09-26 10:50:46 +02:00
parent 78ac51478c
commit aa39980573
15 changed files with 1788 additions and 66 deletions

View File

@@ -15,8 +15,8 @@ config = load_config(config_path)
data_path = f"{project_path}/{config['data_path']}"
models_path = f"{project_path}/{config['models_path']}"
targets_path = f"{project_path}/{config['targets_path']}"
dataset_name = config['dataset_name']
targets_path = f"{project_path}/{config['targets_path']}/{dataset_name}"
configs_path = f"{project_path}/{config['configs_path']}"
sweep_files = config['sweep_files']

View File

@@ -29,11 +29,11 @@ data_aliases = {
"instance": "InstanceCountsCombined",
"iquique": "Iquique",
"lendb": "LenDB",
"scedc": "SCEDC"
"scedc": "SCEDC",
}
def main(weights, targets, sets, batchsize, num_workers, sampling_rate=None, sweep_id=None, test_run=False):
def main(weights, targets, sets, batchsize, num_workers, sampling_rate=None, sweep_id=None):
weights = Path(weights)
targets = Path(os.path.abspath(targets))
print(targets)
@@ -100,8 +100,6 @@ def main(weights, targets, sets, batchsize, num_workers, sampling_rate=None, swe
for task in ["1", "23"]:
task_csv = targets / f"task{task}.csv"
print(task_csv)
if not task_csv.is_file():
continue
@@ -227,9 +225,7 @@ if __name__ == "__main__":
parser.add_argument(
"--sweep_id", type=str, help="wandb sweep_id", required=False, default=None
)
parser.add_argument(
"--test_run", action="store_true", required=False, default=False
)
args = parser.parse_args()
main(
@@ -239,8 +235,7 @@ if __name__ == "__main__":
batchsize=args.batchsize,
num_workers=args.num_workers,
sampling_rate=args.sampling_rate,
sweep_id=args.sweep_id,
test_run=args.test_run
sweep_id=args.sweep_id
)
running_time = str(
datetime.timedelta(seconds=time.perf_counter() - code_start_time)

View File

@@ -3,6 +3,7 @@
# This work was partially funded by EPOS Project funded in frame of PL-POIR4.2
# -----------------
import os
import os.path
import argparse
from pytorch_lightning.loggers import WandbLogger, CSVLogger
@@ -22,6 +23,7 @@ from config_loader import models_path, dataset_name, seed, experiment_count
torch.multiprocessing.set_sharing_strategy('file_system')
os.system("ulimit -n unlimited")
load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')

View File

@@ -17,8 +17,8 @@ import eval
import collect_results
from config_loader import data_path, targets_path, sampling_rate, dataset_name, sweep_files
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('pipeline')
logger.setLevel(logging.INFO)
def load_sweep_config(model_name, args):
@@ -76,16 +76,19 @@ def main():
args = parser.parse_args()
# generate labels
logger.info("Started generating labels for the dataset.")
generate_eval_targets.main(data_path, targets_path, "2,3", sampling_rate, None)
# find the best hyperparams for the models
logger.info("Started training the models.")
for model_name in ["GPD", "PhaseNet"]:
sweep_id = find_the_best_params(model_name, args)
generate_predictions(sweep_id, model_name)
# collect results
logger.info("Collecting results.")
collect_results.traverse_path("pred", "pred/results.csv")
logger.info("Results saved in pred/results.csv")
if __name__ == "__main__":
main()

View File

@@ -20,18 +20,13 @@ import torch
import os
import logging
from pathlib import Path
from dotenv import load_dotenv
import models, data, util
import time
import datetime
import wandb
#
# load_dotenv()
# wandb_api_key = os.environ.get('WANDB_API_KEY')
# if wandb_api_key is None:
# raise ValueError("WANDB_API_KEY environment variable is not set.")
#
# wandb.login(key=wandb_api_key)
def train(config, experiment_name, test_run):
"""
@@ -210,6 +205,14 @@ def generate_phase_mask(dataset, phases):
if __name__ == "__main__":
load_dotenv()
wandb_api_key = os.environ.get('WANDB_API_KEY')
if wandb_api_key is None:
raise ValueError("WANDB_API_KEY environment variable is not set.")
wandb.login(key=wandb_api_key)
code_start_time = time.perf_counter()
torch.manual_seed(42)

View File

@@ -16,7 +16,7 @@ load_dotenv()
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger().setLevel(logging.INFO)
def load_best_model_data(sweep_id, weights):