From 37e678ec181b245b487f69f1c06199168ef42e5e Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 28 Sep 2023 14:53:11 +0200
Subject: [PATCH] Tests of zmask results

---
 extensions/chaeo/test_zstack.py | 28 ++++++++++++++-------
 extensions/chaeo/zstack.py      | 43 ++++++++++++++++++---------------
 model_server/accessors.py       |  8 +++++-
 3 files changed, 49 insertions(+), 30 deletions(-)

diff --git a/extensions/chaeo/test_zstack.py b/extensions/chaeo/test_zstack.py
index dd4f279b..ae13cb0f 100644
--- a/extensions/chaeo/test_zstack.py
+++ b/extensions/chaeo/test_zstack.py
@@ -1,5 +1,7 @@
 import unittest
 
+import numpy as np
+
 from conf.testing import output_path
 
 from extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params
@@ -11,34 +13,42 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
 
     def setUp(self) -> None:
         # need test data incl obj map
-        self.zstack = generate_file_accessor(multichannel_zstack['path'])
+        self.stack = generate_file_accessor(multichannel_zstack['path'])
 
         pxmodel = IlastikPixelClassifierModel(
             {'project_file': pixel_classifier['path']},
         )
-        mip = InMemoryDataAccessor(self.zstack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True))
+        mip = InMemoryDataAccessor(self.stack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True))
         self.pxmap, result = pxmodel.infer(mip)
 
         write_accessor_data_to_file(output_path / 'pxmap.tif', self.pxmap)
-        self.obmap = InMemoryDataAccessor((self.pxmap.data > pipeline_params['threshold']).astype('uint8'))
+        self.obmap = InMemoryDataAccessor(self.pxmap.data > pipeline_params['threshold'])
         write_accessor_data_to_file(output_path / 'obmap.tif', self.obmap)
 
     def test_zmask_makes_correct_boxes(self):
         zmask, meta = build_stack_mask(
             'test_zmask_with boxes',
-            self.stack,
-            self.obmap,
+            self.obmap.get_one_channel_data(0),
+            self.stack.get_one_channel_data(0),
             mask_type='boxes',
         )
         zmask_acc = InMemoryDataAccessor(zmask)
-        self.assertTrue(zmask_acc.is_object_map())
+        self.assertTrue(zmask_acc.is_mask())
 
         # assert dimensionality of zmask
-        self.assertEqual(zmask.shape_dict['Z'] > 1)
-        self.assertEqual(zmask.shape_dict['C'] == 1)
+        self.assertGreater(zmask_acc.shape_dict['Z'], 1)
+        self.assertEqual(zmask_acc.shape_dict['C'], 1)
+        write_accessor_data_to_file(output_path / 'zmask.tif', zmask_acc)
+
+        # mask values are not just all True or all False
+        self.assertTrue(np.any(zmask))
+        self.assertFalse(np.all(zmask))
 
         # assert non-trivial meta info in boxes
-        pass
+        self.assertGreater(len(meta), 1)
+        sh = meta[1]['mask'].shape
+        ar = meta[1]['info'].area
+        self.assertGreaterEqual(sh[0] * sh[1], ar)
 
     def test_zmask_makes_correct_contours(self):
         pass
\ No newline at end of file
diff --git a/extensions/chaeo/zstack.py b/extensions/chaeo/zstack.py
index 475f6fd8..71bcd214 100644
--- a/extensions/chaeo/zstack.py
+++ b/extensions/chaeo/zstack.py
@@ -1,10 +1,12 @@
 import numpy as np
 import pandas as pd
 
-from skimage.measure import find_contours, regionprops_table
+from skimage.measure import find_contours, label, regionprops_table
+
+from model_server.accessors import GenericImageDataAccessor
 
 # build a single boolean 3d mask (objects v. bboxes) and return bounding boxes
-def build_stack_mask(desc, obmap, stack, filters=None, mask_type='contour', expand_box_by=(0, 0)): # TODO: specify boxes data type
+def build_stack_mask(desc, obmap: GenericImageDataAccessor, stack: GenericImageDataAccessor, filters=None, mask_type='contour', expand_box_by=(0, 0)): # TODO: specify boxes data type
     """
 
     filters: dict of (min, max) tuples
@@ -12,17 +14,17 @@ def build_stack_mask(desc, obmap, stack, filters=None, mask_type='contour', expa
     """
 
     # validate inputs
-    assert len(stack.shape) == 3, stack.shape
-    assert mask_type in ('contour', 'box'), mask_type # TODO: replace with call to validator
+    # assert len(stack.shape) == 3, stack.shape
+    assert stack.chroma == 1
+    assert stack.shape_dict['Z'] > 1
+    assert mask_type in ('contours', 'boxes'), mask_type # TODO: replace with call to validator
 
-    for k in filters.keys():
-        assert k in ('area', 'solidity')
-        vmin, vmax = filters[k]
-        assert vmin >= 0
+    assert obmap.is_mask()
+    lamap = label(obmap.data[:, :, 0, 0])
 
     # build object query
     query_str = 'label > 0'  # always true
-    if filters:
+    if filters is not None:
         for k in filters.keys():
             assert k in ('area', 'solidity')
             vmin, vmax = filters[k]
@@ -30,11 +32,11 @@ def build_stack_mask(desc, obmap, stack, filters=None, mask_type='contour', expa
             query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'
 
     # build dataframe of objects, assign z index to each object
-    argmax = stack.argmax(axis=0)
+    argmax = stack.data.argmax(axis=3, keepdims=True)[:, :, 0, 0]
     df = (
         pd.DataFrame(
             regionprops_table(
-                obmap,
+                lamap,
                 intensity_image=argmax,
                 properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox')
             )
@@ -52,12 +54,12 @@ def build_stack_mask(desc, obmap, stack, filters=None, mask_type='contour', expa
     df['zi'] = df['intensity_mean'].round().astype('int')
 
     # make an object map where label is replaced by focus position in stack and background is -1
-    lut = np.zeros(obmap.max() + 1) - 1
+    lut = np.zeros(lamap.max() + 1) - 1
     lut[df.label] = df.zi
 
     # convert bounding boxes to slices
     ebxy, ebz = expand_box_by
-    nz, h, w = stack.shape
+    h, w, c, nz = stack.shape
 
     boxes = []
     for ob in df.itertuples(name='LabeledObject'):
@@ -76,10 +78,11 @@ def build_stack_mask(desc, obmap, stack, filters=None, mask_type='contour', expa
             'x1': ob.x1 - x0,
         }
 
-        sl = np.s_[z0: z1 + 1, y0: y1, x0: x1]
+        # sl = np.s_[z0: z1 + 1, y0: y1, x0: x1]
+        sl = np.s_[y0: y1, x0: x1, 0, z0: z1 + 1]
 
         # compute contours
-        obmask = (obmap == ob.label)
+        obmask = (lamap == ob.label)
         contour = find_contours(obmask)
         mask = obmask[ob.y0: ob.y1, ob.x0: ob.x1]
 
@@ -93,15 +96,15 @@ def build_stack_mask(desc, obmap, stack, filters=None, mask_type='contour', expa
 
     # build mask z-stack
     zi_st = np.zeros(stack.shape, dtype='bool')
-    if mask_type == 'contour':
-        zi_map = (lut[obmap] + 1.0).astype('int')
+    if mask_type == 'contours':
+        zi_map = (lut[lamap] + 1.0).astype('int')
         idxs = np.array([zi_map]) - 1
-        np.put_along_axis(zi_st, idxs, 1, axis=0)
+        np.put_along_axis(zi_st, idxs, 1, axis=3)
 
         # change background level from to 0 in final frame
-        zi_st[-1, :, :][obmap == 0] = 0
+        zi_st[:, :, :, -1][lamap == 0] = 0
 
-    elif mask_type == 'box':
+    elif mask_type == 'boxes':
         for bb in boxes:
             sl = bb['slice']
             zi_st[sl] = 1
diff --git a/model_server/accessors.py b/model_server/accessors.py
index ae5355f0..d9211c48 100644
--- a/model_server/accessors.py
+++ b/model_server/accessors.py
@@ -32,6 +32,9 @@ class GenericImageDataAccessor(ABC):
     def is_3d(self):
         return True if self.shape_dict['Z'] > 1 else False
 
+    def is_mask(self):
+        return self._data.dtype == 'bool'
+
     def get_one_channel_data (self, channel: int):
         c = int(channel)
         return InMemoryDataAccessor(self.data[:, :, c:(c+1), :])
@@ -133,7 +136,10 @@ def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor)
             [3, 2, 0, 1],
             [0, 1, 2, 3]
         )
-        tifffile.imwrite(fpath, zcyx, imagej=True)
+        if accessor.is_mask():
+            tifffile.imwrite(fpath, zcyx.astype('uint8'), imagej=True)
+        else:
+            tifffile.imwrite(fpath, zcyx, imagej=True)
     except:
         raise FileWriteError(f'Unable to write data to file')
     return True
-- 
GitLab