From 04b4d5231329b447b4863832955d1006dcb346ed Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 25 Oct 2024 12:16:48 +0200
Subject: [PATCH] Models' inference methods now just return result, not tuple
 with {'success': True'}

---
 model_server/base/session.py                  |  2 +-
 model_server/extensions/ilastik/models.py     | 22 +++++++++----------
 .../ilastik/pipelines/px_then_ob.py           |  4 ++--
 tests/test_ilastik/test_ilastik.py            |  6 ++---
 tests/test_ilastik/test_roiset_workflow.py    |  8 ++-----
 5 files changed, 18 insertions(+), 24 deletions(-)

diff --git a/model_server/base/session.py b/model_server/base/session.py
index 4aa88a1e..be4911f0 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -251,7 +251,7 @@ class _Session(object):
         :param params: optional parameters that are passed to the model's construct
         :return: model_id of loaded model
         """
-        mi = ModelClass(params=params.dict())
+        mi = ModelClass(params=params.dict() if params else None)
         assert mi.loaded, f'Error loading instance of {ModelClass.__name__}'
         ii = 0
 
diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index dccb269a..5e45bab4 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -137,7 +137,7 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
             [1, 2, 3, 0],
             [0, 1, 2, 3]
         )
-        return InMemoryDataAccessor(data=yxcz), {'success': True}
+        return InMemoryDataAccessor(data=yxcz)
 
     def infer_patch_stack(self, img: PatchStack, **kwargs) -> (np.ndarray, dict):
         """
@@ -150,13 +150,13 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
         for i in range(0, img.count):
             sl = img.get_slice_at(i)
             try:
-                data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True))[0].data
+                data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True)).data
             except FeatureSelectionConstraintError: # occurs occasionally on small patches
                 continue
-        return PatchStack(data), {'success': True}
+        return PatchStack(data)
 
     def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs):
-        pxmap, _ = self.infer(img)
+        pxmap = self.infer(img)
         mask = pxmap.get_mono(self.params['px_class']).apply(lambda x: x > self.params['px_prob_threshold'])
         return mask
 
@@ -219,19 +219,18 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
                 [0, 1, 2, 3, 4],
                 [0, 4, 1, 2, 3]
             )
-            return PatchStack(data=pyxcz), {'success': True}
+            return PatchStack(data=pyxcz)
         else:
             yxcz = np.moveaxis(
                 obmaps[0],
                 [1, 2, 3, 0],
                 [0, 1, 2, 3]
             )
-            return InMemoryDataAccessor(data=yxcz), {'success': True}
+            return InMemoryDataAccessor(data=yxcz)
 
     def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
         super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
-        obmap, _ = self.infer(img, mask)
-        return obmap
+        return self.infer(img, mask)
 
 
 class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel):
@@ -277,14 +276,14 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
                 [0, 1, 2, 3, 4],
                 [0, 4, 1, 2, 3]
             )
-            return PatchStack(data=pyxcz), {'success': True}
+            return PatchStack(data=pyxcz)
         else:
             yxcz = np.moveaxis(
                 obmaps[0],
                 [1, 2, 3, 0],
                 [0, 1, 2, 3]
             )
-            return InMemoryDataAccessor(data=yxcz), {'success': True}
+            return InMemoryDataAccessor(data=yxcz)
 
 
     def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs):
@@ -305,8 +304,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
         pxch = kwargs.get('pixel_classification_channel', 0)
         pxtr = kwargs.get('pixel_classification_threshold', 0.5)
         mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
-        obmap, _ = self.infer(img, mask)
-        return obmap
+        return self.infer(img, mask)
 
 
 class Error(Exception):
diff --git a/model_server/extensions/ilastik/pipelines/px_then_ob.py b/model_server/extensions/ilastik/pipelines/px_then_ob.py
index 1aa64615..2e9fd6e4 100644
--- a/model_server/extensions/ilastik/pipelines/px_then_ob.py
+++ b/model_server/extensions/ilastik/pipelines/px_then_ob.py
@@ -57,8 +57,8 @@ def pixel_then_object_classification_pipeline(
     else:
         channels = range(0, d['input'].chroma)
     d['select_channels'] = d.last.get_channels(channels, mip=k.get('mip', False))
-    d['pxmap'], _ = models['px_model'].infer(d.last)
-    d['ob_map'], _ = models['ob_model'].infer(d['select_channels'], d['pxmap'])
+    d['pxmap'] = models['px_model'].infer(d.last)
+    d['ob_map'] = models['ob_model'].infer(d['select_channels'], d['pxmap'])
 
     return d
 
diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py
index a6f467f9..4eedf981 100644
--- a/tests/test_ilastik/test_ilastik.py
+++ b/tests/test_ilastik/test_ilastik.py
@@ -148,7 +148,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
         self.assertEqual(mask.nz, acc.nz)
         self.assertEqual(mask.count, acc.count)
 
-        pxmap, _ = self.model.infer_patch_stack(acc)
+        pxmap = self.model.infer_patch_stack(acc)
         self.assertEqual(pxmap.dtype, float)
         self.assertEqual(pxmap.chroma, len(self.model.labels))
         self.assertEqual(pxmap.hw, acc.hw)
@@ -162,7 +162,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
             params={'project_file': ilastik_classifiers['pxmap_to_obj']['path'].__str__()}
         )
         mask = self.model.label_pixel_class(self.mono_image)
-        objmap, _ = model.infer(self.mono_image, mask)
+        objmap = model.infer(self.mono_image, mask)
 
         self.assertTrue(
             write_accessor_data_to_file(
@@ -333,7 +333,7 @@ class TestIlastikOnMultichannelInputs(TestServerTestCase):
         img = generate_file_accessor(self.pa_input_image)
         self.assertGreater(img.chroma, 1)
         mod = ilm.IlastikPixelClassifierModel({'project_file': self.pa_px_classifier.__str__()})
-        pxmap = mod.infer(img)[0]
+        pxmap = mod.infer(img)
         self.assertEqual(pxmap.hw, img.hw)
         self.assertEqual(pxmap.nz, img.nz)
         return pxmap
diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py
index 5f799089..4adf24bb 100644
--- a/tests/test_ilastik/test_roiset_workflow.py
+++ b/tests/test_ilastik/test_roiset_workflow.py
@@ -72,18 +72,14 @@ class BaseTestRoiSetMonoProducts(object):
                 'name': 'ilastik_px_mod',
                 'project_file': fp_px,
                 'model': ilm.IlastikPixelClassifierModel(
-                    ilm.IlastikPixelClassifierParams(
-                        project_file=fp_px,
-                    )
+                    {'project_file': fp_px},
                 )
             },
             'object_classifier': {
                 'name': 'ilastik_ob_mod',
                 'project_file': fp_ob,
                 'model': ilm.IlastikObjectClassifierFromSegmentationModel(
-                    ilm.IlastikParams(
-                        project_file=fp_ob
-                    )
+                    {'project_file': fp_ob},
                 )
             },
         }
-- 
GitLab