forked from epos-ai/epos-ai-picking-tools
Initial import
This commit is contained in:
@@ -0,0 +1,29 @@
|
||||
import pathlib
|
||||
|
||||
import seisbench.models as sbm
|
||||
|
||||
from .model_runner import ModelRunner
|
||||
|
||||
|
||||
class GPDModelRunner(ModelRunner):
    """Runner for the SeisBench GPD (Generalized Phase Detection) picker."""

    model_type = "GPD"

    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
        # Resolve via type(self) instead of hard-coding GPDModelRunner so a
        # further subclass that overrides model_type keeps working.
        self.model_name = getattr(sbm, type(self).model_type)
        super().__init__(weights_name=weights_name, output_dir=output_dir, **kwargs)

    def process_kwargs(self, **kwargs):
        """Translate runner options into annotate/classify keyword dicts.

        Recognized options (with defaults): ``stride`` (10 samples),
        ``threshold_p`` (0.75), ``threshold_s`` (0.75).
        """
        self.stride = int(kwargs.get("stride", 10))
        self.threshold_p = float(kwargs.get("threshold_p", 0.75))
        self.threshold_s = float(kwargs.get("threshold_s", 0.75))

        self.annotate_kwargs = {
            "stride": self.stride,
        }

        # Start from the model's own defaults, then override the thresholds.
        self.classify_kwargs = self.model.default_args.copy()
        self.classify_kwargs["P_threshold"] = self.threshold_p
        self.classify_kwargs["S_threshold"] = self.threshold_s
|
@@ -0,0 +1,102 @@
|
||||
import json
|
||||
import pathlib
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
import obspy
|
||||
import seisbench.models as sbm
|
||||
from obspy.core.event import Catalog, Event, Pick, WaveformStreamID
|
||||
from obspy.io.json.default import Default
|
||||
|
||||
|
||||
def list_pretrained_models(model_runner_class):
    """Return the pretrained weight names available for the runner's model type."""
    model_class = getattr(sbm, model_runner_class.model_type)
    return model_class.list_pretrained()
|
||||
|
||||
|
||||
def exit_error(msg):
    """Print *msg* prefixed with "ERROR:" and terminate with exit code 1."""
    print("ERROR:", msg)
    # sys.exit(1) is equivalent to raising SystemExit(1); raise it directly.
    raise SystemExit(1)
|
||||
|
||||
|
||||
class ModelRunner:
    """Base runner wrapping a SeisBench picking model.

    Subclasses set ``model_type`` to a class name from ``seisbench.models``,
    assign ``self.model_name`` before calling ``super().__init__``, and
    override ``process_kwargs`` to populate ``self.annotate_kwargs`` and
    ``self.classify_kwargs`` (both are required by ``find_picks``).
    """

    # Sentinel value; subclasses replace it with e.g. "GPD" or "PhaseNet".
    model_type = "EMPTY"

    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
        # self.model_name = getattr(sbm, __class__.model_type)
        self.model = self.load_model(weights_name)
        self.output_dir = pathlib.Path(output_dir)
        self.process_kwargs(**kwargs)

    def process_kwargs(self, **kwargs):
        """Hook for subclasses: translate options into model kwargs."""
        pass

    def citation(self):
        """Return the citation text of the loaded model."""
        return self.model.citation

    def load_model(self, weights_name):
        """Instantiate the model from pretrained weights.

        Relies on ``self.model_name`` being assigned by the subclass before
        ``ModelRunner.__init__`` runs.
        """
        return self.model_name.from_pretrained(weights_name)

    def load_stream(self, stream_file_name):
        """Read a waveform file into an ObsPy Stream."""
        return obspy.read(stream_file_name)

    def save_picks(self, classs_picks, stream_path):
        """Serialize picks to ``<stem>_picks.json`` in the output directory."""
        dict_picks = list(map(lambda p: p.__dict__, classs_picks))
        fpath = self.output_dir / f"{stream_path.stem}_picks.json"

        with open(fpath, "w") as fp:
            json.dump(dict_picks, fp, default=Default())

    def save_quakeml(self, classs_picks, stream_path):
        """Write picks as a single-event QuakeML file ``<stem>_picks.xml``."""
        e = Event()
        for cpick in classs_picks:
            # assumes trace_id is exactly "NET.STA.LOC" — TODO confirm for
            # every SeisBench model output.
            net, sta, loc = cpick.trace_id.split(".")
            p = Pick(
                time=cpick.peak_time,
                phase_hint=cpick.phase,
                waveform_id=WaveformStreamID(
                    network_code=net, station_code=sta, location_code=loc
                ),
            )
            e.picks.append(p)

        cat = Catalog([e])
        fpath = self.output_dir / f"{stream_path.stem}_picks.xml"
        cat.write(fpath, format="QUAKEML")

    def write_annotations(self, annotations, stream_path):
        """Save model annotations as MiniSEED ``<stem>_annotations.mseed``."""
        ann = annotations.copy()
        for tr in ann:
            # Tag annotation traces with a distinct channel name.
            # NOTE(review): "G_<component>" exceeds the 3-character SEED
            # channel convention — confirm downstream tools accept it.
            tr.stats.channel = f"G_{tr.stats.component}"
        fpath = self.output_dir / f"{stream_path.stem}_annotations.mseed"
        ann.write(fpath, format="MSEED")

    @staticmethod
    def validate_stream(stream):
        """Abort via exit_error unless some station provides three components.

        Fix: an empty stream previously crashed with ``ValueError`` from
        ``max()`` on an empty sequence; it now reports the same
        "not enough traces" error instead.
        """
        groups = defaultdict(list)
        for trace in stream:
            groups[trace.stats.station].append(trace.stats.channel[-1])

        number_of_channels = list(map(len, groups.values()))

        if not number_of_channels or max(number_of_channels) < 3:
            exit_error("Not enough traces in the stream")

    def find_picks(self, stream_file_name, save_annotations=True):
        """Run the full picking pipeline on one file and return the picks.

        Loads and validates the stream, annotates it with the model,
        optionally stores the annotations, classifies picks, and writes
        them as both JSON and QuakeML.
        """
        stream_path = pathlib.Path(stream_file_name)
        stream = self.load_stream(stream_path)

        self.validate_stream(stream)

        annotations = self.model.annotate(stream, **self.annotate_kwargs)

        if save_annotations:
            self.write_annotations(annotations, stream_path)

        classs_picks = self.model.classify_aggregate(annotations, self.classify_kwargs)

        self.save_picks(classs_picks, stream_path)
        self.save_quakeml(classs_picks, stream_path)

        return classs_picks
|
4
src/epos_ai_picking_tools/__about__.py
Normal file
4
src/epos_ai_picking_tools/__about__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# SPDX-FileCopyrightText: 2022-present Hubert Siejkowski <h.siejkowski@gmail.com>
#
# SPDX-License-Identifier: MIT
# Single source of truth for the package version (printed by the CLI
# `version` commands).
__version__ = "0.4.0"
|
0
src/epos_ai_picking_tools/__init__.py
Normal file
0
src/epos_ai_picking_tools/__init__.py
Normal file
90
src/epos_ai_picking_tools/cli_gpd.py
Normal file
90
src/epos_ai_picking_tools/cli_gpd.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import pathlib
|
||||
|
||||
import click
|
||||
import seisbench
|
||||
|
||||
from .__about__ import __version__
|
||||
from .gpd import GPDModelRunner
|
||||
from .model_runner import list_pretrained_models
|
||||
|
||||
|
||||
@click.group()
def cli():
    """Entry point grouping the gpd_tool subcommands."""
    pass
|
||||
|
||||
|
||||
@cli.command
def version():
    """Prints version"""
    lines = (
        f"SeisBench v{seisbench.__version__}",
        f"gpd_tool v{__version__}",
    )
    for line in lines:
        print(line)
|
||||
|
||||
|
||||
@cli.command
def citation():
    """Prints citation of the model"""
    runner = GPDModelRunner()
    print(runner.citation())
|
||||
|
||||
|
||||
@cli.command
def list_pretrained():
    """Show pretrained model names"""
    names = list_pretrained_models(GPDModelRunner)
    print(", ".join(names))
|
||||
|
||||
|
||||
@cli.command
@click.option(
    "-w",
    "--weights",
    default="original",
    type=str,
    show_default=True,
    # Fix: was an f-string with no placeholders (lint F541).
    help="for possible options see output of 'list-pretrained'",
)
@click.option(
    "-o",
    "--output",
    default=pathlib.Path("."),
    type=click.Path(dir_okay=True, path_type=pathlib.Path),
    show_default=True,
    help="directory to store results",
)
@click.option(
    "-s",
    "--stride",
    default=10,
    type=int,
    show_default=True,
    help="stride in samples for point prediction models",
)
@click.option(
    "-tp",
    "--threshold-p",
    default=0.75,
    type=float,
    show_default=True,
    help="detection threshold for the P phase",
)
@click.option(
    "-ts",
    "--threshold-s",
    default=0.75,
    type=float,
    show_default=True,
    help="detection threshold for the S phase",
)
@click.argument("stream_file_names", nargs=-1, type=click.Path(exists=True))
def pick(stream_file_names, weights, output, stride, threshold_p, threshold_s):
    """Detect phases in streams"""
    # Fix: create missing parent directories too, and tolerate the directory
    # appearing between the exists() check and mkdir() (bare mkdir() raised
    # FileNotFoundError / FileExistsError in those cases).
    if not output.exists():
        output.mkdir(parents=True, exist_ok=True)
    m = GPDModelRunner(
        weights,
        output_dir=output,
        stride=stride,
        threshold_p=threshold_p,
        threshold_s=threshold_s,
    )
    for stream_file_name in stream_file_names:
        m.find_picks(stream_file_name)
|
90
src/epos_ai_picking_tools/cli_phasenet.py
Normal file
90
src/epos_ai_picking_tools/cli_phasenet.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import pathlib
|
||||
|
||||
import click
|
||||
import seisbench
|
||||
|
||||
from .__about__ import __version__
|
||||
from .model_runner import list_pretrained_models
|
||||
from .phasenet import PhaseNetModelRunner
|
||||
|
||||
|
||||
@click.group()
def cli():
    """Entry point grouping the phasenet_tool subcommands."""
    pass
|
||||
|
||||
|
||||
@cli.command
def version():
    """Prints version"""
    lines = (
        f"SeisBench v{seisbench.__version__}",
        f"phasenet_tool v{__version__}",
    )
    for line in lines:
        print(line)
|
||||
|
||||
|
||||
@cli.command
def citation():
    """Prints citation of the model"""
    runner = PhaseNetModelRunner()
    print(runner.citation())
|
||||
|
||||
|
||||
@cli.command
def list_pretrained():
    """Show pretrained model names"""
    names = list_pretrained_models(PhaseNetModelRunner)
    print(", ".join(names))
|
||||
|
||||
|
||||
@cli.command
@click.option(
    "-w",
    "--weights",
    default="original",
    type=str,
    show_default=True,
    # Fix: was an f-string with no placeholders (lint F541).
    help="for possible options see output of 'list-pretrained'",
)
@click.option(
    "-o",
    "--output",
    default=pathlib.Path("."),
    type=click.Path(dir_okay=True, path_type=pathlib.Path),
    show_default=True,
    help="directory to store results",
)
@click.option(
    "-tp",
    "--threshold-p",
    default=0.3,
    type=float,
    show_default=True,
    help="detection threshold for the P phase",
)
@click.option(
    "-ts",
    "--threshold-s",
    default=0.3,
    type=float,
    show_default=True,
    help="detection threshold for the S phase",
)
@click.option(
    "-b",
    "--blinding",
    default=(0, 0),
    # Fix: `type=tuple` made click call tuple() on the raw string, splitting
    # it into characters (e.g. "0,0" -> ('0', ',', '0')). A tuple of types is
    # click's composite type: exactly two ints, e.g. `-b 100 100`.
    type=(int, int),
    show_default=True,
    help="number of prediction samples to discard on each side of each window prediction",
)
@click.argument("stream_file_names", nargs=-1, type=click.Path(exists=True))
def pick(stream_file_names, weights, output, threshold_p, threshold_s, blinding):
    """Detect phases in streams"""
    # Fix: create missing parent directories too, and tolerate the directory
    # appearing between the exists() check and mkdir().
    if not output.exists():
        output.mkdir(parents=True, exist_ok=True)
    m = PhaseNetModelRunner(
        weights,
        output_dir=output,
        threshold_p=threshold_p,
        threshold_s=threshold_s,
        blinding=blinding,
    )
    for stream_file_name in stream_file_names:
        m.find_picks(stream_file_name)
|
29
src/epos_ai_picking_tools/gpd.py
Normal file
29
src/epos_ai_picking_tools/gpd.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import pathlib
|
||||
|
||||
import seisbench.models as sbm
|
||||
|
||||
from .model_runner import ModelRunner
|
||||
|
||||
|
||||
class GPDModelRunner(ModelRunner):
    """Runner for the SeisBench GPD (Generalized Phase Detection) picker."""

    model_type = "GPD"

    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
        # Resolve via type(self) instead of hard-coding GPDModelRunner so a
        # further subclass that overrides model_type keeps working.
        self.model_name = getattr(sbm, type(self).model_type)
        super().__init__(weights_name=weights_name, output_dir=output_dir, **kwargs)

    def process_kwargs(self, **kwargs):
        """Translate runner options into annotate/classify keyword dicts.

        Recognized options (with defaults): ``stride`` (10 samples),
        ``threshold_p`` (0.75), ``threshold_s`` (0.75).
        """
        self.stride = int(kwargs.get("stride", 10))
        self.threshold_p = float(kwargs.get("threshold_p", 0.75))
        self.threshold_s = float(kwargs.get("threshold_s", 0.75))

        self.annotate_kwargs = {
            "stride": self.stride,
        }

        # Start from the model's own defaults, then override the thresholds.
        self.classify_kwargs = self.model.default_args.copy()
        self.classify_kwargs["P_threshold"] = self.threshold_p
        self.classify_kwargs["S_threshold"] = self.threshold_s
|
102
src/epos_ai_picking_tools/model_runner.py
Normal file
102
src/epos_ai_picking_tools/model_runner.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import json
|
||||
import pathlib
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
|
||||
import obspy
|
||||
import seisbench.models as sbm
|
||||
from obspy.core.event import Catalog, Event, Pick, WaveformStreamID
|
||||
from obspy.io.json.default import Default
|
||||
|
||||
|
||||
def list_pretrained_models(model_runner_class):
    """Return the pretrained weight names available for the runner's model type."""
    model_class = getattr(sbm, model_runner_class.model_type)
    return model_class.list_pretrained()
|
||||
|
||||
|
||||
def exit_error(msg):
    """Print *msg* prefixed with "ERROR:" and terminate with exit code 1."""
    print("ERROR:", msg)
    # sys.exit(1) is equivalent to raising SystemExit(1); raise it directly.
    raise SystemExit(1)
|
||||
|
||||
|
||||
class ModelRunner:
    """Base runner wrapping a SeisBench picking model.

    Subclasses set ``model_type`` to a class name from ``seisbench.models``,
    assign ``self.model_name`` before calling ``super().__init__``, and
    override ``process_kwargs`` to populate ``self.annotate_kwargs`` and
    ``self.classify_kwargs`` (both are required by ``find_picks``).
    """

    # Sentinel value; subclasses replace it with e.g. "GPD" or "PhaseNet".
    model_type = "EMPTY"

    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
        # self.model_name = getattr(sbm, __class__.model_type)
        self.model = self.load_model(weights_name)
        self.output_dir = pathlib.Path(output_dir)
        self.process_kwargs(**kwargs)

    def process_kwargs(self, **kwargs):
        """Hook for subclasses: translate options into model kwargs."""
        pass

    def citation(self):
        """Return the citation text of the loaded model."""
        return self.model.citation

    def load_model(self, weights_name):
        """Instantiate the model from pretrained weights.

        Relies on ``self.model_name`` being assigned by the subclass before
        ``ModelRunner.__init__`` runs.
        """
        return self.model_name.from_pretrained(weights_name)

    def load_stream(self, stream_file_name):
        """Read a waveform file into an ObsPy Stream."""
        return obspy.read(stream_file_name)

    def save_picks(self, classs_picks, stream_path):
        """Serialize picks to ``<stem>_picks.json`` in the output directory."""
        dict_picks = list(map(lambda p: p.__dict__, classs_picks))
        fpath = self.output_dir / f"{stream_path.stem}_picks.json"

        with open(fpath, "w") as fp:
            json.dump(dict_picks, fp, default=Default())

    def save_quakeml(self, classs_picks, stream_path):
        """Write picks as a single-event QuakeML file ``<stem>_picks.xml``."""
        e = Event()
        for cpick in classs_picks:
            # assumes trace_id is exactly "NET.STA.LOC" — TODO confirm for
            # every SeisBench model output.
            net, sta, loc = cpick.trace_id.split(".")
            p = Pick(
                time=cpick.peak_time,
                phase_hint=cpick.phase,
                waveform_id=WaveformStreamID(
                    network_code=net, station_code=sta, location_code=loc
                ),
            )
            e.picks.append(p)

        cat = Catalog([e])
        fpath = self.output_dir / f"{stream_path.stem}_picks.xml"
        cat.write(fpath, format="QUAKEML")

    def write_annotations(self, annotations, stream_path):
        """Save model annotations as MiniSEED ``<stem>_annotations.mseed``."""
        ann = annotations.copy()
        for tr in ann:
            # Tag annotation traces with a distinct channel name.
            # NOTE(review): "G_<component>" exceeds the 3-character SEED
            # channel convention — confirm downstream tools accept it.
            tr.stats.channel = f"G_{tr.stats.component}"
        fpath = self.output_dir / f"{stream_path.stem}_annotations.mseed"
        ann.write(fpath, format="MSEED")

    @staticmethod
    def validate_stream(stream):
        """Abort via exit_error unless some station provides three components.

        Fix: an empty stream previously crashed with ``ValueError`` from
        ``max()`` on an empty sequence; it now reports the same
        "not enough traces" error instead.
        """
        groups = defaultdict(list)
        for trace in stream:
            groups[trace.stats.station].append(trace.stats.channel[-1])

        number_of_channels = list(map(len, groups.values()))

        if not number_of_channels or max(number_of_channels) < 3:
            exit_error("Not enough traces in the stream")

    def find_picks(self, stream_file_name, save_annotations=True):
        """Run the full picking pipeline on one file and return the picks.

        Loads and validates the stream, annotates it with the model,
        optionally stores the annotations, classifies picks, and writes
        them as both JSON and QuakeML.
        """
        stream_path = pathlib.Path(stream_file_name)
        stream = self.load_stream(stream_path)

        self.validate_stream(stream)

        annotations = self.model.annotate(stream, **self.annotate_kwargs)

        if save_annotations:
            self.write_annotations(annotations, stream_path)

        classs_picks = self.model.classify_aggregate(annotations, self.classify_kwargs)

        self.save_picks(classs_picks, stream_path)
        self.save_quakeml(classs_picks, stream_path)

        return classs_picks
|
28
src/epos_ai_picking_tools/phasenet.py
Normal file
28
src/epos_ai_picking_tools/phasenet.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import pathlib
|
||||
|
||||
import seisbench.models as sbm
|
||||
|
||||
from .model_runner import ModelRunner
|
||||
|
||||
|
||||
class PhaseNetModelRunner(ModelRunner):
    """Runner for the SeisBench PhaseNet picker."""

    model_type = "PhaseNet"

    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
        # Resolve via type(self) instead of hard-coding PhaseNetModelRunner so
        # a further subclass that overrides model_type keeps working.
        self.model_name = getattr(sbm, type(self).model_type)
        super().__init__(weights_name=weights_name, output_dir=output_dir, **kwargs)

    def process_kwargs(self, **kwargs):
        """Translate runner options into annotate/classify keyword dicts.

        Recognized options (with defaults): ``threshold_p`` (0.3),
        ``threshold_s`` (0.3), ``blinding`` ((0, 0)).
        """
        self.threshold_p = float(kwargs.get("threshold_p", 0.3))
        self.threshold_s = float(kwargs.get("threshold_s", 0.3))
        # Coerce blinding entries to int for robustness: the CLI may hand us
        # strings, and a tuple of samples is what the model expects.
        self.blinding = tuple(int(b) for b in kwargs.get("blinding", (0, 0)))

        self.annotate_kwargs = {}

        # Start from the model's own defaults, then override.
        self.classify_kwargs = self.model.default_args.copy()
        self.classify_kwargs["P_threshold"] = self.threshold_p
        self.classify_kwargs["S_threshold"] = self.threshold_s
        self.classify_kwargs["blinding"] = self.blinding
|
Reference in New Issue
Block a user