From 75dca594d3d4d005be6706ab47cf4611d5b9b704 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 1 Aug 2024 12:14:49 +0200
Subject: [PATCH] Test cover use .data_xy and .data_xyz properties

---
 model_server/base/accessors.py |  4 ++--
 model_server/base/roiset.py    | 23 +++++++++++------------
 2 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index cd35da8a..5be57a4b 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -64,14 +64,14 @@ class GenericImageDataAccessor(ABC):
 
     @property
     def data_xy(self) -> np.ndarray:
-        if not self.nc == 1 and self.nz == 1:
+        if not self.chroma == 1 and self.nz == 1:
             raise InvalidDataShape('Can only return XY array from accessors with a single channel and single z-level')
         else:
             return self.data[:, :, 0, 0]
 
     @property
     def data_xyz(self) -> np.ndarray:
-        if not self.nc == 1:
+        if not self.chroma == 1:
             raise InvalidDataShape('Can only return XYZ array from accessors with a single channel')
         else:
             return self.data[:, :, 0, :]
diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index 12310bec..e28c6645 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -71,7 +71,7 @@ def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connec
     """
     if allow_3d and connect_3d:
         nda_la = label(
-            acc_seg_mask.data[:, :, 0, :],
+            acc_seg_mask.data_xyz,
             connectivity=3,
         ).astype('uint16')
         return InMemoryDataAccessor(np.expand_dims(nda_la, 2))
@@ -80,7 +80,7 @@ def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connec
         la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16')
         for zi in range(0, acc_seg_mask.nz):
             la_2d = label(
-                acc_seg_mask.data[:, :, 0, zi],
+                acc_seg_mask.data_xyz[:, :, zi],
                 connectivity=2,
             ).astype('uint16')
             la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla
@@ -90,7 +90,7 @@ def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connec
     else:
         return InMemoryDataAccessor(
             label(
-                acc_seg_mask.data[:, :, 0, :].max(axis=-1),
+                acc_seg_mask.data_xyz.max(axis=-1),
                 connectivity=1,
             ).astype('uint16')
         )
@@ -152,7 +152,7 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame
 
     if acc_obj_ids.nz == 1:  # deproject objects' z-coordinates from argmax of raw image
         df = pd.DataFrame(regionprops_table(
-            acc_obj_ids.data[:, :, 0, 0],
+            acc_obj_ids.data_xy,
             intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'),
             properties=('label', 'area', 'intensity_mean', 'bbox')
         )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'})
@@ -160,7 +160,7 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame
 
     else:  # objects' z-coordinates come from arg of max count in object identities map
         df = pd.DataFrame(regionprops_table(
-            acc_obj_ids.data[:, :, 0, :],
+            acc_obj_ids.data_xyz,
             properties=('label', 'area', 'bbox')
         )).rename(columns={
             'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1'
@@ -171,7 +171,7 @@ def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame
 
     def _make_binary_mask(r):
         acc = InMemoryDataAccessor(acc_obj_ids.data == r.label)
-        cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data[:, :, 0, 0]
+        cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_xy
         return cropped
 
     df['binary_mask'] = df.apply(
@@ -581,7 +581,7 @@ class RoiSet(object):
 
             if white_channel:
                 assert white_channel < raw.chroma
-                stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
+                stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] # TODO: remove direct data access
             else:
                 stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype)
 
@@ -593,7 +593,7 @@ class RoiSet(object):
                 stack[:, :, ii, :] = safe_add(
                     stack[:, :, ii, :],  # either black or grayscale channel
                     rgb_overlay_weights[ii],
-                    raw.data[:, :, ci, :]
+                    raw.get_mono(ci).data_xyz
                 )
         else:
             if white_channel is not None:  # interpret as just a single channel
@@ -608,9 +608,9 @@ class RoiSet(object):
                         annotate_rgb = True
                         break
                 if annotate_rgb:  # make RGB patches anyway to include annotation color
-                    stack = raw.data[:, :, [white_channel, white_channel, white_channel], :]
+                    stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] # TODO: remove direct data access
                 else:  # make monochrome patches
-                    stack = raw.data[:, :, [white_channel], :]
+                    stack = raw.data[:, :, [white_channel], :]  # TODO: remove direct data access
             elif kwargs.get('channels'):
                 stack = raw.get_channels(kwargs['channels']).data
             else:
@@ -778,7 +778,6 @@ class RoiSet(object):
         pad_to = 1
 
         def _poly_from_mask(roi):
-            # mask = generate_file_accessor(roi.mask_path).data[:, :, 0, 0]
             mask = roi.binary_mask
 
             # label and fill holes
@@ -847,7 +846,7 @@ class RoiSet(object):
             try:
                 ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname)
                 assert ma_acc.chroma == 1 and ma_acc.nz == 1
-                mask_data = ma_acc.data[:, :, 0, 0] / np.iinfo(ma_acc.data.dtype).max
+                mask_data = ma_acc.data_xy / np.iinfo(ma_acc.data.dtype).max
                 return mask_data
             except Exception as e:
                 raise DeserializeRoiSet(e)
-- 
GitLab