From bfcd18aff82abc98b78415e585df7cb5d1fce39c Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 25 Oct 2024 11:25:56 +0200
Subject: [PATCH] Tests parameterize models directly with dicts, not pydantic
 models, which are reserved for API calls

---
 model_server/base/pipelines/segment.py    |  2 +-
 model_server/base/session.py              |  2 +-
 model_server/extensions/ilastik/models.py | 18 ++--------
 model_server/extensions/ilastik/router.py | 29 ++++++++++-----
 tests/test_ilastik/test_ilastik.py        | 43 +++++++++++++----------
 5 files changed, 50 insertions(+), 44 deletions(-)

diff --git a/model_server/base/pipelines/segment.py b/model_server/base/pipelines/segment.py
index 887ccea5..e0180b55 100644
--- a/model_server/base/pipelines/segment.py
+++ b/model_server/base/pipelines/segment.py
@@ -36,7 +36,7 @@ def segment_pipeline(
     model = models.get('model')
 
     if not isinstance(model, SemanticSegmentationModel):
-        raise IncompatibleModelsError('Expecting a pixel classification model')
+        raise IncompatibleModelsError('Expecting a semantic segmentation model')
 
     if ch := k.get('channel') is not None:
         d['mono'] = d['input'].get_mono(ch)
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 8f25520d..4aa88a1e 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -251,7 +251,7 @@ class _Session(object):
         :param params: optional parameters that are passed to the model's construct
         :return: model_id of loaded model
         """
-        mi = ModelClass(params=params)
+        mi = ModelClass(params=params.dict())
         assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
         ii = 0
 
diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index 75e0434c..dccb269a 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -2,11 +2,9 @@ import json
 from logging import getLogger
 import os
 from pathlib import Path
-from typing import Union
 import warnings
 
 import numpy as np
-from pydantic import BaseModel, Field
 import vigra
 
 import model_server.extensions.ilastik.conf
@@ -14,17 +12,10 @@ from ...base.accessors import PatchStack
 from ...base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
 from ...base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
 
-# TODO: move params models to router only; model classes to be created only with dict
-class IlastikParams(BaseModel):
-    project_file: str = Field(description='(*.ilp) ilastik project filename')
-    duplicate: bool = Field(
-        True,
-        description='Load another instance of the same project file if True; return existing one if False'
-    )
 
 class IlastikModel(Model):
 
-    def __init__(self, params: IlastikParams, autoload=True, enforce_embedded=True):
+    def __init__(self, params: dict, autoload=True, enforce_embedded=True):
         """
         Base class for models that run via ilastik shell API
         :param params:
@@ -33,7 +24,7 @@ class IlastikModel(Model):
         :param enforce_embedded:
             raise an error if all input data are not embedded in the project file, i.e. on the filesystem
         """
-        pf = Path(params.project_file)
+        pf = Path(params['project_file'])
         self.enforce_embedded = enforce_embedded
         if pf.is_absolute():
             pap = pf
@@ -103,14 +94,11 @@ class IlastikModel(Model):
     def model_3d(self):
         return self.model_shape_dict['Z'] > 1
 
-class IlastikPixelClassifierParams(IlastikParams):
-    px_class: int = 0
-    px_prob_threshold: float = 0.5
 
 class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
     operations = ['segment', ]
 
-    def __init__(self, params: IlastikPixelClassifierParams, **kwargs):
+    def __init__(self, params: dict, **kwargs):
         super(IlastikPixelClassifierModel, self).__init__(params, **kwargs)
 
     @staticmethod
diff --git a/model_server/extensions/ilastik/router.py b/model_server/extensions/ilastik/router.py
index ea852798..e4412221 100644
--- a/model_server/extensions/ilastik/router.py
+++ b/model_server/extensions/ilastik/router.py
@@ -1,4 +1,5 @@
 from fastapi import APIRouter
+from pydantic import BaseModel, Field
 
 from model_server.base.session import session
 
@@ -13,8 +14,20 @@ router = APIRouter(
 import model_server.extensions.ilastik.pipelines.px_then_ob
 router.include_router(model_server.extensions.ilastik.pipelines.px_then_ob.router)
 
+# TODO: move params models to router only; model classes to be created only with dict
+class IlastikParams(BaseModel):
+    project_file: str = Field(description='(*.ilp) ilastik project filename')
+    duplicate: bool = Field(
+        True,
+        description='Load another instance of the same project file if True; return existing one if False'
+    )
+
+class IlastikPixelClassifierParams(IlastikParams):
+    px_class: int = 0
+    px_prob_threshold: float = 0.5
+
 @router.put('/seg/load/')
-def load_px_model(p: ilm.IlastikPixelClassifierParams, model_id=None) -> dict:
+def load_px_model(p: IlastikPixelClassifierParams, model_id=None) -> dict:
     """
     Load an ilastik pixel classifier model from its project file
     """
@@ -25,7 +38,7 @@ def load_px_model(p: ilm.IlastikPixelClassifierParams, model_id=None) -> dict:
     )
 
 @router.put('/pxmap_to_obj/load/')
-def load_pxmap_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict:
+def load_pxmap_to_obj_model(p: IlastikParams, model_id=None) -> dict:
     """
     Load an ilastik object classifier from pixel predictions model from its project file
     """
@@ -36,7 +49,7 @@ def load_pxmap_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict:
     )
 
 @router.put('/seg_to_obj/load/')
-def load_seg_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict:
+def load_seg_to_obj_model(p: IlastikParams, model_id=None) -> dict:
     """
     Load an ilastik object classifier from segmentation model from its project file
     """
@@ -46,13 +59,13 @@ def load_seg_to_obj_model(p: ilm.IlastikParams, model_id=None) -> dict:
         model_id=model_id,
     )
 
-def load_ilastik_model(model_class: ilm.IlastikModel, p: ilm.IlastikParams, model_id=None) -> dict:
-    project_file = p.project_file
+def load_ilastik_model(model_class: ilm.IlastikModel, p: IlastikParams, model_id=None) -> dict:
+    pf = p.project_file
     if not p.duplicate:
-        existing_model_id = session.find_param_in_loaded_models('project_file', project_file, is_path=True)
+        existing_model_id = session.find_param_in_loaded_models('project_file', pf, is_path=True)
         if existing_model_id is not None:
-            session.log_info(f'An ilastik model from {project_file} already existing exists; did not load a duplicate')
+            session.log_info(f'An ilastik model from {pf} already existing exists; did not load a duplicate')
             return {'model_id': existing_model_id}
     result = session.load_model(model_class, key=model_id, params=p)
-    session.log_info(f'Loaded ilastik model {result} from {project_file}')
+    session.log_info(f'Loaded ilastik model {result} from {pf}')
     return {'model_id': result}
\ No newline at end of file
diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py
index 40ba7e0b..a6f467f9 100644
--- a/tests/test_ilastik/test_ilastik.py
+++ b/tests/test_ilastik/test_ilastik.py
@@ -29,14 +29,18 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.cf = CziImageFileAccessor(czifile['path'])
         self.channel = 0
         self.model = ilm.IlastikPixelClassifierModel(
-            params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px']['path'].__str__())
+            params={
+                'project_file': ilastik_classifiers['px']['path'].__str__(),
+                'px_class': 0,
+                'px_prob_threshold': 0.5,
+            }
         )
         self.mono_image = self.cf.get_mono(self.channel)
 
 
     def test_raise_error_if_autoload_disabled(self):
         model = ilm.IlastikPixelClassifierModel(
-            params=ilm.IlastikPixelClassifierParams(project_file=ilastik_classifiers['px']['path'].__str__()),
+            params={'project_file': ilastik_classifiers['px']['path'].__str__()},
             autoload=False
         )
         w = 512
@@ -80,11 +84,12 @@ class TestIlastikPixelClassification(unittest.TestCase):
     def test_label_pixels_with_params(self):
         def _run_seg(tr, sig):
             mod = ilm.IlastikPixelClassifierModel(
-                params=ilm.IlastikPixelClassifierParams(
-                    project_file=ilastik_classifiers['px']['path'].__str__(),
-                    px_prob_threshold=tr,
-                    px_smoothing=sig,
-                ),
+                params={
+                    'project_file': ilastik_classifiers['px']['path'].__str__(),
+                    'px_class': 0,
+                    'px_prob_threshold': tr,
+                    'px_smoothing': sig,
+                },
             )
             mask = mod.label_pixel_class(self.mono_image)
             write_accessor_data_to_file(
@@ -154,7 +159,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.test_run_pixel_classifier()
         fp = czifile['path']
         model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
-            params=ilm.IlastikParams(project_file=ilastik_classifiers['pxmap_to_obj']['path'].__str__())
+            params={'project_file': ilastik_classifiers['pxmap_to_obj']['path'].__str__()}
         )
         mask = self.model.label_pixel_class(self.mono_image)
         objmap, _ = model.infer(self.mono_image, mask)
@@ -172,7 +177,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.test_run_pixel_classifier()
         fp = czifile['path']
         model = ilm.IlastikObjectClassifierFromSegmentationModel(
-            params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj']['path'].__str__())
+            params={'project_file': ilastik_classifiers['seg_to_obj']['path'].__str__()}
         )
         mask = self.model.label_pixel_class(self.mono_image)
         objmap = model.label_instance_class(self.mono_image, mask)
@@ -191,11 +196,11 @@ class TestIlastikPixelClassification(unittest.TestCase):
                 'accessor': generate_file_accessor(czifile['path'])
             },
             models={
-                'model': ilm.IlastikPixelClassifierModel(
-                    params=ilm.IlastikPixelClassifierParams(
-                        project_file=ilastik_classifiers['px']['path'].__str__()
-                    ),
-                ),
+                'model': ilm.IlastikPixelClassifierModel({
+                    'project_file': ilastik_classifiers['px']['path'].__str__(),
+                    'px_class': 0,
+                    'px_prob_threshold': 0.5,
+                }),
             },
             channel=0,
         )
@@ -327,7 +332,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase):
     def test_classify_pixels(self):
         img = generate_file_accessor(self.pa_input_image)
         self.assertGreater(img.chroma, 1)
-        mod = ilm.IlastikPixelClassifierModel(ilm.IlastikPixelClassifierParams(project_file=self.pa_px_classifier.__str__()))
+        mod = ilm.IlastikPixelClassifierModel({'project_file': self.pa_px_classifier.__str__()})
         pxmap = mod.infer(img)[0]
         self.assertEqual(pxmap.hw, img.hw)
         self.assertEqual(pxmap.nz, img.nz)
@@ -337,7 +342,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase):
         pxmap = self.test_classify_pixels()
         img = generate_file_accessor(self.pa_input_image)
         mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
-            ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__())
+            {'project_file': self.pa_ob_pxmap_classifier.__str__()}
         )
         obmap = mod.infer(img, pxmap)[0]
         self.assertEqual(obmap.hw, img.hw)
@@ -354,10 +359,10 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase):
                 },
                 models={
                     'px_model': ilm.IlastikPixelClassifierModel(
-                        ilm.IlastikParams(project_file=self.pa_px_classifier.__str__()),
+                        {'project_file': self.pa_px_classifier.__str__()},
                     ),
                     'ob_model': ilm.IlastikObjectClassifierFromPixelPredictionsModel(
-                        ilm.IlastikParams(project_file=self.pa_ob_pxmap_classifier.__str__()),
+                        {'project_file': self.pa_ob_pxmap_classifier.__str__()}
                     )
                 },
                 channel=channel,
@@ -428,7 +433,7 @@ class TestIlastikObjectClassification(unittest.TestCase):
         )
 
         self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
-            params=ilm.IlastikParams(project_file=ilastik_classifiers['seg_to_obj']['path'].__str__()),
+            params={'project_file': ilastik_classifiers['seg_to_obj']['path'].__str__()},
         )
         self.raw = self.roiset.get_patches_acc()
         self.masks = self.roiset.get_patch_masks_acc()
-- 
GitLab