From f9f8621180a9dd5e837b3f62f1e0d066e03752ec Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 30 Oct 2024 10:17:36 +0100
Subject: [PATCH] Moved intensity gate model to RoiSet module

---
 model_server/base/api.py    |  4 ++--
 model_server/base/models.py | 25 ------------------------
 model_server/base/roiset.py | 38 +++++++++++++++++++++++++++++++++++++
 tests/base/test_model.py    |  6 ------
 tests/base/test_roiset.py   | 17 ++++++++++++++---
 5 files changed, 54 insertions(+), 36 deletions(-)

diff --git a/model_server/base/api.py b/model_server/base/api.py
index 6d9f5814..081b3b03 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -3,8 +3,8 @@ from typing import List, Union
 
 from fastapi import FastAPI, HTTPException
 from .accessors import generate_file_accessor
-from .models import BinaryThresholdSegmentationModel, IntensityThresholdInstanceMaskSegmentationModel
-from .roiset import RoiSetExportParams, SerializeRoiSetError
+from .models import BinaryThresholdSegmentationModel
+from .roiset import IntensityThresholdInstanceMaskSegmentationModel, RoiSetExportParams, SerializeRoiSetError
 from .session import session, AccessorIdError, InvalidPathError, RoiSetIdError, WriteAccessorError
 
 app = FastAPI(debug=True)
diff --git a/model_server/base/models.py b/model_server/base/models.py
index 3e026143..800ec951 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -158,31 +158,6 @@ class InstanceMaskSegmentationModel(ImageToImageModel):
         return PatchStack(data)
 
 
-class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel):
-    def __init__(self, tr: float = 0.5, channel: int = 0):
-        """
-        Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all
-        objects as class 1 if threshold = 0.0
-        :param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range
-        :param channel: channel to use for thresholding
-        """
-        self.tr = tr
-        self.channel = channel
-        self.loaded = self.load()
-        super().__init__(info={'tr': tr, 'channel': channel})
-
-    def load(self):
-        return True
-
-    def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor:
-        return mask.apply(lambda x: (1 * (x > self.tr)).astype(acc.dtype))
-
-    def label_instance_class(
-            self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
-    ) -> GenericImageDataAccessor:
-        super().label_instance_class(img, mask, **kwargs)
-        return self.infer(img, mask)
-
 class Error(Exception):
     pass
 
diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index ba7c6b06..ac93238b 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -1081,6 +1081,7 @@ class RoiSet(object):
 class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams):
     derived_channels: bool = False
 
+
 class RoiSetWithDerivedChannels(RoiSet):
 
     def __init__(self, *a, **k):
@@ -1165,32 +1166,69 @@ class RoiSetWithDerivedChannels(RoiSet):
                 record[k].append(str(fp))
         return record
 
+
+class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationModel):
+    def __init__(self, tr: float = 0.5, channel: int = 0):
+        """
+        Model that labels all objects as class 1 if the intensity in specified channel exceeds a threshold; labels all
+        objects as class 1 if threshold = 0.0
+        :param tr: threshold in range of 0.0 to 1.0; model handles normalization to full pixel intensity range
+        :param channel: channel to use for thresholding
+        """
+        self.tr = tr
+        self.channel = channel
+        self.loaded = self.load()
+        super().__init__(info={'tr': tr, 'channel': channel})
+
+    def load(self):
+        return True
+
+    def infer(self, acc: GenericImageDataAccessor, mask: GenericImageDataAccessor) -> GenericImageDataAccessor:
+        # TODO: gate by intensity in image, not threshold in mask channel; this will require iterating on labeled objects
+        return mask.apply(lambda x: (1 * (x > self.tr)).astype(acc.dtype))
+
+    def label_instance_class(
+            self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
+    ) -> GenericImageDataAccessor:
+        super().label_instance_class(img, mask, **kwargs)
+        return self.infer(img, mask)
+
+
 class Error(Exception):
     pass
 
+
 class BoundingBoxError(Error):
     pass
 
+
 class DeserializeRoiSetError(Error):
     pass
 
+
 class SerializeRoiSetError(Error):
     pass
 
+
 class NoDeprojectChannelSpecifiedError(Error):
     pass
 
+
 class DerivedChannelError(Error):
     pass
 
+
 class MissingSegmentationError(Error):
     pass
 
+
 class PatchMaskShapeError(Error):
     pass
 
+
 class ShapeMismatchError(Error):
     pass
 
+
 class MissingInstanceLabelsError(Error):
     pass
\ No newline at end of file
diff --git a/tests/base/test_model.py b/tests/base/test_model.py
index e5c96894..8b4aa3a7 100644
--- a/tests/base/test_model.py
+++ b/tests/base/test_model.py
@@ -68,9 +68,3 @@ class TestCziImageFileAccess(unittest.TestCase):
         obmap = model.label_instance_class(img, mask)
         self.assertTrue(all(obmap.unique()[0] == [0, 1]))
         self.assertTrue(all(obmap.unique()[1] > 0))
-
-    def test_permissive_instance_segmentation(self):
-        img, mask = self.test_dummy_pixel_segmentation()
-        model = IntensityThresholdInstanceMaskSegmentationModel()
-        obmap = model.label_instance_class(img, mask)
-        self.assertTrue(np.all(mask.data == 255 * obmap.data))
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index c510510a..b9b89410 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -7,9 +7,8 @@ from pathlib import Path
 import pandas as pd
 
 from model_server.base.process import smooth
-from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSetExportParams, RoiSetMetaParams
-from model_server.base.roiset import RoiSet
-from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
+from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSet, RoiSetExportParams, RoiSetMetaParams
+from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
 import model_server.conf.testing as conf
 from model_server.conf.testing import DummyInstanceMaskSegmentationModel
 
@@ -819,3 +818,15 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertTrue((res.bbox_intersec == 2).all())
         self.assertTrue((res.loc[res.seg_overlaps, :].index == [1]).all())
         self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all())
+
+class TestIntensityThresholdObjectModel(BaseTestRoiSetMonoProducts, unittest.TestCase):
+    def test_permissive_instance_segmentation(self):
+        from model_server.base.roiset import IntensityThresholdInstanceMaskSegmentationModel
+
+        img = self.stack.get_mono(channel=0, mip=True)
+        mask = self.seg_mask
+
+        # TODO: test on actual binary mask, not dummy square mask
+        model = IntensityThresholdInstanceMaskSegmentationModel()
+        obmap = model.label_instance_class(img, mask)
+        self.assertTrue(np.all(mask.data == 255 * obmap.data))
\ No newline at end of file
-- 
GitLab