diff --git a/model_server/extensions/chaeo/annotators.py b/model_server/extensions/chaeo/annotators.py index e182d89594f41f98de04ac92b7f2813dd2007979..b2df418d46fc5801f39e34f4512863dee69e7cfb 100644 --- a/model_server/extensions/chaeo/annotators.py +++ b/model_server/extensions/chaeo/annotators.py @@ -3,48 +3,30 @@ from PIL import Image, ImageDraw, ImageFont from model_server.base.process import rescale -def draw_boxes_on_2d_image(yx_img, boxes, **kwargs): - pilimg = Image.fromarray(np.copy(yx_img)) # drawing modifies array in-place - draw = ImageDraw.Draw(pilimg) +def draw_boxes_on_3d_image(roiset, draw_full_depth=False, **kwargs): + _, _, chroma, nz = roiset.acc_raw.shape font_size = kwargs.get('font_size', 18) linewidth = kwargs.get('linewidth', 4) - draw.font = ImageFont.truetype(font="arial.ttf", size=font_size) - - for box in boxes: - y0 = box['info'].y0 - y1 = box['info'].y1 - x0 = box['info'].x0 - x1 = box['info'].x1 - xm = round((x0 + x1) / 2) - - la = box['info'].label - zi = box['info'].zi - - draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth) - - if kwargs.get('draw_label') is True: - draw.text((xm, y0), f'{la:04d}', fill='white', anchor='mb') - - return pilimg - - -def draw_boxes_on_3d_image(yxcz_img, boxes, draw_full_depth=False, **kwargs): - # assert len(yxcz_img.shape) == 4 - # nz = yxcz_img.shape[3] - _, _, chroma, nz = yxcz_img.shape - # assert yxcz_img.shape[2] == 1 - - annotated = np.zeros(yxcz_img.shape, dtype=yxcz_img.dtype) + annotated = np.zeros(roiset.acc_raw.shape, dtype=roiset.acc_raw.dtype) for zi in range(0, nz): if draw_full_depth: - zi_boxes = boxes + subset = roiset.get_df() else: - zi_boxes = [bb for bb in boxes if bb['info'].zi == zi] - # annotated[:, :, 0, zi] = draw_boxes_on_2d_image(yxcz_img[:, :, 0, zi], zi_boxes, **kwargs) + subset = roiset.get_df().query(f'zi == {zi}') for c in range(0, chroma): - annotated[:, :, c, zi] = draw_boxes_on_2d_image(yxcz_img[:, :, c, zi], zi_boxes, **kwargs) + pilimg = Image.fromarray(roiset.acc_raw.data[:, :, c, zi]) + draw = ImageDraw.Draw(pilimg) + draw.font = ImageFont.truetype(font="arial.ttf", size=font_size) + + for roi in subset.itertuples('Roi'): + xm = round((roi.x0 + roi.x1) / 2) + draw.rectangle([(roi.x0, roi.y0), (roi.x1, roi.y1)], outline='white', width=linewidth) + + if kwargs.get('draw_label') is True: + draw.text((xm, roi.y0), f'{roi.label:04d}', fill='white', anchor='mb') + annotated[:, :, c, zi] = pilimg if clip := kwargs.get('rescale_clip'): diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index b1cedda60dacff5d33b782617900344a22d09f40..f7782f8ced3d923c2a15128f96d197eb3846f81c 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -9,9 +9,9 @@ from skimage.io import imsave from skimage.measure import find_contours, shannon_entropy from tifffile import imwrite -from model_server.extensions.chaeo.accessors import MonoPatchStack, Multichannel3dPatchStack +from model_server.extensions.chaeo.accessors import MonoPatchStack from model_server.extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch -from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor +from model_server.base.accessors import InMemoryDataAccessor from model_server.base.process import pad, rescale, resample_to_8bit def _make_rgb(zs): @@ -95,7 +95,7 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask', ** return exported -def get_patches_from_zmask_meta( +def get_roiset_patches( roiset, rescale_clip: float = 0.0, pad_to: int = 256, @@ -225,7 +225,7 @@ def export_patches_from_zstack( **kwargs ): make_3d = kwargs.get('make_3d', False) - patches_df = get_patches_from_zmask_meta(roiset, **kwargs) + patches_df = get_roiset_patches(roiset, **kwargs) def _export_patch(roi): patch = InMemoryDataAccessor(roi.patch) diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index 73d00aaf61536ca5809bf3995f9f07bac0de48c4..b5744062623f41f67839367909a9c7226751de80 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -140,7 +140,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_export_annotated_zstack(self): roiset = self._make_roi_set() file = roiset.export_annotated_zstack( - output_path / 'annotated_stack', + output_path / 'annotated_zstack', ) result = generate_file_accessor(Path(file['location']) / file['filename']) self.assertEqual(result.shape, roiset.acc_raw.shape) diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index 70479f2a063e92c144e8aa4468f55aac9bb3c119..f0cd31705fbe02b77718509a1341b11ab199eafd 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -18,7 +18,7 @@ from model_server.extensions.chaeo.products import export_patches_from_zstack, e from extensions.chaeo.params import RoiFilter, RoiSetMetaParams, RoiSetExportParams from model_server.extensions.chaeo.accessors import MonoPatchStack, Multichannel3dPatchStack from model_server.extensions.chaeo.process import mask_largest_object -from model_server.extensions.chaeo.products import get_patches_from_zmask_meta, get_patch_masks, export_patch_masks +from model_server.extensions.chaeo.products import get_roiset_patches, get_patch_masks, export_patch_masks def get_label_ids(acc_seg_mask): @@ -76,7 +76,7 @@ class RoiSet(object): 'argmax': acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'), } - self.count = len(self.zmask_meta) + self.count = len(self._df) self.object_id_labels = self.interm['label_map'] self.object_class_maps = {} # classification results @@ -193,9 +193,9 @@ class RoiSet(object): def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches if channel: - patches_df = get_patches_from_zmask_meta(self, white_channel=channel, pad_to=pad_to) + patches_df = get_roiset_patches(self, white_channel=channel, pad_to=pad_to) else: - patches_df = get_patches_from_zmask_meta(self, pad_to=pad_to) + patches_df = get_roiset_patches(self, pad_to=pad_to) patches = list(patches_df['patch']) if channel is not None or self.acc_raw.chroma == 1: return MonoPatchStack(patches) @@ -203,8 +203,8 @@ class RoiSet(object): return Multichannel3dPatchStack(patches) def export_annotated_zstack(self, where, prefix='zstack', **kwargs): - annotated = InMemoryDataAccessor( - draw_boxes_on_3d_image(self.acc_raw.data, self.zmask_meta, **kwargs) # TODO remove zmask_meta ref + annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs) + # draw_boxes_on_3d_image(self.acc_raw.data, self.zmask_meta, **kwargs) # TODO remove zmask_meta ref ) success = write_accessor_data_to_file(where / (prefix + '.tif'), annotated) return {'location': where.__str__(), 'filename': prefix + '.tif'} @@ -240,14 +240,12 @@ class RoiSet(object): zi_st[:, :, :, -1][lamap == 0] = 0 elif mask_type == 'boxes': - for bb in self.zmask_meta: - sl = bb['slice'] - zi_st[sl] = 1 + for roi in self: + zi_st[roi.relative_slice] = 1 return zi_st - # TODO: channel restriction as an argument def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): # do this on a patch basis, i.e. only one object per frame