@@ -1,8 +1,13 @@
"""
This file contains functionality related to data.
"""
import os.path

import seisbench.data as sbd
import logging

logging.root.setLevel(logging.INFO)
logger = logging.getLogger('data')


def get_dataset_by_name(name):
@@ -30,3 +35,26 @@ def get_custom_dataset(path):
    except AttributeError:
        raise ValueError(f"Unknown dataset '{path}'.")


def validate_custom_dataset(data_path):
    """
    Validate the dataset
    :param data_path: path to the dataset
    :return:
    """
    # check if path exists
    if not os.path.isdir(data_path):
        raise ValueError(f"Data path {data_path} does not exist.")

    dataset = sbd.WaveformDataset(data_path)
    # check if the dataset is split into train, dev and test
    if len(dataset.train()) == 0:
        raise ValueError(f"Training set is empty.")
    if len(dataset.dev()) == 0:
        raise ValueError(f"Dev set is empty.")
    if len(dataset.test()) == 0:
        raise ValueError(f"Test set is empty.")

    logger.info("Custom dataset validated successfully.")
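
For reference, a minimal usage sketch of the validator above, assuming a SeisBench-format dataset directory whose metadata carries a "split" column with train/dev/test values; the path below is illustrative.

import data

# Hypothetical dataset directory containing metadata.csv and waveforms.hdf5,
# with each trace assigned to "train", "dev" or "test" in the metadata.
dataset_path = "datasets/bogdanka_seisbench"

# Raises ValueError if the path is missing or any split is empty,
# otherwise logs "Custom dataset validated successfully."
data.validate_custom_dataset(dataset_path)
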
@@ -7,8 +7,9 @@ import os
import os.path
import argparse
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

import pytorch_lightning as pl
import wandb
import torch
@@ -20,6 +21,8 @@ import train
import util
import config_loader


torch.multiprocessing.set_sharing_strategy('file_system')
os.system("ulimit -n unlimited")

@@ -31,14 +34,13 @@ host = os.environ.get("WANDB_HOST")
if host is None:
    raise ValueError("WANDB_HOST environment variable is not set.")


wandb.login(key=wandb_api_key, host=host)
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user_name = os.environ.get("WANDB_USER")

script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.WARNING)
logger.setLevel(logging.INFO)


def set_random_seed(seed=3):
@@ -54,6 +56,11 @@ def get_trainer_args(config):
    return trainer_args


def get_arg(arg):
    if type(arg) == list:
        return arg[0]
    return arg


class HyperparameterSweep:
    def __init__(self, project_name, sweep_config):
        self.project_name = project_name
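
The get_arg helper above appears to exist because a wandb sweep can hand a parameter over either as a plain scalar or as a single-element list; a quick illustration with made-up values:

assert get_arg(["PhaseNet"]) == "PhaseNet"  # one-element list from a sweep
assert get_arg(0.001) == 0.001              # plain scalar passes through unchanged
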
@@ -84,25 +91,40 @@ class HyperparameterSweep:
        return all_not_running

    def run_experiment(self):

        try:

            logger.debug("Starting a new run...")
            logger.info("Starting a new run...")

            run = wandb.init(
                project=self.project_name,
                config=config_loader.config,
                save_code=True
                save_code=True,
                entity=wandb_user_name
            )
            run.log_code(
                root=".",
                include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
                exclude_fn=lambda path: path.endswith("template.sh")
            ) # not working as expected
            )

            model_name = wandb.config.model_name[0]
            model_name = get_arg(wandb.config.model_name)
            model_args = models.get_model_specific_args(wandb.config)
            logger.debug(f"Initializing {model_name}")

            if "pretrained" in wandb.config:
                weights = get_arg(wandb.config.pretrained)
                if weights != "false":
                    model_args["pretrained"] = weights

            if "norm" in wandb.config:
                model_args["norm"] = get_arg(wandb.config.norm)

            if "finetuning" in wandb.config:
                model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)

            if "lr_reduce_factor" in wandb.config:
                model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)

            logger.debug(f"Initializing {model_name} with args: {model_args}")
            model = models.__getattribute__(model_name + "Lit")(**model_args)

            train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
@@ -132,10 +154,13 @@ class HyperparameterSweep:

            early_stopping_callback = EarlyStopping(
                monitor="val_loss",
                patience=3,
                patience=5,
                verbose=True,
                mode="min")
            callbacks = [checkpoint_callback, early_stopping_callback]

            lr_monitor = LearningRateMonitor(logging_interval='epoch')

            callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]

            trainer = pl.Trainer(
                default_root_dir=config_loader.models_path,
@@ -143,7 +168,6 @@ class HyperparameterSweep:
                callbacks=callbacks,
                **get_trainer_args(wandb.config)
            )

            trainer.fit(model, train_loader, dev_loader)

        except Exception as e:
@@ -155,7 +179,6 @@ class HyperparameterSweep:


def start_sweep(sweep_config):

    logger.info("Starting sweep with config: " + str(sweep_config))
    set_random_seed(config_loader.seed)
    sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
@@ -165,7 +188,6 @@ def start_sweep(sweep_config):


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument("--sweep_config", type=str, required=True)
    args = parser.parse_args()
scripts/input_validate.py (new file, 146 lines)
@@ -0,0 +1,146 @@
from pydantic import BaseModel, ConfigDict, field_validator
from typing_extensions import Literal
from typing import Union, List, Optional
import yaml
import logging

logging.root.setLevel(logging.INFO)
logger = logging.getLogger('input_validator')

#todo
# 1. check if a single value is allowed in a sweep
# 2. merge input params
# 3. change names of the classes
# 4. add constraints for PhaseNet, GPD


model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
norm_values = Literal["peak", "std"]
finetuning_values = Literal["all", "top", "decoder", "encoder"]
pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic',
                            'original', 'scedc', False]


class Metric(BaseModel):
    goal: str
    name: str


class NumericValue(BaseModel):
    value: Union[int, float, List[Union[int, float]]]


class NumericValues(BaseModel):
    values: List[Union[int, float]]


class IntDistribution(BaseModel):
    distribution: str = "int_uniform"
    min: int
    max: int


class FloatDistribution(BaseModel):
    distribution: str = "uniform"
    min: float
    max: float


class Pretrained(BaseModel):
    distribution: Optional[str] = "categorical"
    values: List[pretrained_values] = None
    value: Union[pretrained_values, List[pretrained_values]] = None


class Finetuning(BaseModel):
    distribution: Optional[str] = "categorical"
    values: List[finetuning_values] = None
    value: Union[finetuning_values, List[finetuning_values]] = None


class Norm(BaseModel):
    distribution: Optional[str] = "categorical"
    values: List[norm_values] = None
    value: Union[norm_values, List[norm_values]] = None


class ModelType(BaseModel):
    distribution: Optional[str] = "categorical"
    value: Union[model_names, List[model_names]] = None
    values: List[model_names] = None


class Parameters(BaseModel):
    model_config = ConfigDict(extra='forbid', protected_namespaces=())
    model_name: ModelType
    batch_size: Union[IntDistribution, NumericValue, NumericValues]
    learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
    max_epochs: Union[IntDistribution, NumericValue, NumericValues]


class PhaseNetParameters(Parameters):
    model_config = ConfigDict(extra='forbid')
    norm: Norm = None
    pretrained: Pretrained = None
    finetuning: Finetuning = None
    lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None

    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None

    @field_validator("model_name")
    def validate_model(cls, v):
        if "PhaseNet" not in v.value:
            raise ValueError("Additional parameters implemented for PhaseNet only")
        return v


class FilteringParameters(Parameters):
    model_config = ConfigDict(extra='forbid')

    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None

    @field_validator("model_name")
    def validate_model(cls, v):
        print(v.value)
        if v.value[0] not in ["GPD", "PhaseNet"]:
            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
        return v  # the validator must return the value, otherwise model_name is dropped


class InputParams(BaseModel):
    name: str
    method: str
    metric: Metric
    parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]


def validate_sweep_yaml(yaml_filename, model_name=None):
    # Load YAML configuration
    with open(yaml_filename, 'r') as f:
        sweep_config = yaml.safe_load(f)

    validate_sweep_config(sweep_config, model_name)


def validate_sweep_config(sweep_config, model_name=None):

    # Validate sweep config
    input_params = InputParams(**sweep_config)

    # Check consistency of input parameters and sweep configuration
    sweep_model_name = input_params.parameters.model_name.value
    if model_name is not None and model_name not in sweep_model_name:
        info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
        logger.info(info)
        raise ValueError(info)
    logger.info("Input validation successful.")


if __name__ == "__main__":
    yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
    validate_sweep_yaml(yaml_filename, None)
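
A minimal sketch of a sweep configuration that satisfies the InputParams schema above, followed by the consistency check against a model name; every name and value below is illustrative and follows the wandb sweep-config layout this validator encodes.

import input_validate

sweep_config = {
    "name": "phasenet_lr_bs",  # illustrative sweep name
    "method": "bayes",
    "metric": {"goal": "minimize", "name": "val_loss"},
    "parameters": {
        "model_name": {"value": ["PhaseNet"]},
        "batch_size": {"values": [64, 128, 256]},
        "learning_rate": {"distribution": "uniform", "min": 1e-4, "max": 1e-2},
        "max_epochs": {"value": 30},
    },
}

# Raises pydantic.ValidationError on schema violations and ValueError when the
# requested model disagrees with the sweep configuration.
input_validate.validate_sweep_config(sweep_config, model_name="PhaseNet")
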
@@ -7,16 +7,26 @@ import seisbench.generate as sbg

import pytorch_lightning as pl
import torch
from torch.optim import lr_scheduler
import torch.nn.functional as F
import numpy as np
from abc import abstractmethod, ABC

# import lightning as L


# Allows to import this file in both jupyter notebook and code
try:
    from .augmentations import DuplicateEvent
except ImportError:
    from augmentations import DuplicateEvent

import os
import logging

script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.DEBUG)

# Phase dict for labelling. We only study P and S phases without differentiating between them.
phase_dict = {
@@ -131,7 +141,31 @@ class PhaseNetLit(SeisBenchModuleLit):
        self.sigma = sigma
        self.sample_boundaries = sample_boundaries
        self.loss = vector_cross_entropy
        self.model = sbm.PhaseNet(**kwargs)
        self.pretrained = kwargs.pop("pretrained", None)
        self.norm = kwargs.pop("norm", "peak")
        self.highpass = kwargs.pop("highpass", None)
        self.lowpass = kwargs.pop("lowpass", None)

        if self.pretrained is not None:
            self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
            # self.norm = self.model.norm
        else:
            self.model = sbm.PhaseNet(**kwargs)

        self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
        self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
        self.reduce_lr_on_plateau = False

        self.initial_epochs = 0

        if self.finetuning_strategy is not None:
            if self.finetuning_strategy == "top":
                self.initial_epochs = 3
            elif self.finetuning_strategy in ["decoder", "encoder"]:
                self.initial_epochs = 6

        self.freeze()

    def forward(self, x):
        return self.model(x)
@@ -154,9 +188,49 @@ class PhaseNetLit(SeisBenchModuleLit):

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
        if self.finetuning_strategy is not None:
            scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
            self.reduce_lr_on_plateau = False
        else:
            scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
            self.reduce_lr_on_plateau = True

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                'interval': 'epoch',
                'reduce_on_plateau': self.reduce_lr_on_plateau,
            },
        }

    def lr_lambda(self, epoch):
        # reduce lr after x initial epochs
        if epoch == self.initial_epochs:
            self.lr *= self.steplr_gamma

        return self.lr

    def lr_scheduler_step(self, scheduler, metric):
        if self.reduce_lr_on_plateau:
            scheduler.step(metric, epoch=self.current_epoch)
        else:
            scheduler.step(epoch=self.current_epoch)

    # def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
    #     scheduler.step(epoch=self.current_epoch)

    def get_augmentations(self):
        filter = []
        if self.highpass is not None:
            filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
            logger.info(f"Using highpass filter {self.highpass}")
        if self.lowpass is not None:
            filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
            logger.info(f"Using lowpass filter {self.lowpass}")
        logger.info(filter)

        return [
            # In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
            # Uses strategy variable, as padding will be handled by the random window.
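
A note on the LambdaLR branch above: torch.optim.lr_scheduler.LambdaLR multiplies the optimizer's base learning rate by the factor returned from the lambda. A standalone sketch of the intended "drop the LR once after the warm-up epochs" behaviour, with illustrative initial_epochs and gamma values, not the repository's exact code:

import torch
from torch.optim import SGD, lr_scheduler

model = torch.nn.Linear(10, 3)
optimizer = SGD(model.parameters(), lr=1e-2)

initial_epochs, gamma = 3, 0.1

# Return a multiplicative factor, not an absolute learning rate.
scheduler = lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: 1.0 if epoch < initial_epochs else gamma
)

for epoch in range(6):
    optimizer.step()   # stand-in for a real training epoch
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # LR: 1e-2 during warm-up, then 1e-3
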
@@ -180,18 +254,26 @@ class PhaseNetLit(SeisBenchModuleLit):
                windowlen=3001,
                strategy="pad",
            ),
            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
            *filter,
            sbg.ChangeDtype(np.float32),
            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
            sbg.ProbabilisticLabeller(
                label_columns=phase_dict, sigma=self.sigma, dim=0
            ),
        ]

    def get_eval_augmentations(self):
        filter = []
        if self.highpass is not None:
            filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
        if self.lowpass is not None:
            filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]

        return [
            sbg.SteeredWindow(windowlen=3001, strategy="pad"),
            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
            *filter,
            sbg.ChangeDtype(np.float32),
            sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
        ]

    def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
@@ -219,6 +301,27 @@ class PhaseNetLit(SeisBenchModuleLit):

        return score_detection, score_p_or_s, p_sample, s_sample

    def freeze(self):
        if self.finetuning_strategy == "decoder":  # finetune decoder branch and freeze encoder branch
            for p in self.model.down_branch.parameters():
                p.requires_grad = False
        elif self.finetuning_strategy == "encoder":  # finetune encoder branch and freeze decoder branch
            for p in self.model.up_branch.parameters():
                p.requires_grad = False
        elif self.finetuning_strategy == "top":
            for p in self.model.out.parameters():
                p.requires_grad = False

    def unfreeze(self):
        logger.info("Unfreezing layers")
        for p in self.model.parameters():
            p.requires_grad = True

    def on_train_epoch_start(self):
        # Unfreeze some layers after x initial epochs
        if self.current_epoch == self.initial_epochs:
            self.unfreeze()


class GPDLit(SeisBenchModuleLit):
    """
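
The freeze/unfreeze pair above implements a staged fine-tuning scheme: one branch stays frozen for the first initial_epochs epochs, then everything trains. A toy illustration of the same pattern on a generic module (the layer names below are made up, not PhaseNet's):

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))
frozen_part = model[0]  # stand-in for e.g. model.down_branch

for p in frozen_part.parameters():  # freeze() at construction time
    p.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters while frozen: {trainable}")

# on_train_epoch_start() flips everything back on once current_epoch == initial_epochs
for p in model.parameters():
    p.requires_grad = True
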
@@ -846,7 +949,7 @@ class BasicPhaseAELit(SeisBenchModuleLit):
        # Create overlapping windows
        re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
        for i, start in enumerate(range(0, 2401, 400)):
            re[:, :, i] = x[:, :, start : start + 600]
            re[:, :, i] = x[:, :, start: start + 600]
        x = re

        x = x.permute(0, 2, 1, 3)  # --> (batch, windows, channels, samples)
@@ -862,9 +965,9 @@ class BasicPhaseAELit(SeisBenchModuleLit):
        for i, start in enumerate(range(0, 2401, 400)):
            if start == 0:
                # Use full window (for start==0, the end will be overwritten)
                pred[:, :, start : start + 600] = window_pred[:, i]
                pred[:, :, start: start + 600] = window_pred[:, i]
            else:
                pred[:, :, start + 100 : start + 600] = window_pred[:, i, :, 100:]
                pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:]

        score_detection = torch.zeros(pred.shape[0])
        score_p_or_s = torch.zeros(pred.shape[0])
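
The window arithmetic in the BasicPhaseAELit hunks above is easiest to check in isolation: seven 600-sample windows with a 400-sample stride tile a 3000-sample trace, so consecutive windows overlap by 200 samples; during reassembly the later window only writes from 100 samples past its start, which splits each overlap down the middle between the two windows.

starts = list(range(0, 2401, 400))
print(starts)            # [0, 400, 800, 1200, 1600, 2000, 2400] -> 7 windows
print(starts[-1] + 600)  # 3000 samples covered by the end of the last window
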
@@ -1112,7 +1215,6 @@ class DPPPickerLit(SeisBenchModuleLit):


def get_model_specific_args(config):

    model = config.model_name[0]
    lr = config.learning_rate
    if type(lr) == list:
@@ -1125,8 +1227,12 @@ def get_model_specific_args(config):
            if 'highpass' in config:
                args['highpass'] = config.highpass
            if 'lowpass' in config:
                args['lowpass'] = config.lowpass[0]
                args['lowpass'] = config.lowpass
        case 'PhaseNet':
            if 'highpass' in config:
                args['highpass'] = config.highpass
            if 'lowpass' in config:
                args['lowpass'] = config.lowpass
            if 'sample_boundaries' in config:
                args['sample_boundaries'] = config.sample_boundaries
        case 'DPPPicker':
@@ -57,6 +57,8 @@ def split_events(events, input_path):
            events_stats.loc[i, 'split'] = 'dev'
        else:
            break

    logger.info(f"Split: {events_stats['split'].value_counts()}")

    logger.info(f"Split: {events_stats['split'].value_counts()}")

@@ -198,9 +200,10 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):

    metadata_path = output_path + "/metadata.csv"
    waveforms_path = output_path + "/waveforms.hdf5"

    events_to_convert = events_stats[events_stats['pick_count'] > 0]

    logger.debug(f"Catalog loaded, starting converting {events_to_convert} events ...")

    with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:
@@ -17,6 +17,8 @@ import eval
import collect_results
import importlib
import config_loader
import input_validate
import data

logging.root.setLevel(logging.INFO)
logger = logging.getLogger('pipeline')
@@ -40,11 +42,22 @@ def load_sweep_config(model_name, args):
    return util.load_sweep_config(sweep_fname)


def find_the_best_params(model_name, args):
def validate_pipeline_input(args):

    # validate input parameters
    for model_name in args.models:
        sweep_config = load_sweep_config(model_name, args)
        input_validate.validate_sweep_config(sweep_config, model_name)

    # validate dataset
    data.validate_custom_dataset(config_loader.data_path)


def find_the_best_params(sweep_config):
    # find the best hyperparams for the model_name
    model_name = sweep_config['parameters']['model_name']
    logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")

    sweep_config = load_sweep_config(model_name, args)
    sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)

    # wait for all runs to finish
@@ -91,6 +104,8 @@ def main():
    util.set_dataset(args.dataset)
    importlib.reload(config_loader)

    validate_pipeline_input(args)

    logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")

    # generate labels
@@ -101,7 +116,8 @@ def main():
    # find the best hyperparams for the models
    logger.info("Started training the models.")
    for model_name in args.models:
        sweep_id = find_the_best_params(model_name, args)
        sweep_config = load_sweep_config(model_name, args)
        sweep_id = find_the_best_params(sweep_config)
        generate_predictions(sweep_id, model_name)

    # collect results