From efe1e45ce6311f17855031c82cde20de7427a6f2 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 3 Feb 2024 22:47:03 +0100
Subject: [PATCH] Product methods call RoiSet directly; this is causing
 problems with multichannel ones where raw data is essentially re-issues to
 RGB channels

---
 model_server/extensions/chaeo/products.py     | 85 ++++++++++---------
 .../extensions/chaeo/tests/test_zstack.py     | 34 +++++---
 model_server/extensions/chaeo/zmask.py        |  1 +
 3 files changed, 68 insertions(+), 52 deletions(-)

diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py
index 3664ce8e..d001114f 100644
--- a/model_server/extensions/chaeo/products.py
+++ b/model_server/extensions/chaeo/products.py
@@ -85,8 +85,9 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask') ->
 
 
 def get_patches_from_zmask_meta(
-        stack: GenericImageDataAccessor,
-        zmask_meta: list,
+        # stack: GenericImageDataAccessor,
+        # zmask_meta: list,
+        roiset,
         rescale_clip: float = 0.0,
         pad_to: int = 256,
         make_3d: bool = False,
@@ -95,22 +96,23 @@ def get_patches_from_zmask_meta(
 ) -> MonoPatchStack:
     patches = []
 
-    for mi in zmask_meta:
-
-        sl = mi['slice']
-        rbb = mi['relative_bounding_box'] # TODO: call rel_ fields in DF
-        idx = mi['df_index']
-
-        x0 = rbb['x0']
-        y0 = rbb['y0']
-        x1 = rbb['x1']
-        y1 = rbb['y1']
-
-        sp_sl = np.s_[y0: y1, x0: x1, :, :]
-
-        patch3d = stack.data[sl]
+    # for mi in zmask_meta:
+    for i, roi in enumerate(roiset.get_df().itertuples()):
+
+        # sl = roi['slice']
+        # rbb = mi['relative_bounding_box'] # TODO: call rel_ fields in DF
+        # idx = mi['df_index']
+        #
+        # x0 = rbb['x0']
+        # y0 = rbb['y0']
+        # x1 = rbb['x1']
+        # y1 = rbb['y1']
+        #
+        # sp_sl = np.s_[y0: y1, x0: x1, :, :]
+
+        patch3d = roiset.acc_raw.data[roi.slice]
         ph, pw, pc, pz = patch3d.shape
-        subpatch = patch3d[sp_sl]
+        subpatch = patch3d[roi.relative_slice]
 
         # make a 3d patch
         if make_3d:
@@ -133,7 +135,7 @@ def get_patches_from_zmask_meta(
             patch = patch3d[:, :, :, [zim]]
 
         assert len(patch.shape) == 4
-        assert patch.shape[2] == stack.chroma
+        assert patch.shape[2] == roiset.acc_raw.chroma
 
         if rescale_clip is not None:
             patch = rescale(patch, rescale_clip)
@@ -146,21 +148,21 @@ def get_patches_from_zmask_meta(
             for zi in range(0, patch.shape[3]):
                 patch[:, :, bci, zi] = draw_box_on_patch(
                     patch[:, :, bci, zi],
-                    ((x0, y0), (x1, y1)),
+                    ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)),
                     linewidth=kwargs.get('bounding_box_linewidth', 1)
                 )
 
         if kwargs.get('draw_mask'):
             mci = kwargs.get('mask_channel', 0)
             mask = np.zeros(patch.shape[0:2], dtype=bool)
-            mask[sp_sl[0:2]] = mi['mask']
+            mask[roi.relative_slice[0:2]] = roi.mask
             for zi in range(0, patch.shape[3]):
                 patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi]
 
         if kwargs.get('draw_contour'):
             mci = kwargs.get('contour_channel', 0)
             mask = np.zeros(patch.shape[0:2], dtype=bool)
-            mask[sp_sl[0:2]] = mi['mask']
+            mask[roi.relative_slice[0:2]] = roi.mask
 
             for zi in range(0, patch.shape[3]):
                 patch[:, :, mci, zi] = draw_contours_on_patch(
@@ -180,8 +182,9 @@ def get_patches_from_zmask_meta(
 
 def export_patches_from_zstack(
         where: Path,
-        stack: GenericImageDataAccessor,
-        zmask_meta: list,
+        # stack: GenericImageDataAccessor,
+        # zmask_meta: list,
+        roiset,
         rescale_clip: float = 0.0,
         pad_to: int = 256,
         make_3d: bool = False,
@@ -190,24 +193,25 @@ def export_patches_from_zstack(
         **kwargs
 ):
     patches_acc = get_patches_from_zmask_meta(
-        stack,
-        zmask_meta,
-        rescale_clip=rescale_clip,
+        roiset,
+        # stack,
+        # zmask_meta,
+        # rescale_clip=rescale_clip,
         pad_to=pad_to,
         make_3d=make_3d,
         focus_metric=focus_metric,
         **kwargs
     )
-    assert len(zmask_meta) == patches_acc.count
 
     exported = []
-    for i in range(0, len(zmask_meta)):
-        mi = zmask_meta[i]
+    for i, roi in enumerate(roiset.get_df().itertuples()):
+    # for i in range(0, len(zmask_meta)):
+    #     mi = zmask_meta[i]
         patch = patches_acc.iat_yxcz(i)
-        obj = mi['info']
-        idx = mi['df_index']
+        # obj = mi['info']
+        # idx = mi['df_index']
         ext = 'tif' if make_3d else 'png'
-        fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}'
+        fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
 
         if patch.dtype is np.dtype('uint16'):
             write_patch_to_file(where, fname, resample_to_8bit(patch))
@@ -215,7 +219,7 @@ def export_patches_from_zstack(
             write_patch_to_file(where, fname, patch)
 
         exported.append({
-            'df_index': idx,
+            'df_index': roi.Index,
             'patch_filename': fname,
         })
     return exported
@@ -303,8 +307,7 @@ def export_3d_patches_with_focus_metrics(
 
 def export_multichannel_patches_from_zstack(
     where: Path,
-    stack: GenericImageDataAccessor,
-    zmask_meta: list,
+    roiset,
     rgb_overlay_channels: list = None,
     rgb_overlay_weights: list = [1.0, 1.0, 1.0],
     ch_white: int = None,
@@ -327,9 +330,9 @@ def export_multichannel_patches_from_zstack(
             np.iinfo(a.dtype).max
         ).astype(a.dtype)
 
-    idata = stack.data
+    idata = roiset.acc_raw.data
     if ch_white:
-        assert ch_white < stack.chroma
+        assert ch_white < roiset.acc_raw.chroma
         mdata = idata[:, :, [ch_white, ch_white, ch_white], :]
     else:
         mdata = idata
@@ -341,14 +344,14 @@ def export_multichannel_patches_from_zstack(
             if ci is None:
                 continue
             assert isinstance(ci, int)
-            assert ci < stack.chroma
+            assert ci < roiset.acc_raw.chroma
             mdata[:, :, ii, :] = _safe_add(
                 mdata[:, :, ii, :],
                 rgb_overlay_weights[ii],
                 idata[:, :, ci, :]
             )
 
+    # TODO: this is a bit of a workaround
     mstack = InMemoryDataAccessor(mdata)
-    return export_patches_from_zstack(
-        where, mstack, zmask_meta, **kwargs
-    )
\ No newline at end of file
+    rgb_roiset = RoiSet(roiset.acc_obj_ids, mstack, roiset.params)
+    return export_patches_from_zstack(where, rgb_roiset, **kwargs)
\ No newline at end of file
diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py
index 6fca51b8..65ecd7b8 100644
--- a/model_server/extensions/chaeo/tests/test_zstack.py
+++ b/model_server/extensions/chaeo/tests/test_zstack.py
@@ -119,8 +119,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
         )
         files = export_patches_from_zstack(
             output_path / '2d_patches',
-            self.stack_ch_pa,
-            roiset.zmask_meta,
+            roiset,
             draw_bounding_box=True,
         )
         self.assertGreaterEqual(len(files), 1)
@@ -132,8 +131,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
         )
         files = export_patches_from_zstack(
             output_path / '3d_patches',
-            self.stack_ch_pa,
-            roiset.zmask_meta,
+            # self.stack_ch_pa,
+            roiset,
+            # roiset.zmask_meta,
             make_3d=True)
         self.assertGreaterEqual(len(files), 1)
 
@@ -162,15 +162,26 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             InMemoryDataAccessor(img)
         )
 
+    def _setup_multichannel_tests(self, mask_type='boxes', **kwargs):
+        id_map = get_label_ids(self.seg_mask)
+        return RoiSet(
+            id_map,
+            self.stack,
+            params=RoiSetMetaParams(
+                mask_type=mask_type, filters=kwargs.get('filters')
+            )
+        )
+
     def test_make_multichannel_2d_patches_from_zmask(self):
-        roiset = self.test_zmask_makes_correct_boxes(
+        roiset = self._setup_multichannel_tests(
             filters={'area': {'min': 1e3, 'max': 1e4}},
             expand_box_by=(128, 2)
         )
         files = export_multichannel_patches_from_zstack(
             output_path / '2d_patches_chlorophyl_bbox_overlay',
-            InMemoryDataAccessor(self.stack.data),
-            roiset.zmask_meta,
+            # InMemoryDataAccessor(self.stack.data),
+            # roiset.zmask_meta,
+            roiset,
             ch_white=4,
             draw_bounding_box=True,
             bounding_box_channel=1,
@@ -184,8 +195,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
         )
         files = export_multichannel_patches_from_zstack(
             output_path / '2d_patches_chlorophyl_mask_overlay',
-            InMemoryDataAccessor(self.stack.data),
-            roiset.zmask_meta,
+            # InMemoryDataAccessor(self.stack.data),
+            # roiset.zmask_meta,
+            roiset,
             ch_white=4,
             ch_rgb_overlay=(3, None, None),
             draw_mask=True,
@@ -201,8 +213,8 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
         )
         files = export_multichannel_patches_from_zstack(
             output_path / '2d_patches_chlorophyl_contour_overlay',
-            InMemoryDataAccessor(self.stack.data),
-            roiset.zmask_meta,
+            # InMemoryDataAccessor(self.stack.data),
+            roiset,
             ch_white=4,
             ch_rgb_overlay=(3, None, None),
             draw_contour=True,
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index 781b6cbc..117e189a 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -34,6 +34,7 @@ class RoiSet(object):
     ):
         self.acc_obj_ids = acc_obj_ids
         self.acc_raw = acc_raw
+        self.params = params
 
         self._df = self.filter_df(
             self.make_df(
-- 
GitLab