"""Command-line interface for phasenet_tool.

Defines a ``click`` command group (``cli``) whose subcommands wrap
``PhaseNetModelRunner``: printing version and citation information, listing
pretrained weights, picking phases in stream files and downloading extra
model weights.
"""

import pathlib

import click
import seisbench
# NOTE(review): ``default_hooks`` is not referenced anywhere in this module and
# looks like a stray auto-import; kept to avoid changing import-time behaviour,
# consider removing it.
from requests.models import default_hooks

from .__about__ import __version__
from .model_runner import list_pretrained_models
from .phasenet import PhaseNetModelRunner


@click.group()
def cli():
    """Detect seismic phases with PhaseNet via SeisBench."""


# ``@cli.command()`` (with parentheses) is used for every subcommand: the bare
# ``@cli.command`` form is only accepted by click >= 8.1.
@cli.command()
def version():
    """Prints version"""
    print(f"SeisBench v{seisbench.__version__}")
    print(f"phasenet_tool v{__version__}")


@cli.command()
def citation():
    """Prints citation of the model"""
    m = PhaseNetModelRunner()
    print(m.citation())


@cli.command()
def list_pretrained():
    """Show pretrained model names"""
    print(", ".join(list_pretrained_models(PhaseNetModelRunner)))


@cli.command()
@click.option(
    "-w",
    "--weights",
    default="original",
    type=str,
    show_default=True,
    help="for possible options see output of 'list-pretrained'",
)
@click.option(
    "-o",
    "--output",
    default=pathlib.Path("."),
    # file_okay=False makes click reject an existing regular file up front
    # instead of failing later when results are written into it.
    type=click.Path(file_okay=False, dir_okay=True, path_type=pathlib.Path),
    show_default=True,
    help="directory to store results",
)
@click.option(
    "-tp",
    "--threshold-p",
    default=0.3,
    type=float,
    show_default=True,
    help="detection threshold for the P phase",
)
@click.option(
    "-ts",
    "--threshold-s",
    default=0.3,
    type=float,
    show_default=True,
    help="detection threshold for the S phase",
)
@click.option(
    "-b",
    "--blinding",
    default=(0, 0),
    # Two integers, e.g. ``-b 100 100``. The previous ``type=tuple`` made click
    # call ``tuple()`` on the raw string, so ``-b 100`` became ('1', '0', '0').
    nargs=2,
    type=int,
    show_default=True,
    help=(
        "number of prediction samples to discard on each side of each "
        "window prediction (two integers)"
    ),
)
@click.argument("stream_file_names", nargs=-1, type=click.Path(exists=True))
def pick(stream_file_names, weights, output, threshold_p, threshold_s, blinding):
    """Detect phases in streams

    Each STREAM_FILE_NAMES entry is passed to the PhaseNet model runner;
    results are stored in the --output directory.
    """
    # parents=True also creates missing intermediate directories, and
    # exist_ok=True replaces the separate exists() check.
    output.mkdir(parents=True, exist_ok=True)
    m = PhaseNetModelRunner(
        weights,
        output_dir=output,
        threshold_p=threshold_p,
        threshold_s=threshold_s,
        blinding=blinding,
    )
    for stream in stream_file_names:
        m.find_picks(stream)


@cli.command()
@click.option(
    "-f",
    "--force",
    is_flag=True,
    help="force downloading extra weights",
)
@click.option(
    "-l",
    "--list",
    "list_weights",
    is_flag=True,
    help="list possible extra weights",
)
def extras(force, list_weights):
    """Downloads extra model weights from EPOS AI Platform"""
    m = PhaseNetModelRunner()
    if list_weights:
        # ``extra_weights`` yields pairs; only the first element (the name)
        # is printed.
        for w, _ in m.extra_weights:
            print(w)
    else:
        m.download_extras(force=force)