From 426b5b79d7309ac763298344c8c87630e71aca7a Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 3 May 2024 16:21:58 +0200
Subject: [PATCH] Added static constructor for making RoiSets from segmentation
 masks

---
 model_server/base/roiset.py | 19 +++++++++++++++++++
 tests/test_roiset.py        | 21 ++++++++-------------
 2 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index f3c0c947..d289284f 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -152,6 +152,25 @@ class RoiSet(object):
         """Expose ROI meta information via the Pandas.DataFrame API"""
         return self._df.itertuples(name='Roi')
 
+    @staticmethod
+    def from_segmentation(
+            acc_raw: GenericImageDataAccessor,
+            acc_seg: GenericImageDataAccessor,
+            allow_3d=False,
+            connect_3d=True,
+            params: RoiSetMetaParams = RoiSetMetaParams()
+    ):
+        """
+        Create a RoiSet from a binary segmentation mask (either 2D or 3D)
+        :param acc_raw: accessor to a generally a multichannel z-stack
+        :param acc_seg: accessor of a binary segmentation mask (mono) of either two or three dimensions
+        :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity project if False
+        :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False
+        :param params: optional arguments that influence the definition and representation of ROIs
+        :return: object identities map
+        """
+        return RoiSet(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params)
+
     @staticmethod
     def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame:
         """
diff --git a/tests/test_roiset.py b/tests/test_roiset.py
index 47856888..c1d773c6 100644
--- a/tests/test_roiset.py
+++ b/tests/test_roiset.py
@@ -10,7 +10,7 @@ import pandas as pd
 from model_server.conf.testing import output_path, roiset_test_data
 
 from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams
-from model_server.base.roiset import _get_label_ids, RoiSet
+from model_server.base.roiset import RoiSet
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
 from model_server.base.models import DummyInstanceSegmentationModel
 
@@ -26,10 +26,9 @@ class BaseTestRoiSetMonoProducts(object):
 class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def _make_roi_set(self, mask_type='boxes', **kwargs):
-        id_map = _get_label_ids(self.seg_mask)
-        roiset = RoiSet(
+        roiset = RoiSet.from_segmentation(
             self.stack_ch_pa,
-            id_map,
+            self.seg_mask,
             params=RoiSetMetaParams(
                 mask_type=mask_type,
                 filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}),
@@ -67,9 +66,8 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
     def test_roiset_from_non_zstacks(self, **kwargs):
         acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
         self.assertEqual(acc_zstack_slice.nz, 1)
-        id_map = _get_label_ids(self.seg_mask)
 
-        roiset = RoiSet(acc_zstack_slice, id_map, params=RoiSetMetaParams(mask_type='boxes'))
+        roiset = RoiSet.from_segmentation(acc_zstack_slice, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
         zmask = roiset.get_zmask()
 
         zmask_acc = InMemoryDataAccessor(zmask)
@@ -158,9 +156,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertEqual(result.shape, roiset.acc_raw.shape)
 
     def test_flatten_image(self):
-        id_map = _get_label_ids(self.seg_mask)
-
-        roiset = RoiSet(self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='boxes'))
+        roiset = RoiSet.from_segmentation(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
         df = roiset.get_df()
 
         from model_server.base.roiset import project_stack_from_focal_points
@@ -270,10 +266,9 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
     def setUp(self) -> None:
         super().setUp()
-        id_map = _get_label_ids(self.seg_mask)
-        self.roiset = RoiSet(
+        self.roiset = RoiSet.from_segmentation(
             self.stack,
-            id_map,
+            self.seg_mask,
             params=RoiSetMetaParams(
                 expand_box_by=(128, 2),
                 mask_type='boxes',
@@ -492,7 +487,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         print('res')
 
 
-
+from model_server.base.roiset import _get_label_ids
 class TestRoiSetSerialization(unittest.TestCase):
 
     def setUp(self) -> None:
-- 
GitLab