diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index ba35a9a97f7f6b83f9fb718a236b9e7269a9139d..caf27aaec4e0c7ecd3ced0eb6fc29c45481c5bc7 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -796,8 +796,11 @@ class RoiSet(object): patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded).copy() def _export_patch_mask(roi): - patch = InMemoryDataAccessor(roi.patch_mask) - ext = 'png' + patch = InMemoryDataAccessor.from_mono(roi.patch_mask) + if patch.nz == 1: + ext = 'png' + else: + ext = 'tif' fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' write_accessor_data_to_file(where / fname, patch) return fname @@ -835,16 +838,18 @@ class RoiSet(object): if pad_to: patch = pad(patch, pad_to) if is_3d: - return np.expand_dims(patch, 3) + return patch else: - return np.expand_dims(patch, (2, 3)) + return np.expand_dims(patch, 2) dfe = self._df.copy() dfe['patch_mask'] = dfe.apply(_make_patch_mask, axis=1) return dfe def get_patch_masks_acc(self, **kwargs) -> PatchStack: - return PatchStack(list(self.get_patch_masks(**kwargs).patch_mask)) + se_pm = self.get_patch_masks(**kwargs).patch_mask + se_ext = se_pm.apply(lambda x: np.expand_dims(x, 2)) + return PatchStack(list(se_ext)) def get_patches( self, diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 46896538a47c5f9067465e0a75fb3c9890535ca7..5dc1fd986c193ac4cff99f1e9c77704eda853236 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -242,8 +242,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): roiset = self._make_roi_set() df_patch_masks = roiset.get_patch_masks() for roi in df_patch_masks.itertuples(): - h, w, nc, nz = roi.patch_mask.shape - self.assertEqual(nc, 1) + h, w, nz = roi.patch_mask.shape self.assertEqual(nz, 1) self.assertEqual(h, roi.h) self.assertEqual(w, roi.w)