SourceParametersEstimation/opt_algorithms.py

76 lines
2.7 KiB
Python
Raw Permalink Normal View History

2024-10-09 11:24:44 +02:00
import scipy.optimize
import numpy as np
from SpectralParameters import SpectralParams
import f_models
from f_norms import FNorm
class OptAlgorithm:
    """Base class for optimization algorithms.

    The constructor builds the source model and the misfit norm object that
    concrete algorithms work with. Subclasses override :meth:`run`, store
    their outcome in ``self.solution`` / ``self.error`` and return both.

    :param initial_model: starting spectral parameters; stored and handed to
        both the source model and ``FNorm``
    :param freq_bins: frequency bins handed to the source model and ``FNorm``
    :param amplitude_spectrum: amplitude spectrum handed to ``FNorm``
    :param weights: weights handed to ``FNorm``
    :param travel_time: travel time; stored and handed to ``FNorm``
    :param config: object read through ``config.get``; the keys
        ``"sp_fit_model"`` and ``"norm"`` are used here
    :param logger: logger handed to ``FNorm``
    """

    def __init__(self, initial_model: "SpectralParams", freq_bins, amplitude_spectrum, weights, travel_time, config,
                 logger):
        self.initial_model = initial_model
        self.config = config
        self.travel_time = travel_time
        # The source model class is looked up in f_models by the name configured under "sp_fit_model".
        self.f_model = getattr(f_models, config.get("sp_fit_model"))(freq_bins=freq_bins,
                                                                     spectral_parameters=initial_model)
        self.f_norm = FNorm(norm=self.config.get("norm"), spectral_params=initial_model, freq_bins=freq_bins,
                            amplitude_spectrum=amplitude_spectrum, weights=weights, travel_time=travel_time,
                            source_model=self.f_model, logger=logger)
        # Both stay None until run() of a concrete algorithm fills them in.
        self.solution = None
        self.error = None
        self.name = self.__class__.__name__

    def run(self):
        """Run the optimization algorithm.

        The base implementation performs no optimization; it only hands back
        whatever is currently stored on the instance.

        :return: tuple ``(error, solution)``
        """
        return self.error, self.solution

    def __repr__(self):
        # Compare against None explicitly: a solution object that happens to be
        # falsy must still be reported as a result.
        if self.solution is None:
            return f"{self.name}: no solution"
        # __repr__ must never raise, so a missing error value is printed as-is
        # instead of being pushed through the float format.
        error_text = "None" if self.error is None else f"{self.error:.4f}"
        return (f"{self.name} results:\n"
                f" {self.solution!s} \n"
                f" Error: {error_text}")
class OptNelderMead(OptAlgorithm):
    """
    Fit the spectral parameters with the Nelder-Mead (downhill simplex) method of scipy.optimize.minimize.

    The optimized vector is ``[mo, fo, q]``. ``mo`` is unbounded, ``fo`` is limited to
    ``[1 / window_len, freq_max]`` and ``q`` to ``[q_min, q_max]``, all taken from ``config``.
    """

    def __init__(self, initial_model: SpectralParams, freq_bins, amplitude_spectrum, weights, travel_time, config,
                 logger):
        super().__init__(initial_model, freq_bins, amplitude_spectrum, weights, travel_time, config, logger)
        # q starts from the midpoint of its configured range.
        q_low = self.config.get("q_min")
        q_high = self.config.get("q_max")
        self.initial_q = (q_low + q_high) / 2

    def run(self):
        """Minimize the misfit and store the best parameters found.

        :return: tuple ``(error, solution)`` where ``error`` is the objective value reported by
            the optimizer and ``solution`` is a ``SpectralParams`` built from the optimizer's ``x``
        """
        cfg = self.config

        def misfit_at(params):
            # NOTE(review): FNorm is constructed with ``spectral_params=`` but updated here through
            # the attribute ``spectral_par`` - confirm that FNorm really reads this attribute.
            mo, fo, q = params
            self.f_norm.spectral_par = SpectralParams(mo=mo, fo=fo, q=q)
            return self.f_norm.calculate().misfit

        # Starting point: mo and fo from the initial model, q from the configured range.
        start = np.array([self.initial_model.mo, self.initial_model.fo, self.initial_q])

        # Search limits, in the same order as the optimized vector.
        fo_limits = (1 / cfg.get("window_len"), cfg.get("freq_max"))
        q_limits = (cfg.get("q_min"), cfg.get("q_max"))
        limits = [(None, None), fo_limits, q_limits]

        result = scipy.optimize.minimize(misfit_at, start, method='Nelder-Mead', bounds=limits)

        # Keep the outcome on the instance so that __repr__ and later callers can see it.
        best_mo, best_fo, best_q = result.x
        self.solution = SpectralParams(mo=best_mo, fo=best_fo, q=best_q)
        self.error = result.fun
        return self.error, self.solution