Add command downloading extra weights
This commit is contained in:
parent
123df08603
commit
b5d5622f1d
@ -2,6 +2,7 @@ import pathlib
|
||||
|
||||
import click
|
||||
import seisbench
|
||||
from requests.models import default_hooks
|
||||
|
||||
from .__about__ import __version__
|
||||
from .model_runner import list_pretrained_models
|
||||
@ -88,3 +89,27 @@ def pick(stream_file_names, weights, output, threshold_p, threshold_s, blinding)
|
||||
)
|
||||
for stream in stream_file_names:
|
||||
m.find_picks(stream)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
"-f",
|
||||
"--force",
|
||||
is_flag=True,
|
||||
help="force downloading extra weights",
|
||||
)
|
||||
@click.option(
|
||||
"-l",
|
||||
"--list",
|
||||
"list_weights",
|
||||
is_flag=True,
|
||||
help="list possible extra weights",
|
||||
)
|
||||
def extras(force, list_weights):
|
||||
"""Downloads extra model weights from EPOS AI Platform"""
|
||||
m = PhaseNetModelRunner()
|
||||
if list_weights:
|
||||
for w, _ in m.extra_weights:
|
||||
print(w)
|
||||
else:
|
||||
m.download_extras(force=force)
|
||||
|
@ -1,13 +1,19 @@
|
||||
import os
|
||||
import pathlib
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import seisbench
|
||||
import seisbench.models as sbm
|
||||
import seisbench.util
|
||||
|
||||
from .model_runner import ModelRunner
|
||||
|
||||
EPOS_AI_MODEL_REPOSIOTRY_URL = "https://models.isl.grid.cyfronet.pl/models/v3/"
|
||||
|
||||
|
||||
class PhaseNetModelRunner(ModelRunner):
|
||||
|
||||
model_type = "PhaseNet"
|
||||
extra_weights = [("bogdanka", "1"), ("lgcd", "1")]
|
||||
|
||||
def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
|
||||
self.model_name = getattr(sbm, PhaseNetModelRunner.model_type)
|
||||
@ -26,3 +32,49 @@ class PhaseNetModelRunner(ModelRunner):
|
||||
self.classify_kwargs["P_threshold"] = self.threshold_p
|
||||
self.classify_kwargs["S_threshold"] = self.threshold_s
|
||||
self.classify_kwargs["blinding"] = self.blinding
|
||||
|
||||
def download_extras(self, force=False):
|
||||
for weight_name, version in PhaseNetModelRunner.extra_weights:
|
||||
weight_path, metadata_path = self.model_name._pretrained_path(
|
||||
weight_name, version
|
||||
)
|
||||
|
||||
if weight_path.exists() and not force:
|
||||
seisbench.logger.info(f"Weight file {weight_name} already in cache")
|
||||
else:
|
||||
if force:
|
||||
os.remove(weight_path)
|
||||
os.remove(metadata_path)
|
||||
|
||||
self._download_wieght(weight_name, version, weight_path, metadata_path)
|
||||
|
||||
def _download_wieght(self, weight_name, version, weight_path, metadata_path):
|
||||
def download_callback(files):
|
||||
weight_path, metadata_path = files
|
||||
seisbench.logger.info(
|
||||
f"Weight file {weight_path.name} not in cache. Downloading..."
|
||||
)
|
||||
weight_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
remote_weight_name = f"{weight_name}.pt.v{version}"
|
||||
remote_metadata_name = f"{weight_name}.json.v{version}"
|
||||
|
||||
remote_path = urljoin(
|
||||
EPOS_AI_MODEL_REPOSIOTRY_URL,
|
||||
self.model_name._name_internal().lower(),
|
||||
)
|
||||
|
||||
remote_weight_path = f"{remote_path}/{remote_weight_name}"
|
||||
remote_metadata_path = f"{remote_path}/{remote_metadata_name}"
|
||||
|
||||
seisbench.util.download_http(remote_weight_path, weight_path)
|
||||
seisbench.util.download_http(
|
||||
remote_metadata_path, metadata_path, progress_bar=False
|
||||
)
|
||||
|
||||
seisbench.util.callback_if_uncached(
|
||||
[weight_path, metadata_path],
|
||||
download_callback,
|
||||
force=False,
|
||||
wait_for_file=True,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user