platform-demo-scripts/scripts/input_validate.py

147 lines
4.4 KiB
Python
Raw Permalink Normal View History

from pydantic import BaseModel, ConfigDict, field_validator
from typing_extensions import Literal
from typing import Union, List, Optional
import yaml
import logging
# Lower the root logger's threshold to INFO at import time so that the
# validator's informational messages are not filtered out.
logging.getLogger().setLevel(logging.INFO)
# Module-wide logger used by the validation helpers below.
logger = logging.getLogger("input_validator")
# TODO:
# 1. Check whether a single value is allowed in a sweep.
# 2. Merge input params.
# 3. Change the names of the classes.
# 4. Add constraints for PhaseNet and GPD.
# Closed sets of values accepted by the categorical sweep parameters below.

# Model identifiers that may appear under ``model_name``.
model_names = Literal[
    "PhaseNet",
    "GPD",
    "BasicPhaseAE",
    "EQTransformer",
]

# Allowed entries for the ``norm`` parameter.
norm_values = Literal["peak", "std"]

# Allowed entries for the ``finetuning`` parameter.
finetuning_values = Literal[
    "all",
    "top",
    "decoder",
    "encoder",
]

# Allowed entries for the ``pretrained`` parameter; the boolean ``False`` is
# accepted alongside the named options.
pretrained_values = Literal[
    "diting",
    "ethz",
    "geofon",
    "instance",
    "iquique",
    "lendb",
    "neic",
    "original",
    "scedc",
    False,
]
class Metric(BaseModel):
    """Schema of the ``metric`` section of a sweep configuration.

    Both fields are required strings; no further constraint is applied to
    their content here.
    """

    goal: str
    name: str
class NumericValue(BaseModel):
    """Numeric sweep parameter given under the ``value`` key.

    ``value`` may be a single int/float or a list of ints/floats.
    """

    value: Union[int, float, List[Union[int, float]]]
class NumericValues(BaseModel):
    """Numeric sweep parameter given under the ``values`` key.

    ``values`` is a required list of ints/floats.
    """

    values: List[Union[int, float]]
class IntDistribution(BaseModel):
    """Integer sweep parameter described by a range.

    ``min`` and ``max`` are required integers; ``distribution`` is a free-form
    string that defaults to ``"int_uniform"``.  The field names ``min`` and
    ``max`` mirror the keys expected in the configuration.
    """

    distribution: str = "int_uniform"
    min: int
    max: int
class FloatDistribution(BaseModel):
    """Floating-point sweep parameter described by a range.

    ``min`` and ``max`` are required floats; ``distribution`` is a free-form
    string that defaults to ``"uniform"``.
    """

    distribution: str = "uniform"
    min: float
    max: float
class Pretrained(BaseModel):
    """Sweep entry for the ``pretrained`` parameter.

    Candidates can be supplied as ``values`` (a list) or as ``value`` (a
    single entry or a list).  Both default to ``None`` and every supplied
    entry must be one of ``pretrained_values``.
    """

    # Defaults to "categorical"; any string (or None) is accepted.
    distribution: Optional[str] = "categorical"
    values: List[pretrained_values] = None
    value: Union[pretrained_values, List[pretrained_values]] = None
class Finetuning(BaseModel):
    """Sweep entry for the ``finetuning`` parameter.

    Candidates can be supplied as ``values`` (a list) or as ``value`` (a
    single entry or a list).  Both default to ``None`` and every supplied
    entry must be one of ``finetuning_values``.
    """

    # Defaults to "categorical"; any string (or None) is accepted.
    distribution: Optional[str] = "categorical"
    values: List[finetuning_values] = None
    value: Union[finetuning_values, List[finetuning_values]] = None
class Norm(BaseModel):
    """Sweep entry for the ``norm`` parameter.

    Candidates can be supplied as ``values`` (a list) or as ``value`` (a
    single entry or a list).  Both default to ``None`` and every supplied
    entry must be one of ``norm_values``.
    """

    # Defaults to "categorical"; any string (or None) is accepted.
    distribution: Optional[str] = "categorical"
    values: List[norm_values] = None
    value: Union[norm_values, List[norm_values]] = None
class ModelType(BaseModel):
    """Sweep entry for the ``model_name`` parameter.

    The model(s) can be supplied as ``value`` (a single name or a list of
    names) or as ``values`` (a list of names).  Both default to ``None`` and
    every supplied name must be one of ``model_names``.
    """

    # Defaults to "categorical"; any string (or None) is accepted.
    distribution: Optional[str] = "categorical"
    value: Union[model_names, List[model_names]] = None
    values: List[model_names] = None
class Parameters(BaseModel):
    """Base set of sweep parameters; all four fields are required.

    Unknown keys are rejected (``extra='forbid'``).  ``protected_namespaces``
    is emptied because pydantic reserves the ``model_`` prefix by default,
    which would otherwise clash with the ``model_name`` field.
    """

    model_config = ConfigDict(extra='forbid', protected_namespaces=())

    model_name: ModelType
    # Each numeric parameter is either a range or explicit value(s).
    batch_size: Union[IntDistribution, NumericValue, NumericValues]
    learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
    max_epochs: Union[IntDistribution, NumericValue, NumericValues]
class PhaseNetParameters(Parameters):
    """``Parameters`` extended with options that are only valid for PhaseNet.

    All additional fields are optional (default ``None``).  Unknown keys are
    rejected (``extra='forbid'``).  Validation fails unless ``"PhaseNet"`` is
    among the model names declared under ``model_name``.
    """

    model_config = ConfigDict(extra='forbid')

    norm: Norm = None
    pretrained: Pretrained = None
    finetuning: Finetuning = None
    lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None

    @field_validator("model_name")
    @classmethod
    def validate_model(cls, v):
        """Ensure PhaseNet is one of the declared models.

        Args:
            v: The already validated ``ModelType`` instance.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: If ``"PhaseNet"`` is declared neither under ``value``
                nor under ``values``.
        """
        # ``value`` may be a single name or a list and ``values`` is a list;
        # either may be None.  Flatten both into one list so that a missing
        # ``value`` no longer triggers ``TypeError`` on ``in None``.
        declared = []
        for entry in (v.value, v.values):
            if entry is None:
                continue
            declared.extend([entry] if isinstance(entry, str) else entry)

        if "PhaseNet" not in declared:
            raise ValueError("Additional parameters implemented for PhaseNet only")
        return v
class FilteringParameters(Parameters):
    """``Parameters`` extended with the optional ``highpass``/``lowpass`` options.

    Unknown keys are rejected (``extra='forbid'``).  Validation fails unless
    every model name declared under ``model_name`` is ``"GPD"`` or
    ``"PhaseNet"``.
    """

    model_config = ConfigDict(extra='forbid')

    highpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None
    lowpass: Union[NumericValue, NumericValues, FloatDistribution, IntDistribution] = None

    @field_validator("model_name")
    @classmethod
    def validate_model(cls, v):
        """Ensure only models that support filtering are declared.

        Args:
            v: The already validated ``ModelType`` instance.

        Returns:
            ``v`` unchanged.  A field validator must return the value;
            otherwise the field is replaced by ``None``.

        Raises:
            ValueError: If no model is declared, or if any declared model is
                not ``"GPD"`` or ``"PhaseNet"``.
        """
        # ``value`` may be a single name or a list and ``values`` is a list;
        # either may be None.  Flatten both so a plain string is treated as
        # one name rather than being indexed character by character.
        declared = []
        for entry in (v.value, v.values):
            if entry is None:
                continue
            declared.extend([entry] if isinstance(entry, str) else entry)

        supported = ("GPD", "PhaseNet")
        if not declared or any(name not in supported for name in declared):
            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
        return v
class InputParams(BaseModel):
    """Top-level schema of a sweep configuration.

    ``name`` and ``method`` are required strings, ``metric`` must match
    ``Metric`` and ``parameters`` must match one of the parameter schemas
    listed in the union.
    """

    name: str
    method: str
    metric: Metric
    parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]
def validate_sweep_yaml(yaml_filename, model_name=None):
    """Read a sweep configuration from a YAML file and validate it.

    Args:
        yaml_filename: Path of the YAML file to load.
        model_name: Optional model name forwarded to
            ``validate_sweep_config`` for the consistency check.
    """
    # ``safe_load`` is used, so only plain YAML data types are constructed.
    with open(yaml_filename, 'r') as stream:
        loaded_config = yaml.safe_load(stream)

    validate_sweep_config(loaded_config, model_name)
def validate_sweep_config(sweep_config, model_name=None):
    """Validate a sweep configuration and optionally cross-check the model.

    Args:
        sweep_config: Mapping holding the sweep definition; it is unpacked
            into ``InputParams`` as keyword arguments.
        model_name: Model the caller intends to run.  When given, it must be
            one of the model names declared by the sweep, either under
            ``value`` or under ``values``.

    Raises:
        pydantic.ValidationError: If ``sweep_config`` does not match the
            ``InputParams`` schema.
        ValueError: If ``model_name`` is given but is not declared by the
            sweep configuration.
    """
    # Validate sweep config
    input_params = InputParams(**sweep_config)

    # Check consistency of input parameters and sweep configuration.
    model_type = input_params.parameters.model_name

    # ``value`` may be a single name or a list and ``values`` is a list;
    # either may be None.  Flatten both so that the membership test below is
    # an exact name comparison (not a substring test on a string) and does
    # not fail on None.
    declared_names = []
    for entry in (model_type.value, model_type.values):
        if entry is None:
            continue
        declared_names.extend([entry] if isinstance(entry, str) else entry)

    if model_name is not None and model_name not in declared_names:
        # Report ``value`` as before when it is set; fall back to ``values``.
        sweep_model_name = (
            model_type.value if model_type.value is not None else model_type.values
        )
        info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
        logger.info(info)
        raise ValueError(info)

    logger.info("Input validation successful.")
if __name__ == "__main__":
    # When run as a script, validate the bundled example sweep without a
    # model-name consistency check.  The path is relative to the current
    # working directory.
    example_sweep = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
    validate_sweep_yaml(example_sweep, None)