From c3a8f9153b3b50ad5c556c32b55643fd2e79595f Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 25 Mar 2024 12:44:46 +0100
Subject: [PATCH] Generate an InstanceSegmentationModel from an
 IlastikObjectClassifierFromPixelPredictionsModel given binarization
 parameters

---
 model_server/extensions/ilastik/models.py      | 18 +++++++++++++++++-
 .../extensions/ilastik/tests/test_ilastik.py   | 17 +++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index ac322e85..d3cc0d93 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -235,11 +235,27 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
         if not img.shape == pxmap.shape:
             raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
         pxch = kwargs.get('pixel_classification_channel', 0)
-        pxtr = kwargs('pixel_classification_threshold', 0.5)
+        pxtr = kwargs.get('pixel_classification_threshold', 0.5)
         mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
         obmap, _ = self.infer(img, mask)
         return obmap
 
+    def make_instance_segmentation_model(self, px_ch: int):
+        """
+        Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a
+        second input.
+        :param px_ch: channel of pixel probability map to use
+        :return:
+            InstanceSegmentationModel object
+        """
+        class _Mod(self.__class__, InstanceSegmentationModel):
+            def label_instance_class(
+                    self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
+            ) -> GenericImageDataAccessor:
+                return super().label_instance_class(img, mask, pixel_classification_channel=px_ch)
+        return _Mod(params={'project_file': self.project_file})
+
+
 
 class Error(Exception):
     pass
diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py
index 21e3d9af..15bca41d 100644
--- a/model_server/extensions/ilastik/tests/test_ilastik.py
+++ b/model_server/extensions/ilastik/tests/test_ilastik.py
@@ -136,6 +136,23 @@ class TestIlastikPixelClassification(unittest.TestCase):
         )
         self.assertEqual(objmap.data.max(), 2)
 
+    def test_make_seg_obj_model_from_pxmap_obj(self):
+        self.test_run_pixel_classifier()
+        fp = czifile['path']
+        pxmap_model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
+            {'project_file': ilastik_classifiers['pxmap_to_obj']}
+        )
+        seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0)
+        objmap = seg_model.label_instance_class(self.mono_image, self.mask)
+
+        self.assertTrue(
+            write_accessor_data_to_file(
+                output_path / f'obmap_seg_from_pxmap_{fp.stem}.tif',
+                objmap,
+            )
+        )
+        self.assertEqual(objmap.data.max(), 2)
+
     def test_run_object_classifier_from_segmentation(self):
         self.test_run_pixel_classifier()
         fp = czifile['path']
-- 
GitLab