From ee6740a1d59f1525ce63bdf047a0af5fd8a3e3e8 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 7 Dec 2024 08:58:31 +0100 Subject: [PATCH] Getting a patch mask can optionally update patch focus --- model_server/base/roiset.py | 22 +++++++++++++++------- tests/base/test_roiset.py | 26 ++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 510b9fe3..0a8775e6 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -40,6 +40,10 @@ class PatchParams(BaseModel): pad_to: Union[int, None] = 256 expanded: bool = False force_tif: bool = False + update_focus_zi: bool = Field( + False, + description='If generating 2d patches with a 3d focus metric, update the RoiSet patch focus zi accordingly' + ) class AnnotatedPatchParams(PatchParams): bounding_box_channel: int = 1 @@ -888,6 +892,7 @@ class RoiSet(object): rgb_overlay_weights: list = [1.0, 1.0, 1.0], white_channel: int = None, expanded=False, + update_focus_zi=False, **kwargs ) -> pd.Series: @@ -937,7 +942,7 @@ class RoiSet(object): else: stack = raw.data - def _make_patch(roi): + def _make_patch(roi): # extract, focus, and annotate a patch if expanded: patch3d = stack[roi.expanded_slice] subpatch = patch3d[roi.relative_slice] @@ -947,11 +952,11 @@ class RoiSet(object): ph, pw, pc, pz = patch3d.shape - # make a 3d patch + # make a 3d patch, focus stays where it is if make_3d: patch = patch3d.copy() + zif = roi.zi - # TODO: somehow persist the result of patch focusing for use by patch mask exporter # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately elif focus_metric is not None: foc = focus_metrics()[focus_metric] @@ -965,8 +970,8 @@ class RoiSet(object): # make a 2d patch from middle of z-stack else: - zim = floor(pz / 2) - patch = patch3d[:, :, :, [zim]] + zif = floor(pz / 2) + patch = patch3d[:, :, :, [zif]] assert len(patch.shape) == 4 @@ -1012,9 +1017,12 @@ class RoiSet(object): if pad_to and expanded: patch = pad(patch, pad_to) - return patch + return {'patch': patch, 'zif': zif} - return self._df.apply(lambda r: _make_patch(r), axis=1) + df_processed_patches = self._df.apply(lambda r: _make_patch(r), axis=1, result_type='expand') + if update_focus_zi: + self._df['zi'] = df_processed_patches['zif'] + return df_processed_patches['patch'] @property def classification_columns(self): diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index e0b9168d..376b8b44 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -296,6 +296,7 @@ class TestRoiSet3dProducts(unittest.TestCase): 'draw_mask': False, 'pad_to': None, 'rgb_overlay_channels': None, + 'update_focus_zi': False, }, '3d_masks': { 'make_3d': True, @@ -324,6 +325,31 @@ class TestRoiSet3dProducts(unittest.TestCase): self.assertTrue(any([_get_nz_from_file(f) > 1 for f in res['patches_3d_masks']])) self.assertTrue(all([_get_nz_from_file(f) == 1 for f in res['patches_2d_masks']])) + def test_run_update_focus_during_patch_export(self): + p = RoiSetExportParams(**{ + 'patches': { + '2d': { + 'make_3d': False, + 'focus_metric': 'max_sobel', + 'white_channel': -1, + 'draw_bounding_box': False, + 'draw_mask': False, + 'pad_to': None, + 'rgb_overlay_channels': None, + 'update_focus_zi': True, + }, + }, + }) + roiset = self.test_create_roiset_from_3d_obj_ids() + starting_zi = roiset.get_df()['zi'] + + res = roiset.run_exports( + self.where, + prefix='test', + params=p + ) + updated_zi = roiset.get_df()['zi'] + self.assertTrue((starting_zi != updated_zi).any()) def test_run_export_mono_3d_labels_overlay(self): # export via .run_exports() method -- GitLab