From 644b4e69bda963709abc25b0bdbe7827b6791f75 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 30 Oct 2024 12:14:35 +0100
Subject: [PATCH] Threshold object classifier using PatchStacks passes tests

---
 model_server/base/roiset.py  | 33 ++++++++++++++++++++-------------
 tests/base/test_accessors.py |  1 +
 2 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 23532707..27e32944 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -1190,19 +1190,26 @@ class IntensityThresholdInstanceMaskSegmentationModel(InstanceMaskSegmentationMo
             allow_3d: bool = False,
             connect_3d: bool = True,
     ) -> GenericImageDataAccessor:
-        labels = get_label_ids(mask)
-        df = pd.DataFrame(regionprops_table(
-            labels.data_yxz,
-            intensity_image=img.get_mono(self.channel).data_yxz,
-            properties=('label', 'area', 'intensity_mean', 'bbox')
-        ))
-        df['intensity_mean'] > self.tr
-
-        om = np.zeros(labels.shape, labels.dtype)
-        def _label_object_class(la):
-            om[labels.data == la] = 1
-        df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_object_class)
-        return InMemoryDataAccessor(om)
+        if isinstance(img, PatchStack):  # assume one object per patch
+            df = img.get_object_df(mask)
+            om = np.zeros(mask.shape, 'uint16')
+            def _label_patch_class(la):
+                om[la] = (mask.iat(la).data > 0) * 1
+            df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_patch_class)
+            return PatchStack(om)
+        else:
+            labels = get_label_ids(mask)
+            df = pd.DataFrame(regionprops_table(
+                labels.data_yxz,
+                intensity_image=img.get_mono(self.channel).data_yxz,
+                properties=('label', 'area', 'intensity_mean')
+            ))
+
+            om = np.zeros(labels.shape, labels.dtype)
+            def _label_object_class(la):
+                om[labels.data == la] = 1
+            df.loc[df['intensity_mean'] > (self.tr * img.dtype_max), 'label'].apply(_label_object_class)
+            return InMemoryDataAccessor(om)
 
     def label_instance_class(
             self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py
index 05663dd7..2c5c052b 100644
--- a/tests/base/test_accessors.py
+++ b/tests/base/test_accessors.py
@@ -302,6 +302,7 @@ class TestPatchStackAccessor(unittest.TestCase):
         self.assertEqual(acc.hw, (h, w))
         self.assertEqual(acc.get_mono(channel=0).data_yxz.shape, (n, h, w, nz))
         self.assertEqual(acc.get_mono(channel=0, mip=True).data_yx.shape, (n, h, w))
+        return acc
 
     def test_object_df(self):
         w = 30
-- 
GitLab