from uuid import uuid4

import numpy as np
import pandas as pd
from skimage.measure import find_contours, label, regionprops_table
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

from extensions.chaeo.annotators import draw_boxes_on_3d_image
from extensions.chaeo.products import (
    export_patches_from_zstack,
    export_multichannel_patches_from_zstack,
    export_patch_masks_from_zstack,
    get_patches_from_zmask_meta,
    get_patch_masks_from_zmask_meta,
)
from extensions.chaeo.params import AnnotatedZStackParams, PatchParams, RoiFilter, RoiSetExportParams
from extensions.chaeo.process import mask_largest_object

from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import InstanceSegmentationModel


class RoiSet(object):
    """
    A set of regions of interest (ROIs) derived from a 2D object mask and the z-stack it was segmented from.

    Attributes set by the constructor:
        zmask: boolean mask z-stack returned by build_zmask_from_object_mask
        zmask_meta: list with one dict per retained object (see build_zmask_from_object_mask)
        df: pd.DataFrame with one row per labeled object, including a boolean 'keeper' column
        interm: dict of intermediate image products ('label_map' and 'argmax')
        acc_raw: the raw z-stack accessor
        count: number of retained objects, i.e. len(zmask_meta)
        object_id_labels: the label map, interm['label_map']
    """

    def __init__(
            self,
            acc_mask: GenericImageDataAccessor,
            acc_raw: GenericImageDataAccessor,
            filters=None,
            mask_type='contours',
            expand_box_by=(0, 0),
    ):
        """
        Build the ROI set; all arguments are forwarded to build_zmask_from_object_mask.

        :param acc_mask: accessor of the 2D object mask
        :param acc_raw: accessor of the raw z-stack
        :param filters: optional dictionary of form {attribute: (min, max)}
        :param mask_type: 'contours' or 'boxes'
        :param expand_box_by: (xy, z) number of pixels by which to expand each bounding box
        """
        self.zmask, self.zmask_meta, self.df, self.interm = build_zmask_from_object_mask(
            acc_mask, acc_raw, filters=filters, mask_type=mask_type, expand_box_by=expand_box_by
        )
        self.acc_raw = acc_raw
        self.count = len(self.zmask_meta)
        self.object_id_labels = self.interm['label_map']

    def get_argmax(self):
        """Return the map of z-indices of maximum intensity computed while building the ROI set."""
        # interm is a plain dict, so the entry must be read by key rather than by attribute
        return self.interm['argmax']

    def get_multichannel_projection(self):
        """
        Project the raw z-stack onto a surface fitted through the focal points of the retained objects.

        :return: array with a singleton z-axis as its last axis; if the set holds no objects, this is the
            maximum intensity projection along the last axis instead
        """
        if self.count:
            dff = self.df[self.df['keeper']]
            projected = project_stack_from_focal_points(
                dff['centroid-0'].to_numpy(),
                dff['centroid-1'].to_numpy(),
                dff['zi'].to_numpy(),
                self.acc_raw,
                degree=4,
            )
        else:
            # no focal points to fit, so fall back to a MIP; keepdims keeps the same number of axes
            # as the fitted projection returned by the other branch
            projected = self.acc_raw.data.max(axis=-1, keepdims=True)
        return projected

    def get_raw_patches(self, channel):
        """Return patches of one channel of the raw stack, one per retained object."""
        return get_patches_from_zmask_meta(
            self.acc_raw.get_one_channel_data(channel),
            self.zmask_meta
        )

    def get_patch_masks(self):
        """Return patch masks, one per retained object."""
        return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta)

    def classify_by(self, channel, object_classification_model: InstanceSegmentationModel):
        """
        Classify each retained object and record the result.

        The class of each object is written to the 'instance_class' column of self.df (NaN for objects
        that were not classified) and painted into a map with the same shape as the label map.

        :param channel: channel of the raw stack from which patches are taken
        :param object_classification_model: model whose label_instance_class method is called on the
            patches and patch masks
        :return: InMemoryDataAccessor wrapping the map of object classes; background is 0
        """
        # do this on a patch basis, i.e. only one object per frame
        obmap_patches = object_classification_model.label_instance_class(
            self.get_raw_patches(channel),
            self.get_patch_masks()
        )

        lamap = self.object_id_labels
        output_map = np.zeros(lamap.shape, dtype=lamap.dtype)
        self.df['instance_class'] = np.nan

        # assign labels to object map:
        for ii, meta in enumerate(self.zmask_meta):
            object_id = meta['info'].label
            result_patch = mask_largest_object(obmap_patches.iat(ii))

            # second-smallest unique value; presumably the first is background (0) -- TODO confirm
            object_class = np.unique(result_patch)[1]
            output_map[lamap == object_id] = object_class

            # select the row by its label; indexing df with a (label, column) tuple would instead
            # create a new column named by that tuple
            self.df.loc[self.df['label'] == object_id, 'instance_class'] = object_class

        return InMemoryDataAccessor(output_map)

    # TODO: test
    def get_object_mask_by_id(self, obj_id):
        """Return a boolean mask that is True where the label map equals obj_id."""
        return self.object_id_labels == obj_id

    def get_object_mask_by_class(self, class_id):
        """Return a boolean mask that is True where the label map equals class_id."""
        # NOTE(review): this compares object ID labels against a class ID, so it selects the object
        # whose ID happens to equal class_id rather than all objects of that class -- confirm intent
        return self.object_id_labels == class_id

    # TODO: implement
    def get_object_patch_by_id(self, obj_id):
        pass

    def get_object_map(self, filters: RoiFilter):
        pass

    def run_exports(self, where, channel, prefix, params: RoiSetExportParams):
        """
        Run each export that is configured in params; does nothing if the set holds no objects.

        :param where: parent directory; each export writes to the subdirectory named after its key
        :param channel: channel of the raw stack used for monochrome products and as ch_white
        :param prefix: prefix passed on to the export functions and used to name the annotated z-stack
        :param params: export parameters; the 'meta' entry and entries set to None are skipped
        """
        if not self.count:
            return

        raw_ch = self.acc_raw.get_one_channel_data(channel)
        for k, kp in params.dict().items():
            if k == 'meta' or kp is None:
                continue
            subdir = where / k

            if k == 'patches_3d':
                export_patches_from_zstack(
                    subdir, raw_ch, self.zmask_meta, prefix=prefix, make_3d=True, **kp
                )
            elif k == 'patches_2d_for_annotation':
                export_multichannel_patches_from_zstack(
                    subdir,
                    self.acc_raw,
                    self.zmask_meta,
                    prefix=prefix,
                    make_3d=False,
                    ch_white=channel,
                    bounding_box_channel=1,
                    bounding_box_linewidth=2,
                    **kp,
                )
            elif k == 'patches_2d_for_training':
                files = export_multichannel_patches_from_zstack(
                    subdir, self.acc_raw, self.zmask_meta, ch_white=channel, prefix=prefix, make_3d=False, **kp
                )
                # attach the exported patch records to the object table and give each row a unique ID
                df_patches = pd.DataFrame(files)
                self.df = pd.merge(
                    self.df, df_patches, left_index=True, right_on='df_index'
                ).drop(columns='df_index')
                self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1)
            elif k == 'patch_masks':
                export_patch_masks_from_zstack(
                    subdir, raw_ch, self.zmask_meta, prefix=prefix,
                )
            elif k == 'annotated_zstacks':
                annotated = InMemoryDataAccessor(
                    draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp)
                )
                write_accessor_data_to_file(subdir / (prefix + '.tif'), annotated)


def build_zmask_from_object_mask(
        obmask: GenericImageDataAccessor,
        zstack: GenericImageDataAccessor,
        filters=None,
        mask_type='contours',
        expand_box_by=(0, 0),
):
    """
    Given a 2D mask of objects, build a 3D mask, where each object's z-position is determined by the index of
    maximum intensity in z.  Return this zmask, a list of each object's meta information, a table of objects,
    and intermediate image products.

    :param obmask: GenericImageDataAccessor monochrome 2D binary mask of objects
    :param zstack: GenericImageDataAccessor zstack of same Y, X dimension as obmask
    :param filters: dictionary of form {attribute: (min, max)}; valid attributes are 'area' and 'solidity';
        bounds are exclusive
    :param mask_type: if 'boxes', zmask is True in each object's complete bounding box; otherwise 'contours'
    :param expand_box_by: (xy, z) expands bounding box by (xy, z) pixels except where this hits a boundary
    :return: tuple (zmask, meta, df, interm)
        np.ndarray: boolean mask of same size as stack
        List containing one Dict per object that passes the filters, with keys:
            df_index: index of the object's row in the returned DataFrame
            info: object's row of the DataFrame as a named tuple, including bounding box (y0, y1, x0, x1)
            slice: named slice (np.s_) of (optionally) expanded bounding box
            relative_bounding_box: bounding box (y0, y1, x0, x1) in relative frame of (optionally) expanded
                bounding box
            contour: object's contour returned by skimage.measure.find_contours
            mask: mask of object cropped to its (unexpanded) bounding box
        pd.DataFrame: all labeled objects, including bounding box information; column 'keeper' is True for
            objects that pass the filters
        Dict of intermediate image products:
            label_map: np.ndarray (h x w) where each unique object has an integer label
            argmax: np.ndarray (h x w) z-index of highest intensity in the first channel of zstack
    """

    # validate inputs
    # assert zstack.chroma == 1
    assert mask_type in ('contours', 'boxes'), mask_type
    assert obmask.is_mask()
    assert obmask.chroma == 1
    assert obmask.nz == 1
    assert zstack.hw == obmask.hw

    # assign object labels and build object query
    lamap = label(obmask.data[:, :, 0, 0]).astype('uint16')
    query_str = 'label > 0'  # always true
    if filters is not None:
        for k in filters.keys():
            assert k in ('area', 'solidity')
            vmin, vmax = filters[k]
            assert vmin >= 0
            query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}'

    # build dataframe of objects, assign z index to each object
    argmax = zstack.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16')
    df = (
        pd.DataFrame(
            regionprops_table(
                lamap,
                intensity_image=argmax,
                properties=('label', 'area', 'intensity_mean', 'solidity', 'bbox', 'centroid')
            )
        )
        .rename(
            columns={
                'bbox-0': 'y0',
                'bbox-1': 'x0',
                'bbox-2': 'y1',
                'bbox-3': 'x1',
            }
        )
    )

    # an object's z index is the mean of argmax over its pixels, rounded to the nearest frame
    df['zi'] = df['intensity_mean'].round().astype('int')
    df['keeper'] = False
    df.loc[df.query(query_str).index, 'keeper'] = True

    # make a lookup table from label to focus position in stack, where background is -1
    lut = np.zeros(lamap.max() + 1) - 1
    lut[df.label] = df.zi

    # convert bounding boxes to numpy slice objects
    ebxy, ebz = expand_box_by
    h, w, c, nz = zstack.shape

    meta = []
    for ob in df[df['keeper']].itertuples(name='LabeledObject'):
        y0 = max(ob.y0 - ebxy, 0)
        y1 = min(ob.y1 + ebxy, h)
        x0 = max(ob.x0 - ebxy, 0)
        x1 = min(ob.x1 + ebxy, w)
        z0 = max(ob.zi - ebz, 0)
        z1 = min(ob.zi + ebz, nz)

        # relative bounding box positions
        rbb = {
            'y0': ob.y0 - y0,
            'y1': ob.y1 - y0,
            'x0': ob.x0 - x0,
            'x1': ob.x1 - x0,
        }

        sl = np.s_[y0: y1, x0: x1, :, z0: z1 + 1]

        # compute contours; use a local name here so that the obmask argument is not overwritten
        this_obmask = (lamap == ob.label)
        contour = find_contours(this_obmask)
        mask = this_obmask[ob.y0: ob.y1, ob.x0: ob.x1]

        assert rbb['x1'] <= (x1 - x0)
        assert rbb['y1'] <= (y1 - y0)

        meta.append({
            'df_index': ob.Index,
            'info': ob,
            'slice': sl,
            'relative_bounding_box': rbb,
            'contour': contour,
            'mask': mask
        })

    # build mask z-stack
    zi_st = np.zeros(zstack.shape, dtype='bool')
    if mask_type == 'contours':
        # NOTE(review): lut is filled from every row of df, so objects rejected by filters are
        # also marked here -- confirm this is intended
        idxs = lut[lamap].astype('int')
        np.put_along_axis(
            zi_st,
            np.expand_dims(idxs, (2, 3)),
            1,
            axis=3
        )

        # background pixels have index -1, which marked them in the final frame; reset these to 0
        zi_st[:, :, :, -1][lamap == 0] = 0

    elif mask_type == 'boxes':
        for bb in meta:
            sl = bb['slice']
            zi_st[sl] = 1

    # return intermediate image arrays
    interm = {
        'label_map': lamap,
        'argmax': argmax,
    }

    return zi_st, meta, df, interm


def project_stack_from_focal_points(
        xx: np.ndarray,
        yy: np.ndarray,
        zz: np.ndarray,
        stack: GenericImageDataAccessor,
        degree: int = 2,
) -> np.ndarray:
    """
    Given a set of 3D points, project a multichannel z-stack based on a surface fit of the provided points

    The fitted surface is evaluated on the grid np.indices(stack.hw), so xx must hold coordinates along the
    first axis of that grid and yy along the second.

    :param xx: vector of point coordinates along the first image axis
    :param yy: vector of point coordinates along the second image axis
    :param zz: vector of point z-coordinates
    :param stack: z-stack to project
    :param degree: order of polynomial to fit
    :return: multichannel projected image array, with one z-position taken at each pixel
    """
    assert xx.shape == yy.shape
    assert xx.shape == zz.shape

    # fit z as a polynomial function of the two in-plane coordinates
    poly = PolynomialFeatures(degree=degree)
    X = np.stack([xx, yy]).T
    features = poly.fit_transform(X)
    model = LinearRegression(fit_intercept=False)
    model.fit(features, zz)

    # evaluate the fit at every pixel; the transformer is already fitted, so only transform here
    xy_indices = np.indices(stack.hw).reshape(2, -1).T
    xy_features = np.dot(
        poly.transform(xy_indices),
        model.coef_
    )

    # round to the nearest frame and keep indices inside the stack
    zi_image = xy_features.reshape(
        stack.hw
    ).round().clip(
        0, (stack.nz - 1)
    ).astype('uint16')

    return np.take_along_axis(
        stack.data,
        np.repeat(
            np.expand_dims(zi_image, (2, 3)),
            stack.chroma,
            axis=2
        ),
        axis=3
    )