epos-ai-picking-tools/tests/test_epos_ai_picking_tools.py
2023-09-20 09:44:18 +00:00

86 lines
2.5 KiB
Python

import pathlib
import pytest
import seisbench.models as sbm
from epos_ai_picking_tools.gpd import GPDModelRunner
from epos_ai_picking_tools.phasenet import PhaseNetModelRunner
# Parametrization table shared by the generic ModelRunner tests.
# Each entry: (runner class, weights name, expected SeisBench model class).
models = [
    (
        GPDModelRunner,
        "stead",
        sbm.GPD,
    ),
    (
        PhaseNetModelRunner,
        "original",
        sbm.PhaseNet,
    ),
]
@pytest.fixture()
def simple_seed():
    """Return the path of ``dummy.mseed`` located next to this test module."""
    return pathlib.Path(__file__).with_name("dummy.mseed")
@pytest.fixture()
def phasenet_model_runner(tmpdir):
    """Provide a PhaseNet runner ("original" weights) writing into ``tmpdir``.

    NOTE(review): no test in the visible part of this file requests this
    fixture — confirm it is still used before keeping it.
    """
    return PhaseNetModelRunner("original", output_dir=tmpdir)
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_fail_ModelRunner_load_weights(model_runner, weights, seisbench_model):
    """Constructing a runner with an unknown weights name raises ``ValueError``.

    ``weights`` and ``seisbench_model`` are not used here, but they must remain
    in the signature because ``pytest.mark.parametrize`` supplies them.
    """
    with pytest.raises(ValueError):
        # Only the constructor's failure matters; the instance is never used.
        model_runner("no_exitsting_model_weights")
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_ModelRunner_load_weights(model_runner, weights, seisbench_model):
    """A runner built from valid weights exposes the expected SeisBench model."""
    loaded_model = model_runner(weights).model
    assert isinstance(loaded_model, seisbench_model)
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_ModelRunner_output_dir(model_runner, weights, seisbench_model):
    """Check the default ``output_dir`` and an explicit positional override.

    Without an output directory argument the runner reports the current
    directory. A string passed as the second positional argument is exposed
    as an equal ``pathlib.Path``.
    """
    default_runner = model_runner(weights)
    assert default_runner.output_dir == pathlib.Path(".")

    # NOTE(review): "./tmp_dir" is relative to the working directory; if the
    # runner creates the directory on construction, it is left behind — confirm.
    custom_runner = model_runner(weights, "./tmp_dir")
    assert custom_runner.output_dir == pathlib.Path("./tmp_dir")
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_ModelRunner_find_picks(
    model_runner, weights, seisbench_model, simple_seed, tmpdir
):
    """``find_picks`` returns a list and writes its outputs into ``output_dir``.

    Exactly one ``*.json`` file and exactly one ``*annotations*`` file are
    expected in the output directory after a single run.
    """
    runner = model_runner(weights, output_dir=tmpdir)
    picks = runner.find_picks(simple_seed)
    assert isinstance(picks, list)

    json_files = list(runner.output_dir.glob("*.json"))
    assert len(json_files) == 1

    annotation_files = list(runner.output_dir.glob("*annotations*"))
    assert len(annotation_files) == 1
def test_GPDModelRunner_find_picks_reults(simple_seed, tmpdir):
    """Regression check of GPD ("stead" weights) picks on the dummy seed file.

    Pins the phase and peak time of the first and third returned picks.
    """
    m = GPDModelRunner("stead", output_dir=tmpdir)
    picks = m.find_picks(simple_seed)
    # Report a readable failure instead of an IndexError when the model
    # returns fewer picks than this regression test expects.
    assert len(picks) >= 3, f"expected at least 3 picks, got {len(picks)}: {picks!r}"
    assert picks[0].phase == "P"
    assert picks[0].peak_time == "2000-01-01T07:00:05.100000Z"
    # NOTE(review): picks[1] is not checked — confirm that is intentional.
    assert picks[2].phase == "S"
    assert picks[2].peak_time == "2000-01-01T07:00:15.700000Z"
def test_PhaseNetModelRunner_find_picks_reults(simple_seed, tmpdir):
    """Regression check of PhaseNet ("original" weights) picks on the dummy seed file.

    Pins the phase and peak time of the first two returned picks.
    """
    m = PhaseNetModelRunner("original", output_dir=tmpdir)
    picks = m.find_picks(simple_seed)
    # Report a readable failure instead of an IndexError when the model
    # returns fewer picks than this regression test expects.
    assert len(picks) >= 2, f"expected at least 2 picks, got {len(picks)}: {picks!r}"
    assert picks[0].phase == "P"
    assert picks[0].peak_time == "2000-01-01T07:00:05.350000Z"
    assert picks[1].phase == "S"
    assert picks[1].peak_time == "2000-01-01T07:00:06.140000Z"