Initial import
This commit is contained in:
		
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					.python-version
 | 
				
			||||||
 | 
					dist
 | 
				
			||||||
 | 
					*.pyc
 | 
				
			||||||
							
								
								
									
										165
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,165 @@
 | 
				
			|||||||
 | 
					# P and S Waves Detection with Deep Learning
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Installation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					pip install --extra-index-url https://epos-apps.grid.cyfronet.pl/api/packages/epos-ai/pypi/simple epos_ai_picking_tools
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Generalized Phase Detection (GPD) Model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Documentation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<!-- [[[cog
 | 
				
			||||||
 | 
					import cog
 | 
				
			||||||
 | 
					from epos_ai_picking_tools import cli_gpd as cli
 | 
				
			||||||
 | 
					from click.testing import CliRunner
 | 
				
			||||||
 | 
					runner = CliRunner()
 | 
				
			||||||
 | 
					result = runner.invoke(cli.cli, ["--help"])
 | 
				
			||||||
 | 
					help = result.output.replace("Usage: cli", "Usage: gpd_tool")
 | 
				
			||||||
 | 
					cog.out(
 | 
				
			||||||
 | 
					    "```\n{}\n```".format(help)
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					]]] -->
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					Usage: gpd_tool [OPTIONS] COMMAND [ARGS]...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Options:
 | 
				
			||||||
 | 
					  --help  Show this message and exit.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Commands:
 | 
				
			||||||
 | 
					  citation         Prints citation of the model
 | 
				
			||||||
 | 
					  list-pretrained  Show pretrained model names
 | 
				
			||||||
 | 
					  pick             Detect phases in streams
 | 
				
			||||||
 | 
					  version          Prints version
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					<!-- [[[end]]] -->
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Pick
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<!-- [[[cog
 | 
				
			||||||
 | 
					import cog
 | 
				
			||||||
 | 
					from epos_ai_picking_tools import cli_gpd as cli
 | 
				
			||||||
 | 
					from click.testing import CliRunner
 | 
				
			||||||
 | 
					runner = CliRunner()
 | 
				
			||||||
 | 
					result = runner.invoke(cli.cli, ["pick", "--help"])
 | 
				
			||||||
 | 
					help = result.output.replace("Usage: cli", "Usage: gpd_tool")
 | 
				
			||||||
 | 
					cog.out(
 | 
				
			||||||
 | 
					    "```\n{}\n```".format(help)
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					]]] -->
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					Usage: gpd_tool pick [OPTIONS] [STREAM_FILE_NAMES]...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Detect phases in streams
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Options:
 | 
				
			||||||
 | 
					  -w, --weights TEXT        for possible options see output of 'list-pretrained'
 | 
				
			||||||
 | 
					                            [default: original]
 | 
				
			||||||
 | 
					  -o, --output PATH         directory to store results  [default: .]
 | 
				
			||||||
 | 
					  -s, --stride INTEGER      stride in samples for point prediction models
 | 
				
			||||||
 | 
					                            [default: 10]
 | 
				
			||||||
 | 
					  -tp, --threshold-p FLOAT  detection threshold for the P phase  [default: 0.75]
 | 
				
			||||||
 | 
					  -ts, --threshold-s FLOAT  detection threshold for the S phase  [default: 0.75]
 | 
				
			||||||
 | 
					  --help                    Show this message and exit.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					<!-- [[[end]]] -->
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Citation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<!-- [[[cog
 | 
				
			||||||
 | 
					import cog
 | 
				
			||||||
 | 
					from epos_ai_picking_tools import cli_gpd as cli
 | 
				
			||||||
 | 
					from click.testing import CliRunner
 | 
				
			||||||
 | 
					runner = CliRunner()
 | 
				
			||||||
 | 
					result = runner.invoke(cli.cli, ["citation"])
 | 
				
			||||||
 | 
					help = result.output.replace("Usage: cli", "Usage: gpd_tool")
 | 
				
			||||||
 | 
					cog.out(
 | 
				
			||||||
 | 
					    "\n{}\n".format(help)
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					]]] -->
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Ross, Z. E., Meier, M.-A., Hauksson, E., & Heaton, T. H. (2018). Generalized Seismic Phase Detection with Deep Learning. ArXiv:1805.01075 [Physics]. https://arxiv.org/abs/1805.01075
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<!-- [[[end]]] -->
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## PhaseNet
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Documentation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<!-- [[[cog
 | 
				
			||||||
 | 
					import cog
 | 
				
			||||||
 | 
					from epos_ai_picking_tools import cli_phasenet as cli
 | 
				
			||||||
 | 
					from click.testing import CliRunner
 | 
				
			||||||
 | 
					runner = CliRunner()
 | 
				
			||||||
 | 
					result = runner.invoke(cli.cli, ["--help"])
 | 
				
			||||||
 | 
					help = result.output.replace("Usage: cli", "Usage: phasenet_tool")
 | 
				
			||||||
 | 
					cog.out(
 | 
				
			||||||
 | 
					    "```\n{}\n```".format(help)
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					]]] -->
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					Usage: phasenet_tool [OPTIONS] COMMAND [ARGS]...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Options:
 | 
				
			||||||
 | 
					  --help  Show this message and exit.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Commands:
 | 
				
			||||||
 | 
					  citation         Prints citation of the model
 | 
				
			||||||
 | 
					  list-pretrained  Show pretrained model names
 | 
				
			||||||
 | 
					  pick             Detect phases in streams
 | 
				
			||||||
 | 
					  version          Prints version
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					<!-- [[[end]]] -->
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Pick
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<!-- [[[cog
 | 
				
			||||||
 | 
					import cog
 | 
				
			||||||
 | 
					from epos_ai_picking_tools import cli_phasenet as cli
 | 
				
			||||||
 | 
					from click.testing import CliRunner
 | 
				
			||||||
 | 
					runner = CliRunner()
 | 
				
			||||||
 | 
					result = runner.invoke(cli.cli, ["pick", "--help"])
 | 
				
			||||||
 | 
					help = result.output.replace("Usage: cli", "Usage: phasenet_tool")
 | 
				
			||||||
 | 
					cog.out(
 | 
				
			||||||
 | 
					    "```\n{}\n```".format(help)
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					]]] -->
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					Usage: phasenet_tool pick [OPTIONS] [STREAM_FILE_NAMES]...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Detect phases in streams
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Options:
 | 
				
			||||||
 | 
					  -w, --weights TEXT        for possible options see output of 'list-pretrained'
 | 
				
			||||||
 | 
					                            [default: original]
 | 
				
			||||||
 | 
					  -o, --output PATH         directory to store results  [default: .]
 | 
				
			||||||
 | 
					  -tp, --threshold-p FLOAT  detection threshold for the P phase  [default: 0.3]
 | 
				
			||||||
 | 
					  -ts, --threshold-s FLOAT  detection threshold for the S phase  [default: 0.3]
 | 
				
			||||||
 | 
					  -b, --blinding TUPLE      number of prediction samples to discard on each side
 | 
				
			||||||
 | 
					                            of each window prediction  [default: 0, 0]
 | 
				
			||||||
 | 
					  --help                    Show this message and exit.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					<!-- [[[end]]] -->
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					### Citation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<!-- [[[cog
 | 
				
			||||||
 | 
					import cog
 | 
				
			||||||
 | 
					from epos_ai_picking_tools import cli_phasenet as cli
 | 
				
			||||||
 | 
					from click.testing import CliRunner
 | 
				
			||||||
 | 
					runner = CliRunner()
 | 
				
			||||||
 | 
					result = runner.invoke(cli.cli, ["citation"])
 | 
				
			||||||
 | 
					help = result.output.replace("Usage: cli", "Usage: phasenet_tool")
 | 
				
			||||||
 | 
					cog.out(
 | 
				
			||||||
 | 
					    "\n{}\n".format(help)
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					]]] -->
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Zhu, W., & Beroza, G. C. (2019). PhaseNet: a deep-neural-network-based seismic arrival-time picking method. Geophysical Journal International, 216(1), 261-273. https://doi.org/10.1093/gji/ggy423
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<!-- [[[end]]] -->
 | 
				
			||||||
							
								
								
									
										65
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								pyproject.toml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,65 @@
 | 
				
			|||||||
 | 
					[build-system]
 | 
				
			||||||
 | 
					requires = ["hatchling"]
 | 
				
			||||||
 | 
					build-backend = "hatchling.build"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[project]
 | 
				
			||||||
 | 
					name = "epos_ai_picking_tools"
 | 
				
			||||||
 | 
					description = 'P and S Waves Detection with Deep Learning'
 | 
				
			||||||
 | 
					readme = "README.md"
 | 
				
			||||||
 | 
					requires-python = ">=3.7"
 | 
				
			||||||
 | 
					license = "MIT"
 | 
				
			||||||
 | 
					keywords = []
 | 
				
			||||||
 | 
					authors = [
 | 
				
			||||||
 | 
					  { name = "Hubert Siejkowski", email = "h.siejkowski@cyfronet.pl" },
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					classifiers = [
 | 
				
			||||||
 | 
					  "Programming Language :: Python",
 | 
				
			||||||
 | 
					  "Programming Language :: Python :: 3.7",
 | 
				
			||||||
 | 
					  "Programming Language :: Python :: 3.8",
 | 
				
			||||||
 | 
					  "Programming Language :: Python :: 3.9",
 | 
				
			||||||
 | 
					  "Programming Language :: Python :: 3.10",
 | 
				
			||||||
 | 
					  "Programming Language :: Python :: 3.11",
 | 
				
			||||||
 | 
					  "Programming Language :: Python :: Implementation :: CPython",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					dependencies = [
 | 
				
			||||||
 | 
					    "seisbench==0.4.*",
 | 
				
			||||||
 | 
					    "click"
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					dynamic = ["version"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[tool.hatch.version]
 | 
				
			||||||
 | 
					path = "src/epos_ai_picking_tools/__about__.py"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[tool.hatch.envs.default]
 | 
				
			||||||
 | 
					dependencies = [
 | 
				
			||||||
 | 
					  "pytest",
 | 
				
			||||||
 | 
					  "pytest-cov",
 | 
				
			||||||
 | 
					  "ipython",
 | 
				
			||||||
 | 
					  "black",
 | 
				
			||||||
 | 
					  "isort",
 | 
				
			||||||
 | 
					  "cogapp"
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					[tool.hatch.envs.default.scripts]
 | 
				
			||||||
 | 
					cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=gpd_tool --cov=tests"
 | 
				
			||||||
 | 
					no-cov = "cov --no-cov"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[[tool.hatch.envs.test.matrix]]
 | 
				
			||||||
 | 
					python = ["310", "311"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[tool.coverage.run]
 | 
				
			||||||
 | 
					branch = true
 | 
				
			||||||
 | 
					parallel = true
 | 
				
			||||||
 | 
					omit = [
 | 
				
			||||||
 | 
					  "gpd_tool/__about__.py",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[tool.coverage.report]
 | 
				
			||||||
 | 
					exclude_lines = [
 | 
				
			||||||
 | 
					  "no cov",
 | 
				
			||||||
 | 
					  "if __name__ == .__main__.:",
 | 
				
			||||||
 | 
					  "if TYPE_CHECKING:",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[project.scripts]
 | 
				
			||||||
 | 
					gpd_tool = "epos_ai_picking_tools.cli_gpd:cli"
 | 
				
			||||||
 | 
					phasenet_tool = "epos_ai_picking_tools.cli_phasenet:cli"
 | 
				
			||||||
@@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					import pathlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import seisbench.models as sbm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .model_runner import ModelRunner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GPDModelRunner(ModelRunner):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_type = "GPD"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
 | 
				
			||||||
 | 
					        self.model_name = getattr(sbm, GPDModelRunner.model_type)
 | 
				
			||||||
 | 
					        super(GPDModelRunner, self).__init__(
 | 
				
			||||||
 | 
					            weights_name=weights_name, output_dir=output_dir, **kwargs
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def process_kwargs(self, **kwargs):
 | 
				
			||||||
 | 
					        self.stride = int(kwargs.get("stride", 10))
 | 
				
			||||||
 | 
					        self.threshold_p = float(kwargs.get("threshold_p", 0.75))
 | 
				
			||||||
 | 
					        self.threshold_s = float(kwargs.get("threshold_s", 0.75))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.annotate_kwargs = {
 | 
				
			||||||
 | 
					            "stride": self.stride,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.classify_kwargs = self.model.default_args.copy()
 | 
				
			||||||
 | 
					        self.classify_kwargs["P_threshold"] = self.threshold_p
 | 
				
			||||||
 | 
					        self.classify_kwargs["S_threshold"] = self.threshold_s
 | 
				
			||||||
@@ -0,0 +1,102 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					import pathlib
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					from collections import defaultdict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import obspy
 | 
				
			||||||
 | 
					import seisbench.models as sbm
 | 
				
			||||||
 | 
					from obspy.core.event import Catalog, Event, Pick, WaveformStreamID
 | 
				
			||||||
 | 
					from obspy.io.json.default import Default
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def list_pretrained_models(model_runner_class):
 | 
				
			||||||
 | 
					    m = getattr(sbm, model_runner_class.model_type)
 | 
				
			||||||
 | 
					    weights = m.list_pretrained()
 | 
				
			||||||
 | 
					    return weights
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def exit_error(msg):
 | 
				
			||||||
 | 
					    print("ERROR:", msg)
 | 
				
			||||||
 | 
					    sys.exit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ModelRunner:
 | 
				
			||||||
 | 
					    model_type = "EMPTY"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
 | 
				
			||||||
 | 
					        # self.model_name = getattr(sbm, __class__.model_type)
 | 
				
			||||||
 | 
					        self.model = self.load_model(weights_name)
 | 
				
			||||||
 | 
					        self.output_dir = pathlib.Path(output_dir)
 | 
				
			||||||
 | 
					        self.process_kwargs(**kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def process_kwargs(self, **kwargs):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def citation(self):
 | 
				
			||||||
 | 
					        return self.model.citation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load_model(self, weights_name):
 | 
				
			||||||
 | 
					        return self.model_name.from_pretrained(weights_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load_stream(self, stream_file_name):
 | 
				
			||||||
 | 
					        return obspy.read(stream_file_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def save_picks(self, classs_picks, stream_path):
 | 
				
			||||||
 | 
					        dict_picks = list(map(lambda p: p.__dict__, classs_picks))
 | 
				
			||||||
 | 
					        fpath = self.output_dir / f"{stream_path.stem}_picks.json"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with open(fpath, "w") as fp:
 | 
				
			||||||
 | 
					            json.dump(dict_picks, fp, default=Default())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def save_quakeml(self, classs_picks, stream_path):
 | 
				
			||||||
 | 
					        e = Event()
 | 
				
			||||||
 | 
					        for cpick in classs_picks:
 | 
				
			||||||
 | 
					            net, sta, loc = cpick.trace_id.split(".")
 | 
				
			||||||
 | 
					            p = Pick(
 | 
				
			||||||
 | 
					                time=cpick.peak_time,
 | 
				
			||||||
 | 
					                phase_hint=cpick.phase,
 | 
				
			||||||
 | 
					                waveform_id=WaveformStreamID(
 | 
				
			||||||
 | 
					                    network_code=net, station_code=sta, location_code=loc
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            e.picks.append(p)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cat = Catalog([e])
 | 
				
			||||||
 | 
					        fpath = self.output_dir / f"{stream_path.stem}_picks.xml"
 | 
				
			||||||
 | 
					        cat.write(fpath, format="QUAKEML")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def write_annotations(self, annotations, stream_path):
 | 
				
			||||||
 | 
					        ann = annotations.copy()
 | 
				
			||||||
 | 
					        for tr in ann:
 | 
				
			||||||
 | 
					            tr.stats.channel = f"G_{tr.stats.component}"
 | 
				
			||||||
 | 
					        fpath = self.output_dir / f"{stream_path.stem}_annotations.mseed"
 | 
				
			||||||
 | 
					        ann.write(fpath, format="MSEED")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def validate_stream(stream):
 | 
				
			||||||
 | 
					        groups = defaultdict(list)
 | 
				
			||||||
 | 
					        for trace in stream:
 | 
				
			||||||
 | 
					            groups[trace.stats.station].append(trace.stats.channel[-1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        number_of_channels = list(map(len, groups.values()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if max(number_of_channels) < 3:
 | 
				
			||||||
 | 
					            exit_error("Not enough traces in the stream")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def find_picks(self, stream_file_name, save_annotations=True):
 | 
				
			||||||
 | 
					        stream_path = pathlib.Path(stream_file_name)
 | 
				
			||||||
 | 
					        stream = self.load_stream(stream_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.validate_stream(stream)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        annotations = self.model.annotate(stream, **self.annotate_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if save_annotations:
 | 
				
			||||||
 | 
					            self.write_annotations(annotations, stream_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        classs_picks = self.model.classify_aggregate(annotations, self.classify_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.save_picks(classs_picks, stream_path)
 | 
				
			||||||
 | 
					        self.save_quakeml(classs_picks, stream_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return classs_picks
 | 
				
			||||||
							
								
								
									
										4
									
								
								src/epos_ai_picking_tools/__about__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								src/epos_ai_picking_tools/__about__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
				
			|||||||
 | 
					# SPDX-FileCopyrightText: 2022-present Hubert Siejkowski <h.siejkowski@gmail.com>
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# SPDX-License-Identifier: MIT
 | 
				
			||||||
 | 
					__version__ = "0.4.0"
 | 
				
			||||||
							
								
								
									
										0
									
								
								src/epos_ai_picking_tools/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								src/epos_ai_picking_tools/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										90
									
								
								src/epos_ai_picking_tools/cli_gpd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								src/epos_ai_picking_tools/cli_gpd.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					import pathlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import click
 | 
				
			||||||
 | 
					import seisbench
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .__about__ import __version__
 | 
				
			||||||
 | 
					from .gpd import GPDModelRunner
 | 
				
			||||||
 | 
					from .model_runner import list_pretrained_models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@click.group()
 | 
				
			||||||
 | 
					def cli():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@cli.command
 | 
				
			||||||
 | 
					def version():
 | 
				
			||||||
 | 
					    """Prints version"""
 | 
				
			||||||
 | 
					    print(f"SeisBench v{seisbench.__version__}")
 | 
				
			||||||
 | 
					    print(f"gpd_tool v{__version__}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@cli.command
 | 
				
			||||||
 | 
					def citation():
 | 
				
			||||||
 | 
					    """Prints citation of the model"""
 | 
				
			||||||
 | 
					    m = GPDModelRunner()
 | 
				
			||||||
 | 
					    print(m.citation())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@cli.command
 | 
				
			||||||
 | 
					def list_pretrained():
 | 
				
			||||||
 | 
					    """Show pretrained model names"""
 | 
				
			||||||
 | 
					    print(", ".join(list_pretrained_models(GPDModelRunner)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@cli.command
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-w",
 | 
				
			||||||
 | 
					    "--weights",
 | 
				
			||||||
 | 
					    default="original",
 | 
				
			||||||
 | 
					    type=str,
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help=f"for possible options see output of 'list-pretrained'",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-o",
 | 
				
			||||||
 | 
					    "--output",
 | 
				
			||||||
 | 
					    default=pathlib.Path("."),
 | 
				
			||||||
 | 
					    type=click.Path(dir_okay=True, path_type=pathlib.Path),
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help="directory to store results",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-s",
 | 
				
			||||||
 | 
					    "--stride",
 | 
				
			||||||
 | 
					    default=10,
 | 
				
			||||||
 | 
					    type=int,
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help="stride in samples for point prediction models",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-tp",
 | 
				
			||||||
 | 
					    "--threshold-p",
 | 
				
			||||||
 | 
					    default=0.75,
 | 
				
			||||||
 | 
					    type=float,
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help="detection threshold for the P phase",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-ts",
 | 
				
			||||||
 | 
					    "--threshold-s",
 | 
				
			||||||
 | 
					    default=0.75,
 | 
				
			||||||
 | 
					    type=float,
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help="detection threshold for the S phase",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.argument("stream_file_names", nargs=-1, type=click.Path(exists=True))
 | 
				
			||||||
 | 
					def pick(stream_file_names, weights, output, stride, threshold_p, threshold_s):
 | 
				
			||||||
 | 
					    """Detect phases in streams"""
 | 
				
			||||||
 | 
					    if not output.exists():
 | 
				
			||||||
 | 
					        output.mkdir()
 | 
				
			||||||
 | 
					    m = GPDModelRunner(
 | 
				
			||||||
 | 
					        weights,
 | 
				
			||||||
 | 
					        output_dir=output,
 | 
				
			||||||
 | 
					        stride=stride,
 | 
				
			||||||
 | 
					        threshold_p=threshold_p,
 | 
				
			||||||
 | 
					        threshold_s=threshold_s,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    for stream in stream_file_names:
 | 
				
			||||||
 | 
					        m.find_picks(stream)
 | 
				
			||||||
							
								
								
									
										90
									
								
								src/epos_ai_picking_tools/cli_phasenet.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								src/epos_ai_picking_tools/cli_phasenet.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,90 @@
 | 
				
			|||||||
 | 
					import pathlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import click
 | 
				
			||||||
 | 
					import seisbench
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .__about__ import __version__
 | 
				
			||||||
 | 
					from .model_runner import list_pretrained_models
 | 
				
			||||||
 | 
					from .phasenet import PhaseNetModelRunner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@click.group()
 | 
				
			||||||
 | 
					def cli():
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@cli.command
 | 
				
			||||||
 | 
					def version():
 | 
				
			||||||
 | 
					    """Prints version"""
 | 
				
			||||||
 | 
					    print(f"SeisBench v{seisbench.__version__}")
 | 
				
			||||||
 | 
					    print(f"phasenet_tool v{__version__}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@cli.command
 | 
				
			||||||
 | 
					def citation():
 | 
				
			||||||
 | 
					    """Prints citation of the model"""
 | 
				
			||||||
 | 
					    m = PhaseNetModelRunner()
 | 
				
			||||||
 | 
					    print(m.citation())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@cli.command
 | 
				
			||||||
 | 
					def list_pretrained():
 | 
				
			||||||
 | 
					    """Show pretrained model names"""
 | 
				
			||||||
 | 
					    print(", ".join(list_pretrained_models(PhaseNetModelRunner)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@cli.command
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-w",
 | 
				
			||||||
 | 
					    "--weights",
 | 
				
			||||||
 | 
					    default="original",
 | 
				
			||||||
 | 
					    type=str,
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help=f"for possible options see output of 'list-pretrained'",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-o",
 | 
				
			||||||
 | 
					    "--output",
 | 
				
			||||||
 | 
					    default=pathlib.Path("."),
 | 
				
			||||||
 | 
					    type=click.Path(dir_okay=True, path_type=pathlib.Path),
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help="directory to store results",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-tp",
 | 
				
			||||||
 | 
					    "--threshold-p",
 | 
				
			||||||
 | 
					    default=0.3,
 | 
				
			||||||
 | 
					    type=float,
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help="detection threshold for the P phase",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-ts",
 | 
				
			||||||
 | 
					    "--threshold-s",
 | 
				
			||||||
 | 
					    default=0.3,
 | 
				
			||||||
 | 
					    type=float,
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help="detection threshold for the S phase",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.option(
 | 
				
			||||||
 | 
					    "-b",
 | 
				
			||||||
 | 
					    "--blinding",
 | 
				
			||||||
 | 
					    default=(0, 0),
 | 
				
			||||||
 | 
					    type=tuple,
 | 
				
			||||||
 | 
					    show_default=True,
 | 
				
			||||||
 | 
					    help="number of prediction samples to discard on each side of each window prediction",
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					@click.argument("stream_file_names", nargs=-1, type=click.Path(exists=True))
 | 
				
			||||||
 | 
					def pick(stream_file_names, weights, output, threshold_p, threshold_s, blinding):
 | 
				
			||||||
 | 
					    """Detect phases in streams"""
 | 
				
			||||||
 | 
					    if not output.exists():
 | 
				
			||||||
 | 
					        output.mkdir()
 | 
				
			||||||
 | 
					    m = PhaseNetModelRunner(
 | 
				
			||||||
 | 
					        weights,
 | 
				
			||||||
 | 
					        output_dir=output,
 | 
				
			||||||
 | 
					        threshold_p=threshold_p,
 | 
				
			||||||
 | 
					        threshold_s=threshold_s,
 | 
				
			||||||
 | 
					        blinding=blinding,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    for stream in stream_file_names:
 | 
				
			||||||
 | 
					        m.find_picks(stream)
 | 
				
			||||||
							
								
								
									
										29
									
								
								src/epos_ai_picking_tools/gpd.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								src/epos_ai_picking_tools/gpd.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					import pathlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import seisbench.models as sbm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .model_runner import ModelRunner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GPDModelRunner(ModelRunner):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_type = "GPD"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
 | 
				
			||||||
 | 
					        self.model_name = getattr(sbm, GPDModelRunner.model_type)
 | 
				
			||||||
 | 
					        super(GPDModelRunner, self).__init__(
 | 
				
			||||||
 | 
					            weights_name=weights_name, output_dir=output_dir, **kwargs
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def process_kwargs(self, **kwargs):
 | 
				
			||||||
 | 
					        self.stride = int(kwargs.get("stride", 10))
 | 
				
			||||||
 | 
					        self.threshold_p = float(kwargs.get("threshold_p", 0.75))
 | 
				
			||||||
 | 
					        self.threshold_s = float(kwargs.get("threshold_s", 0.75))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.annotate_kwargs = {
 | 
				
			||||||
 | 
					            "stride": self.stride,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.classify_kwargs = self.model.default_args.copy()
 | 
				
			||||||
 | 
					        self.classify_kwargs["P_threshold"] = self.threshold_p
 | 
				
			||||||
 | 
					        self.classify_kwargs["S_threshold"] = self.threshold_s
 | 
				
			||||||
							
								
								
									
										102
									
								
								src/epos_ai_picking_tools/model_runner.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										102
									
								
								src/epos_ai_picking_tools/model_runner.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,102 @@
 | 
				
			|||||||
 | 
					import json
 | 
				
			||||||
 | 
					import pathlib
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					from collections import defaultdict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import obspy
 | 
				
			||||||
 | 
					import seisbench.models as sbm
 | 
				
			||||||
 | 
					from obspy.core.event import Catalog, Event, Pick, WaveformStreamID
 | 
				
			||||||
 | 
					from obspy.io.json.default import Default
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def list_pretrained_models(model_runner_class):
 | 
				
			||||||
 | 
					    m = getattr(sbm, model_runner_class.model_type)
 | 
				
			||||||
 | 
					    weights = m.list_pretrained()
 | 
				
			||||||
 | 
					    return weights
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def exit_error(msg):
 | 
				
			||||||
 | 
					    print("ERROR:", msg)
 | 
				
			||||||
 | 
					    sys.exit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ModelRunner:
 | 
				
			||||||
 | 
					    model_type = "EMPTY"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
 | 
				
			||||||
 | 
					        # self.model_name = getattr(sbm, __class__.model_type)
 | 
				
			||||||
 | 
					        self.model = self.load_model(weights_name)
 | 
				
			||||||
 | 
					        self.output_dir = pathlib.Path(output_dir)
 | 
				
			||||||
 | 
					        self.process_kwargs(**kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def process_kwargs(self, **kwargs):
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def citation(self):
 | 
				
			||||||
 | 
					        return self.model.citation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load_model(self, weights_name):
 | 
				
			||||||
 | 
					        return self.model_name.from_pretrained(weights_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def load_stream(self, stream_file_name):
 | 
				
			||||||
 | 
					        return obspy.read(stream_file_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def save_picks(self, classs_picks, stream_path):
 | 
				
			||||||
 | 
					        dict_picks = list(map(lambda p: p.__dict__, classs_picks))
 | 
				
			||||||
 | 
					        fpath = self.output_dir / f"{stream_path.stem}_picks.json"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with open(fpath, "w") as fp:
 | 
				
			||||||
 | 
					            json.dump(dict_picks, fp, default=Default())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def save_quakeml(self, classs_picks, stream_path):
 | 
				
			||||||
 | 
					        e = Event()
 | 
				
			||||||
 | 
					        for cpick in classs_picks:
 | 
				
			||||||
 | 
					            net, sta, loc = cpick.trace_id.split(".")
 | 
				
			||||||
 | 
					            p = Pick(
 | 
				
			||||||
 | 
					                time=cpick.peak_time,
 | 
				
			||||||
 | 
					                phase_hint=cpick.phase,
 | 
				
			||||||
 | 
					                waveform_id=WaveformStreamID(
 | 
				
			||||||
 | 
					                    network_code=net, station_code=sta, location_code=loc
 | 
				
			||||||
 | 
					                ),
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            e.picks.append(p)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        cat = Catalog([e])
 | 
				
			||||||
 | 
					        fpath = self.output_dir / f"{stream_path.stem}_picks.xml"
 | 
				
			||||||
 | 
					        cat.write(fpath, format="QUAKEML")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def write_annotations(self, annotations, stream_path):
 | 
				
			||||||
 | 
					        ann = annotations.copy()
 | 
				
			||||||
 | 
					        for tr in ann:
 | 
				
			||||||
 | 
					            tr.stats.channel = f"G_{tr.stats.component}"
 | 
				
			||||||
 | 
					        fpath = self.output_dir / f"{stream_path.stem}_annotations.mseed"
 | 
				
			||||||
 | 
					        ann.write(fpath, format="MSEED")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def validate_stream(stream):
 | 
				
			||||||
 | 
					        groups = defaultdict(list)
 | 
				
			||||||
 | 
					        for trace in stream:
 | 
				
			||||||
 | 
					            groups[trace.stats.station].append(trace.stats.channel[-1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        number_of_channels = list(map(len, groups.values()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if max(number_of_channels) < 3:
 | 
				
			||||||
 | 
					            exit_error("Not enough traces in the stream")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def find_picks(self, stream_file_name, save_annotations=True):
 | 
				
			||||||
 | 
					        stream_path = pathlib.Path(stream_file_name)
 | 
				
			||||||
 | 
					        stream = self.load_stream(stream_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.validate_stream(stream)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        annotations = self.model.annotate(stream, **self.annotate_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if save_annotations:
 | 
				
			||||||
 | 
					            self.write_annotations(annotations, stream_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        classs_picks = self.model.classify_aggregate(annotations, self.classify_kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.save_picks(classs_picks, stream_path)
 | 
				
			||||||
 | 
					        self.save_quakeml(classs_picks, stream_path)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return classs_picks
 | 
				
			||||||
							
								
								
									
										28
									
								
								src/epos_ai_picking_tools/phasenet.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								src/epos_ai_picking_tools/phasenet.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,28 @@
 | 
				
			|||||||
 | 
					import pathlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import seisbench.models as sbm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .model_runner import ModelRunner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PhaseNetModelRunner(ModelRunner):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    model_type = "PhaseNet"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
 | 
				
			||||||
 | 
					        self.model_name = getattr(sbm, PhaseNetModelRunner.model_type)
 | 
				
			||||||
 | 
					        super(PhaseNetModelRunner, self).__init__(
 | 
				
			||||||
 | 
					            weights_name=weights_name, output_dir=output_dir, **kwargs
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def process_kwargs(self, **kwargs):
 | 
				
			||||||
 | 
					        self.threshold_p = float(kwargs.get("threshold_p", 0.3))
 | 
				
			||||||
 | 
					        self.threshold_s = float(kwargs.get("threshold_s", 0.3))
 | 
				
			||||||
 | 
					        self.blinding = kwargs.get("blinding", (0, 0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.annotate_kwargs = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.classify_kwargs = self.model.default_args.copy()
 | 
				
			||||||
 | 
					        self.classify_kwargs["P_threshold"] = self.threshold_p
 | 
				
			||||||
 | 
					        self.classify_kwargs["S_threshold"] = self.threshold_s
 | 
				
			||||||
 | 
					        self.classify_kwargs["blinding"] = self.blinding
 | 
				
			||||||
							
								
								
									
										0
									
								
								tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										
											BIN
										
									
								
								tests/dummy.mseed
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								tests/dummy.mseed
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										85
									
								
								tests/test_epos_ai_picking_tools.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								tests/test_epos_ai_picking_tools.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,85 @@
 | 
				
			|||||||
 | 
					import pathlib
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					import seisbench.models as sbm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from epos_ai_picking_tools.gpd import GPDModelRunner
 | 
				
			||||||
 | 
					from epos_ai_picking_tools.phasenet import PhaseNetModelRunner
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					models = [
 | 
				
			||||||
 | 
					    (GPDModelRunner, "stead", sbm.GPD),
 | 
				
			||||||
 | 
					    (PhaseNetModelRunner, "original", sbm.PhaseNet),
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture()
 | 
				
			||||||
 | 
					def simple_seed():
 | 
				
			||||||
 | 
					    current_dir = pathlib.Path(__file__).parent
 | 
				
			||||||
 | 
					    seed = current_dir / "dummy.mseed"
 | 
				
			||||||
 | 
					    return seed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture()
 | 
				
			||||||
 | 
					def phasenet_model_runner(tmpdir):
 | 
				
			||||||
 | 
					    m = PhaseNetModelRunner("original", output_dir=tmpdir)
 | 
				
			||||||
 | 
					    return m
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
 | 
				
			||||||
 | 
					def test_fail_ModelRunner_load_weights(model_runner, weights, seisbench_model):
 | 
				
			||||||
 | 
					    with pytest.raises(ValueError):
 | 
				
			||||||
 | 
					        m = model_runner("no_exitsting_model_weights")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
 | 
				
			||||||
 | 
					def test_ModelRunner_load_weights(model_runner, weights, seisbench_model):
 | 
				
			||||||
 | 
					    m = model_runner(weights)
 | 
				
			||||||
 | 
					    assert isinstance(m.model, seisbench_model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
 | 
				
			||||||
 | 
					def test_ModelRunner_output_dir(model_runner, weights, seisbench_model):
 | 
				
			||||||
 | 
					    m = model_runner(weights)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert m.output_dir == pathlib.Path(".")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    m = model_runner(weights, "./tmp_dir")
 | 
				
			||||||
 | 
					    assert m.output_dir == pathlib.Path("./tmp_dir")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
 | 
				
			||||||
 | 
					def test_ModelRunner_find_picks(
 | 
				
			||||||
 | 
					    model_runner, weights, seisbench_model, simple_seed, tmpdir
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    m = model_runner(weights, output_dir=tmpdir)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    picks = m.find_picks(simple_seed)
 | 
				
			||||||
 | 
					    assert isinstance(picks, list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    json_in_dir = m.output_dir.glob("*.json")
 | 
				
			||||||
 | 
					    assert len(list(json_in_dir)) == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    annotations_in_dir = m.output_dir.glob("*annotations*")
 | 
				
			||||||
 | 
					    assert len(list(annotations_in_dir)) == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_GPDModelRunner_find_picks_reults(simple_seed, tmpdir):
 | 
				
			||||||
 | 
					    m = GPDModelRunner("stead", output_dir=tmpdir)
 | 
				
			||||||
 | 
					    picks = m.find_picks(simple_seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert picks[0].phase == "P"
 | 
				
			||||||
 | 
					    assert picks[0].peak_time == "2000-01-01T07:00:05.100000Z"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert picks[2].phase == "S"
 | 
				
			||||||
 | 
					    assert picks[2].peak_time == "2000-01-01T07:00:15.700000Z"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_PhaseNetModelRunner_find_picks_reults(simple_seed, tmpdir):
 | 
				
			||||||
 | 
					    m = PhaseNetModelRunner("original", output_dir=tmpdir)
 | 
				
			||||||
 | 
					    picks = m.find_picks(simple_seed)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert picks[0].phase == "P"
 | 
				
			||||||
 | 
					    assert picks[0].peak_time == "2000-01-01T07:00:05.350000Z"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert picks[1].phase == "S"
 | 
				
			||||||
 | 
					    assert picks[1].peak_time == "2000-01-01T07:00:06.140000Z"
 | 
				
			||||||
		Reference in New Issue
	
	Block a user