From cb1bba3e40f1063a37813a1f53dd5fd732ccc7af Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 7 Feb 2024 15:10:04 +0100
Subject: [PATCH] Removed zmask_meta references

---
 .../extensions/chaeo/tests/test_zstack.py     | 17 ++--
 model_server/extensions/chaeo/zmask.py        | 77 +++++--------------
 2 files changed, 24 insertions(+), 70 deletions(-)

diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py
index dfc548db..41571203 100644
--- a/model_server/extensions/chaeo/tests/test_zstack.py
+++ b/model_server/extensions/chaeo/tests/test_zstack.py
@@ -54,11 +54,9 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         )
         return roiset
 
-    def test_zmask_makes_correct_boxes(self, **kwargs):
+    def test_roi_mask_shape(self, **kwargs):
         roiset = self._make_roi_set(**kwargs)
         zmask = roiset.get_zmask()
-        meta = roiset.zmask_meta
-        interm = roiset.interm
         zmask_acc = InMemoryDataAccessor(zmask)
         self.assertTrue(zmask_acc.is_mask())
 
@@ -73,16 +71,10 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
 
         # assert non-trivial meta info in boxes
         self.assertGreater(roiset.count, 1)
-        sh = meta[1]['mask'].shape
-        ar = meta[1]['info'].area
+        sh = roiset.get_df().iloc[1]['mask'].shape
+        ar = roiset.get_df().iloc[1]['area']
         self.assertGreaterEqual(sh[0] * sh[1], ar)
 
-        # assert dimensionality of intermediate data products
-        self.assertEqual(interm['label_map'].shape, zmask.shape[0:2])
-        self.assertEqual(interm['argmax'].shape, zmask.shape[0:2])
-
-        return roiset
-
     def test_zmask_works_on_non_zstacks(self, **kwargs):
         acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
         self.assertEqual(acc_zstack_slice.nz, 1)
@@ -176,7 +168,8 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
     def test_classify_by(self):
         roiset = self._make_roi_set()
         roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel())
-        self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1.]))
+        self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
+        self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1]))
 
     def test_object_map_workflow(self):
         pp = pipeline_params
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index 4ed84a6b..55030731 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -21,17 +21,19 @@ from model_server.extensions.chaeo.process import mask_largest_object
 from model_server.extensions.chaeo.products import get_roiset_patches, get_patch_masks, export_patch_masks
 
 
-def get_label_ids(acc_seg_mask):
-    return label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')
+def get_label_ids(acc_seg_mask: GenericImageDataAccessor) -> InMemoryDataAccessor:
+    return InMemoryDataAccessor(label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16'))
 
 class RoiSet(object):
 
     def __init__(
             self,
-            acc_obj_ids: GenericImageDataAccessor,  # TODO: enforce subtype of binary or label ID mask
+            acc_obj_ids: GenericImageDataAccessor,
             acc_raw: GenericImageDataAccessor,
             params: RoiSetMetaParams = RoiSetMetaParams(),
     ):
+        assert acc_obj_ids.chroma == 1
+        assert acc_obj_ids.nz == 1
         self.acc_obj_ids = acc_obj_ids
         self.acc_raw = acc_raw
         self.params = params
@@ -43,41 +45,7 @@ class RoiSet(object):
             params.filters,
         )
 
-        # temporarily build zmask meta here
-        meta = []
-        for ob in self.get_df().itertuples(name='LabeledObject'):
-            sl = np.s_[ob.ebb_y0: ob.ebb_y1, ob.ebb_x0: ob.ebb_x1, :, ob.ebb_z0: ob.ebb_z1 + 1]  # TODO: on-the-fly in RoiSet, given DF
-
-            # compute contours
-            obmask = (acc_obj_ids == ob.label)  # TODO: on-the-fly
-            contour = find_contours(obmask)  # TODO: on-the-fly
-            mask = obmask[ob.y0: ob.y1, ob.x0: ob.x1]
-
-            rbb = {  # TODO: just put in the DF
-                'y0': ob.rel_y0,
-                'y1': ob.rel_y1,
-                'x0': ob.rel_x0,
-                'x1': ob.rel_x1,
-            }
-
-            meta.append({
-                'df_index': ob.Index,
-                'info': ob,
-                'slice': sl,
-                'relative_bounding_box': rbb,  # TODO: put in DF
-                'contour': contour,  # TODO: delegate to getter
-                'mask': mask  # TODO: delegate to getter
-            })
-        self.zmask_meta = meta
-
-        # return intermediate image arrays  # TODO: make on-the-fly
-        self.interm = {
-            'label_map': acc_obj_ids,
-            'argmax': acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'),
-        }
-
         self.count = len(self._df)
-        self.object_id_labels = self.interm['label_map']
         self.object_class_maps = {}  # classification results
 
     def __iter__(self):
@@ -98,7 +66,7 @@ class RoiSet(object):
         df = (
             pd.DataFrame(
                 regionprops_table(
-                    acc_obj_ids,
+                    acc_obj_ids.data[:, :, 0, 0],
                     intensity_image=argmax,
                     properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid')
                 )
@@ -142,7 +110,7 @@ class RoiSet(object):
             axis=1
         )
         df['mask'] = df.apply(
-            lambda r: (acc_obj_ids == r.label)[r.y0: r.y1, r.x0: r.x1],
+            lambda r: (acc_obj_ids.data == r.label)[r.y0: r.y1, r.x0: r.x1, 0, 0],
             axis=1
         )
         return df
@@ -182,9 +150,6 @@ class RoiSet(object):
             projected = self.acc_raw.data.max(axis=-1)
         return projected
 
-    def get_object_mask_by_class(self, class_id):
-        return self.object_id_labels == class_id
-
     def get_patch_masks(self, **kwargs):
         return get_patch_masks(self, **kwargs)
 
@@ -217,7 +182,7 @@ class RoiSet(object):
 
         assert mask_type in ('contours', 'boxes')
         zi_st = np.zeros(self.acc_raw.shape, dtype='bool')
-        lamap = self.acc_obj_ids
+        lamap = self.acc_obj_ids.data
 
         # make an object map where label is replaced by focus position in stack and background is -1
         lut = np.zeros(lamap.max() + 1) - 1
@@ -248,27 +213,23 @@ class RoiSet(object):
 
         # do this on a patch basis, i.e. only one object per frame
         obmap_patches = object_classification_model.label_instance_class(
-            self.get_raw_patches(channel=channel),  # TODO: enforce df index
+            self.get_raw_patches(channel=channel),
             self.get_patch_masks()
         )
 
-        lamap = self.object_id_labels
-        om = np.zeros(lamap.shape, dtype=lamap.dtype)
+        om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype)
 
-        df = self.get_df()
-        idx = df.index
-        se = pd.Series(data=np.nan, index=idx)
+        self._df['classify_by_' + name] = pd.Series(dtype='Int64')
 
         # assign labels to object map:
-        for i in range(0, len(idx)):
-            # object_id = self.zmask_meta[i]['info'].label
-            object_id = df.loc[idx[i], 'label']
-            result_patch = mask_largest_object(obmap_patches.iat(i))
-            object_class = np.unique(result_patch)[1]
-            om[self.object_id_labels == object_id] = object_class
-            se.loc[idx[i]] = object_class
-
-        self.add_df_col('classify_by_' + name, se)
+        for i, roi in enumerate(self):
+            oc = np.unique(
+                mask_largest_object(
+                    obmap_patches.iat(i)
+                )
+            )[1]
+            self._df.loc[roi.Index, 'classify_by_' + name] = oc
+            om[self.acc_obj_ids.data == roi.label] = oc
         self.object_class_maps[name] = InMemoryDataAccessor(om)
 
     def run_exports(self, where, channel, prefix, params: RoiSetExportParams):
-- 
GitLab