Initial import
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
BIN
tests/dummy.mseed
Normal file
BIN
tests/dummy.mseed
Normal file
Binary file not shown.
85
tests/test_epos_ai_picking_tools.py
Normal file
85
tests/test_epos_ai_picking_tools.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
import seisbench.models as sbm
|
||||
|
||||
from epos_ai_picking_tools.gpd import GPDModelRunner
|
||||
from epos_ai_picking_tools.phasenet import PhaseNetModelRunner
|
||||
|
||||
models = [
|
||||
(GPDModelRunner, "stead", sbm.GPD),
|
||||
(PhaseNetModelRunner, "original", sbm.PhaseNet),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def simple_seed():
|
||||
current_dir = pathlib.Path(__file__).parent
|
||||
seed = current_dir / "dummy.mseed"
|
||||
return seed
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def phasenet_model_runner(tmpdir):
|
||||
m = PhaseNetModelRunner("original", output_dir=tmpdir)
|
||||
return m
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
|
||||
def test_fail_ModelRunner_load_weights(model_runner, weights, seisbench_model):
|
||||
with pytest.raises(ValueError):
|
||||
m = model_runner("no_exitsting_model_weights")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
|
||||
def test_ModelRunner_load_weights(model_runner, weights, seisbench_model):
|
||||
m = model_runner(weights)
|
||||
assert isinstance(m.model, seisbench_model)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
|
||||
def test_ModelRunner_output_dir(model_runner, weights, seisbench_model):
|
||||
m = model_runner(weights)
|
||||
|
||||
assert m.output_dir == pathlib.Path(".")
|
||||
|
||||
m = model_runner(weights, "./tmp_dir")
|
||||
assert m.output_dir == pathlib.Path("./tmp_dir")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
|
||||
def test_ModelRunner_find_picks(
|
||||
model_runner, weights, seisbench_model, simple_seed, tmpdir
|
||||
):
|
||||
m = model_runner(weights, output_dir=tmpdir)
|
||||
|
||||
picks = m.find_picks(simple_seed)
|
||||
assert isinstance(picks, list)
|
||||
|
||||
json_in_dir = m.output_dir.glob("*.json")
|
||||
assert len(list(json_in_dir)) == 1
|
||||
|
||||
annotations_in_dir = m.output_dir.glob("*annotations*")
|
||||
assert len(list(annotations_in_dir)) == 1
|
||||
|
||||
|
||||
def test_GPDModelRunner_find_picks_reults(simple_seed, tmpdir):
|
||||
m = GPDModelRunner("stead", output_dir=tmpdir)
|
||||
picks = m.find_picks(simple_seed)
|
||||
|
||||
assert picks[0].phase == "P"
|
||||
assert picks[0].peak_time == "2000-01-01T07:00:05.100000Z"
|
||||
|
||||
assert picks[2].phase == "S"
|
||||
assert picks[2].peak_time == "2000-01-01T07:00:15.700000Z"
|
||||
|
||||
|
||||
def test_PhaseNetModelRunner_find_picks_reults(simple_seed, tmpdir):
|
||||
m = PhaseNetModelRunner("original", output_dir=tmpdir)
|
||||
picks = m.find_picks(simple_seed)
|
||||
|
||||
assert picks[0].phase == "P"
|
||||
assert picks[0].peak_time == "2000-01-01T07:00:05.350000Z"
|
||||
|
||||
assert picks[1].phase == "S"
|
||||
assert picks[1].peak_time == "2000-01-01T07:00:06.140000Z"
|
Reference in New Issue
Block a user