147 lines
4.4 KiB
Python
147 lines
4.4 KiB
Python
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||
|
from typing_extensions import Literal
|
||
|
from typing import Union, List, Optional
|
||
|
import yaml
|
||
|
import logging
|
||
|
|
||
|
# Set the root logger's threshold to INFO at import time; this affects every
# logger that inherits its level from root, not only this module's.
logging.getLogger().setLevel(logging.INFO)

# Named logger used by this module's validation functions.
logger = logging.getLogger('input_validator')
|
||
|
|
||
|
# TODO:
# 1. Check whether a single value is allowed in a sweep.
# 2. Merge input params.
# 3. Rename the classes.
# 4. Add constraints for PhaseNet and GPD.
|
||
|
|
||
|
|
||
|
# Accepted entries for ``parameters.model_name``.
model_names = Literal[
    "PhaseNet",
    "GPD",
    "BasicPhaseAE",
    "EQTransformer",
]

# Accepted entries for the ``norm`` parameter.
norm_values = Literal["peak", "std"]

# Accepted entries for the ``finetuning`` parameter.
finetuning_values = Literal["all", "top", "decoder", "encoder"]

# Accepted entries for the ``pretrained`` parameter: names of weight sets,
# or ``False`` (presumably "no pretrained weights" -- confirm with the trainer).
pretrained_values = Literal[
    'diting', 'ethz', 'geofon', 'instance', 'iquique',
    'lendb', 'neic', 'original', 'scedc',
    False,
]
|
||
|
|
||
|
|
||
|
class Metric(BaseModel):
    """The sweep's optimization target.

    Both fields are free-form strings; their contents are not checked here.
    """

    # Optimization direction (presumably "minimize"/"maximize"; not enforced).
    goal: str
    # Name of the metric the sweep optimizes.
    name: str
|
||
|
|
||
|
|
||
|
class NumericValue(BaseModel):
    """Numeric parameter given under a ``value:`` key.

    The value may be a single int/float or a list of ints/floats.
    """

    value: Union[int, float, List[Union[int, float]]]
|
||
|
|
||
|
|
||
|
class NumericValues(BaseModel):
    """Numeric parameter given as an explicit candidate list under ``values:``."""

    values: List[Union[int, float]]
|
||
|
|
||
|
|
||
|
class IntDistribution(BaseModel):
    """Integer parameter sampled from the range ``[min, max]``.

    ``distribution`` defaults to ``"int_uniform"``; its value is not restricted
    to a fixed set of names. ``max`` must be greater than or equal to ``min``.
    """

    distribution: str = "int_uniform"
    min: int
    max: int

    @field_validator("max")
    def check_min_not_above_max(cls, v, info):
        """Reject inverted ranges (``max < min``).

        ``min`` is declared before ``max``, so it is present in ``info.data``
        whenever it validated successfully; if it did not, the check is
        skipped so only the original ``min`` error is reported.
        """
        lower = info.data.get("min")
        if lower is not None and v < lower:
            raise ValueError(f"max ({v}) must be greater than or equal to min ({lower})")
        return v
|
||
|
|
||
|
|
||
|
class FloatDistribution(BaseModel):
    """Float parameter sampled from the range ``[min, max]``.

    ``distribution`` defaults to ``"uniform"``; its value is not restricted
    to a fixed set of names. ``max`` must be greater than or equal to ``min``.
    """

    distribution: str = "uniform"
    min: float
    max: float

    @field_validator("max")
    def check_min_not_above_max(cls, v, info):
        """Reject inverted ranges (``max < min``).

        ``min`` is declared before ``max``, so it is present in ``info.data``
        whenever it validated successfully; if it did not, the check is
        skipped so only the original ``min`` error is reported.
        """
        lower = info.data.get("min")
        if lower is not None and v < lower:
            raise ValueError(f"max ({v}) must be greater than or equal to min ({lower})")
        return v
|
||
|
|
||
|
|
||
|
class Pretrained(BaseModel):
    """Choice of pretrained weights: a fixed ``value`` and/or candidate ``values``.

    Neither key is required; every entry must be one of ``pretrained_values``.
    """

    # Sampling distribution name; "categorical" unless the config overrides it.
    distribution: Union[str, None] = "categorical"
    # Candidate entries to sweep over.
    values: List[pretrained_values] = None
    # A fixed entry, or a list of entries.
    value: Union[pretrained_values, List[pretrained_values]] = None
|
||
|
|
||
|
|
||
|
class Finetuning(BaseModel):
    """Fine-tuning mode: a fixed ``value`` and/or candidate ``values``.

    Neither key is required; every entry must be one of ``finetuning_values``.
    """

    # Sampling distribution name; "categorical" unless the config overrides it.
    distribution: Union[str, None] = "categorical"
    # Candidate entries to sweep over.
    values: List[finetuning_values] = None
    # A fixed entry, or a list of entries.
    value: Union[finetuning_values, List[finetuning_values]] = None
|
||
|
|
||
|
|
||
|
class Norm(BaseModel):
    """Normalization mode: a fixed ``value`` and/or candidate ``values``.

    Neither key is required; every entry must be one of ``norm_values``.
    """

    # Sampling distribution name; "categorical" unless the config overrides it.
    distribution: Union[str, None] = "categorical"
    # Candidate entries to sweep over.
    values: List[norm_values] = None
    # A fixed entry, or a list of entries.
    value: Union[norm_values, List[norm_values]] = None
|
||
|
|
||
|
|
||
|
class ModelType(BaseModel):
    """Model architecture(s) for the sweep: a fixed ``value`` and/or candidate ``values``.

    Neither key is required; every entry must be one of ``model_names``.
    """

    # Sampling distribution name; "categorical" unless the config overrides it.
    distribution: Union[str, None] = "categorical"
    # A fixed model name, or a list of names.
    value: Union[model_names, List[model_names]] = None
    # Candidate model names to sweep over.
    values: List[model_names] = None
|
||
|
|
||
|
|
||
|
class Parameters(BaseModel):
    """Sweep parameters shared by every model.

    Unknown keys are rejected (``extra='forbid'``). ``protected_namespaces=()``
    lets the ``model_name`` field use pydantic's reserved ``model_`` prefix
    without a warning.
    """

    model_config = ConfigDict(extra='forbid', protected_namespaces=())

    model_name: ModelType
    # Each of the following accepts a range, a fixed value, or a candidate list.
    batch_size: Union[IntDistribution, NumericValue, NumericValues]
    learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
    max_epochs: Union[IntDistribution, NumericValue, NumericValues]
|
||
|
|
||
|
|
||
|
class PhaseNetParameters(Parameters):
    """Sweep parameters with the extra options that require PhaseNet.

    Adds ``norm``, ``pretrained``, ``finetuning``, ``lr_reduce_factor``,
    ``highpass`` and ``lowpass`` on top of :class:`Parameters`. Unknown keys
    are still rejected.
    """

    model_config = ConfigDict(extra='forbid')

    norm: Norm = None
    pretrained: Pretrained = None
    finetuning: Finetuning = None
    lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None

    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None

    @field_validator("model_name")
    def validate_model(cls, v):
        """Require "PhaseNet" among the model names declared in ``model_name``.

        Names are collected from ``value`` (a single name or a list) and from
        ``values``. A ``ValueError`` is raised -- surfaced by pydantic as a
        validation error -- when PhaseNet is absent, including when neither
        key is set. (Previously a missing ``value`` raised a bare TypeError.)
        """
        names = []
        if v.value is not None:
            names.extend(v.value if isinstance(v.value, list) else [v.value])
        if v.values is not None:
            names.extend(v.values)
        if "PhaseNet" not in names:
            raise ValueError("Additional parameters implemented for PhaseNet only")
        return v
|
||
|
|
||
|
|
||
|
class FilteringParameters(Parameters):
    """Sweep parameters plus the ``highpass``/``lowpass`` filter options.

    Only valid when every model named in ``model_name`` is GPD or PhaseNet.
    Unknown keys are still rejected.
    """

    model_config = ConfigDict(extra='forbid')

    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None

    @field_validator("model_name")
    def validate_model(cls, v):
        """Require every declared model name to be GPD or PhaseNet.

        Names are collected from ``value`` (a single name or a list) and from
        ``values``. Raises ``ValueError`` when no name is declared or when any
        name is outside the supported set. The validated ``ModelType`` is
        returned unchanged -- pydantic stores the validator's return value, so
        omitting the return would reset ``model_name`` to None.
        """
        names = []
        if v.value is not None:
            # A plain string must be treated as one name, not indexed per character.
            names.extend(v.value if isinstance(v.value, list) else [v.value])
        if v.values is not None:
            names.extend(v.values)
        if not names or any(name not in ("GPD", "PhaseNet") for name in names):
            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
        return v
|
||
|
|
||
|
|
||
|
class InputParams(BaseModel):
    """Top-level sweep configuration.

    ``parameters`` must match one of the three parameter schemas in the
    union. Top-level keys not declared here are ignored (pydantic's default
    ``extra`` behaviour).
    """

    name: str
    method: str
    metric: Metric
    # Keep this member order: it is the order pydantic considers the schemas in.
    parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
|
||
|
|
||
|
|
||
|
def validate_sweep_yaml(yaml_filename, model_name=None):
    """Load a sweep configuration from a YAML file and validate it.

    Args:
        yaml_filename: Path of the YAML sweep file.
        model_name: Optional model name, forwarded to ``validate_sweep_config``
            to be checked against the sweep's ``parameters.model_name``.

    Returns:
        Whatever ``validate_sweep_config`` returns.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        Any exception raised by ``validate_sweep_config``.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's locale encoding.
    with open(yaml_filename, 'r', encoding='utf-8') as f:
        sweep_config = yaml.safe_load(f)

    return validate_sweep_config(sweep_config, model_name)
|
||
|
|
||
|
|
||
|
def _declared_model_names(model_type):
    """Return the names in a ``ModelType`` as a list: ``value`` entries, then ``values``."""
    names = []
    value = model_type.value
    if value is not None:
        # A single string is one name; do not fall back to substring/char semantics.
        names.extend(value if isinstance(value, list) else [value])
    if model_type.values is not None:
        names.extend(model_type.values)
    return names


def validate_sweep_config(sweep_config, model_name=None):
    """Validate a parsed sweep configuration.

    Args:
        sweep_config: The sweep configuration, e.g. as returned by
            ``yaml.safe_load`` (expected to be a mapping).
        model_name: If given, must be one of the model names declared under
            ``parameters.model_name`` (its ``value`` and/or ``values``).

    Returns:
        The validated ``InputParams`` instance.

    Raises:
        pydantic.ValidationError: if ``sweep_config`` does not match the
            schema, including when it is not a mapping (e.g. ``None`` from an
            empty YAML file).
        ValueError: if ``model_name`` is not declared in the sweep.
    """
    # model_validate reports a non-mapping input as a ValidationError instead of
    # the opaque TypeError that ``InputParams(**sweep_config)`` would raise.
    input_params = InputParams.model_validate(sweep_config)

    # Check consistency of the requested model and the sweep configuration.
    if model_name is not None:
        sweep_model_names = _declared_model_names(input_params.parameters.model_name)
        if model_name not in sweep_model_names:
            info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_names}."
            logger.error(info)
            raise ValueError(info)

    logger.info("Input validation successful.")
    return input_params
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    import sys

    # Usage: python <this file> [SWEEP_YAML [MODEL_NAME]]
    # Without arguments the PhaseNet/Bogdanka example is validated; the relative
    # default path is resolved against the current working directory.
    yaml_filename = (
        sys.argv[1] if len(sys.argv) > 1
        else "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
    )
    model_name = sys.argv[2] if len(sys.argv) > 2 else None
    validate_sweep_yaml(yaml_filename, model_name)
|