from pydantic import BaseModel, ConfigDict, field_validator
from typing_extensions import Literal
from typing import Union, List, Optional
import yaml
import logging

logging.root.setLevel(logging.INFO)
logger = logging.getLogger('input_validator')

# TODO:
# 1. check if a single value is allowed in a sweep
# 2. merge input params
# 3. change names of the classes
# 4. add constraints for PhaseNet, GPD

model_names = Literal["PhaseNet", "GPD", "BasicPhaseAE", "EQTransformer"]
norm_values = Literal["peak", "std"]
finetuning_values = Literal["all", "top", "decoder", "encoder"]
pretrained_values = Literal['diting', 'ethz', 'geofon', 'instance', 'iquique',
                            'lendb', 'neic', 'original', 'scedc', False]


class Metric(BaseModel):
    goal: str
    name: str


class NumericValue(BaseModel):
    value: Union[int, float, List[Union[int, float]]]


class NumericValues(BaseModel):
    values: List[Union[int, float]]


class IntDistribution(BaseModel):
    distribution: str = "int_uniform"
    min: int
    max: int


class FloatDistribution(BaseModel):
    distribution: str = "uniform"
    min: float
    max: float


class Pretrained(BaseModel):
    distribution: Optional[str] = "categorical"
    values: Optional[List[pretrained_values]] = None
    value: Optional[Union[pretrained_values, List[pretrained_values]]] = None


class Finetuning(BaseModel):
    distribution: Optional[str] = "categorical"
    values: Optional[List[finetuning_values]] = None
    value: Optional[Union[finetuning_values, List[finetuning_values]]] = None


class Norm(BaseModel):
    distribution: Optional[str] = "categorical"
    values: Optional[List[norm_values]] = None
    value: Optional[Union[norm_values, List[norm_values]]] = None


class ModelType(BaseModel):
    distribution: Optional[str] = "categorical"
    value: Optional[Union[model_names, List[model_names]]] = None
    values: Optional[List[model_names]] = None


class Parameters(BaseModel):
    model_config = ConfigDict(extra='forbid', protected_namespaces=())

    model_name: ModelType
    batch_size: Union[IntDistribution, NumericValue, NumericValues]
    learning_rate: Union[FloatDistribution, NumericValue, NumericValues]
    max_epochs: Union[IntDistribution, NumericValue, NumericValues]


class PhaseNetParameters(Parameters):
    model_config = ConfigDict(extra='forbid')

    norm: Optional[Norm] = None
    pretrained: Optional[Pretrained] = None
    finetuning: Optional[Finetuning] = None
    lr_reduce_factor: Optional[Union[FloatDistribution, NumericValue, NumericValues]] = None
    highpass: Optional[Union[NumericValue, NumericValues, FloatDistribution, IntDistribution]] = None
    lowpass: Optional[Union[NumericValue, NumericValues, FloatDistribution, IntDistribution]] = None

    @field_validator("model_name")
    def validate_model(cls, v):
        # The model name may be declared as a single `value` or a list of `values`.
        names = v.value if v.value is not None else v.values
        if "PhaseNet" not in names:
            raise ValueError("Additional parameters implemented for PhaseNet only")
        return v


class FilteringParameters(Parameters):
    model_config = ConfigDict(extra='forbid')

    highpass: Optional[Union[NumericValue, NumericValues, FloatDistribution, IntDistribution]] = None
    lowpass: Optional[Union[NumericValue, NumericValues, FloatDistribution, IntDistribution]] = None

    @field_validator("model_name")
    def validate_model(cls, v):
        names = v.value if v.value is not None else v.values
        if isinstance(names, str):
            names = [names]
        if names[0] not in ["GPD", "PhaseNet"]:
            raise ValueError("Filtering parameters implemented for GPD and PhaseNet only")
        # Field validators must return the value, otherwise the field is set to None.
        return v


class InputParams(BaseModel):
    name: str
    method: str
    metric: Metric
    parameters: Union[Parameters, PhaseNetParameters, FilteringParameters]


def validate_sweep_yaml(yaml_filename, model_name=None):
    # Load the YAML sweep configuration
    with open(yaml_filename, 'r') as f:
        sweep_config = yaml.safe_load(f)
    validate_sweep_config(sweep_config, model_name)


def validate_sweep_config(sweep_config, model_name=None):
    # Validate the sweep config against the pydantic models above
    input_params = InputParams(**sweep_config)

    # Check consistency of the input parameters and the sweep configuration;
    # the model name may be given as a single `value` or as a list of `values`.
    model_type = input_params.parameters.model_name
    sweep_model_name = model_type.value if model_type.value is not None else model_type.values
    if model_name is not None and model_name not in sweep_model_name:
        info = f"Model name {model_name} is inconsistent with the sweep configuration {sweep_model_name}."
        logger.info(info)
        raise ValueError(info)
    logger.info("Input validation successful.")


if __name__ == "__main__":
    yaml_filename = "../experiments/sweep_phasenet_bogdanka_lr_bs.yaml"
    validate_sweep_yaml(yaml_filename, None)
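
# Example (hypothetical, for illustration only): a minimal sweep configuration
# that these models accept. The file name and all values below are assumptions,
# not taken from the experiments directory; the field layout follows the
# pydantic models above.
#
#   name: phasenet_lr_bs_sweep
#   method: bayes
#   metric:
#     goal: minimize
#     name: val_loss
#   parameters:
#     model_name:
#       value: PhaseNet
#     batch_size:
#       values: [64, 128, 256]
#     learning_rate:
#       distribution: uniform
#       min: 0.0001
#       max: 0.01
#     max_epochs:
#       value: 30
#
# With such a file on disk, the check could be run as:
#
#   validate_sweep_yaml("sweep_phasenet_example.yaml", model_name="PhaseNet")
#
# or, for an already-loaded dict, directly via
# validate_sweep_config(sweep_config, model_name="PhaseNet").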