finetuning #1

Merged
k.milian merged 2 commits from finetuning into master 2024-07-29 13:41:06 +02:00
7 changed files with 343 additions and 27 deletions
Showing only changes of commit ebc759c215


@@ -111,7 +111,7 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
The script performs the following steps:
1. Generates evaluation targets in `datasets/<dataset_name>/targets` directory.
1. Trains multiple versions of GPD, PhaseNet, BasicPhaseAE, and EQTransformer models to find the best hyperparameters, i.e. those producing the lowest validation loss.
This step utilizes the Weights & Biases platform to perform the hyperparameter search (called sweeping), track the training process, and store the results.
The results are available at
@@ -126,14 +126,31 @@ After adjusting the grant name, the paths to conda env and the paths to data sen
The results are saved in the `scripts/pred/results.csv` file. They are additionally logged in the Weights & Biases platform as summary metrics of the corresponding runs.
<br/>
The default settings for the maximum number of experiments and the paths are saved in the config.json file. To change the settings, edit the config.json file or pass the new settings as arguments to the script. For example, to change the sweep configuration file for the GPD model, run:
```python pipeline.py --gpd_config <new config file>```
The new config file should be placed in the `experiments` folder or in the location specified by the `configs_path` parameter in the config.json file.
Sweep configs define the maximum number of epochs to run and the hyperparameter search space for the following parameters:
* `batch_size`
* `learning_rate`
The PhaseNet model has the following additional parameters:
* `norm` - normalization method, options ('peak', 'std')
* `pretrained` - pretrained seisbench models used for transfer learning
* `finetuning` - the type of layers to finetune first, options ('all', 'top', 'encoder', 'decoder')
* `lr_reduce_factor` - factor to reduce learning rate after unfreezing layers
The GPD model has additional parameters for filtering:
* `highpass` - highpass filter frequency
* `lowpass` - lowpass filter frequency
The sweep configs are saved in the `experiments` folder; a minimal example is sketched below.
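For orientation, here is a minimal PhaseNet sweep configuration expressed as a Python dict and checked with the validator introduced in `scripts/input_validate.py`. The real configs are YAML files in `experiments`; all values below are illustrative assumptions, not repository defaults.
```python
# Sketch only: a PhaseNet sweep config with illustrative values,
# matching the schema enforced by scripts/input_validate.py (run from the scripts directory).
import input_validate

sweep_config = {
    "name": "phasenet-example-sweep",   # hypothetical sweep name
    "method": "bayes",
    "metric": {"goal": "minimize", "name": "val_loss"},
    "parameters": {
        "model_name": {"value": ["PhaseNet"]},
        "batch_size": {"values": [64, 128, 256]},
        "learning_rate": {"distribution": "uniform", "min": 1e-4, "max": 1e-2},
        "max_epochs": {"value": 20},
        "norm": {"values": ["peak", "std"]},
        "pretrained": {"values": ["instance", False]},
        "finetuning": {"values": ["all", "encoder"]},
        "lr_reduce_factor": {"value": 0.5},
    },
}

# Raises ValueError if the config does not match the schema or the model name.
input_validate.validate_sweep_config(sweep_config, model_name="PhaseNet")
```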
If you have multiple datasets, you can run the pipeline for each dataset separately by specifying the dataset name as an argument:
```python pipeline.py --dataset <dataset_name>```
### Troubleshooting


@@ -1,8 +1,13 @@
"""
This file contains functionality related to data.
"""
import os.path
import seisbench.data as sbd
import logging
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('data')
def get_dataset_by_name(name):
@@ -30,3 +35,26 @@ def get_custom_dataset(path):
except AttributeError:
raise ValueError(f"Unknown dataset '{path}'.")
def validate_custom_dataset(data_path):
"""
Validate the dataset
:param data_path: path to the dataset
:return:
"""
# check if path exists
if not os.path.isdir(data_path):
raise ValueError(f"Data path {data_path} does not exist.")
dataset = sbd.WaveformDataset(data_path)
# check if the dataset is split into train, dev and test
if len(dataset.train()) == 0:
raise ValueError("Training set is empty.")
if len(dataset.dev()) == 0:
raise ValueError("Dev set is empty.")
if len(dataset.test()) == 0:
raise ValueError("Test set is empty.")
logger.info("Custom dataset validated successfully.")
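A brief usage sketch for the new helper; the dataset path below is hypothetical.
```python
# Sketch: check a custom SeisBench dataset before launching the pipeline.
# "datasets/my_dataset/seisbench_format" is a hypothetical path.
from data import validate_custom_dataset

try:
    validate_custom_dataset("datasets/my_dataset/seisbench_format")
except ValueError as err:
    print(f"Dataset rejected: {err}")
```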


@@ -7,8 +7,9 @@ import os
import os.path
import argparse
from pytorch_lightning.loggers import WandbLogger, CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
import wandb
import torch
@@ -20,6 +21,8 @@ import train
import util
import config_loader
torch.multiprocessing.set_sharing_strategy('file_system')
os.system("ulimit -n unlimited")
@@ -31,14 +34,13 @@ host = os.environ.get("WANDB_HOST")
if host is None:
raise ValueError("WANDB_HOST environment variable is not set.")
wandb.login(key=wandb_api_key, host=host)
wandb_project_name = os.environ.get("WANDB_PROJECT")
wandb_user_name = os.environ.get("WANDB_USER")
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.INFO)
def set_random_seed(seed=3):
@@ -54,6 +56,7 @@ def get_trainer_args(config):
return trainer_args
class HyperparameterSweep:
def __init__(self, project_name, sweep_config):
self.project_name = project_name
@@ -87,21 +90,42 @@
try:
logger.info("Starting a new run...")
run = wandb.init(
project=self.project_name,
config=config_loader.config,
save_code=True,
entity=wandb_user_name
)
run.log_code(
root=".",
include_fn=lambda path: path.endswith(".py") or path.endswith(".sh"),
exclude_fn=lambda path: path.endswith("template.sh")
)
model_name = wandb.config.model_name[0]
model_args = models.get_model_specific_args(wandb.config)
if "pretrained" in wandb.config:
weights = wandb.config.get("pretrained")
if type(weights) == list:
weights = weights[0]
if weights != "false":
model_args["pretrained"] = weights
if "norm" in wandb.config:
model_args["norm"] = wandb.config.norm
logger.debug(f"Initializing {model_name}")
if "finetuning" in wandb.config:
# train for a few epochs with some frozen params, then unfreeze and continue training
if type(wandb.config.finetuning) == list:
finetuning_strategy = wandb.config.finetuning[0]
else:
finetuning_strategy = wandb.config.finetuning
model_args['finetuning_strategy'] = finetuning_strategy
if "lr_reduce_factor" in wandb.config:
model_args['steplr_gamma'] = wandb.config.lr_reduce_factor
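# Note: categorical sweep parameters (model_name, pretrained, finetuning) can arrive from
# wandb.config as single-element lists when the sweep config defines them as lists,
# hence the [0] indexing and list checks above.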
model = models.__getattribute__(model_name + "Lit")(**model_args)
@@ -132,10 +156,13 @@ class HyperparameterSweep:
early_stopping_callback = EarlyStopping(
monitor="val_loss",
patience=5,
verbose=True,
mode="min")
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks = [checkpoint_callback, early_stopping_callback, lr_monitor]
trainer = pl.Trainer(
default_root_dir=config_loader.models_path,
@@ -143,7 +170,6 @@ class HyperparameterSweep:
callbacks=callbacks,
**get_trainer_args(wandb.config)
)
trainer.fit(model, train_loader, dev_loader)
except Exception as e:
@@ -155,7 +181,6 @@ class HyperparameterSweep:
def start_sweep(sweep_config):
logger.info("Starting sweep with config: " + str(sweep_config))
set_random_seed(config_loader.seed)
sweep_runner = HyperparameterSweep(project_name=wandb_project_name, sweep_config=sweep_config)
@@ -165,7 +190,6 @@ def start_sweep(sweep_config):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--sweep_config", type=str, required=True)
args = parser.parse_args()
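For reference, the sweep script can also be launched on its own; based on the argument parser above, a standalone invocation presumably looks like ```python hyperparameter_sweep.py --sweep_config <sweep config file>```.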

scripts/input_validate.py Normal file

@@ -0,0 +1,142 @@
from pydantic import BaseModel, ConfigDict, field_validator
from typing_extensions import Literal
from typing import Union, List, Optional
import yaml
import logging
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('input_validator')
#todo
# 1. check if a single value is allowed in a sweep
# 2. merge input params
# 3. change names of the classes
# 4. add constraints for PhaseNet, GPD
model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
norm_values = Literal["peak", "std"]
finetuning_values = Literal["all", "top", "decoder", "encoder"]
pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique', 'lendb', 'neic',
'original', 'scedc', False]
class Metric(BaseModel):
goal: str
name: str
class NumericValue(BaseModel):
value: Union[int, float, List[Union[int, float]]]
class NumericValues(BaseModel):
values: List[Union[int, float]]
class IntDistribution(BaseModel):
distribution: str = "int_uniform"
min: int
max: int
class FloatDistribution(BaseModel):
distribution: str = "uniform"
min: float
max: float
class Pretrained(BaseModel):
distribution: Optional[str] = "categorical"
values: List[pretrained_values] = None
value: Union[pretrained_values, List[pretrained_values]] = None
class Finetuning(BaseModel):
distribution: Optional[str] = "categorical"
values: List[finetuning_values] = None
value: Union[finetuning_values, List[finetuning_values]] = None
class Norm(BaseModel):
distribution: Optional[str] = "categorical"
values: List[norm_values] = None
value: Union[norm_values, List[norm_values]] = None
class ModelType(BaseModel):
distribution: Optional[str] = "categorical"
value: Union[model_names, List[model_names]] = None
values: List[model_names] = None
class Parameters(BaseModel):
model_config = ConfigDict(extra='forbid', protected_namespaces=())
model_name: ModelType
batch_size: Union[IntDistribution, NumericValue, NumericValues]
learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
max_epochs: Union[IntDistribution, NumericValue, NumericValues]
class PhaseNetParameters(Parameters):
model_config = ConfigDict(extra='forbid')
norm: Norm = None
pretrained: Pretrained = None
finetuning: Finetuning = None
lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
@field_validator("model_name")
def validate_model(cls, v):
if "PhaseNet" not in v.value:
raise ValueError("Additional parameters implemented for PhaseNet only")
return v
class GPDParameters(Parameters):
model_config = ConfigDict(extra='forbid')
highpass: Union[NumericValue, NumericValues, FloatDistribution] = None
lowpass: Union[NumericValue, NumericValues, FloatDistribution] = None
@field_validator("model_name")
def validate_model(cls, v):
if "GPD" not in v.value:
raise ValueError("Additional parameters implemented for GPD only")
return v
class InputParams(BaseModel):
name: str
method: str
metric: Metric
parameters: Union[Parameters, PhaseNetParameters, GPDParameters]
def validate_sweep_yaml(yaml_filename, model_name=None):
# Load YAML configuration
with open(yaml_filename, 'r') as f:
sweep_config = yaml.safe_load(f)
validate_sweep_config(sweep_config, model_name)
def validate_sweep_config(sweep_config, model_name=None):
# Validate sweep config
input_params = InputParams(**sweep_config)
# Check consistency of input parameters and sweep configuration
sweep_model_name = input_params.parameters.model_name.value
if model_name is not None and model_name not in sweep_model_name:
info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
logger.info(info)
raise ValueError(info)
logger.info("Input validation successful.")
if __name__ == "__main__":
yaml_filename = "../experiments/sweep_phasenet_lumineos_lr_bs.yaml"
validate_sweep_yaml(yaml_filename, None)


@@ -7,16 +7,26 @@ import seisbench.generate as sbg
import pytorch_lightning as pl
import torch
from torch.optim import lr_scheduler
import torch.nn.functional as F
import numpy as np
from abc import abstractmethod, ABC
# import lightning as L
# Allows to import this file in both jupyter notebook and code
try:
from .augmentations import DuplicateEvent
except ImportError:
from augmentations import DuplicateEvent
import os
import logging
script_name = os.path.splitext(os.path.basename(__file__))[0]
logger = logging.getLogger(script_name)
logger.setLevel(logging.DEBUG)
# Phase dict for labelling. We only study P and S phases without differentiating between them.
phase_dict = {
@@ -131,30 +141,84 @@ class PhaseNetLit(SeisBenchModuleLit):
self.sigma = sigma
self.sample_boundaries = sample_boundaries
self.loss = vector_cross_entropy
self.pretrained = kwargs.pop("pretrained", None)
self.norm = kwargs.pop("norm", "peak")
if self.pretrained is not None:
self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
# self.norm = self.model.norm
else:
self.model = sbm.PhaseNet(**kwargs)
self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
self.initial_epochs = 0
if self.finetuning_strategy is not None:
if self.finetuning_strategy == "top":
self.initial_epochs = 3
elif self.finetuning_strategy in ["decoder", "encoder"]:
self.initial_epochs = 6
self.freeze()
def forward(self, x):
return self.model(x)
def shared_step(self, batch):
x = batch["X"]
y_true = batch["y"]
y_pred = self.model(x)
return self.loss(y_pred, y_true)
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log("val_loss", loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
if self.finetuning_strategy is not None:
scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
else:
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
#
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'monitor': 'val_loss',
'interval': 'epoch',
'reduce_on_plateau': False,
},
}
def lr_lambda(self, epoch):
# reduce lr after x initial epochs
if epoch == self.initial_epochs:
self.lr *= self.steplr_gamma
return self.lr
def lr_scheduler_step(self, scheduler, metric):
scheduler.step(epoch=self.current_epoch)
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
# scheduler.step(epoch=self.current_epoch)
def get_augmentations(self):
return [
@@ -181,19 +245,21 @@ class PhaseNetLit(SeisBenchModuleLit):
strategy="pad",
),
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
sbg.ProbabilisticLabeller(
label_columns=phase_dict, sigma=self.sigma, dim=0
),
]
def get_eval_augmentations(self):
return [
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
x = batch["X"]
window_borders = batch["window_borders"]
@@ -219,6 +285,27 @@ class PhaseNetLit(SeisBenchModuleLit):
return score_detection, score_p_or_s, p_sample, s_sample
def freeze(self):
if self.finetuning_strategy == "decoder": # finetune decoder branch and freeze encoder branch
for p in self.model.down_branch.parameters():
p.requires_grad = False
elif self.finetuning_strategy == "encoder": # finetune encoder branch and freeze decoder branch
for p in self.model.up_branch.parameters():
p.requires_grad = False
elif self.finetuning_strategy == "top":
for p in self.model.out.parameters():
p.requires_grad = False
def unfreeze(self):
logger.info("Unfreezing layers")
for p in self.model.parameters():
p.requires_grad = True
def on_train_epoch_start(self):
# Unfreeze some layers after x initial epochs
if self.current_epoch == self.initial_epochs:
self.unfreeze()
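For illustration, a hedged construction sketch of the finetuning path described above; the learning rate is an arbitrary example and the base constructor arguments (`lr`, `sigma`, `sample_boundaries`) are assumed to keep their upstream pick-benchmark defaults.
```python
# Sketch: PhaseNetLit configured for transfer learning; values are illustrative.
from models import PhaseNetLit  # assuming this file is importable as `models`

model = PhaseNetLit(
    lr=1e-3,                        # example value, not a repository default
    pretrained="instance",          # start from a SeisBench pretrained PhaseNet
    norm="std",                     # normalization used by the augmentations
    finetuning_strategy="encoder",  # freeze the decoder (up_branch) for the first 6 epochs
    steplr_gamma=0.5,               # factor applied to the lr once layers are unfrozen
)
```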
class GPDLit(SeisBenchModuleLit):
"""
@@ -846,7 +933,7 @@ class BasicPhaseAELit(SeisBenchModuleLit):
# Create overlapping windows
re = torch.zeros(x.shape[:2] + (7, 600), dtype=x.dtype, device=x.device)
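# 7 overlapping windows of 600 samples, one starting every 400 samples (0, 400, ..., 2400);
# consecutive windows therefore overlap by 200 samples.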
for i, start in enumerate(range(0, 2401, 400)):
re[:, :, i] = x[:, :, start: start + 600]
x = re
x = x.permute(0, 2, 1, 3)  # --> (batch, windows, channels, samples)
@@ -862,9 +949,9 @@ class BasicPhaseAELit(SeisBenchModuleLit):
for i, start in enumerate(range(0, 2401, 400)):
if start == 0:
# Use full window (for start==0, the end will be overwritten)
pred[:, :, start: start + 600] = window_pred[:, i]
else:
pred[:, :, start + 100: start + 600] = window_pred[:, i, :, 100:]
score_detection = torch.zeros(pred.shape[0])
score_p_or_s = torch.zeros(pred.shape[0])
@@ -1112,7 +1199,6 @@ class DPPPickerLit(SeisBenchModuleLit):
def get_model_specific_args(config):
model = config.model_name[0]
lr = config.learning_rate
if type(lr) == list:


@@ -57,6 +57,8 @@ def split_events(events, input_path):
events_stats.loc[i, 'split'] = 'dev'
else:
break
logger.info(f"Split: {events_stats['split'].value_counts()}")
logger.info(f"Split: {events_stats['split'].value_counts()}")
@@ -198,9 +200,10 @@ def convert_mseed_to_seisbench_format(input_path, catalog_path, output_path):
metadata_path = output_path + "/metadata.csv"
waveforms_path = output_path + "/waveforms.hdf5"
events_to_convert = events_stats[events_stats['pick_count'] > 0]
logger.debug("Catalog loaded, starting converting {events_to_convert} events ...")
with sbd.WaveformDataWriter(metadata_path, waveforms_path) as writer:


@@ -17,6 +17,8 @@ import eval
import collect_results
import importlib
import config_loader
import input_validate
import data
logging.root.setLevel(logging.INFO)
logger = logging.getLogger('pipeline')
@@ -40,11 +42,22 @@ def load_sweep_config(model_name, args):
return util.load_sweep_config(sweep_fname)
def validate_pipeline_input(args):
# validate input parameters
for model_name in args.models:
sweep_config = load_sweep_config(model_name, args)
input_validate.validate_sweep_config(sweep_config, model_name)
# validate dataset
data.validate_custom_dataset(config_loader.data_path)
def find_the_best_params(sweep_config):
# find the best hyperparams for the model_name
model_name = sweep_config['parameters']['model_name']
logger.info(f"Starting searching for the best hyperparams for the model: {model_name}")
sweep_runner = hyperparameter_sweep.start_sweep(sweep_config)
# wait for all runs to finish
@@ -91,6 +104,8 @@ def main():
util.set_dataset(args.dataset)
importlib.reload(config_loader)
validate_pipeline_input(args)
logger.info(f"Started pipeline for the {config_loader.dataset_name} dataset.")
# generate labels
@@ -101,7 +116,8 @@
# find the best hyperparams for the models
logger.info("Started training the models.")
for model_name in args.models:
sweep_config = load_sweep_config(model_name, args)
sweep_id = find_the_best_params(sweep_config)
generate_predictions(sweep_id, model_name)
# collect results