From 51339418dce1eaef2e42e7ee02aa7e19301cbe34 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 17 Jul 2024 09:40:25 +0200
Subject: [PATCH] Accidentally not taking mask MIP

---
 model_server/base/roiset.py | 11 +++--------
 tests/base/test_roiset.py   |  1 +
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index f0af11a7..1e19ecfd 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -174,12 +174,10 @@ def _make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFram
     # TODO: make this contingent on whether seg is included
     def _make_binary_mask(r):
         acc = InMemoryDataAccessor(acc_obj_ids.data == r.label)
-        acc.get_mono(0, mip=True)
-        cropped = acc.crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data
+        cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data
         return cropped
 
     df['binary_mask'] = df.apply(
-        # lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0],
         _make_binary_mask,
         axis=1,
         result_type='reduce',
@@ -525,16 +523,13 @@ class RoiSet(object):
                 patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8')
                 patch[roi.relative_slice][:, :, 0, 0] = roi.binary_mask * 255
             else:
-                patch = np.zeros((roi.y1 - roi.y0, roi.x1 - roi.x0, 1, 1), dtype='uint8')
-                patch = roi.binary_mask * 255
-
+                patch = (roi.binary_mask * 255).astype('uint8')
             if pad_to:
                 patch = pad(patch, pad_to)
             return patch
 
         dfe = self._df.copy()
-        # TODO: can just pass function handle
-        dfe['patch_mask'] = dfe.apply(lambda r: _make_patch_mask(r), axis=1)
+        dfe['patch_mask'] = dfe.apply(_make_patch_mask, axis=1)
         return dfe
 
     def get_patch_masks_acc(self, **kwargs) -> PatchStack:
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 8e577f50..bb428b5b 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -634,6 +634,7 @@ class TestRoiSetSerialization(unittest.TestCase):
             m_acc = generate_file_accessor(pmf)
             self.assertEqual((roi.h, roi.w), m_acc.hw)
             patch_filenames.append(pmf.name)
+            self.assertEqual(m_acc.nz, 1)
 
         # make another RoiSet from just the data table, raw images, and (tight) patch masks
         test_roiset = RoiSet.deserialize(self.stack_ch_pa, where_ser, prefix='ref')
-- 
GitLab