From dc1b5e0886343d398d8353e78c1e31c17c7a079d Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 15 Jul 2024 15:30:23 +0200
Subject: [PATCH] Renamed static generator function

---
 model_server/base/roiset.py | 10 +++++++++-
 tests/base/test_roiset.py   | 14 +++++++-------
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index d437683c..1e39496b 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -123,6 +123,7 @@ def _safe_add(a, g, b):
 
 class RoiSet(object):
 
+    # TODO: __init__ to take bounding boxes e.g. from obj det model; flag if overlaps are allowed
     def __init__(
             self,
             acc_raw: GenericImageDataAccessor,
@@ -157,8 +158,9 @@ class RoiSet(object):
         """Expose ROI meta information via the Pandas.DataFrame API"""
         return self._df.itertuples(name='Roi')
 
+    # TODO: add or overload for object detection case
     @staticmethod
-    def from_segmentation(
+    def from_binary_mask(
             acc_raw: GenericImageDataAccessor,
             acc_seg: GenericImageDataAccessor,
             allow_3d=False,
@@ -247,6 +249,8 @@ class RoiSet(object):
             axis=1,
             result_type='reduce',
         )
+
+        # TODO: make this contingent on whether seg is included
         df['binary_mask'] = df.apply(
             lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0],
             axis=1,
@@ -389,6 +393,7 @@ class RoiSet(object):
 
         self._df['classify_by_' + name] = pd.Series(dtype='Int64')
 
+        # TODO: separate method to get object map
         # assign labels to object map:
         for i, roi in enumerate(self):
             oc = np.unique(
@@ -673,6 +678,7 @@ class RoiSet(object):
         record['tight_patch_masks'] = list(se_pa)
         return record
 
+    # TODO: implement
     def serialize_coco(self, where: Path, prefix='') -> dict:
         """
         Export the RoiSet according to the COCO seg standard
@@ -695,6 +701,8 @@ class RoiSet(object):
 
         return {}
 
+    # TODO: add docstring
+    # TODO: make this work with obj det dataset
     @staticmethod
     def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''):
         df = pd.read_csv(where / 'dataframe' / (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']]
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index cff0ec01..beff2a73 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -29,7 +29,7 @@ class BaseTestRoiSetMonoProducts(object):
 class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
     def _make_roi_set(self, mask_type='boxes', **kwargs):
-        roiset = RoiSet.from_segmentation(
+        roiset = RoiSet.from_binary_mask(
             self.stack_ch_pa,
             self.seg_mask,
             params=RoiSetMetaParams(
@@ -70,7 +70,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
         self.assertEqual(acc_zstack_slice.nz, 1)
 
-        roiset = RoiSet.from_segmentation(acc_zstack_slice, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
+        roiset = RoiSet.from_binary_mask(acc_zstack_slice, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
         zmask = roiset.get_zmask()
 
         zmask_acc = InMemoryDataAccessor(zmask)
@@ -164,7 +164,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertEqual(result.shape, roiset.acc_raw.shape)
 
     def test_flatten_image(self):
-        roiset = RoiSet.from_segmentation(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
+        roiset = RoiSet.from_binary_mask(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes'))
         df = roiset.get_df()
 
         from model_server.base.roiset import project_stack_from_focal_points
@@ -205,7 +205,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         return roiset
 
     def test_classify_by_multiple_channels(self):
-        roiset = RoiSet.from_segmentation(self.stack, self.seg_mask)
+        roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask)
         roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
         self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
         self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1]))
@@ -216,7 +216,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
             def infer(self, img, mask):
                 return PatchStack(super().infer(img, mask).data * img.chroma)
 
-        roiset = RoiSet.from_segmentation(
+        roiset = RoiSet.from_binary_mask(
             self.stack,
             self.seg_mask,
             params=RoiSetMetaParams(
@@ -282,7 +282,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
 
     def setUp(self) -> None:
         super().setUp()
-        self.roiset = RoiSet.from_segmentation(
+        self.roiset = RoiSet.from_binary_mask(
             self.stack,
             self.seg_mask,
             params=RoiSetMetaParams(
@@ -617,7 +617,7 @@ class TestRoiSetSerialization(unittest.TestCase):
             self.assertTrue(np.all(r_acc.data == t_acc.data))
 
     def test_serialize_coco(self):
-        roiset = RoiSet.from_segmentation(
+        roiset = RoiSet.from_binary_mask(
             self.stack_ch_pa,
             self.seg_mask_3d,
             params=RoiSetMetaParams(mask_type='contours')
-- 
GitLab