finetuning #1

Merged
k.milian merged 2 commits from finetuning into master 2024-07-29 13:41:06 +02:00
7 changed files with 72 additions and 61 deletions
Showing only changes of commit f40ac35cc8

View File

@@ -1,6 +1,6 @@
 {
-  "dataset_name": "bogdanka",
-  "data_path": "datasets/bogdanka/seisbench_format/",
+  "dataset_name": "bogdanka_2018_2022",
+  "data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
   "targets_path": "datasets/targets",
   "models_path": "weights",
   "configs_path": "experiments",
@@ -13,5 +13,5 @@
     "BasicPhaseAE": "sweep_basicphase_ae.yaml",
     "EQTransformer": "sweep_eqtransformer.yaml"
   },
-  "experiment_count": 20
+  "experiment_count": 15
 }
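This commit points the pipeline at the renamed bogdanka_2018_2022 dataset and lowers experiment_count from 20 to 15. For orientation, a minimal sketch of how such a configuration might be read on the Python side; the filename config.json and the json/pathlib usage are assumptions for illustration, not taken from this repository:

import json
from pathlib import Path

# Hypothetical loader: read the experiment configuration shown in the diff above.
config = json.loads(Path("config.json").read_text())  # assumed filename

dataset_name = config["dataset_name"]          # "bogdanka_2018_2022" after this change
data_path = Path(config["data_path"])          # datasets/bogdanka_2018_2022/seisbench_format/
experiment_count = config["experiment_count"]  # reduced from 20 to 15 here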

View File

@@ -1,3 +1,4 @@
+name: BasicPhaseAE
 method: bayes
 metric:
   goal: minimize
@@ -7,13 +8,9 @@ parameters:
     value:
       - BasicPhaseAE
   batch_size:
-    distribution: int_uniform
-    max: 1024
-    min: 256
+    values: [64, 128, 256]
   max_epochs:
     value:
-      - 20
+      - 30
   learning_rate:
-    distribution: uniform
-    max: 0.02
-    min: 0.001
+    values: [0.01, 0.005, 0.001]
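The sweep above (and the EQTransformer and GPD sweeps below) replaces the continuous Bayesian ranges (distribution: int_uniform / uniform with min and max) with short discrete values lists, so the search now picks from a fixed grid of batch sizes and learning rates. A minimal sketch of how such a sweep is typically launched with the wandb API; the project name, the metric name val_loss, and the train_run entry point are placeholders rather than this repository's actual code:

import wandb

# Discrete search space mirroring the YAML above.
sweep_config = {
    "name": "BasicPhaseAE",
    "method": "bayes",
    "metric": {"goal": "minimize", "name": "val_loss"},  # metric name assumed
    "parameters": {
        "model_name": {"value": ["BasicPhaseAE"]},
        "batch_size": {"values": [64, 128, 256]},
        "max_epochs": {"value": [30]},
        "learning_rate": {"values": [0.01, 0.005, 0.001]},
    },
}

def train_run():
    # Placeholder for the real training entry point in the sweep runner.
    with wandb.init() as run:
        print(run.config)

sweep_id = wandb.sweep(sweep_config, project="seisbench-sweeps")  # project name assumed
wandb.agent(sweep_id, function=train_run, count=15)               # count mirrors experiment_count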

View File

@@ -8,13 +8,9 @@ parameters:
     value:
       - EQTransformer
   batch_size:
-    distribution: int_uniform
-    max: 1024
-    min: 256
+    values: [64, 128, 256]
   max_epochs:
     value:
       - 30
   learning_rate:
-    distribution: uniform
-    max: 0.02
-    min: 0.005
+    values: [0.01, 0.005, 0.001]

View File

@@ -1,4 +1,4 @@
-name: GPD_fixed_highpass:2-10
+name: GPD
 method: bayes
 metric:
   goal: minimize
@@ -8,16 +8,12 @@ parameters:
     value:
       - GPD
   batch_size:
-    distribution: int_uniform
-    max: 1024
-    min: 256
+    values: [64, 128, 256]
   max_epochs:
     value:
       - 30
   learning_rate:
-    distribution: uniform
-    max: 0.02
-    min: 0.005
+    values: [0.01, 0.005, 0.001]
   highpass:
     value:
       - 1

View File

@@ -56,6 +56,10 @@ def get_trainer_args(config):
     return trainer_args


+def get_arg(arg):
+    if type(arg) == list:
+        return arg[0]
+    return arg


 class HyperparameterSweep:
     def __init__(self, project_name, sweep_config):
@@ -87,10 +91,10 @@
         return all_not_running

     def run_experiment(self):
         try:
             logger.info("Starting a new run...")
             run = wandb.init(
                 project=self.project_name,
                 config=config_loader.config,
@@ -103,30 +107,24 @@
                 exclude_fn=lambda path: path.endswith("template.sh")
             )

-            model_name = wandb.config.model_name[0]
+            model_name = get_arg(wandb.config.model_name)
             model_args = models.get_model_specific_args(wandb.config)

             if "pretrained" in wandb.config:
-                weights = wandb.config.get("pretrained")
-                if type(weights) == list:
-                    weights = weights[0]
+                weights = get_arg(wandb.config.pretrained)
                 if weights != "false":
                     model_args["pretrained"] = weights
-            if "norm" in wandb.config:
-                model_args["norm"] = wandb.config.norm
-            logger.debug(f"Initializing {model_name}")

+            if "norm" in wandb.config:
+                model_args["norm"] = get_arg(wandb.config.norm)

             if "finetuning" in wandb.config:
-                # train for a few epochs with some frozen params, then unfreeze and continue training
-                if type(wandb.config.finetuning) == list:
-                    finetuning_strategy = wandb.config.finetuning[0]
-                else:
-                    finetuning_strategy = wandb.config.finetuning
-                model_args['finetuning_strategy'] = finetuning_strategy
+                model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)

             if "lr_reduce_factor" in wandb.config:
-                model_args['steplr_gamma'] = wandb.config.lr_reduce_factor
+                model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
+
+            logger.debug(f"Initializing {model_name} with args: {model_args}")

             model = models.__getattribute__(model_name + "Lit")(**model_args)
             train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
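The new get_arg helper handles a quirk of these sweep configs: parameters declared with value: followed by a single list item arrive in wandb.config as one-element lists, while parameters sampled from values: arrive as plain scalars. A small standalone illustration, with a plain dict standing in for wandb.config:

def get_arg(arg):
    # Unwrap single-element lists produced by "value: [X]"-style sweep entries.
    if type(arg) == list:
        return arg[0]
    return arg

# Stand-in for wandb.config inside a sweep run:
config = {
    "model_name": ["PhaseNet"],   # declared as a one-element list in the YAML
    "learning_rate": 0.005,       # sampled from "values:", already a scalar
    "lr_reduce_factor": [0.1],
}

assert get_arg(config["model_name"]) == "PhaseNet"
assert get_arg(config["learning_rate"]) == 0.005
assert get_arg(config["lr_reduce_factor"]) == 0.1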

View File

@@ -85,6 +85,9 @@ class PhaseNetParameters(Parameters):
     finetuning: Finetuning = None
     lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
+    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
+    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None

     @field_validator("model_name")
     def validate_model(cls, v):
         if "PhaseNet" not in v.value:
@@ -92,23 +95,24 @@
         return v


-class GPDParameters(Parameters):
+class FilteringParameters(Parameters):
     model_config = ConfigDict(extra='forbid')

-    highpass: Union[NumericValue, NumericValues, FloatDistribution] = None
-    lowpass: Union[NumericValue, NumericValues, FloatDistribution] = None
+    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
+    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None

     @field_validator("model_name")
     def validate_model(cls, v):
-        if "GPD" not in v.value:
-            raise ValueError("Additional parameters implemented for GPD only")
+        print(v.value)
+        if v.value[0] not in ["GPD", "PhaseNet"]:
+            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")


 class InputParams(BaseModel):
     name: str
     method: str
     metric: Metric
-    parameters: Union[Parameters, PhaseNetParameters, GPDParameters]
+    parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]


 def validate_sweep_yaml(yaml_filename, model_name=None):
@@ -138,5 +142,5 @@ def validate_sweep_config(sweep_config, model_name=None):
 if __name__ == "__main__":
-    yaml_filename = "../experiments/sweep_phasenet_lumineos_lr_bs.yaml"
+    yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
     validate_sweep_yaml(yaml_filename, None)
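The former GPDParameters class becomes FilteringParameters and now accepts highpass/lowpass for both GPD and PhaseNet. A simplified, self-contained sketch of the same pydantic v2 pattern; the field types are reduced to plain numbers here instead of the NumericValue/Distribution wrappers, and the base class is plain BaseModel rather than Parameters:

from typing import List, Optional, Union

from pydantic import BaseModel, ConfigDict, ValidationError, field_validator


class FilteringParameters(BaseModel):
    # protected_namespaces=() silences pydantic's warning about the "model_" prefix.
    model_config = ConfigDict(extra="forbid", protected_namespaces=())

    model_name: List[str]
    highpass: Optional[Union[int, float]] = None
    lowpass: Optional[Union[int, float]] = None

    @field_validator("model_name")
    @classmethod
    def validate_model(cls, v):
        if v[0] not in ["GPD", "PhaseNet"]:
            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
        return v  # a field_validator must return the (possibly modified) value


FilteringParameters(model_name=["PhaseNet"], highpass=1, lowpass=10)  # accepted
try:
    FilteringParameters(model_name=["EQTransformer"], highpass=1)
except ValidationError as e:
    print(e)  # rejected: filtering is only wired up for GPD and PhaseNet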

View File

@@ -143,6 +143,9 @@ class PhaseNetLit(SeisBenchModuleLit):
         self.loss = vector_cross_entropy
         self.pretrained = kwargs.pop("pretrained", None)
         self.norm = kwargs.pop("norm", "peak")
+        self.highpass = kwargs.pop("highpass", None)
+        self.lowpass = kwargs.pop("lowpass", None)

         if self.pretrained is not None:
             self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
@@ -152,6 +155,7 @@
         self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
         self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
+        self.reduce_lr_on_plateau = False

         self.initial_epochs = 0
@@ -163,36 +167,33 @@
             self.freeze()

     def forward(self, x):
         return self.model(x)

     def shared_step(self, batch):
         x = batch["X"]
         y_true = batch["y"]
         y_pred = self.model(x)
         return self.loss(y_pred, y_true)

     def training_step(self, batch, batch_idx):
         loss = self.shared_step(batch)
         self.log("train_loss", loss)
         return loss

     def validation_step(self, batch, batch_idx):
         loss = self.shared_step(batch)
         self.log("val_loss", loss)
         return loss

     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
         if self.finetuning_strategy is not None:
             scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
+            self.reduce_lr_on_plateau = False
         else:
             scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
+            self.reduce_lr_on_plateau = True
         #
         return {
             'optimizer': optimizer,
@@ -200,11 +201,10 @@
                 'scheduler': scheduler,
                 'monitor': 'val_loss',
                 'interval': 'epoch',
-                'reduce_on_plateau': False,
+                'reduce_on_plateau': self.reduce_lr_on_plateau,
             },
         }

     def lr_lambda(self, epoch):
         # reduce lr after x initial epochs
         if epoch == self.initial_epochs:
@@ -212,15 +212,25 @@
         return self.lr

     def lr_scheduler_step(self, scheduler, metric):
-        scheduler.step(epoch=self.current_epoch)
+        if self.reduce_lr_on_plateau:
+            scheduler.step(metric, epoch=self.current_epoch)
+        else:
+            scheduler.step(epoch=self.current_epoch)

     # def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
     #     scheduler.step(epoch=self.current_epoch)

     def get_augmentations(self):
+        filter = []
+        if self.highpass is not None:
+            filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
+            logger.info(f"Using highpass filter {self.highpass}")
+        if self.lowpass is not None:
+            filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
+            logger.info(f"Using lowpass filter {self.lowpass}")
+        logger.info(filter)
         return [
             # In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
             # Uses strategy variable, as padding will be handled by the random window.
@@ -244,22 +254,28 @@
                 windowlen=3001,
                 strategy="pad",
             ),
-            sbg.ChangeDtype(np.float32),
             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
+            *filter,
+            sbg.ChangeDtype(np.float32),
             sbg.ProbabilisticLabeller(
                 label_columns=phase_dict, sigma=self.sigma, dim=0
             ),
         ]

     def get_eval_augmentations(self):
+        filter = []
+        if self.highpass is not None:
+            filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
+        if self.lowpass is not None:
+            filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
         return [
             sbg.SteeredWindow(windowlen=3001, strategy="pad"),
-            sbg.ChangeDtype(np.float32),
             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
+            *filter,
+            sbg.ChangeDtype(np.float32),
         ]

     def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
         x = batch["X"]
         window_borders = batch["window_borders"]
@@ -1211,8 +1227,12 @@ def get_model_specific_args(config):
             if 'highpass' in config:
                 args['highpass'] = config.highpass
             if 'lowpass' in config:
-                args['lowpass'] = config.lowpass[0]
+                args['lowpass'] = config.lowpass
         case 'PhaseNet':
+            if 'highpass' in config:
+                args['highpass'] = config.highpass
+            if 'lowpass' in config:
+                args['lowpass'] = config.lowpass
             if 'sample_boundaries' in config:
                 args['sample_boundaries'] = config.sample_boundaries
         case 'DPPPicker':
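The reduce_lr_on_plateau flag threaded through PhaseNetLit above exists because the two schedulers step differently: torch's ReduceLROnPlateau.step() needs the monitored metric, while the epoch-based LambdaLR used for finetuning does not, so the overridden lr_scheduler_step must know which scheduler is active. A standalone sketch of that difference with toy optimizers, independent of Lightning:

import torch
from torch.optim import lr_scheduler

# Two toy optimizers so each scheduler owns its own learning rate.
p1 = torch.nn.Parameter(torch.zeros(1))
p2 = torch.nn.Parameter(torch.zeros(1))
opt_plateau = torch.optim.Adam([p1], lr=1e-3)
opt_lambda = torch.optim.Adam([p2], lr=1e-3)

plateau = lr_scheduler.ReduceLROnPlateau(opt_plateau, mode="min", factor=0.1, patience=3)
step_after_warmup = lr_scheduler.LambdaLR(opt_lambda, lambda epoch: 1.0 if epoch < 5 else 0.1)

val_loss = 0.42              # placeholder for the monitored metric
plateau.step(val_loss)       # needs the metric -> the reduce_lr_on_plateau branch
step_after_warmup.step()     # purely epoch-based -> the finetuning branch

The commented-out lr_scheduler_step(self, scheduler, optimizer_idx, metric) kept in the diff appears to be the older PyTorch Lightning hook signature; newer Lightning versions pass only (scheduler, metric), which matches the active override.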
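The highpass/lowpass arguments added to PhaseNetLit are turned into seisbench.generate.Filter augmentations, placed after normalization and before the cast to float32 in both the training and evaluation pipelines. A minimal sketch of that augmentation ordering outside the Lightning module; DummyDataset and the 1 Hz / 10 Hz corner frequencies are placeholders, not the Bogdanka setup:

import numpy as np
import seisbench.data as sbd
import seisbench.generate as sbg

# Small example dataset shipped with SeisBench (downloads on first use); stands in
# for the Bogdanka waveforms here.
data = sbd.DummyDataset()
highpass, lowpass = 1.0, 10.0  # assumed corner frequencies in Hz

augmentations = [
    sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
    sbg.Filter(1, highpass, "highpass", forward_backward=True),
    sbg.Filter(1, lowpass, "lowpass", forward_backward=True),
    sbg.ChangeDtype(np.float32),  # cast last, so filtering runs on the raw dtype
]

generator = sbg.GenericGenerator(data)
generator.add_augmentations(augmentations)

sample = generator[0]
print(sample["X"].shape, sample["X"].dtype)  # filtered, normalized float32 traces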