"""Tests for the GPD and PhaseNet model runners.

Most tests are parametrized over ``models`` so that each runner class is
exercised with a weights name it accepts and checked against the SeisBench
model class it is expected to wrap.
"""

import pathlib

import pytest
import seisbench.models as sbm

from epos_ai_picking_tools.gpd import GPDModelRunner
from epos_ai_picking_tools.phasenet import PhaseNetModelRunner

# Each entry is (runner class, weights name, expected SeisBench model class).
models = [
    (GPDModelRunner, "stead", sbm.GPD),
    (PhaseNetModelRunner, "original", sbm.PhaseNet),
]


@pytest.fixture()
def simple_seed():
    """Return the path of ``dummy.mseed`` located next to this test module."""
    current_dir = pathlib.Path(__file__).parent
    seed = current_dir / "dummy.mseed"
    return seed


@pytest.fixture()
def phasenet_model_runner(tmpdir):
    """Return a ``PhaseNetModelRunner`` ("original" weights) writing to ``tmpdir``.

    NOTE(review): no test visible in this module requests this fixture.
    """
    m = PhaseNetModelRunner("original", output_dir=tmpdir)
    return m


@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_fail_ModelRunner_load_weights(model_runner, weights, seisbench_model):
    """Constructing a runner with an unknown weights name raises ``ValueError``."""
    # ``weights`` and ``seisbench_model`` are unused here but must stay in the
    # signature because parametrize supplies all three argument names.
    with pytest.raises(ValueError):
        model_runner("no_exitsting_model_weights")


@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_ModelRunner_load_weights(model_runner, weights, seisbench_model):
    """A runner built with a valid weights name exposes the expected model type."""
    m = model_runner(weights)
    assert isinstance(m.model, seisbench_model)


@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_ModelRunner_output_dir(model_runner, weights, seisbench_model):
    """``output_dir`` defaults to ``Path(".")`` and honours the second argument."""
    m = model_runner(weights)
    assert m.output_dir == pathlib.Path(".")

    # The output directory is passed positionally here, as the second argument.
    m = model_runner(weights, "./tmp_dir")
    assert m.output_dir == pathlib.Path("./tmp_dir")


@pytest.mark.parametrize("model_runner,weights,seisbench_model", models)
def test_ModelRunner_find_picks(
    model_runner, weights, seisbench_model, simple_seed, tmpdir
):
    """``find_picks`` returns a list and leaves two kinds of output files.

    After one call, the output directory must contain exactly one ``*.json``
    file and exactly one file whose name contains ``annotations``.
    """
    m = model_runner(weights, output_dir=tmpdir)
    picks = m.find_picks(simple_seed)
    assert isinstance(picks, list)

    json_in_dir = m.output_dir.glob("*.json")
    assert len(list(json_in_dir)) == 1

    annotations_in_dir = m.output_dir.glob("*annotations*")
    assert len(list(annotations_in_dir)) == 1


def test_GPDModelRunner_find_picks_reults(simple_seed, tmpdir):
    """Check the phase and peak time of GPD picks 0 and 2 on the dummy seed."""
    m = GPDModelRunner("stead", output_dir=tmpdir)
    picks = m.find_picks(simple_seed)

    assert picks[0].phase == "P"
    assert picks[0].peak_time == "2000-01-01T07:00:05.100000Z"

    assert picks[2].phase == "S"
    assert picks[2].peak_time == "2000-01-01T07:00:15.700000Z"


def test_PhaseNetModelRunner_find_picks_reults(simple_seed, tmpdir):
    """Check the phase and peak time of PhaseNet picks 0 and 1 on the dummy seed."""
    m = PhaseNetModelRunner("original", output_dir=tmpdir)
    picks = m.find_picks(simple_seed)

    assert picks[0].phase == "P"
    assert picks[0].peak_time == "2000-01-01T07:00:05.350000Z"

    assert picks[1].phase == "S"
    assert picks[1].peak_time == "2000-01-01T07:00:06.140000Z"