86 lines
2.5 KiB
Python
86 lines
2.5 KiB
Python
import pathlib
|
|
|
|
import pytest
|
|
import seisbench.models as sbm
|
|
|
|
from epos_ai_picking_tools.gpd import GPDModelRunner
|
|
from epos_ai_picking_tools.phasenet import PhaseNetModelRunner
|
|
|
|
# Parametrization table shared by the generic runner tests in this module.
# Each entry is:
#   (runner class,
#    pretrained-weights name passed to the runner,
#    SeisBench model class the runner is expected to wrap for those weights)
models = [
    (GPDModelRunner, "stead", sbm.GPD),
    (PhaseNetModelRunner, "original", sbm.PhaseNet),
]
|
|
|
|
|
|
@pytest.fixture()
def simple_seed():
    """Return the path to the ``dummy.mseed`` file that sits next to this test module."""
    return pathlib.Path(__file__).parent / "dummy.mseed"
|
|
|
|
|
|
@pytest.fixture()
def phasenet_model_runner(tmpdir):
    """Build a ``PhaseNetModelRunner`` using the "original" weights.

    Output is directed to pytest's per-test temporary directory (``tmpdir``).
    NOTE(review): no test visible in this module requests this fixture;
    confirm whether it is still needed.
    """
    return PhaseNetModelRunner("original", output_dir=tmpdir)
|
|
|
|
|
|
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_fail_ModelRunner_load_weights(model_runner, weights, seisbench_model):
    """Constructing a runner with an unknown weights name must raise ``ValueError``.

    ``weights`` and ``seisbench_model`` are not used here; they are accepted
    only because the shared ``models`` parametrization supplies them.
    """
    with pytest.raises(ValueError):
        # Only the raised exception matters; the constructed object is never
        # used, so it is not bound to a name.
        model_runner("no_exitsting_model_weights")
|
|
|
|
|
|
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_ModelRunner_load_weights(model_runner, weights, seisbench_model):
    """A runner built from known weights exposes an instance of the matching SeisBench model."""
    runner = model_runner(weights)
    loaded = runner.model
    assert isinstance(loaded, seisbench_model)
|
|
|
|
|
|
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_ModelRunner_output_dir(model_runner, weights, seisbench_model):
    """Check that ``output_dir`` defaults to "." and that a custom value is accepted.

    In both cases the value on the runner must compare equal to a
    ``pathlib.Path``.
    """
    # No explicit output directory: the runner should fall back to ".".
    default_runner = model_runner(weights)
    assert default_runner.output_dir == pathlib.Path(".")

    # A string given as the second positional argument should come back as a Path.
    # NOTE(review): if the runner creates this directory when it is constructed,
    # this test leaves "./tmp_dir" behind in the working directory. Confirm.
    custom_runner = model_runner(weights, "./tmp_dir")
    assert custom_runner.output_dir == pathlib.Path("./tmp_dir")
|
|
|
|
|
|
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_ModelRunner_find_picks(
    model_runner, weights, seisbench_model, simple_seed, tmpdir
):
    """``find_picks`` returns a list and writes its outputs into ``output_dir``.

    After the run, the temporary output directory must contain exactly one
    ``*.json`` file and exactly one ``*annotations*`` file.
    """
    runner = model_runner(weights, output_dir=tmpdir)

    result = runner.find_picks(simple_seed)
    assert isinstance(result, list)

    out_dir = runner.output_dir
    assert sum(1 for _ in out_dir.glob("*.json")) == 1
    assert sum(1 for _ in out_dir.glob("*annotations*")) == 1
|
|
|
|
|
|
def test_GPDModelRunner_find_picks_reults(simple_seed, tmpdir):
    """Check that GPD with the "stead" weights reproduces known picks on ``dummy.mseed``.

    These are regression values: ``picks[0]`` must be the P pick and
    ``picks[2]`` the S pick. ``picks[1]`` is not checked.
    """
    m = GPDModelRunner("stead", output_dir=tmpdir)
    picks = m.find_picks(simple_seed)

    # If picks go missing, fail with a readable message instead of an IndexError.
    assert len(picks) >= 3, f"expected at least 3 GPD picks, got {len(picks)}"

    # NOTE(review): these comparisons rely on ``peak_time`` comparing equal to
    # an ISO-8601 string (ObsPy's UTCDateTime supports this). Confirm the pick type.
    assert picks[0].phase == "P"
    assert picks[0].peak_time == "2000-01-01T07:00:05.100000Z"

    assert picks[2].phase == "S"
    assert picks[2].peak_time == "2000-01-01T07:00:15.700000Z"
|
|
|
|
|
|
def test_PhaseNetModelRunner_find_picks_reults(simple_seed, tmpdir):
    """Check that PhaseNet with the "original" weights reproduces known picks on ``dummy.mseed``.

    These are regression values: ``picks[0]`` must be the P pick and
    ``picks[1]`` the S pick.
    """
    m = PhaseNetModelRunner("original", output_dir=tmpdir)
    picks = m.find_picks(simple_seed)

    # If picks go missing, fail with a readable message instead of an IndexError.
    assert len(picks) >= 2, f"expected at least 2 PhaseNet picks, got {len(picks)}"

    # NOTE(review): these comparisons rely on ``peak_time`` comparing equal to
    # an ISO-8601 string (ObsPy's UTCDateTime supports this). Confirm the pick type.
    assert picks[0].phase == "P"
    assert picks[0].peak_time == "2000-01-01T07:00:05.350000Z"

    assert picks[1].phase == "S"
    assert picks[1].peak_time == "2000-01-01T07:00:06.140000Z"
|