PandSWavesDetectionTool/src/epos_ai_picking_tools/cli_phasenet.py

116 lines
2.5 KiB
Python

import pathlib
import click
import seisbench
from requests.models import default_hooks
from .__about__ import __version__
from .model_runner import list_pretrained_models
from .phasenet import PhaseNetModelRunner
@click.group()
def cli():
    """PhaseNet phase picking tools built on SeisBench."""
    # Root command group: subcommands attach via ``@cli.command()``.
    # The docstring above is what Click prints for ``--help``.
# ``command()`` is called with parentheses: the bare ``@cli.command`` form is
# only understood by Click >= 8.1; older versions silently register nothing.
@cli.command()
def version():
    """Prints version"""
    # Report both the SeisBench library version and this tool's own version.
    print(f"SeisBench v{seisbench.__version__}")
    print(f"phasenet_tool v{__version__}")
# Parenthesised ``command()`` works on every Click 8.x (bare form needs >= 8.1).
@cli.command()
def citation():
    """Prints citation of the model"""
    # The runner is built with its default constructor arguments only to
    # obtain the citation text (presumably for the default weights — confirm).
    m = PhaseNetModelRunner()
    print(m.citation())
# Parenthesised ``command()`` works on every Click 8.x (bare form needs >= 8.1).
# Click exposes this function as the ``list-pretrained`` subcommand.
@cli.command()
def list_pretrained():
    """Show pretrained model names"""
    # The helper receives the runner *class*, not an instance, and returns
    # an iterable of names that is printed comma-separated on one line.
    print(", ".join(list_pretrained_models(PhaseNetModelRunner)))
# Parenthesised ``command()`` works on every Click 8.x (bare form needs >= 8.1).
@cli.command()
@click.option(
    "-w",
    "--weights",
    default="original",
    type=str,
    show_default=True,
    help="for possible options see output of 'list-pretrained'",
)
@click.option(
    "-o",
    "--output",
    default=pathlib.Path("."),
    # file_okay=False: reject an existing *file* up front instead of failing
    # later when results are written into it.
    type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
    show_default=True,
    help="directory to store results",
)
@click.option(
    "-tp",
    "--threshold-p",
    default=0.3,
    type=float,
    show_default=True,
    help="detection threshold for the P phase",
)
@click.option(
    "-ts",
    "--threshold-s",
    default=0.3,
    type=float,
    show_default=True,
    help="detection threshold for the S phase",
)
@click.option(
    "-b",
    "--blinding",
    default=(0, 0),
    # A tuple of types makes Click parse exactly two integers
    # (``-b 100 100``). The previous ``type=tuple`` called ``tuple()`` on the
    # raw string, turning ``100`` into ``('1', '0', '0')``.
    type=(int, int),
    show_default=True,
    help="number of prediction samples to discard on each side of each window prediction",
)
@click.argument("stream_file_names", nargs=-1, type=click.Path(exists=True))
def pick(stream_file_names, weights, output, threshold_p, threshold_s, blinding):
    """Detect phases in streams"""
    # parents=True allows nested output paths (e.g. results/run1);
    # exist_ok=True replaces the racy exists()-then-mkdir() check.
    output.mkdir(parents=True, exist_ok=True)
    m = PhaseNetModelRunner(
        weights,
        output_dir=output,
        threshold_p=threshold_p,
        threshold_s=threshold_s,
        blinding=blinding,
    )
    # One runner is reused for every input file; each file is processed in
    # the order given on the command line.
    for stream in stream_file_names:
        m.find_picks(stream)
@cli.command()
@click.option("-f", "--force", is_flag=True, help="force downloading extra weights")
@click.option(
    "-l", "--list", "list_weights", is_flag=True, help="list possible extra weights"
)
def extras(force, list_weights):
    """Downloads extra model weights from EPOS AI Platform"""
    runner = PhaseNetModelRunner()
    # Default action: download the extra weights (``--force`` is forwarded).
    if not list_weights:
        runner.download_extras(force=force)
        return
    # ``--list``: each entry of ``extra_weights`` is unpacked as a pair and
    # only its first element (the weight name) is printed.
    for weight_name, _unused in runner.extra_weights:
        print(weight_name)