Add new PhaseNet weights

This commit is contained in:
Hubert Siejkowski 2023-12-06 12:43:51 +00:00
parent fd1417b70d
commit 140ec4def3

View File

@ -13,7 +13,7 @@ EPOS_AI_MODEL_REPOSIOTRY_URL = "https://models.isl.grid.cyfronet.pl/models/v3/"
class PhaseNetModelRunner(ModelRunner): class PhaseNetModelRunner(ModelRunner):
model_type = "PhaseNet" model_type = "PhaseNet"
extra_weights = [("bogdanka", "1"), ("lgcd", "1")] extra_weights = [("bogdanka", "1"), ("lgcd", "1"), ("bogdanka_lgcd", "1")]
def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs): def __init__(self, weights_name="original", output_dir=pathlib.Path("."), **kwargs):
self.model_name = getattr(sbm, PhaseNetModelRunner.model_type) self.model_name = getattr(sbm, PhaseNetModelRunner.model_type)