finetuning #1
@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"dataset_name": "bogdanka",
|
"dataset_name": "bogdanka_2018_2022",
|
||||||
"data_path": "datasets/bogdanka/seisbench_format/",
|
"data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
|
||||||
"targets_path": "datasets/targets",
|
"targets_path": "datasets/targets",
|
||||||
"models_path": "weights",
|
"models_path": "weights",
|
||||||
"configs_path": "experiments",
|
"configs_path": "experiments",
|
||||||
@ -13,5 +13,5 @@
|
|||||||
"BasicPhaseAE": "sweep_basicphase_ae.yaml",
|
"BasicPhaseAE": "sweep_basicphase_ae.yaml",
|
||||||
"EQTransformer": "sweep_eqtransformer.yaml"
|
"EQTransformer": "sweep_eqtransformer.yaml"
|
||||||
},
|
},
|
||||||
"experiment_count": 20
|
"experiment_count": 15
|
||||||
}
|
}
|
@ -1,3 +1,4 @@
|
|||||||
|
name: BasicPhaseAE
|
||||||
method: bayes
|
method: bayes
|
||||||
metric:
|
metric:
|
||||||
goal: minimize
|
goal: minimize
|
||||||
@ -7,13 +8,9 @@ parameters:
|
|||||||
value:
|
value:
|
||||||
- BasicPhaseAE
|
- BasicPhaseAE
|
||||||
batch_size:
|
batch_size:
|
||||||
distribution: int_uniform
|
values: [64, 128, 256]
|
||||||
max: 1024
|
|
||||||
min: 256
|
|
||||||
max_epochs:
|
max_epochs:
|
||||||
value:
|
value:
|
||||||
- 20
|
- 30
|
||||||
learning_rate:
|
learning_rate:
|
||||||
distribution: uniform
|
values: [0.01, 0.005, 0.001]
|
||||||
max: 0.02
|
|
||||||
min: 0.001
|
|
||||||
|
@ -8,13 +8,9 @@ parameters:
|
|||||||
value:
|
value:
|
||||||
- EQTransformer
|
- EQTransformer
|
||||||
batch_size:
|
batch_size:
|
||||||
distribution: int_uniform
|
values: [64, 128, 256]
|
||||||
max: 1024
|
|
||||||
min: 256
|
|
||||||
max_epochs:
|
max_epochs:
|
||||||
value:
|
value:
|
||||||
- 30
|
- 30
|
||||||
learning_rate:
|
learning_rate:
|
||||||
distribution: uniform
|
values: [0.01, 0.005, 0.001]
|
||||||
max: 0.02
|
|
||||||
min: 0.005
|
|
@ -1,4 +1,4 @@
|
|||||||
name: GPD_fixed_highpass:2-10
|
name: GPD
|
||||||
method: bayes
|
method: bayes
|
||||||
metric:
|
metric:
|
||||||
goal: minimize
|
goal: minimize
|
||||||
@ -8,16 +8,12 @@ parameters:
|
|||||||
value:
|
value:
|
||||||
- GPD
|
- GPD
|
||||||
batch_size:
|
batch_size:
|
||||||
distribution: int_uniform
|
values: [64, 128, 256]
|
||||||
max: 1024
|
|
||||||
min: 256
|
|
||||||
max_epochs:
|
max_epochs:
|
||||||
value:
|
value:
|
||||||
- 30
|
- 30
|
||||||
learning_rate:
|
learning_rate:
|
||||||
distribution: uniform
|
values: [0.01, 0.005, 0.001]
|
||||||
max: 0.02
|
|
||||||
min: 0.005
|
|
||||||
highpass:
|
highpass:
|
||||||
value:
|
value:
|
||||||
- 1
|
- 1
|
||||||
|
@ -56,6 +56,10 @@ def get_trainer_args(config):
|
|||||||
return trainer_args
|
return trainer_args
|
||||||
|
|
||||||
|
|
||||||
|
def get_arg(arg):
|
||||||
|
if type(arg) == list:
|
||||||
|
return arg[0]
|
||||||
|
return arg
|
||||||
|
|
||||||
class HyperparameterSweep:
|
class HyperparameterSweep:
|
||||||
def __init__(self, project_name, sweep_config):
|
def __init__(self, project_name, sweep_config):
|
||||||
@ -87,10 +91,10 @@ class HyperparameterSweep:
|
|||||||
return all_not_running
|
return all_not_running
|
||||||
|
|
||||||
def run_experiment(self):
|
def run_experiment(self):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
logger.info("Starting a new run...")
|
logger.info("Starting a new run...")
|
||||||
|
|
||||||
run = wandb.init(
|
run = wandb.init(
|
||||||
project=self.project_name,
|
project=self.project_name,
|
||||||
config=config_loader.config,
|
config=config_loader.config,
|
||||||
@ -103,30 +107,24 @@ class HyperparameterSweep:
|
|||||||
exclude_fn=lambda path: path.endswith("template.sh")
|
exclude_fn=lambda path: path.endswith("template.sh")
|
||||||
)
|
)
|
||||||
|
|
||||||
model_name = wandb.config.model_name[0]
|
model_name = get_arg(wandb.config.model_name)
|
||||||
model_args = models.get_model_specific_args(wandb.config)
|
model_args = models.get_model_specific_args(wandb.config)
|
||||||
|
|
||||||
if "pretrained" in wandb.config:
|
if "pretrained" in wandb.config:
|
||||||
weights = wandb.config.get("pretrained")
|
weights = get_arg(wandb.config.pretrained)
|
||||||
if type(weights) == list:
|
|
||||||
weights = weights[0]
|
|
||||||
if weights != "false":
|
if weights != "false":
|
||||||
model_args["pretrained"] = weights
|
model_args["pretrained"] = weights
|
||||||
if "norm" in wandb.config:
|
|
||||||
model_args["norm"] = wandb.config.norm
|
|
||||||
|
|
||||||
logger.debug(f"Initializing {model_name}")
|
if "norm" in wandb.config:
|
||||||
|
model_args["norm"] = get_arg(wandb.config.norm)
|
||||||
|
|
||||||
if "finetuning" in wandb.config:
|
if "finetuning" in wandb.config:
|
||||||
# train for a few epochs with some frozen params, then unfreeze and continue training
|
model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
|
||||||
if type(wandb.config.finetuning) == list:
|
|
||||||
finetuning_strategy = wandb.config.finetuning[0]
|
|
||||||
else:
|
|
||||||
finetuning_strategy = wandb.config.finetuning
|
|
||||||
model_args['finetuning_strategy'] = finetuning_strategy
|
|
||||||
|
|
||||||
if "lr_reduce_factor" in wandb.config:
|
if "lr_reduce_factor" in wandb.config:
|
||||||
model_args['steplr_gamma'] = wandb.config.lr_reduce_factor
|
model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
|
||||||
|
|
||||||
|
|
||||||
|
logger.debug(f"Initializing {model_name} with args: {model_args}")
|
||||||
model = models.__getattribute__(model_name + "Lit")(**model_args)
|
model = models.__getattribute__(model_name + "Lit")(**model_args)
|
||||||
|
|
||||||
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
|
train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
|
||||||
|
@ -85,6 +85,9 @@ class PhaseNetParameters(Parameters):
|
|||||||
finetuning: Finetuning = None
|
finetuning: Finetuning = None
|
||||||
lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
|
lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
|
||||||
|
|
||||||
|
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||||
|
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||||
|
|
||||||
@field_validator("model_name")
|
@field_validator("model_name")
|
||||||
def validate_model(cls, v):
|
def validate_model(cls, v):
|
||||||
if "PhaseNet" not in v.value:
|
if "PhaseNet" not in v.value:
|
||||||
@ -92,23 +95,24 @@ class PhaseNetParameters(Parameters):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
class GPDParameters(Parameters):
|
class FilteringParameters(Parameters):
|
||||||
model_config = ConfigDict(extra='forbid')
|
model_config = ConfigDict(extra='forbid')
|
||||||
|
|
||||||
highpass: Union[NumericValue, NumericValues, FloatDistribution] = None
|
highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||||
lowpass: Union[NumericValue, NumericValues, FloatDistribution] = None
|
lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
|
||||||
|
|
||||||
@field_validator("model_name")
|
@field_validator("model_name")
|
||||||
def validate_model(cls, v):
|
def validate_model(cls, v):
|
||||||
if "GPD" not in v.value:
|
print(v.value)
|
||||||
raise ValueError("Additional parameters implemented for GPD only")
|
if v.value[0] not in ["GPD", "PhaseNet"]:
|
||||||
|
raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
|
||||||
|
|
||||||
|
|
||||||
class InputParams(BaseModel):
|
class InputParams(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
method: str
|
method: str
|
||||||
metric: Metric
|
metric: Metric
|
||||||
parameters: Union[Parameters, PhaseNetParameters, GPDParameters]
|
parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
|
||||||
|
|
||||||
|
|
||||||
def validate_sweep_yaml(yaml_filename, model_name=None):
|
def validate_sweep_yaml(yaml_filename, model_name=None):
|
||||||
@ -138,5 +142,5 @@ def validate_sweep_config(sweep_config, model_name=None):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
yaml_filename = "../experiments/sweep_phasenet_lumineos_lr_bs.yaml"
|
yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
|
||||||
validate_sweep_yaml(yaml_filename, None)
|
validate_sweep_yaml(yaml_filename, None)
|
||||||
|
@ -143,6 +143,9 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
self.loss = vector_cross_entropy
|
self.loss = vector_cross_entropy
|
||||||
self.pretrained = kwargs.pop("pretrained", None)
|
self.pretrained = kwargs.pop("pretrained", None)
|
||||||
self.norm = kwargs.pop("norm", "peak")
|
self.norm = kwargs.pop("norm", "peak")
|
||||||
|
self.highpass = kwargs.pop("highpass", None)
|
||||||
|
self.lowpass = kwargs.pop("lowpass", None)
|
||||||
|
|
||||||
|
|
||||||
if self.pretrained is not None:
|
if self.pretrained is not None:
|
||||||
self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
|
self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
|
||||||
@ -152,6 +155,7 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
|
|
||||||
self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
|
self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
|
||||||
self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
|
self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
|
||||||
|
self.reduce_lr_on_plateau = False
|
||||||
|
|
||||||
self.initial_epochs = 0
|
self.initial_epochs = 0
|
||||||
|
|
||||||
@ -163,36 +167,33 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
|
|
||||||
self.freeze()
|
self.freeze()
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.model(x)
|
return self.model(x)
|
||||||
|
|
||||||
|
|
||||||
def shared_step(self, batch):
|
def shared_step(self, batch):
|
||||||
x = batch["X"]
|
x = batch["X"]
|
||||||
y_true = batch["y"]
|
y_true = batch["y"]
|
||||||
y_pred = self.model(x)
|
y_pred = self.model(x)
|
||||||
return self.loss(y_pred, y_true)
|
return self.loss(y_pred, y_true)
|
||||||
|
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
loss = self.shared_step(batch)
|
loss = self.shared_step(batch)
|
||||||
self.log("train_loss", loss)
|
self.log("train_loss", loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
loss = self.shared_step(batch)
|
loss = self.shared_step(batch)
|
||||||
self.log("val_loss", loss)
|
self.log("val_loss", loss)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||||
if self.finetuning_strategy is not None:
|
if self.finetuning_strategy is not None:
|
||||||
scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
|
scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
|
||||||
|
self.reduce_lr_on_plateau = False
|
||||||
else:
|
else:
|
||||||
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
|
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
|
||||||
|
self.reduce_lr_on_plateau = True
|
||||||
#
|
#
|
||||||
return {
|
return {
|
||||||
'optimizer': optimizer,
|
'optimizer': optimizer,
|
||||||
@ -200,11 +201,10 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
'scheduler': scheduler,
|
'scheduler': scheduler,
|
||||||
'monitor': 'val_loss',
|
'monitor': 'val_loss',
|
||||||
'interval': 'epoch',
|
'interval': 'epoch',
|
||||||
'reduce_on_plateau': False,
|
'reduce_on_plateau': self.reduce_lr_on_plateau,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def lr_lambda(self, epoch):
|
def lr_lambda(self, epoch):
|
||||||
# reduce lr after x initial epochs
|
# reduce lr after x initial epochs
|
||||||
if epoch == self.initial_epochs:
|
if epoch == self.initial_epochs:
|
||||||
@ -212,15 +212,25 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
|
|
||||||
return self.lr
|
return self.lr
|
||||||
|
|
||||||
|
|
||||||
def lr_scheduler_step(self, scheduler, metric):
|
def lr_scheduler_step(self, scheduler, metric):
|
||||||
|
if self.reduce_lr_on_plateau:
|
||||||
|
scheduler.step(metric, epoch=self.current_epoch)
|
||||||
|
else:
|
||||||
scheduler.step(epoch=self.current_epoch)
|
scheduler.step(epoch=self.current_epoch)
|
||||||
|
|
||||||
|
|
||||||
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
|
# def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
|
||||||
# scheduler.step(epoch=self.current_epoch)
|
# scheduler.step(epoch=self.current_epoch)
|
||||||
|
|
||||||
def get_augmentations(self):
|
def get_augmentations(self):
|
||||||
|
filter = []
|
||||||
|
if self.highpass is not None:
|
||||||
|
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
|
||||||
|
logger.info(f"Using highpass filer {self.highpass}")
|
||||||
|
if self.lowpass is not None:
|
||||||
|
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
|
||||||
|
logger.info(f"Using lowpass filer {self.lowpass}")
|
||||||
|
logger.info(filter)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
# In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
|
# In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
|
||||||
# Uses strategy variable, as padding will be handled by the random window.
|
# Uses strategy variable, as padding will be handled by the random window.
|
||||||
@ -244,22 +254,28 @@ class PhaseNetLit(SeisBenchModuleLit):
|
|||||||
windowlen=3001,
|
windowlen=3001,
|
||||||
strategy="pad",
|
strategy="pad",
|
||||||
),
|
),
|
||||||
sbg.ChangeDtype(np.float32),
|
|
||||||
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
|
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
|
||||||
|
*filter,
|
||||||
|
sbg.ChangeDtype(np.float32),
|
||||||
sbg.ProbabilisticLabeller(
|
sbg.ProbabilisticLabeller(
|
||||||
label_columns=phase_dict, sigma=self.sigma, dim=0
|
label_columns=phase_dict, sigma=self.sigma, dim=0
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_eval_augmentations(self):
|
def get_eval_augmentations(self):
|
||||||
|
filter = []
|
||||||
|
if self.highpass is not None:
|
||||||
|
filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
|
||||||
|
if self.lowpass is not None:
|
||||||
|
filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
|
||||||
|
|
||||||
return [
|
return [
|
||||||
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
|
sbg.SteeredWindow(windowlen=3001, strategy="pad"),
|
||||||
sbg.ChangeDtype(np.float32),
|
|
||||||
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
|
sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
|
||||||
|
*filter,
|
||||||
|
sbg.ChangeDtype(np.float32),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
|
def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
|
||||||
x = batch["X"]
|
x = batch["X"]
|
||||||
window_borders = batch["window_borders"]
|
window_borders = batch["window_borders"]
|
||||||
@ -1211,8 +1227,12 @@ def get_model_specific_args(config):
|
|||||||
if 'highpass' in config:
|
if 'highpass' in config:
|
||||||
args['highpass'] = config.highpass
|
args['highpass'] = config.highpass
|
||||||
if 'lowpass' in config:
|
if 'lowpass' in config:
|
||||||
args['lowpass'] = config.lowpass[0]
|
args['lowpass'] = config.lowpass
|
||||||
case 'PhaseNet':
|
case 'PhaseNet':
|
||||||
|
if 'highpass' in config:
|
||||||
|
args['highpass'] = config.highpass
|
||||||
|
if 'lowpass' in config:
|
||||||
|
args['lowpass'] = config.lowpass
|
||||||
if 'sample_boundaries' in config:
|
if 'sample_boundaries' in config:
|
||||||
args['sample_boundaries'] = config.sample_boundaries
|
args['sample_boundaries'] = config.sample_boundaries
|
||||||
case 'DPPPicker':
|
case 'DPPPicker':
|
||||||
|
Loading…
Reference in New Issue
Block a user