Add command downloading extra weights

This commit is contained in:
Hubert Siejkowski 2023-10-24 20:30:57 +00:00
parent 123df08603
commit b5d5622f1d
2 changed files with 78 additions and 1 deletions

View File

@ -2,6 +2,7 @@ import pathlib
import click
import seisbench
from requests.models import default_hooks
from .__about__ import __version__
from .model_runner import list_pretrained_models
@ -88,3 +89,27 @@ def pick(stream_file_names, weights, output, threshold_p, threshold_s, blinding)
)
for stream in stream_file_names:
m.find_picks(stream)
@cli.command()
@click.option(
"-f",
"--force",
is_flag=True,
help="force downloading extra weights",
)
@click.option(
"-l",
"--list",
"list_weights",
is_flag=True,
help="list possible extra weights",
)
def extras(force, list_weights):
"""Downloads extra model weights from EPOS AI Platform"""
m = PhaseNetModelRunner()
if list_weights:
for w, _ in m.extra_weights:
print(w)
else:
m.download_extras(force=force)

View File

@ -1,13 +1,19 @@
import os
import pathlib
from urllib.parse import urljoin
import seisbench
import seisbench.models as sbm
import seisbench.util
from .model_runner import ModelRunner
EPOS_AI_MODEL_REPOSIOTRY_URL = "https://models.isl.grid.cyfronet.pl/models/v3/"
class PhaseNetModelRunner(ModelRunner):
model_type = "PhaseNet"
extra_weights = [("bogdanka", "1"), ("lgcd", "1")]
def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
self.model_name = getattr(sbm, PhaseNetModelRunner.model_type)
@ -26,3 +32,49 @@ class PhaseNetModelRunner(ModelRunner):
self.classify_kwargs["P_threshold"] = self.threshold_p
self.classify_kwargs["S_threshold"] = self.threshold_s
self.classify_kwargs["blinding"] = self.blinding
def download_extras(self, force=False):
for weight_name, version in PhaseNetModelRunner.extra_weights:
weight_path, metadata_path = self.model_name._pretrained_path(
weight_name, version
)
if weight_path.exists() and not force:
seisbench.logger.info(f"Weight file {weight_name} already in cache")
else:
if force:
os.remove(weight_path)
os.remove(metadata_path)
self._download_wieght(weight_name, version, weight_path, metadata_path)
def _download_wieght(self, weight_name, version, weight_path, metadata_path):
def download_callback(files):
weight_path, metadata_path = files
seisbench.logger.info(
f"Weight file {weight_path.name} not in cache. Downloading..."
)
weight_path.parent.mkdir(exist_ok=True, parents=True)
remote_weight_name = f"{weight_name}.pt.v{version}"
remote_metadata_name = f"{weight_name}.json.v{version}"
remote_path = urljoin(
EPOS_AI_MODEL_REPOSIOTRY_URL,
self.model_name._name_internal().lower(),
)
remote_weight_path = f"{remote_path}/{remote_weight_name}"
remote_metadata_path = f"{remote_path}/{remote_metadata_name}"
seisbench.util.download_http(remote_weight_path, weight_path)
seisbench.util.download_http(
remote_metadata_path, metadata_path, progress_bar=False
)
seisbench.util.callback_if_uncached(
[weight_path, metadata_path],
download_callback,
force=False,
wait_for_file=True,
)