Skip to content
Snippets Groups Projects
Commit 820cc3e3 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Single annotator method for annotated zstacks

parent 52d16ed7
No related branches found
No related tags found
No related merge requests found
......@@ -3,48 +3,30 @@ from PIL import Image, ImageDraw, ImageFont
from model_server.base.process import rescale
def draw_boxes_on_2d_image(yx_img, boxes, **kwargs):
pilimg = Image.fromarray(np.copy(yx_img)) # drawing modifies array in-place
draw = ImageDraw.Draw(pilimg)
def draw_boxes_on_3d_image(roiset, draw_full_depth=False, **kwargs):
_, _, chroma, nz = roiset.acc_raw.shape
font_size = kwargs.get('font_size', 18)
linewidth = kwargs.get('linewidth', 4)
draw.font = ImageFont.truetype(font="arial.ttf", size=font_size)
for box in boxes:
y0 = box['info'].y0
y1 = box['info'].y1
x0 = box['info'].x0
x1 = box['info'].x1
xm = round((x0 + x1) / 2)
la = box['info'].label
zi = box['info'].zi
draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth)
if kwargs.get('draw_label') is True:
draw.text((xm, y0), f'{la:04d}', fill='white', anchor='mb')
return pilimg
def draw_boxes_on_3d_image(yxcz_img, boxes, draw_full_depth=False, **kwargs):
# assert len(yxcz_img.shape) == 4
# nz = yxcz_img.shape[3]
_, _, chroma, nz = yxcz_img.shape
# assert yxcz_img.shape[2] == 1
annotated = np.zeros(yxcz_img.shape, dtype=yxcz_img.dtype)
annotated = np.zeros(roiset.acc_raw.shape, dtype=roiset.acc_raw.dtype)
for zi in range(0, nz):
if draw_full_depth:
zi_boxes = boxes
subset = roiset.get_df()
else:
zi_boxes = [bb for bb in boxes if bb['info'].zi == zi]
# annotated[:, :, 0, zi] = draw_boxes_on_2d_image(yxcz_img[:, :, 0, zi], zi_boxes, **kwargs)
subset = roiset.get_df().query(f'zi == {zi}')
for c in range(0, chroma):
annotated[:, :, c, zi] = draw_boxes_on_2d_image(yxcz_img[:, :, c, zi], zi_boxes, **kwargs)
pilimg = Image.fromarray(roiset.acc_raw.data[:, :, c, zi])
draw = ImageDraw.Draw(pilimg)
draw.font = ImageFont.truetype(font="arial.ttf", size=font_size)
for roi in subset.itertuples('Roi'):
xm = round((roi.x0 + roi.x1) / 2)
draw.rectangle([(roi.x0, roi.y0), (roi.x1, roi.y1)], outline='white', width=linewidth)
if kwargs.get('draw_label') is True:
draw.text((xm, roi.y0), f'{roi.label:04d}', fill='white', anchor='mb')
annotated[:, :, c, zi] = pilimg
if clip := kwargs.get('rescale_clip'):
......
......@@ -9,9 +9,9 @@ from skimage.io import imsave
from skimage.measure import find_contours, shannon_entropy
from tifffile import imwrite
from model_server.extensions.chaeo.accessors import MonoPatchStack, Multichannel3dPatchStack
from model_server.extensions.chaeo.accessors import MonoPatchStack
from model_server.extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.base.accessors import InMemoryDataAccessor
from model_server.base.process import pad, rescale, resample_to_8bit
def _make_rgb(zs):
......@@ -95,7 +95,7 @@ def export_patch_masks(roiset, where: Path, pad_to: int = 256, prefix='mask', **
return exported
def get_patches_from_zmask_meta(
def get_roiset_patches(
roiset,
rescale_clip: float = 0.0,
pad_to: int = 256,
......@@ -225,7 +225,7 @@ def export_patches_from_zstack(
**kwargs
):
make_3d = kwargs.get('make_3d', False)
patches_df = get_patches_from_zmask_meta(roiset, **kwargs)
patches_df = get_roiset_patches(roiset, **kwargs)
def _export_patch(roi):
patch = InMemoryDataAccessor(roi.patch)
......
......@@ -140,7 +140,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_export_annotated_zstack(self):
roiset = self._make_roi_set()
file = roiset.export_annotated_zstack(
output_path / 'annotated_stack',
output_path / 'annotated_zstack',
)
result = generate_file_accessor(Path(file['location']) / file['filename'])
self.assertEqual(result.shape, roiset.acc_raw.shape)
......
......@@ -18,7 +18,7 @@ from model_server.extensions.chaeo.products import export_patches_from_zstack, e
from extensions.chaeo.params import RoiFilter, RoiSetMetaParams, RoiSetExportParams
from model_server.extensions.chaeo.accessors import MonoPatchStack, Multichannel3dPatchStack
from model_server.extensions.chaeo.process import mask_largest_object
from model_server.extensions.chaeo.products import get_patches_from_zmask_meta, get_patch_masks, export_patch_masks
from model_server.extensions.chaeo.products import get_roiset_patches, get_patch_masks, export_patch_masks
def get_label_ids(acc_seg_mask):
......@@ -76,7 +76,7 @@ class RoiSet(object):
'argmax': acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'),
}
self.count = len(self.zmask_meta)
self.count = len(self._df)
self.object_id_labels = self.interm['label_map']
self.object_class_maps = {} # classification results
......@@ -193,9 +193,9 @@ class RoiSet(object):
def get_raw_patches(self, channel=None, pad_to=256, make_3d=False): # padded, un-annotated 2d patches
if channel:
patches_df = get_patches_from_zmask_meta(self, white_channel=channel, pad_to=pad_to)
patches_df = get_roiset_patches(self, white_channel=channel, pad_to=pad_to)
else:
patches_df = get_patches_from_zmask_meta(self, pad_to=pad_to)
patches_df = get_roiset_patches(self, pad_to=pad_to)
patches = list(patches_df['patch'])
if channel is not None or self.acc_raw.chroma == 1:
return MonoPatchStack(patches)
......@@ -203,8 +203,8 @@ class RoiSet(object):
return Multichannel3dPatchStack(patches)
def export_annotated_zstack(self, where, prefix='zstack', **kwargs):
annotated = InMemoryDataAccessor(
draw_boxes_on_3d_image(self.acc_raw.data, self.zmask_meta, **kwargs) # TODO remove zmask_meta ref
annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)
# draw_boxes_on_3d_image(self.acc_raw.data, self.zmask_meta, **kwargs) # TODO remove zmask_meta ref
)
success = write_accessor_data_to_file(where / (prefix + '.tif'), annotated)
return {'location': where.__str__(), 'filename': prefix + '.tif'}
......@@ -240,14 +240,12 @@ class RoiSet(object):
zi_st[:, :, :, -1][lamap == 0] = 0
elif mask_type == 'boxes':
for bb in self.zmask_meta:
sl = bb['slice']
zi_st[sl] = 1
for roi in self:
zi_st[roi.relative_slice] = 1
return zi_st
# TODO: channel restriction as an argument
def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ):
# do this on a patch basis, i.e. only one object per frame
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment