Skip to content
Snippets Groups Projects
Commit 7fe537b3 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Standardized RoiSet constructor works on object id maps

parent b5953056
No related merge requests found
......@@ -8,13 +8,15 @@ from model_server.extensions.chaeo.conf.testing import multichannel_zstack, pixe
from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
from model_server.extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack
from model_server.extensions.chaeo.workflows import infer_object_map_from_zstack
from model_server.extensions.chaeo.zmask import build_zmask_from_object_mask
from model_server.extensions.chaeo.zmask import build_zmask_from_object_mask, get_label_ids
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.base.models import DummyInstanceSegmentationModel
class TestZStackDerivedDataProducts(unittest.TestCase):
# TODO: add cases that call RoiSet directly, not just through workflow function
def setUp(self) -> None:
# need test data incl obj map
......@@ -39,8 +41,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
write_accessor_data_to_file(output_path / 'seg_mask.tif', self.seg_mask)
def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs):
id_map = get_label_ids(self.seg_mask)
zmask, meta, df, interm = build_zmask_from_object_mask(
self.seg_mask,
id_map,
self.stack_ch_pa,
params=RoiSetMetaParams(mask_type=mask_type, filters=kwargs.get('filters')),
)
......@@ -71,8 +74,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
def test_zmask_works_on_non_zstacks(self, **kwargs):
acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
self.assertEqual(acc_zstack_slice.nz, 1)
id_map = get_label_ids(self.seg_mask)
zmask, meta, df, interm = build_zmask_from_object_mask(
self.seg_mask,
id_map,
acc_zstack_slice,
params=RoiSetMetaParams(mask_type='boxes'),
**kwargs,
......@@ -115,8 +119,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
self.assertGreaterEqual(len(files), 1)
def test_flatten_image(self):
id_map = get_label_ids(self.seg_mask)
zmask, meta, df, interm = build_zmask_from_object_mask(
self.seg_mask,
id_map,
self.stack_ch_pa,
params=RoiSetMetaParams(mask_type='boxes')
)
......
......@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
from model_server.extensions.chaeo.process import mask_largest_object
from model_server.extensions.chaeo.zmask import RoiSet
from model_server.extensions.chaeo.zmask import get_label_ids, RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
......@@ -48,7 +48,7 @@ def infer_object_map_from_zstack(
ti.click('classify_pixels')
# make zmask
rois = RoiSet(mip_mask, stack, params=roi_params)
rois = RoiSet(get_label_ids(mip_mask), stack, params=roi_params)
ti.click('generate_zmasks')
rois.classify_by(patches_channel, models['object_classifier']['model'])
......@@ -64,7 +64,7 @@ def infer_object_map_from_zstack(
'output_path': output_folder_path,
}
# TODO: to app-specific ecotaxa module
def transfer_ecotaxa_labels_to_patch_stacks(
where_masks: str,
where_patches: str,
......
......@@ -14,17 +14,19 @@ from model_server.extensions.chaeo.process import mask_largest_object
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.base.models import InstanceSegmentationModel
def get_label_ids(acc_seg_mask):
return label(acc_seg_mask.data[:, :, 0, 0]).astype('uint16')
class RoiSet(object):
def __init__(
self,
acc_mask: GenericImageDataAccessor,
acc_obj_ids: GenericImageDataAccessor, # TODO: enforce subtype of binary or label ID mask
acc_raw: GenericImageDataAccessor,
params: RoiSetMetaParams = RoiSetMetaParams(),
):
self.zmask, self.zmask_meta, self.df, self.interm = build_zmask_from_object_mask(
acc_mask,
acc_obj_ids,
acc_raw,
params=params,
)
......@@ -135,7 +137,7 @@ def build_zmask_from_object_mask(
"""
Given a 2D mask of objects, build a 3D mask, where each object's z-position is determined by the index of
maximum intensity in z. Return this zmask and a list of each object's meta information.
:param obmask: GenericImageDataAccessor monochrome 2D binary mask of objects
:param obmask: GenericImageDataAccessor 2D map of objects IDs
:param zstack: GenericImageDataAccessor monochrome zstack of same Y, X dimension as obmask
:param params: RoiSetMetaParams
filters: dictionary of form {attribute: (min, max)}; valid attributes are 'area' and 'solidity'
......@@ -161,13 +163,16 @@ def build_zmask_from_object_mask(
# validate inputs
# assert zstack.chroma == 1
assert mask_type in ('contours', 'boxes'), mask_type
assert obmask.is_mask()
assert obmask.chroma == 1
assert obmask.nz == 1
assert zstack.hw == obmask.hw
# assert obmask.is_mask()
# assert obmask.chroma == 1
# assert obmask.nz == 1
if zstack.hw != obmask.shape:
input()
assert zstack.hw == obmask.shape
# assign object labels and build object query
lamap = label(obmask.data[:, :, 0, 0]).astype('uint16')
# lamap = label(obmask.data[:, :, 0, 0]).astype('uint16')
lamap = obmask
query_str = 'label > 0' # always true
if filters is not None:
for k, val in filters.dict(exclude_unset=True).items():
......@@ -218,18 +223,18 @@ def build_zmask_from_object_mask(
z1 = min(ob.zi + ebz, nz)
# relative bounding box positions
rbb = {
rbb = { # TODO: just put in the DF
'y0': ob.y0 - y0,
'y1': ob.y1 - y0,
'x0': ob.x0 - x0,
'x1': ob.x1 - x0,
}
sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1]
sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1] # TODO: on-the-fly in RoiSet, given DF
# compute contours
obmask = (lamap == ob.label)
contour = find_contours(obmask)
obmask = (lamap == ob.label) # TODO: on-the-fly
contour = find_contours(obmask) # TODO: on-the-fly
mask = obmask[ob.y0: ob.y1, ob.x0: ob.x1]
assert rbb['x1'] <= (x1 - x0)
......@@ -244,7 +249,7 @@ def build_zmask_from_object_mask(
'mask': mask
})
# build mask z-stack
# build mask z-stack # TODO: on-the-fly
zi_st = np.zeros(zstack.shape, dtype='bool')
if mask_type == 'contours':
zi_map = (lut[lamap] + 1.0).astype('int')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment