From 7fe537b345dd4ea45d9ed36e75c0b164787fa0ed Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 3 Feb 2024 09:06:44 +0100
Subject: [PATCH] Standardized RoiSet constructor works on object id maps

---
 .../extensions/chaeo/tests/test_zstack.py     | 13 +++++---
 model_server/extensions/chaeo/workflows.py    |  6 ++--
 model_server/extensions/chaeo/zmask.py        | 31 +++++++++++--------
 3 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py
index 0907846e..e0e09eed 100644
--- a/model_server/extensions/chaeo/tests/test_zstack.py
+++ b/model_server/extensions/chaeo/tests/test_zstack.py
@@ -8,13 +8,15 @@ from model_server.extensions.chaeo.conf.testing import multichannel_zstack, pixe
 from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
 from model_server.extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack
 from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack
-from model_server.extensions.chaeo.zmask import build_zmask_from_object_mask
+from model_server.extensions.chaeo.zmask import build_zmask_from_object_mask, get_label_ids
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.extensions.ilastik.models import IlastikPixelClassifierModel
 from model_server.base.models import DummyInstanceSegmentationModel
 
 class TestZStackDerivedDataProducts(unittest.TestCase):
 
+    # TODO: add cases that call RoiSet directly, not just through workflow function
+
     def setUp(self) -> None:
 
         # need test data incl obj map
@@ -39,8 +41,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
         write_accessor_data_to_file(output_path / 'seg_mask.tif', self.seg_mask)
 
     def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs):
+        id_map = get_label_ids(self.seg_mask)
         zmask, meta, df, interm = build_zmask_from_object_mask(
-            self.seg_mask,
+            id_map,
             self.stack_ch_pa,
             params=RoiSetMetaParams(mask_type=mask_type, filters=kwargs.get('filters')),
         )
@@ -71,8 +74,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
     def test_zmask_works_on_non_zstacks(self, **kwargs):
         acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
         self.assertEqual(acc_zstack_slice.nz, 1)
+        id_map = get_label_ids(self.seg_mask)
         zmask, meta, df, interm = build_zmask_from_object_mask(
-            self.seg_mask,
+            id_map,
             acc_zstack_slice,
             params=RoiSetMetaParams(mask_type='boxes'),
             **kwargs,
@@ -115,8 +119,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
         self.assertGreaterEqual(len(files), 1)
 
     def test_flatten_image(self):
+        id_map = get_label_ids(self.seg_mask)
         zmask, meta, df, interm = build_zmask_from_object_mask(
-            self.seg_mask,
+            id_map,
             self.stack_ch_pa,
             params=RoiSetMetaParams(mask_type='boxes')
         )
diff --git a/model_server/extensions/chaeo/workflows.py b/model_server/extensions/chaeo/workflows.py
index aa07205b..71333fd3 100644
--- a/model_server/extensions/chaeo/workflows.py
+++ b/model_server/extensions/chaeo/workflows.py
@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
 
 from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
 from model_server.extensions.chaeo.process import mask_largest_object
-from model_server.extensions.chaeo.zmask import RoiSet
+from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet
 
 from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
@@ -48,7 +48,7 @@ def infer_object_map_from_zstack(
     ti.click('classify_pixels')
 
     # make zmask
-    rois = RoiSet(mip_mask, stack, params=roi_params)
+    rois = RoiSet(get_label_ids(mip_mask), stack, params=roi_params)
     ti.click('generate_zmasks')
 
     rois.classify_by(patches_channel, models['object_classifier']['model'])
@@ -64,7 +64,7 @@ def infer_object_map_from_zstack(
         'output_path': output_folder_path,
     }
 
-
+# TODO: to app-specific ecotaxa module
 def transfer_ecotaxa_labels_to_patch_stacks(
     where_masks: str,
     where_patches: str,
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index 380f7b7d..2199a677 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -14,17 +14,19 @@ from model_server.extensions.chaeo.process import mask_largest_object
 from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.base.models import InstanceSegmentationModel
 
+def get_label_ids(acc_seg_mask):
+    return label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')
 
 class RoiSet(object):
 
     def __init__(
             self,
-            acc_mask: GenericImageDataAccessor,
+            acc_obj_ids: GenericImageDataAccessor,  # TODO: enforce subtype of binary or label ID mask
             acc_raw: GenericImageDataAccessor,
             params: RoiSetMetaParams = RoiSetMetaParams(),
     ):
         self.zmask, self.zmask_meta, self.df, self.interm = build_zmask_from_object_mask(
-            acc_mask,
+            acc_obj_ids,
             acc_raw,
             params=params,
         )
@@ -135,7 +137,7 @@ def build_zmask_from_object_mask(
     """
     Given a 2D mask of objects, build a 3D mask, where each object's z-position is determined by the index of
     maximum intensity in z.  Return this zmask and a list of each object's meta information.
-    :param obmask: GenericImageDataAccessor monochrome 2D binary mask of objects
+    :param obmask: GenericImageDataAccessor  2D map of objects IDs
     :param zstack: GenericImageDataAccessor monochrome zstack of same Y, X dimension as obmask
     :param params: RoiSetMetaParams
         filters: dictionary of form {attribute: (min, max)}; valid attributes are 'area' and 'solidity'
@@ -161,13 +163,16 @@ def build_zmask_from_object_mask(
     # validate inputs
     # assert zstack.chroma == 1
     assert mask_type in ('contours', 'boxes'), mask_type
-    assert obmask.is_mask()
-    assert obmask.chroma == 1
-    assert obmask.nz == 1
-    assert zstack.hw == obmask.hw
+    # assert obmask.is_mask()
+    # assert obmask.chroma == 1
+    # assert obmask.nz == 1
+    if zstack.hw != obmask.shape:
+        input()
+    assert zstack.hw == obmask.shape
 
     # assign object labels and build object query
-    lamap = label(obmask.data[:, :, 0, 0]).astype('uint16')
+    # lamap = label(obmask.data[:, :, 0, 0]).astype('uint16')
+    lamap = obmask
     query_str = 'label > 0'  # always true
     if filters is not None:
         for k, val in filters.dict(exclude_unset=True).items():
@@ -218,18 +223,18 @@ def build_zmask_from_object_mask(
         z1 = min(ob.zi + ebz, nz)
 
         # relative bounding box positions
-        rbb = {
+        rbb = {  # TODO: just put in the DF
             'y0': ob.y0 - y0,
             'y1': ob.y1 - y0,
             'x0': ob.x0 - x0,
             'x1': ob.x1 - x0,
         }
 
-        sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1]
+        sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1]  # TODO: on-the-fly in RoiSet, given DF
 
         # compute contours
-        obmask = (lamap == ob.label)
-        contour = find_contours(obmask)
+        obmask = (lamap == ob.label) # TODO: on-the-fly
+        contour = find_contours(obmask) # TODO: on-the-fly
         mask = obmask[ob.y0: ob.y1, ob.x0: ob.x1]
 
         assert rbb['x1'] <= (x1 - x0)
@@ -244,7 +249,7 @@ def build_zmask_from_object_mask(
             'mask': mask
         })
 
-    # build mask z-stack
+    # build mask z-stack # TODO: on-the-fly
     zi_st = np.zeros(zstack.shape, dtype='bool')
     if mask_type == 'contours':
         zi_map = (lut[lamap] + 1.0).astype('int')
-- 
GitLab