From efe1e45ce6311f17855031c82cde20de7427a6f2 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 3 Feb 2024 22:47:03 +0100 Subject: [PATCH] Product methods call RoiSet directly; this is causing problems with multichannel ones where raw data is essentially re-issues to RGB channels --- model_server/extensions/chaeo/products.py | 85 ++++++++++--------- .../extensions/chaeo/tests/test_zstack.py | 34 +++++--- model_server/extensions/chaeo/zmask.py | 1 + 3 files changed, 68 insertions(+), 52 deletions(-) diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index 3664ce8e..d001114f 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -85,8 +85,9 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask') -> def get_patches_from_zmask_meta( - stack: GenericImageDataAccessor, - zmask_meta: list, + # stack: GenericImageDataAccessor, + # zmask_meta: list, + roiset, rescale_clip: float = 0.0, pad_to: int = 256, make_3d: bool = False, @@ -95,22 +96,23 @@ def get_patches_from_zmask_meta( ) -> MonoPatchStack: patches = [] - for mi in zmask_meta: - - sl = mi['slice'] - rbb = mi['relative_bounding_box'] # TODO: call rel_ fields in DF - idx = mi['df_index'] - - x0 = rbb['x0'] - y0 = rbb['y0'] - x1 = rbb['x1'] - y1 = rbb['y1'] - - sp_sl = np.s_[y0: y1, x0: x1, :, :] - - patch3d = stack.data[sl] + # for mi in zmask_meta: + for i, roi in enumerate(roiset.get_df().itertuples()): + + # sl = roi['slice'] + # rbb = mi['relative_bounding_box'] # TODO: call rel_ fields in DF + # idx = mi['df_index'] + # + # x0 = rbb['x0'] + # y0 = rbb['y0'] + # x1 = rbb['x1'] + # y1 = rbb['y1'] + # + # sp_sl = np.s_[y0: y1, x0: x1, :, :] + + patch3d = roiset.acc_raw.data[roi.slice] ph, pw, pc, pz = patch3d.shape - subpatch = patch3d[sp_sl] + subpatch = patch3d[roi.relative_slice] # make a 3d patch if make_3d: @@ -133,7 +135,7 @@ def get_patches_from_zmask_meta( patch = patch3d[:, :, :, [zim]] assert len(patch.shape) == 4 - assert patch.shape[2] == stack.chroma + assert patch.shape[2] == roiset.acc_raw.chroma if rescale_clip is not None: patch = rescale(patch, rescale_clip) @@ -146,21 +148,21 @@ def get_patches_from_zmask_meta( for zi in range(0, patch.shape[3]): patch[:, :, bci, zi] = draw_box_on_patch( patch[:, :, bci, zi], - ((x0, y0), (x1, y1)), + ((roi.rel_x0, roi.rel_y0), (roi.rel_x1, roi.rel_y1)), linewidth=kwargs.get('bounding_box_linewidth', 1) ) if kwargs.get('draw_mask'): mci = kwargs.get('mask_channel', 0) mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[sp_sl[0:2]] = mi['mask'] + mask[roi.relative_slice[0:2]] = roi.mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = np.invert(mask) * patch[:, :, mci, zi] if kwargs.get('draw_contour'): mci = kwargs.get('contour_channel', 0) mask = np.zeros(patch.shape[0:2], dtype=bool) - mask[sp_sl[0:2]] = mi['mask'] + mask[roi.relative_slice[0:2]] = roi.mask for zi in range(0, patch.shape[3]): patch[:, :, mci, zi] = draw_contours_on_patch( @@ -180,8 +182,9 @@ def get_patches_from_zmask_meta( def export_patches_from_zstack( where: Path, - stack: GenericImageDataAccessor, - zmask_meta: list, + # stack: GenericImageDataAccessor, + # zmask_meta: list, + roiset, rescale_clip: float = 0.0, pad_to: int = 256, make_3d: bool = False, @@ -190,24 +193,25 @@ def export_patches_from_zstack( **kwargs ): patches_acc = get_patches_from_zmask_meta( - stack, - zmask_meta, - rescale_clip=rescale_clip, + roiset, + # stack, + # zmask_meta, + # rescale_clip=rescale_clip, pad_to=pad_to, make_3d=make_3d, focus_metric=focus_metric, **kwargs ) - assert len(zmask_meta) == patches_acc.count exported = [] - for i in range(0, len(zmask_meta)): - mi = zmask_meta[i] + for i, roi in enumerate(roiset.get_df().itertuples()): + # for i in range(0, len(zmask_meta)): + # mi = zmask_meta[i] patch = patches_acc.iat_yxcz(i) - obj = mi['info'] - idx = mi['df_index'] + # obj = mi['info'] + # idx = mi['df_index'] ext = 'tif' if make_3d else 'png' - fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}' + fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' if patch.dtype is np.dtype('uint16'): write_patch_to_file(where, fname, resample_to_8bit(patch)) @@ -215,7 +219,7 @@ def export_patches_from_zstack( write_patch_to_file(where, fname, patch) exported.append({ - 'df_index': idx, + 'df_index': roi.Index, 'patch_filename': fname, }) return exported @@ -303,8 +307,7 @@ def export_3d_patches_with_focus_metrics( def export_multichannel_patches_from_zstack( where: Path, - stack: GenericImageDataAccessor, - zmask_meta: list, + roiset, rgb_overlay_channels: list = None, rgb_overlay_weights: list = [1.0, 1.0, 1.0], ch_white: int = None, @@ -327,9 +330,9 @@ def export_multichannel_patches_from_zstack( np.iinfo(a.dtype).max ).astype(a.dtype) - idata = stack.data + idata = roiset.acc_raw.data if ch_white: - assert ch_white < stack.chroma + assert ch_white < roiset.acc_raw.chroma mdata = idata[:, :, [ch_white, ch_white, ch_white], :] else: mdata = idata @@ -341,14 +344,14 @@ def export_multichannel_patches_from_zstack( if ci is None: continue assert isinstance(ci, int) - assert ci < stack.chroma + assert ci < roiset.acc_raw.chroma mdata[:, :, ii, :] = _safe_add( mdata[:, :, ii, :], rgb_overlay_weights[ii], idata[:, :, ci, :] ) + # TODO: this is a bit of a workaround mstack = InMemoryDataAccessor(mdata) - return export_patches_from_zstack( - where, mstack, zmask_meta, **kwargs - ) \ No newline at end of file + rgb_roiset = RoiSet(roiset.acc_obj_ids, mstack, roiset.params) + return export_patches_from_zstack(where, rgb_roiset, **kwargs) \ No newline at end of file diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index 6fca51b8..65ecd7b8 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -119,8 +119,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ) files = export_patches_from_zstack( output_path / '2d_patches', - self.stack_ch_pa, - roiset.zmask_meta, + roiset, draw_bounding_box=True, ) self.assertGreaterEqual(len(files), 1) @@ -132,8 +131,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ) files = export_patches_from_zstack( output_path / '3d_patches', - self.stack_ch_pa, - roiset.zmask_meta, + # self.stack_ch_pa, + roiset, + # roiset.zmask_meta, make_3d=True) self.assertGreaterEqual(len(files), 1) @@ -162,15 +162,26 @@ class TestZStackDerivedDataProducts(unittest.TestCase): InMemoryDataAccessor(img) ) + def _setup_multichannel_tests(self, mask_type='boxes', **kwargs): + id_map = get_label_ids(self.seg_mask) + return RoiSet( + id_map, + self.stack, + params=RoiSetMetaParams( + mask_type=mask_type, filters=kwargs.get('filters') + ) + ) + def test_make_multichannel_2d_patches_from_zmask(self): - roiset = self.test_zmask_makes_correct_boxes( + roiset = self._setup_multichannel_tests( filters={'area': {'min': 1e3, 'max': 1e4}}, expand_box_by=(128, 2) ) files = export_multichannel_patches_from_zstack( output_path / '2d_patches_chlorophyl_bbox_overlay', - InMemoryDataAccessor(self.stack.data), - roiset.zmask_meta, + # InMemoryDataAccessor(self.stack.data), + # roiset.zmask_meta, + roiset, ch_white=4, draw_bounding_box=True, bounding_box_channel=1, @@ -184,8 +195,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ) files = export_multichannel_patches_from_zstack( output_path / '2d_patches_chlorophyl_mask_overlay', - InMemoryDataAccessor(self.stack.data), - roiset.zmask_meta, + # InMemoryDataAccessor(self.stack.data), + # roiset.zmask_meta, + roiset, ch_white=4, ch_rgb_overlay=(3, None, None), draw_mask=True, @@ -201,8 +213,8 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ) files = export_multichannel_patches_from_zstack( output_path / '2d_patches_chlorophyl_contour_overlay', - InMemoryDataAccessor(self.stack.data), - roiset.zmask_meta, + # InMemoryDataAccessor(self.stack.data), + roiset, ch_white=4, ch_rgb_overlay=(3, None, None), draw_contour=True, diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index 781b6cbc..117e189a 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -34,6 +34,7 @@ class RoiSet(object): ): self.acc_obj_ids = acc_obj_ids self.acc_raw = acc_raw + self.params = params self._df = self.filter_df( self.make_df( -- GitLab