finetuning #1
@@ -1,6 +1,6 @@
 {
-  "dataset_name": "bogdanka",
-  "data_path": "datasets/bogdanka/seisbench_format/",
+  "dataset_name": "bogdanka_2018_2022",
+  "data_path": "datasets/bogdanka_2018_2022/seisbench_format/",
   "targets_path": "datasets/targets",
   "models_path": "weights",
   "configs_path": "experiments",
@@ -13,5 +13,5 @@
     "BasicPhaseAE": "sweep_basicphase_ae.yaml",
     "EQTransformer": "sweep_eqtransformer.yaml"
   },
-  "experiment_count": 20
+  "experiment_count": 15
 }
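`experiment_count` is the number of runs the sweep agent will execute, and the per-model filenames point at the sweep definitions below. A minimal sketch of how a config like this is typically fed to Weights & Biases; the JSON path, the key holding the filename mapping, and the use of `dataset_name` as the project name are illustrative assumptions, not the project's actual wiring:

    import json
    import yaml
    import wandb

    # Illustrative wiring only: load the JSON config and start a sweep for one model.
    with open("experiments/config.json") as f:   # path is hypothetical
        cfg = json.load(f)

    with open(f"{cfg['configs_path']}/sweep_eqtransformer.yaml") as f:
        sweep_config = yaml.safe_load(f)

    def run_once():
        # Each agent run receives one sampled parameter combination via wandb.config.
        with wandb.init() as run:
            print(dict(wandb.config))

    sweep_id = wandb.sweep(sweep=sweep_config, project=cfg["dataset_name"])
    wandb.agent(sweep_id, function=run_once, count=cfg["experiment_count"])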
@@ -1,3 +1,4 @@
+name: BasicPhaseAE
 method: bayes
 metric:
   goal: minimize
@@ -7,13 +8,9 @@ parameters:
     value:
       - BasicPhaseAE
   batch_size:
-    distribution: int_uniform
-    max: 1024
-    min: 256
+    values: [64, 128, 256]
   max_epochs:
     value:
-      - 20
+      - 30
   learning_rate:
-    distribution: uniform
-    max: 0.02
-    min: 0.001
+    values: [0.01, 0.005, 0.001]
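Replacing `distribution: int_uniform`/`uniform` with `min`/`max` by an explicit `values:` list narrows the search from a continuous or integer range to a fixed grid; with `method: bayes` the optimizer then treats the parameter as categorical and only proposes the listed values. The two spec forms, written as Python dicts for comparison (standard W&B sweep semantics, not project-specific code):

    # Range spec (old): the sweep may propose any integer in [256, 1024].
    batch_size_range = {"distribution": "int_uniform", "min": 256, "max": 1024}

    # Grid spec (new): the sweep only ever proposes one of these values.
    batch_size_grid = {"values": [64, 128, 256]}

    parameters = {
        "batch_size": batch_size_grid,
        "learning_rate": {"values": [0.01, 0.005, 0.001]},
    }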
@@ -8,13 +8,9 @@ parameters:
     value:
       - EQTransformer
   batch_size:
-    distribution: int_uniform
-    max: 1024
-    min: 256
+    values: [64, 128, 256]
   max_epochs:
     value:
       - 30
   learning_rate:
-    distribution: uniform
-    max: 0.02
-    min: 0.005
+    values: [0.01, 0.005, 0.001]
@@ -1,4 +1,4 @@
-name: GPD_fixed_highpass:2-10
+name: GPD
 method: bayes
 metric:
   goal: minimize
@@ -8,16 +8,12 @@ parameters:
     value:
       - GPD
   batch_size:
-    distribution: int_uniform
-    max: 1024
-    min: 256
+    values: [64, 128, 256]
   max_epochs:
     value:
       - 30
   learning_rate:
-    distribution: uniform
-    max: 0.02
-    min: 0.005
+    values: [0.01, 0.005, 0.001]
   highpass:
     value:
       - 1
@@ -56,6 +56,10 @@ def get_trainer_args(config):
     return trainer_args
 
 
+def get_arg(arg):
+    if type(arg) == list:
+        return arg[0]
+    return arg
 
 class HyperparameterSweep:
     def __init__(self, project_name, sweep_config):
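Parameters pinned in the sweep YAML with a single-element `value:` list come back from `wandb.config` as one-element lists, while plain scalars come back unchanged; `get_arg` normalizes both cases so the call sites further down do not repeat the list check. Usage of the helper added above:

    # get_arg unwraps one-element lists and passes scalars through unchanged.
    assert get_arg(["PhaseNet"]) == "PhaseNet"   # YAML: value with a single list item
    assert get_arg(0.01) == 0.01                 # plain scalar
    assert get_arg("peak") == "peak"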
@@ -87,10 +91,10 @@ class HyperparameterSweep:
         return all_not_running
 
     def run_experiment(self):
 
         try:
 
             logger.info("Starting a new run...")
 
             run = wandb.init(
                 project=self.project_name,
                 config=config_loader.config,
@@ -103,30 +107,24 @@ class HyperparameterSweep:
                 exclude_fn=lambda path: path.endswith("template.sh")
             )
 
-            model_name = wandb.config.model_name[0]
+            model_name = get_arg(wandb.config.model_name)
             model_args = models.get_model_specific_args(wandb.config)
 
             if "pretrained" in wandb.config:
-                weights = wandb.config.get("pretrained")
-                if type(weights) == list:
-                    weights = weights[0]
+                weights = get_arg(wandb.config.pretrained)
                 if weights != "false":
                     model_args["pretrained"] = weights
-            if "norm" in wandb.config:
-                model_args["norm"] = wandb.config.norm
-
-            logger.debug(f"Initializing {model_name}")
+            if "norm" in wandb.config:
+                model_args["norm"] = get_arg(wandb.config.norm)
 
             if "finetuning" in wandb.config:
                 # train for a few epochs with some frozen params, then unfreeze and continue training
-                if type(wandb.config.finetuning) == list:
-                    finetuning_strategy = wandb.config.finetuning[0]
-                else:
-                    finetuning_strategy = wandb.config.finetuning
-                model_args['finetuning_strategy'] = finetuning_strategy
+                model_args['finetuning_strategy'] = get_arg(wandb.config.finetuning)
 
             if "lr_reduce_factor" in wandb.config:
-                model_args['steplr_gamma'] = wandb.config.lr_reduce_factor
+                model_args['steplr_gamma'] = get_arg(wandb.config.lr_reduce_factor)
 
+            logger.debug(f"Initializing {model_name} with args: {model_args}")
             model = models.__getattribute__(model_name + "Lit")(**model_args)
 
             train_loader, dev_loader = train.prepare_data(wandb.config, model, test_run=False)
@@ -85,6 +85,9 @@ class PhaseNetParameters(Parameters):
     finetuning: Finetuning = None
     lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
 
+    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
+    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
+
     @field_validator("model_name")
     def validate_model(cls, v):
         if "PhaseNet" not in v.value:
@@ -92,23 +95,24 @@ class PhaseNetParameters(Parameters):
         return v
 
 
-class GPDParameters(Parameters):
+class FilteringParameters(Parameters):
     model_config = ConfigDict(extra='forbid')
 
-    highpass: Union[NumericValue, NumericValues, FloatDistribution] = None
-    lowpass: Union[NumericValue, NumericValues, FloatDistribution] = None
+    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
+    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
 
     @field_validator("model_name")
     def validate_model(cls, v):
-        if "GPD" not in v.value:
-            raise ValueError("Additional parameters implemented for GPD only")
+        print(v.value)
+        if v.value[0] not in ["GPD", "PhaseNet"]:
+            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
 
 
 class InputParams(BaseModel):
     name: str
     method: str
     metric: Metric
-    parameters: Union[Parameters, PhaseNetParameters, GPDParameters]
+    parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
 
 
 def validate_sweep_yaml(yaml_filename, model_name=None):
@@ -138,5 +142,5 @@ def validate_sweep_config(sweep_config, model_name=None):
 
 
 if __name__ == "__main__":
-    yaml_filename = "../experiments/sweep_phasenet_lumineos_lr_bs.yaml"
+    yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
     validate_sweep_yaml(yaml_filename, None)
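For readers unfamiliar with the validation pattern used here: pydantic tries the members of the `parameters:` union in turn, and `extra='forbid'` makes a member reject configs carrying keys it does not declare, so a sweep that sets `highpass`/`lowpass` resolves to `FilteringParameters` rather than the base `Parameters`. A self-contained sketch of that mechanism with made-up classes and fields (not the project's actual models):

    from typing import Union
    from pydantic import BaseModel, ConfigDict

    class Base(BaseModel):
        model_config = ConfigDict(extra="forbid")
        batch_size: int

    class WithFilter(Base):
        highpass: float

    class Sweep(BaseModel):
        parameters: Union[Base, WithFilter]

    s = Sweep.model_validate({"parameters": {"batch_size": 64, "highpass": 2.0}})
    # Base forbids the extra "highpass" key, so the union resolves to WithFilter.
    print(type(s.parameters).__name__)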
@@ -143,6 +143,9 @@ class PhaseNetLit(SeisBenchModuleLit):
         self.loss = vector_cross_entropy
         self.pretrained = kwargs.pop("pretrained", None)
         self.norm = kwargs.pop("norm", "peak")
+        self.highpass = kwargs.pop("highpass", None)
+        self.lowpass = kwargs.pop("lowpass", None)
+
 
         if self.pretrained is not None:
             self.model = sbm.PhaseNet.from_pretrained(self.pretrained)
@@ -152,6 +155,7 @@ class PhaseNetLit(SeisBenchModuleLit):
 
         self.finetuning_strategy = kwargs.pop("finetuning_strategy", None)
         self.steplr_gamma = kwargs.pop("steplr_gamma", 0.1)
+        self.reduce_lr_on_plateau = False
 
         self.initial_epochs = 0
 
@@ -163,36 +167,33 @@ class PhaseNetLit(SeisBenchModuleLit):
 
             self.freeze()
 
     def forward(self, x):
         return self.model(x)
 
     def shared_step(self, batch):
         x = batch["X"]
         y_true = batch["y"]
         y_pred = self.model(x)
         return self.loss(y_pred, y_true)
 
     def training_step(self, batch, batch_idx):
         loss = self.shared_step(batch)
         self.log("train_loss", loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
         loss = self.shared_step(batch)
         self.log("val_loss", loss)
         return loss
 
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
         if self.finetuning_strategy is not None:
             scheduler = lr_scheduler.LambdaLR(optimizer, self.lr_lambda)
+            self.reduce_lr_on_plateau = False
+        else:
+            scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
+            self.reduce_lr_on_plateau = True
         #
         return {
             'optimizer': optimizer,
@@ -200,11 +201,10 @@ class PhaseNetLit(SeisBenchModuleLit):
                 'scheduler': scheduler,
                 'monitor': 'val_loss',
                 'interval': 'epoch',
-                'reduce_on_plateau': False,
+                'reduce_on_plateau': self.reduce_lr_on_plateau,
             },
         }
 
     def lr_lambda(self, epoch):
         # reduce lr after x initial epochs
         if epoch == self.initial_epochs:
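Background on the scheduler split above: `LambdaLR` multiplies the optimizer's base learning rate by whatever factor the supplied function returns for the current epoch, which is what lets the finetuning path hold the rate steady while layers are frozen and drop it once everything is unfrozen, whereas `ReduceLROnPlateau` needs the monitored metric, hence the `reduce_on_plateau` flag in the scheduler dict. A standalone sketch of the `LambdaLR` mechanism in plain PyTorch (illustrative values; not the module's exact `lr_lambda`):

    import torch
    from torch.optim import lr_scheduler

    initial_epochs = 3   # epochs trained with some layers frozen (illustrative)
    steplr_gamma = 0.1   # multiplicative drop applied after unfreezing (illustrative)

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=0.01)

    # The lambda returns a factor on the base lr: 1.0 during the frozen phase, gamma after.
    scheduler = lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: 1.0 if epoch < initial_epochs else steplr_gamma
    )

    for epoch in range(6):
        optimizer.step()
        scheduler.step()
        # lr stays at 0.01 while the scheduler's epoch counter is below initial_epochs,
        # then switches to 0.001.
        print(epoch, scheduler.get_last_lr())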
@@ -212,15 +212,25 @@ class PhaseNetLit(SeisBenchModuleLit):
 
         return self.lr
 
+    def lr_scheduler_step(self, scheduler, metric):
+        if self.reduce_lr_on_plateau:
+            scheduler.step(metric, epoch=self.current_epoch)
+        else:
+            scheduler.step(epoch=self.current_epoch)
+
     # def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
     #     scheduler.step(epoch=self.current_epoch)
 
     def get_augmentations(self):
+        filter = []
+        if self.highpass is not None:
+            filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
+            logger.info(f"Using highpass filer {self.highpass}")
+        if self.lowpass is not None:
+            filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
+            logger.info(f"Using lowpass filer {self.lowpass}")
+        logger.info(filter)
+
         return [
             # In 2/3 of the cases, select windows around picks, to reduce amount of noise traces in training.
             # Uses strategy variable, as padding will be handled by the random window.
@@ -244,22 +254,28 @@ class PhaseNetLit(SeisBenchModuleLit):
                 windowlen=3001,
                 strategy="pad",
             ),
-            sbg.ChangeDtype(np.float32),
             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
+            *filter,
+            sbg.ChangeDtype(np.float32),
             sbg.ProbabilisticLabeller(
                 label_columns=phase_dict, sigma=self.sigma, dim=0
             ),
         ]
 
 
     def get_eval_augmentations(self):
+        filter = []
+        if self.highpass is not None:
+            filter = [sbg.Filter(1, self.highpass, "highpass", forward_backward=True)]
+        if self.lowpass is not None:
+            filter += [sbg.Filter(1, self.lowpass, "lowpass", forward_backward=True)]
+
         return [
             sbg.SteeredWindow(windowlen=3001, strategy="pad"),
-            sbg.ChangeDtype(np.float32),
             sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type=self.norm),
+            *filter,
+            sbg.ChangeDtype(np.float32),
         ]
 
 
     def predict_step(self, batch, batch_idx=None, dataloader_idx=None):
         x = batch["X"]
         window_borders = batch["window_borders"]
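The `*filter` splice keeps the augmentation list identical to the unfiltered pipeline when neither cutoff is configured, because unpacking an empty list inserts nothing. The same pattern in isolation:

    # Unpacking an empty list is a no-op, so optional steps can be spliced in place.
    extra = []  # no highpass/lowpass configured
    pipeline = ["window", "normalize", *extra, "to_float32", "label"]
    assert pipeline == ["window", "normalize", "to_float32", "label"]

    extra = ["highpass_filter"]
    pipeline = ["window", "normalize", *extra, "to_float32", "label"]
    assert "highpass_filter" in pipeline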
@@ -1211,8 +1227,12 @@ def get_model_specific_args(config):
             if 'highpass' in config:
                 args['highpass'] = config.highpass
             if 'lowpass' in config:
-                args['lowpass'] = config.lowpass[0]
+                args['lowpass'] = config.lowpass
         case 'PhaseNet':
+            if 'highpass' in config:
+                args['highpass'] = config.highpass
+            if 'lowpass' in config:
+                args['lowpass'] = config.lowpass
             if 'sample_boundaries' in config:
                 args['sample_boundaries'] = config.sample_boundaries
         case 'DPPPicker':