import os
import pathlib
from urllib.parse import urljoin

import seisbench
import seisbench.models as sbm
import seisbench.util

from .model_runner import ModelRunner

# Base URL under which the extra model weights are published.  The trailing
# slash matters: ``urljoin`` appends to the base only when it ends with "/".
# NOTE: the identifier is misspelled ("REPOSIOTRY"); it is kept as-is because
# other modules may import it under this name.
EPOS_AI_MODEL_REPOSIOTRY_URL = "https://models.isl.grid.cyfronet.pl/models/v3/"


class PhaseNetModelRunner(ModelRunner):
    """Model runner specialised for the SeisBench ``PhaseNet`` model.

    Besides configuring picking thresholds, it can fetch the additional
    weight sets listed in ``extra_weights`` from
    ``EPOS_AI_MODEL_REPOSIOTRY_URL`` into the paths reported by the model
    class' ``_pretrained_path``.
    """

    # Name of the model class looked up in ``seisbench.models``.
    model_type = "PhaseNet"
    # (weight name, version) pairs downloaded by ``download_extras``.
    extra_weights = [("bogdanka", "1"), ("lgcd", "1")]

    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
        """Resolve the model class and delegate the rest to ``ModelRunner``.

        Args:
            weights_name: Name of the pretrained weights, forwarded to the
                base class.
            output_dir: Output directory, forwarded to the base class.
            **kwargs: Additional options forwarded to the base class.
        """
        # Despite its name this attribute holds the model *class*, not a
        # string.  It is assigned before the base initialiser runs --
        # presumably because the base class reads it; verify against
        # ModelRunner before reordering.
        self.model_name = getattr(sbm, PhaseNetModelRunner.model_type)
        super().__init__(weights_name=weights_name, output_dir=output_dir, **kwargs)

    def process_kwargs(self, **kwargs):
        """Translate keyword options into annotate/classify arguments.

        Recognised keys:
            threshold_p: P threshold, converted to ``float`` (default 0.3).
            threshold_s: S threshold, converted to ``float`` (default 0.3).
            blinding: Blinding value passed through unchanged
                (default ``(0, 0)``).

        Sets ``threshold_p``, ``threshold_s``, ``blinding``,
        ``annotate_kwargs`` and ``classify_kwargs`` on the instance.
        """
        self.threshold_p = float(kwargs.get("threshold_p", 0.3))
        self.threshold_s = float(kwargs.get("threshold_s", 0.3))
        self.blinding = kwargs.get("blinding", (0, 0))

        self.annotate_kwargs = {}

        # Work on a copy so the model's own default arguments stay untouched.
        # ``self.model`` is not assigned in this class -- presumably the base
        # class provides it; confirm against ModelRunner.
        self.classify_kwargs = self.model.default_args.copy()
        self.classify_kwargs["P_threshold"] = self.threshold_p
        self.classify_kwargs["S_threshold"] = self.threshold_s
        self.classify_kwargs["blinding"] = self.blinding

    def download_extras(self, force=False):
        """Make sure every entry of ``extra_weights`` is present locally.

        Args:
            force: When true, any existing local files are deleted and the
                weights are downloaded again.
        """
        for weight_name, version in PhaseNetModelRunner.extra_weights:
            weight_path, metadata_path = self.model_name._pretrained_path(
                weight_name, version
            )

            # Only the weight file decides whether the entry counts as cached.
            if weight_path.exists() and not force:
                seisbench.logger.info(f"Weight file {weight_name} already in cache")
                continue

            if force:
                # A forced download must also work when one or both files are
                # absent (e.g. the very first download), so a missing file is
                # not treated as an error here.
                for stale_path in (weight_path, metadata_path):
                    try:
                        os.remove(stale_path)
                    except FileNotFoundError:
                        pass

            self._download_wieght(weight_name, version, weight_path, metadata_path)

    def _download_wieght(self, weight_name, version, weight_path, metadata_path):
        """Download one weight file and its metadata file.

        The method name is misspelled ("wieght"); it is kept unchanged so
        existing callers keep working.

        Args:
            weight_name: Name of the weight set on the remote repository.
            version: Version string appended to the remote file names.
            weight_path: Local destination of the ``.pt`` weight file.
            metadata_path: Local destination of the ``.json`` metadata file.
        """

        def download_callback(files):
            # Use the paths handed in by ``callback_if_uncached`` rather than
            # the enclosing method's arguments.
            weight_target, metadata_target = files
            seisbench.logger.info(
                f"Weight file {weight_target.name} not in cache. Downloading..."
            )
            weight_target.parent.mkdir(exist_ok=True, parents=True)

            remote_weight_name = f"{weight_name}.pt.v{version}"
            remote_metadata_name = f"{weight_name}.json.v{version}"

            # The remote directory is the lower-cased internal model name.
            remote_path = urljoin(
                EPOS_AI_MODEL_REPOSIOTRY_URL,
                self.model_name._name_internal().lower(),
            )
            remote_weight_path = f"{remote_path}/{remote_weight_name}"
            remote_metadata_path = f"{remote_path}/{remote_metadata_name}"

            seisbench.util.download_http(remote_weight_path, weight_target)
            seisbench.util.download_http(
                remote_metadata_path, metadata_target, progress_bar=False
            )

        seisbench.util.callback_if_uncached(
            [weight_path, metadata_path],
            download_callback,
            force=False,
            wait_for_file=True,
        )