From 75618c9a0d38c71417677994757eccc9de77aa76 Mon Sep 17 00:00:00 2001 From: Hubert Siejkowski Date: Thu, 10 Apr 2025 14:54:49 +0200 Subject: [PATCH] ISEPOS-2350 improve validation of input stream --- src/epos_ai_picking_tools/model_runner.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/epos_ai_picking_tools/model_runner.py b/src/epos_ai_picking_tools/model_runner.py index eeb40da..ca881d6 100644 --- a/src/epos_ai_picking_tools/model_runner.py +++ b/src/epos_ai_picking_tools/model_runner.py @@ -73,16 +73,25 @@ class ModelRunner: fpath = self.output_dir / f"{stream_path.stem}_annotations.mseed" ann.write(fpath, format="MSEED") - @staticmethod - def validate_stream(stream): + def validate_stream(self, stream): + model = self.model groups = defaultdict(list) + samples = defaultdict(list) for trace in stream: groups[trace.stats.station].append(trace.stats.channel[-1]) + samples[trace.stats.station].append( + float(trace.stats.npts / trace.stats.sampling_rate) + ) number_of_channels = list(map(len, groups.values())) + lenght_of_traces = list(map(max , samples.values())) - if max(number_of_channels) < 3: - exit_error("Not enough traces in the stream") + if max(number_of_channels) < model.in_channels: + exit_error("Not enough channels in the stream.") + + minimal_trace_length = model.in_samples / model.sampling_rate + if max(lenght_of_traces) < minimal_trace_length: + exit_error(f"All traces are shorter than required {int(minimal_trace_length):d} seconds") def find_picks(self, stream_file_name, save_annotations=True): stream_path = pathlib.Path(stream_file_name)