finetuning #1

Merged
k.milian merged 2 commits from finetuning into master 2024-07-29 13:41:06 +02:00
7 changed files with 72 additions and 61 deletions
Showing only changes of commit f40ac35cc8 - Show all commits

View File

@ -1,6 +1,6 @@
{
"dataset_name": "bogdanka",
"data_path": "datasets/bogdanka/seisbench_format/",
"dataset_name": "bogdanka_2018_2022",
"data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
"targets_path": "datasets/targets",
"models_path": "weights",
"configs_path": "experiments",
@ -13,5 +13,5 @@
"BasicPhaseAE": "sweep_basicphase_ae.yaml",
"EQTransformer": "sweep_eqtransformer.yaml"
},
"experiment_count": 20
"experiment_count": 15
}

View File

@ -1,3 +1,4 @@
name: BasicPhaseAE
method: bayes
metric:
goal: minimize
@ -7,13 +8,9 @@ parameters:
value:
- BasicPhaseAE
batch_size:
distribution: int_uniform
max: 1024
min: 256
values: [64, 128, 256]
max_epochs:
value:
- 20
- 30
learning_rate:
distribution: uniform
max: 0.02
min: 0.001
values: [0.01, 0.005, 0.001]

View File

@ -8,13 +8,9 @@ parameters:
value:
- EQTransformer
batch_size:
distribution: int_uniform
max: 1024
min: 256
values: [64, 128, 256]
max_epochs:
value:
- 30
learning_rate:
distribution: uniform
max: 0.02
min: 0.005
values: [0.01, 0.005, 0.001]

View File

@ -1,4 +1,4 @@
name: GPD_fixed_highpass:2-10
name: GPD
method: bayes
metric:
goal: minimize
@ -8,16 +8,12 @@ parameters:
value:
- GPD
batch_size:
distribution: int_uniform
max: 1024
min: 256
values: [64, 128, 256]
max_epochs:
value:
- 30
learning_rate:
distribution: uniform
max: 0.02
min: 0.005
values: [0.01, 0.005, 0.001]
highpass:
value:
- 1

View File

@ -56,6 +56,10 @@ def get_trainer_args(config):
return trainer_args
def get_arg(arg):
    """Return a configuration value, unwrapping it from a list when needed.

    Args:
        arg: Either a plain value or a list carrying the value as its
            first element.

    Returns:
        ``arg[0]`` when ``arg`` is a list (or a list subclass), otherwise
        ``arg`` itself, unchanged.

    Raises:
        IndexError: If ``arg`` is an empty list.
    """
    # isinstance instead of ``type(arg) == list`` so that list subclasses
    # are unwrapped as well; tuples and strings are deliberately left as-is.
    if isinstance(arg, list):
        return arg[0]
    return arg
class HyperparameterSweep:
def __init__(self, project_name, sweep_config):
@ -87,10 +91,10 @@ class HyperparameterSweep:
return all_not_running
def run_experiment(self):
try:
logger.info("Starting a new run...")
run = wandb.init(
project=self.project_name,
config=config_loader.config,
@ -103,30 +107,24 @@ class HyperparameterSweep:
exclude_fn=lambda path: path.endswith("template.sh")
)
model_name = wandb.config.model_name[0]
model_name = get_arg(wandb.config.model_name)
model_args = models.get_model_specific_args(wandb.config)
if "pretrained" in wandb.config:
weights = wandb.config.get("pretrained")
if type(weights) == list:
weights = weights[0]
weights = get_arg(wandb.config.pretrained)
if weights != "false":
model_args["pretrained"] = weights
if "norm" in wandb.config:
model_args["norm"] = wandb.config.norm
logger.debug(f"Initializing {model_name}")
if "norm" in wandb.config:
model_args["norm"] = get_arg(wandb.config.norm)
if "finetuning" in wandb.config:
# train for a few epochs with some frozen params, then unfreeze and continue training
if type(wandb.config.finetuning) == list:
finetuning_strategy = wandb.config.finetuning[0]
else:
finetuning_strategy = wandb.config.finetuning
model_args['finetuning_strategy'] = finetuning_strategy
model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
if "lr_reduce_factor" in wandb.config:
model_args['steplr_gamma'] = wandb.config.lr_reduce_factor
model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
logger.debug(f"Initializing {model_name} with args: {model_args}")
model = models.__getattribute__(model_name + "Lit")(**model_args)
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)

View File

@ -85,6 +85,9 @@ class PhaseNetParameters(Parameters):
finetuning: Finetuning = None
lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
@field_validator("model_name")
def validate_model(cls, v):
if "PhaseNet" not in v.value:
@ -92,23 +95,24 @@ class PhaseNetParameters(Parameters):
return v
class GPDParameters(Parameters):
class FilteringParameters(Parameters):
model_config = ConfigDict(extra='forbid')
highpass: Union[NumericValue, NumericValues, FloatDistribution] = None
lowpass: Union[NumericValue, NumericValues, FloatDistribution] = None
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
@field_validator("model_name")
def validate_model(cls, v):
if "GPD" not in v.value:
raise ValueError("Additional parameters implemented for GPD only")
print(v.value)
if v.value[0] not in ["GPD", "PhaseNet"]:
raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
class InputParams(BaseModel):
name: str
method: str
metric: Metric
parameters: Union[Parameters, PhaseNetParameters, GPDParameters]
parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
def validate_sweep_yaml(yaml_filename, model_name=None):
@ -138,5 +142,5 @@ def validate_sweep_config(sweep_config, model_name=None):
if __name__ == "__main__":
yaml_filename = "../experiments/sweep_phasenet_lumineos_lr_bs.yaml"
yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
validate_sweep_yaml(yaml_filename, None)

View File

@ -143,6 +143,9 @@ class PhaseNetLit(SeisBenchModuleLit):
self.loss = vector_cross_entropy
self.pretrained = kwargs.pop("pretrained", None)
self.norm = kwargs.pop("norm", "peak")
self.highpass = kwargs.pop("highpass", None)
self.lowpass = kwargs.pop("lowpass", None)
if self.pretrained is not None:
self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
@ -152,6 +155,7 @@ class PhaseNetLit(SeisBenchModuleLit):
self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
self.reduce_lr_on_plateau = False
self.initial_epochs = 0
@ -163,36 +167,33 @@ class PhaseNetLit(SeisBenchModuleLit):
self.freeze()
def forward(self, x):
    """Pass ``x`` through ``self.model`` and return whatever it produces."""
    output = self.model(x)
    return output
def shared_step(self, batch):
    """Compute the loss for a single batch.

    Inputs are read from ``batch["X"]`` and targets from ``batch["y"]``; the
    inputs go through ``self.model`` and the prediction is scored against
    the targets with ``self.loss``.

    Returns:
        The value returned by ``self.loss(prediction, targets)``.
    """
    inputs, targets = batch["X"], batch["y"]
    return self.loss(self.model(inputs), targets)
def training_step(self, batch, batch_idx):
    """Compute the batch loss, log it under ``train_loss`` and return it.

    ``batch_idx`` is accepted but not used by this method.
    """
    train_loss = self.shared_step(batch)
    self.log("train_loss", train_loss)
    return train_loss
def validation_step(self, batch, batch_idx):
    """Compute the batch loss, log it under ``val_loss`` and return it.

    ``batch_idx`` is accepted but not used by this method.
    """
    val_loss = self.shared_step(batch)
    self.log("val_loss", val_loss)
    return val_loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
if self.finetuning_strategy is not None:
scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
self.reduce_lr_on_plateau = False
else:
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
self.reduce_lr_on_plateau = True
#
return {
'optimizer': optimizer,
@ -200,11 +201,10 @@ class PhaseNetLit(SeisBenchModuleLit):
'scheduler': scheduler,
'monitor': 'val_loss',
'interval': 'epoch',
'reduce_on_plateau': False,
'reduce_on_plateau': self.reduce_lr_on_plateau,
},
}
def lr_lambda(self, epoch):
# reduce lr after x initial epochs
if epoch == self.initial_epochs:
@ -212,15 +212,25 @@ class PhaseNetLit(SeisBenchModuleLit):
return self.lr
def lr_scheduler_step(self, scheduler, metric):
    """Step ``scheduler`` using the current epoch.

    ``metric`` is forwarded as the first positional argument only when
    ``self.reduce_lr_on_plateau`` is truthy; otherwise the scheduler is
    stepped with ``epoch=self.current_epoch`` alone.
    """
    leading_args = (metric,) if self.reduce_lr_on_plateau else ()
    scheduler.step(*leading_args, epoch=self.current_epoch)
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
# scheduler.step(epoch=self.current_epoch)
def get_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
logger.info(f"Using highpass filer {self.highpass}")
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
logger.info(f"Using lowpass filer {self.lowpass}")
logger.info(filter)
return [
# In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
# Uses strategy variable, as padding will be handled by the random window.
@ -244,22 +254,28 @@ class PhaseNetLit(SeisBenchModuleLit):
windowlen=3001,
strategy="pad",
),
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
*filter,
sbg.ChangeDtype(np.float32),
sbg.ProbabilisticLabeller(
label_columns=phase_dict, sigma=self.sigma, dim=0
),
]
def get_eval_augmentations(self):
filter = []
if self.highpass is not None:
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
if self.lowpass is not None:
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
return [
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
sbg.ChangeDtype(np.float32),
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
*filter,
sbg.ChangeDtype(np.float32),
]
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
x = batch["X"]
window_borders = batch["window_borders"]
@ -1211,8 +1227,12 @@ def get_model_specific_args(config):
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass[0]
args['lowpass'] = config.lowpass
case 'PhaseNet':
if 'highpass' in config:
args['highpass'] = config.highpass
if 'lowpass' in config:
args['lowpass'] = config.lowpass
if 'sample_boundaries' in config:
args['sample_boundaries'] = config.sample_boundaries
case 'DPPPicker':