From 7fe537b345dd4ea45d9ed36e75c0b164787fa0ed Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Sat, 3 Feb 2024 09:06:44 +0100 Subject: [PATCH] Standardized RoiSet constructor works on object id maps --- .../extensions/chaeo/tests/test_zstack.py | 13 +++++--- model_server/extensions/chaeo/workflows.py | 6 ++-- model_server/extensions/chaeo/zmask.py | 31 +++++++++++-------- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index 0907846e..e0e09eed 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -8,13 +8,15 @@ from model_server.extensions.chaeo.conf.testing import multichannel_zstack, pixe from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams from model_server.extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack -from model_server.extensions.chaeo.zmask import build_zmask_from_object_mask +from model_server.extensions.chaeo.zmask import build_zmask_from_object_mask, get_label_ids from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.extensions.ilastik.models import IlastikPixelClassifierModel from model_server.base.models import DummyInstanceSegmentationModel class TestZStackDerivedDataProducts(unittest.TestCase): + # TODO: add cases that call RoiSet directly, not just through workflow function + def setUp(self) -> None: # need test data incl obj map @@ -39,8 +41,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase): write_accessor_data_to_file(output_path / 'seg_mask.tif', self.seg_mask) def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs): + id_map = get_label_ids(self.seg_mask) zmask, meta, df, interm = build_zmask_from_object_mask( - self.seg_mask, + id_map, self.stack_ch_pa, params=RoiSetMetaParams(mask_type=mask_type, filters=kwargs.get('filters')), ) @@ -71,8 +74,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase): def test_zmask_works_on_non_zstacks(self, **kwargs): acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) self.assertEqual(acc_zstack_slice.nz, 1) + id_map = get_label_ids(self.seg_mask) zmask, meta, df, interm = build_zmask_from_object_mask( - self.seg_mask, + id_map, acc_zstack_slice, params=RoiSetMetaParams(mask_type='boxes'), **kwargs, @@ -115,8 +119,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase): self.assertGreaterEqual(len(files), 1) def test_flatten_image(self): + id_map = get_label_ids(self.seg_mask) zmask, meta, df, interm = build_zmask_from_object_mask( - self.seg_mask, + id_map, self.stack_ch_pa, params=RoiSetMetaParams(mask_type='boxes') ) diff --git a/model_server/extensions/chaeo/workflows.py b/model_server/extensions/chaeo/workflows.py index aa07205b..71333fd3 100644 --- a/model_server/extensions/chaeo/workflows.py +++ b/model_server/extensions/chaeo/workflows.py @@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams from model_server.extensions.chaeo.process import mask_largest_object -from model_server.extensions.chaeo.zmask import RoiSet +from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel @@ -48,7 +48,7 @@ def infer_object_map_from_zstack( ti.click('classify_pixels') # make zmask - rois = RoiSet(mip_mask, stack, params=roi_params) + rois = RoiSet(get_label_ids(mip_mask), stack, params=roi_params) ti.click('generate_zmasks') rois.classify_by(patches_channel, models['object_classifier']['model']) @@ -64,7 +64,7 @@ def infer_object_map_from_zstack( 'output_path': output_folder_path, } - +# TODO: to app-specific ecotaxa module def transfer_ecotaxa_labels_to_patch_stacks( where_masks: str, where_patches: str, diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index 380f7b7d..2199a677 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -14,17 +14,19 @@ from model_server.extensions.chaeo.process import mask_largest_object from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.models import InstanceSegmentationModel +def get_label_ids(acc_seg_mask): + return label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16') class RoiSet(object): def __init__( self, - acc_mask: GenericImageDataAccessor, + acc_obj_ids: GenericImageDataAccessor, # TODO: enforce subtype of binary or label ID mask acc_raw: GenericImageDataAccessor, params: RoiSetMetaParams = RoiSetMetaParams(), ): self.zmask, self.zmask_meta, self.df, self.interm = build_zmask_from_object_mask( - acc_mask, + acc_obj_ids, acc_raw, params=params, ) @@ -135,7 +137,7 @@ def build_zmask_from_object_mask( """ Given a 2D mask of objects, build a 3D mask, where each object's z-position is determined by the index of maximum intensity in z. Return this zmask and a list of each object's meta information. - :param obmask: GenericImageDataAccessor monochrome 2D binary mask of objects + :param obmask: GenericImageDataAccessor 2D map of objects IDs :param zstack: GenericImageDataAccessor monochrome zstack of same Y, X dimension as obmask :param params: RoiSetMetaParams filters: dictionary of form {attribute: (min, max)}; valid attributes are 'area' and 'solidity' @@ -161,13 +163,16 @@ def build_zmask_from_object_mask( # validate inputs # assert zstack.chroma == 1 assert mask_type in ('contours', 'boxes'), mask_type - assert obmask.is_mask() - assert obmask.chroma == 1 - assert obmask.nz == 1 - assert zstack.hw == obmask.hw + # assert obmask.is_mask() + # assert obmask.chroma == 1 + # assert obmask.nz == 1 + if zstack.hw != obmask.shape: + input() + assert zstack.hw == obmask.shape # assign object labels and build object query - lamap = label(obmask.data[:, :, 0, 0]).astype('uint16') + # lamap = label(obmask.data[:, :, 0, 0]).astype('uint16') + lamap = obmask query_str = 'label > 0' # always true if filters is not None: for k, val in filters.dict(exclude_unset=True).items(): @@ -218,18 +223,18 @@ def build_zmask_from_object_mask( z1 = min(ob.zi + ebz, nz) # relative bounding box positions - rbb = { + rbb = { # TODO: just put in the DF 'y0': ob.y0 - y0, 'y1': ob.y1 - y0, 'x0': ob.x0 - x0, 'x1': ob.x1 - x0, } - sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1] + sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1] # TODO: on-the-fly in RoiSet, given DF # compute contours - obmask = (lamap == ob.label) - contour = find_contours(obmask) + obmask = (lamap == ob.label) # TODO: on-the-fly + contour = find_contours(obmask) # TODO: on-the-fly mask = obmask[ob.y0: ob.y1, ob.x0: ob.x1] assert rbb['x1'] <= (x1 - x0) @@ -244,7 +249,7 @@ def build_zmask_from_object_mask( 'mask': mask }) - # build mask z-stack + # build mask z-stack # TODO: on-the-fly zi_st = np.zeros(zstack.shape, dtype='bool') if mask_type == 'contours': zi_map = (lut[lamap] + 1.0).astype('int') -- GitLab