diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 11f048bb6e6ed35d1523cc1275ef03c408866c52..a9366156bb6a5fcd63267cd3fb6e319d13bfbe37 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -49,13 +49,51 @@ class GenericImageDataAccessor(ABC): nda = self.data.take(indices=carr, axis=self._ga('C')) return self._derived_accessor(nda) + + def get_zi(self, zi: int): + """ + Return a new accessor of a specific z-coordinate + """ + return self._derived_accessor( + self.data.take( + indices=[zi], + axis=self._ga('Z') + ) + ) + + def get_mip(self): + """ + Return a new accessor of maximum intensity projection (MIP) along z-axis + """ + return self.apply(lambda x: x.max(axis=self._ga('Z'), keepdims=True)) + def get_mono(self, channel: int, mip: bool = False): return self.get_channels([channel], mip=mip) + def get_z_argmax(self): + return self.apply(lambda x: x.argmax(axis=self.get_axis('Z'))) + + def get_focus_vector(self): + return self.data.sum(axis=(0, 1, 2)) + + @property + def data_xy(self) -> np.ndarray: + if not self.chroma == 1 and self.nz == 1: + raise InvalidDataShape('Can only return XY array from accessors with a single channel and single z-level') + else: + return self.data[:, :, 0, 0] + + @property + def data_xyz(self) -> np.ndarray: + if not self.chroma == 1: + raise InvalidDataShape('Can only return XYZ array from accessors with a single channel') + else: + return self.data[:, :, 0, :] + def _gc(self, channels): return self.get_channels(list(channels)) - def _unique(self): + def unique(self): return np.unique(self.data, return_counts=True) @property @@ -75,6 +113,15 @@ class GenericImageDataAccessor(ABC): def _ga(self, arg): return self.get_axis(arg) + def crop_hw(self, yxhw: tuple): + """ + Return subset of data cropped in X and Y + :param yxhw: tuple (Y, X, H, W) + :return: InMemoryDataAccessor of size (H x W), starting at (Y, X) + """ + y, x, h, w = yxhw + return InMemoryDataAccessor(self.data[y: (y + h), x: (x + w), :, :]) + @property def hw(self): """ @@ -120,6 +167,14 @@ class GenericImageDataAccessor(ABC): func(self.data) ) + @property + def info(self): + return { + 'shape_dict': self.shape_dict, + 'dtype': str(self.dtype), + 'filepath': '', + } + class InMemoryDataAccessor(GenericImageDataAccessor): def __init__(self, data): self._data = self.conform_data(data) @@ -139,6 +194,12 @@ class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded def read(fp: Path): return generate_file_accessor(fp) + @property + def info(self): + d = super().info + d['filepath'] = self.fpath.__str__() + return d + class TifSingleSeriesFileAccessor(GenericImageFileAccessor): def __init__(self, fpath: Path): super().__init__(fpath) @@ -240,7 +301,7 @@ class CziImageFileAccessor(GenericImageFileAccessor): def write_accessor_data_to_file(fpath: Path, acc: GenericImageDataAccessor, mkdir=True) -> bool: """ - Export an image accessor to file. + Export an image accessor to file :param fpath: complete path including filename and extension :param acc: image accessor to be written :param mkdir: create any needed subdirectories in fpath if True @@ -287,7 +348,7 @@ def write_accessor_data_to_file(fpath: Path, acc: GenericImageDataAccessor, mkdi def generate_file_accessor(fpath): """ Given an image file path, return an image accessor, assuming the file is a supported format and represents - a single position array, which may be single or multi-channel, single plane or z-stack. 
+ a single position array, which may be single or multichannel, single plane or z-stack. """ if str(fpath).upper().endswith('.TIF') or str(fpath).upper().endswith('.TIFF'): return TifSingleSeriesFileAccessor(fpath) @@ -379,6 +440,11 @@ class PatchStack(InMemoryDataAccessor): else: tifffile.imwrite(fpath, tzcyx, imagej=True) + def write(self, fp: Path, mkdir=True): + if mkdir: + fp.parent.mkdir(parents=True, exist_ok=True) + self.export_pyxcz(fp) + @property def shape_dict(self): return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) @@ -437,7 +503,6 @@ def make_patch_stack_from_file(fpath): # interpret t-dimension as patch positio return PatchStack(pyxcz) - class Error(Exception): pass diff --git a/model_server/base/process.py b/model_server/base/process.py index d475b063ca641922e4c08707bdc2bfc36049f9ea..dac347cd8c2e7a43e1ef82c4c00d9d119fdf303b 100644 --- a/model_server/base/process.py +++ b/model_server/base/process.py @@ -18,7 +18,7 @@ def is_mask(img): return True elif img.dtype == 'uint8': unique = np.unique(img) - if unique.shape[0] == 2 and np.all(unique == [0, 255]): + if unique.shape[0] <= 2 and np.all(unique == [0, 255]): return True return False @@ -136,7 +136,14 @@ def smooth(img: np.ndarray, sig: float) -> np.ndarray: :param sig: threshold parameter :return: smoothed image """ - return gaussian(img, sig) + ga = gaussian(img, sig, preserve_range=True) + if is_mask(img): + if img.dtype == 'bool': + return ga > ga.mean() + elif img.dtype == 'uint8': + return (255 * (ga > ga.mean())).astype('uint8') + else: + return ga class Error(Exception): pass diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 2b52d5a3db08ae68cf39ad57fdecd231e791d6ce..4eeda4909fa4e4c813698b0dc315f7c3b2c980e8 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1,7 +1,8 @@ +import itertools from math import sqrt, floor from pathlib import Path -import re -from typing import List, Union +from typing import Dict, List, Union +from typing_extensions import Self from uuid import uuid4 import numpy as np @@ -10,9 +11,9 @@ from pydantic import BaseModel from scipy.stats import moment from skimage.filters import sobel -from skimage.measure import label, regionprops_table, shannon_entropy, find_contours -from sklearn.preprocessing import PolynomialFeatures -from sklearn.linear_model import LinearRegression +from skimage import draw +from skimage.measure import approximate_polygon, find_contours, label, points_in_poly, regionprops, regionprops_table, shannon_entropy +from skimage.morphology import binary_dilation, disk from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file from .models import InstanceSegmentationModel @@ -51,6 +52,7 @@ class RoiFilter(BaseModel): class RoiSetMetaParams(BaseModel): filters: Union[RoiFilter, None] = None expand_box_by: List[int] = [128, 0] + deproject_channel: Union[int, None] = None class RoiSetExportParams(BaseModel): @@ -62,8 +64,7 @@ class RoiSetExportParams(BaseModel): derived_channels: bool = False - -def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor: +def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor: """ Convert binary segmentation mask into either a 2D or 3D object identities map :param acc_seg_mask: binary segmentation mask (mono) of either two or three dimensions @@ -73,7 +74,7 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, 
allow_3d=False, conne """ if allow_3d and connect_3d: nda_la = label( - acc_seg_mask.data[:, :, 0, :], + acc_seg_mask.data_xyz, connectivity=3, ).astype('uint16') return InMemoryDataAccessor(np.expand_dims(nda_la, 2)) @@ -82,7 +83,7 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16') for zi in range(0, acc_seg_mask.nz): la_2d = label( - acc_seg_mask.data[:, :, 0, zi], + acc_seg_mask.data_xyz[:, :, zi], connectivity=2, ).astype('uint16') la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla @@ -92,13 +93,13 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne else: return InMemoryDataAccessor( label( - acc_seg_mask.data[:, :, 0, :].max(axis=-1), + acc_seg_mask.get_mip().data_xy, connectivity=1, ).astype('uint16') ) -def _focus_metrics(): +def focus_metrics(): return { 'max_intensity': lambda x: np.max(x), 'stdev': lambda x: np.std(x), @@ -109,7 +110,216 @@ def _focus_metrics(): } -def _safe_add(a, g, b): +def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: + query_str = 'label > 0' # always true + if filters is not None: # parse filters + for k, val in filters.dict(exclude_unset=True).items(): + assert k in ('area') + vmin = val['min'] + vmax = val['max'] + assert vmin >= 0 + query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}' + return df.loc[df.query(query_str).index, :] + + +def filter_df_overlap_bbox(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame: + """ + If passed a single DataFrame, return the subset whose bounding boxes overlap in 3D space. If passed two DataFrames, + return the subset where a ROI in the first overlaps a ROI in the second. May return duplicates entries where a ROI + overlaps with multiple neighbors. + :param df1: DataFrame with potentially overlapping bounding boxes + :param df2: (optional) second DataFrame + :return DataFrame describing subset of overlapping ROIs + bbox_overlaps_with: index of ROI that overlaps + bbox_intersec: pixel area of intersecting region + """ + + def _compare(r0, r1): + olx = (r0.x0 < r1.x1) and (r0.x1 > r1.x0) + oly = (r0.y0 < r1.y1) and (r0.y1 > r1.y0) + olz = (r0.zi == r1.zi) + return olx and oly and olz + + def _intersec(r0, r1): + return (r0.x1 - r1.x0) * (r0.y1 - r1.y0) + + first = [] + second = [] + intersec = [] + + if df2 is not None: + for pair in itertools.product(df1.index, df2.index): + if _compare(df1.iloc[pair[0]], df2.iloc[pair[1]]): + first.append(pair[0]) + second.append(pair[1]) + intersec.append( + _intersec(df1.iloc[pair[0]], df2.iloc[pair[1]]) + ) + else: + for pair in itertools.combinations(df1.index, 2): + if _compare(df1.iloc[pair[0]], df1.iloc[pair[1]]): + first.append(pair[0]) + second.append(pair[1]) + first.append(pair[1]) + second.append(pair[0]) + isc = _intersec(df1.iloc[pair[0]], df1.iloc[pair[1]]) + intersec.append(isc) + intersec.append(isc) + + sdf = df1.iloc[first] + sdf.loc[:, 'overlaps_with'] = second + sdf.loc[:, 'bbox_intersec'] = intersec + return sdf + + +def filter_df_overlap_seg(df1: pd.DataFrame, df2: pd.DataFrame = None) -> pd.DataFrame: + """ + If passed a single DataFrame, return the subset whose segmentations overlap in 3D space. If passed two DataFrames, + return the subset where a ROI in the first overlaps a ROI in the second. May return duplicates entries where a ROI + overlaps with multiple neighbors. 
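+    Overlap is only evaluated between ROIs that share the same z-index (zi), since candidate pairs come from filter_df_overlap_bbox.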
+ :param df1: DataFrame with potentially overlapping bounding boxes + :param df2: (optional) second DataFrame + :return DataFrame describing subset of overlapping ROIs + seg_overlaps_with: index of ROI that overlaps + seg_intersec: pixel area of intersecting region + seg_iou: intersection over union + """ + + dfbb = filter_df_overlap_bbox(df1, df2) + + def _overlap_seg(r): + roi1 = df1.loc[r.name] + if df2 is not None: + roi2 = df2.loc[r.overlaps_with] + else: + roi2 = df1.loc[r.overlaps_with] + ex0 = min(roi1.x0, roi2.x0, roi1.x1, roi2.x1) + ew = max(roi1.x0, roi2.x0, roi1.x1, roi2.x1) - ex0 + ey0 = min(roi1.y0, roi2.y0, roi1.y1, roi2.y1) + eh = max(roi1.y0, roi2.y0, roi1.y1, roi2.y1) - ey0 + emask = np.zeros((eh, ew), dtype='uint8') + sl1 = np.s_[(roi1.y0 - ey0): (roi1.y1 - ey0), (roi1.x0 - ex0): (roi1.x1 - ex0)] + sl2 = np.s_[(roi2.y0 - ey0): (roi2.y1 - ey0), (roi2.x0 - ex0): (roi2.x1 - ex0)] + emask[sl1] = roi1.binary_mask + emask[sl2] = emask[sl2] + roi2.binary_mask + return emask + + emasks = dfbb.apply(_overlap_seg, axis=1) + dfbb['seg_overlaps'] = emasks.apply(lambda x: np.any(x > 1)) + dfbb['seg_intersec'] = emasks.apply(lambda x: (x == 2).sum()) + dfbb['seg_iou'] = emasks.apply(lambda x: (x == 2).sum() / (x > 0).sum()) + return dfbb + + +def make_df_from_object_ids(acc_raw, acc_obj_ids, expand_box_by, deproject_channel=None) -> pd.DataFrame: + """ + Build dataframe that associate object IDs with summary stats; + :param acc_raw: accessor to raw image data + :param acc_obj_ids: accessor to map of object IDs + :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) + :param deproject_channel: if objects' z-coordinates are not specified, compute them based on argmax of this channel + :return: pd.DataFrame + """ + # build dataframe of objects, assign z index to each object + + if acc_obj_ids.nz == 1 and acc_raw.nz > 1: + + if deproject_channel is None or deproject_channel >= acc_raw.chroma or deproject_channel < 0: + if acc_raw.chroma == 1: + deproject_channel = 0 + else: + raise NoDeprojectChannelSpecifiedError( + f'When labeling objects, either their z-coordinates or a valid deprojection channel are required.' 
+ ) + acc_raw.get_mono(deproject_channel) + + zi_map = acc_raw.get_mono(deproject_channel).get_z_argmax().data_xy.astype('uint16') + assert len(zi_map.shape) == 2 + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data_xy, + intensity_image=zi_map, + properties=('label', 'area', 'intensity_mean', 'bbox') + )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) + df['zi'] = df['intensity_mean'].round().astype('int') + + else: # objects' z-coordinates come from arg of max count in object identities map + df = pd.DataFrame(regionprops_table( + acc_obj_ids.data_xyz, + properties=('label', 'area', 'bbox') + )).rename(columns={ + 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' + }) + + def _get_zi_from_label(la): + return acc_obj_ids.apply(lambda x: x == la).get_focus_vector().argmax() + + df['zi'] = df['label'].apply(_get_zi_from_label) + + df = df_insert_slices(df, acc_raw.shape_dict, expand_box_by) + + def _make_binary_mask(r): + acc = InMemoryDataAccessor(acc_obj_ids.data == r.label) + cropped = acc.get_mono(0, mip=True).crop_hw((r.y0, r.x0, (r.y1 - r.y0), (r.x1 - r.x0))).data_xy + return cropped + + + df['binary_mask'] = df.apply( + _make_binary_mask, + axis=1, + result_type='reduce', + ) + return df + + +def df_insert_slices(df: pd.DataFrame, sd: dict, expand_box_by) -> pd.DataFrame: + h = sd['Y'] + w = sd['X'] + nz = sd['Z'] + + df['h'] = df['y1'] - df['y0'] + df['w'] = df['x1'] - df['x0'] + ebxy, ebz = expand_box_by + df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0)) + df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h)) + df['ebb_x0'] = (df.x0 - ebxy).apply(lambda x: max(x, 0)) + df['ebb_x1'] = (df.x1 + ebxy).apply(lambda x: min(x, w)) + df['ebb_z0'] = (df.zi - ebz).apply(lambda x: max(x, 0)) + df['ebb_z1'] = (df.zi + ebz).apply(lambda x: min(x, nz)) + df['ebb_h'] = df['ebb_y1'] - df['ebb_y0'] + df['ebb_w'] = df['ebb_x1'] - df['ebb_x0'] + df['ebb_nz'] = df['ebb_z1'] - df['ebb_z0'] + 1 + + # compute relative bounding boxes + df['rel_y0'] = df.y0 - df.ebb_y0 + df['rel_y1'] = df.y1 - df.ebb_y0 + df['rel_x0'] = df.x0 - df.ebb_x0 + df['rel_x1'] = df.x1 - df.ebb_x0 + + assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0'])) + assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0'])) + + df['slice'] = df.apply( + lambda r: + np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)], + axis=1, + result_type='reduce', + ) + df['expanded_slice'] = df.apply( + lambda r: + np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], + axis=1, + result_type='reduce', + ) + df['relative_slice'] = df.apply( + lambda r: + np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :], + axis=1, + result_type='reduce', + ) + return df + + +def safe_add(a, g, b): assert a.dtype == b.dtype assert a.shape == b.shape assert g >= 0.0 @@ -120,45 +330,125 @@ def _safe_add(a, g, b): np.iinfo(a.dtype).max ).astype(a.dtype) +def make_object_ids_from_df(df: pd.DataFrame, sd: dict) -> InMemoryDataAccessor: + id_mask = np.zeros((sd['Y'], sd['X'], 1, sd['Z']), dtype='uint16') + if 'binary_mask' not in df.columns: + raise MissingSegmentationError('RoiSet dataframe does not contain segmentation') + + def _label_obj(r): + sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1] + mask = np.expand_dims(r.binary_mask, (2, 3)) + id_mask[sl] = id_mask[sl] + r.label * mask + + df.apply(_label_obj, axis=1) + return InMemoryDataAccessor(id_mask) + class 
RoiSet(object): def __init__( self, acc_raw: GenericImageDataAccessor, - acc_obj_ids: GenericImageDataAccessor, + df: pd.DataFrame, params: RoiSetMetaParams = RoiSetMetaParams(), ): """ A set of regions of interest, referenced by their positions and contours in the YXCZ space of stack acc_raw. RoiSet contains their internal state, which may be exported as patches, maps, and other products by export methods. :param acc_raw: accessor to a generally a multichannel z-stack - :param acc_obj_ids: accessor to a 2D single-channel object identities map, where each pixel's intensity - labels its membership in a connected object + :param df: dataframe containing at minimum bounding box and segmentation mask information :param params: optional arguments that influence the definition and representation of ROIs """ - assert acc_obj_ids.chroma == 1 - self.acc_obj_ids = acc_obj_ids self.acc_raw = acc_raw self.accs_derived = [] self.params = params - self._df = self.filter_df( - self.make_df( - self.acc_raw, self.acc_obj_ids, expand_box_by=params.expand_box_by - ), - params.filters, - ) - + self._df = df self.count = len(self._df) - self.object_class_maps = {} # classification results def __iter__(self): """Expose ROI meta information via the Pandas.DataFrame API""" return self._df.itertuples(name='Roi') @staticmethod - def from_segmentation( + def from_object_ids( + acc_raw: GenericImageDataAccessor, + acc_obj_ids: GenericImageDataAccessor, + params: RoiSetMetaParams = RoiSetMetaParams(), + ): + """ + + :param acc_raw: + :param acc_obj_ids: accessor to a 2D single-channel object identities map, where each pixel's intensity + labels its membership in a connected object + :param params: + :return: + """ + assert acc_obj_ids.chroma == 1 + + df = filter_df( + make_df_from_object_ids( + acc_raw, acc_obj_ids, + expand_box_by=params.expand_box_by, + deproject_channel=params.deproject_channel, + ), + params.filters, + ) + + return RoiSet(acc_raw, df, params) + + @staticmethod + def from_bounding_boxes( + acc_raw: GenericImageDataAccessor, + bbox_yxhw: List[Dict], + bbox_zi: Union[List[int], int] = None, + params: RoiSetMetaParams = RoiSetMetaParams() + ): + bbox_df = pd.DataFrame(bbox_yxhw) + if list(bbox_df.columns.str.upper().sort_values()) != ['H', 'W', 'X', 'Y']: + raise BoundingBoxError(f'Expecting bounding box coordinates Y, X, H, and W, not {list(bbox_df.columns)}') + + + # deproject if zi is not specified + if bbox_zi is None: + dch = params.deproject_channel + if dch is None or dch >= acc_raw.chroma or dch < 0: + if acc_raw.chroma == 1: + dch = 0 + else: + raise NoDeprojectChannelSpecifiedError( + f'When labeling objects, either their z-coordinates or a valid deprojection channel are required.' 
+ ) + bbox_df['zi'] = acc_raw.get_mono(dch).get_focus_vector().argmax() + else: + bbox_df['zi'] = bbox_zi + + bbox_df['y0'] = bbox_df['y'] + bbox_df['x0'] = bbox_df['x'] + bbox_df['y1'] = bbox_df['y0'] + bbox_df['h'] + bbox_df['x1'] = bbox_df['x0'] + bbox_df['w'] + bbox_df['label'] = bbox_df.index + + + df = df_insert_slices( + bbox_df[['y0', 'x0', 'y1', 'x1', 'zi', 'label']], + acc_raw.shape_dict, + params.expand_box_by, + ) + + def _make_binary_mask(r): + return np.ones((r.h, r.w), dtype=bool) + + df['binary_mask'] = df.apply( + _make_binary_mask, + axis=1, + result_type='reduce', + ) + return RoiSet(acc_raw, df, params) + + + @staticmethod + def from_binary_mask( acc_raw: GenericImageDataAccessor, acc_seg: GenericImageDataAccessor, allow_3d=False, @@ -169,102 +459,43 @@ class RoiSet(object): Create a RoiSet from a binary segmentation mask (either 2D or 3D) :param acc_raw: accessor to a generally a multichannel z-stack :param acc_seg: accessor of a binary segmentation mask (mono) of either two or three dimensions - :param allow_3d: return a 3D map if True; return a 2D map of the mask's maximum intensity project if False + :param allow_3d: use a 3D map if True; use a 2D map of the mask's maximum intensity project if False :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False :param params: optional arguments that influence the definition and representation of ROIs - :return: object identities map """ - return RoiSet(acc_raw, _get_label_ids(acc_seg, allow_3d=allow_3d, connect_3d=connect_3d), params) + return RoiSet.from_object_ids( + acc_raw, + get_label_ids( + acc_seg, + allow_3d=allow_3d, + connect_3d=connect_3d + ), + params + ) @staticmethod - def make_df(acc_raw, acc_obj_ids, expand_box_by) -> pd.DataFrame: + def from_polygons_2d( + acc_raw, + polygons: List[np.ndarray], + params: RoiSetMetaParams = RoiSetMetaParams() + ): """ - Build dataframe associate object IDs with summary stats - :param acc_raw: accessor to raw image data - :param acc_obj_ids: accessor to map of object IDs - :param expand_box_by: number of pixels to expand bounding box in all directions (without exceeding image boundary) - # :param deproject: assign object's z-position based on argmax of raw data if True - :return: pd.DataFrame + Create a RoiSet where objects are defined from a list of polygon coordinates + :param acc_raw: accessor to a generally a multichannel z-stack + :param polygons: list of (variable x 2) np.ndarrays describing (x, y) polymer coordinates + :param params: optional arguments that influence the definition and representation of ROIs """ - # build dataframe of objects, assign z index to each object - - if acc_obj_ids.nz == 1: # deproject objects' z-coordinates from argmax of raw image - df = pd.DataFrame(regionprops_table( - acc_obj_ids.data[:, :, 0, 0], - intensity_image=acc_raw.data.argmax(axis=3, keepdims=True)[:, :, 0, 0].astype('uint16'), - properties=('label', 'area', 'intensity_mean', 'bbox', 'centroid') - )).rename(columns={'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'y1', 'bbox-3': 'x1'}) - df['zi'] = df['intensity_mean'].round().astype('int') - - else: # objects' z-coordinates come from arg of max count in object identities map - df = pd.DataFrame(regionprops_table( - acc_obj_ids.data[:, :, 0, :], - properties=('label', 'area', 'bbox', 'centroid') - )).rename(columns={ - 'bbox-0': 'y0', 'bbox-1': 'x0', 'bbox-2': 'z0', 'bbox-3': 'y1', 'bbox-4': 'x1', 'bbox-5': 'z1' - }) - df['zi'] = df['label'].apply(lambda x: (acc_obj_ids.data == 
x).sum(axis=(0, 1, 2)).argmax()) - - # compute expanded bounding boxes - h, w, c, nz = acc_raw.shape - df['h'] = df['y1'] - df['y0'] - df['w'] = df['x1'] - df['x0'] - ebxy, ebz = expand_box_by - df['ebb_y0'] = (df.y0 - ebxy).apply(lambda x: max(x, 0)) - df['ebb_y1'] = (df.y1 + ebxy).apply(lambda x: min(x, h)) - df['ebb_x0'] = (df.x0 - ebxy).apply(lambda x: max(x, 0)) - df['ebb_x1'] = (df.x1 + ebxy).apply(lambda x: min(x, w)) - df['ebb_z0'] = (df.zi - ebz).apply(lambda x: max(x, 0)) - df['ebb_z1'] = (df.zi + ebz).apply(lambda x: min(x, nz)) - df['ebb_h'] = df['ebb_y1'] - df['ebb_y0'] - df['ebb_w'] = df['ebb_x1'] - df['ebb_x0'] - df['ebb_nz'] = df['ebb_z1'] - df['ebb_z0'] + 1 - - # compute relative bounding boxes - df['rel_y0'] = df.y0 - df.ebb_y0 - df['rel_y1'] = df.y1 - df.ebb_y0 - df['rel_x0'] = df.x0 - df.ebb_x0 - df['rel_x1'] = df.x1 - df.ebb_x0 - - assert np.all(df['rel_x1'] <= (df['ebb_x1'] - df['ebb_x0'])) - assert np.all(df['rel_y1'] <= (df['ebb_y1'] - df['ebb_y0'])) - - df['slice'] = df.apply( - lambda r: - np.s_[int(r.y0): int(r.y1), int(r.x0): int(r.x1), :, int(r.zi): int(r.zi + 1)], - axis=1, - result_type='reduce', - ) - df['expanded_slice'] = df.apply( - lambda r: - np.s_[int(r.ebb_y0): int(r.ebb_y1), int(r.ebb_x0): int(r.ebb_x1), :, int(r.ebb_z0): int(r.ebb_z1) + 1], - axis=1, - result_type='reduce', - ) - df['relative_slice'] = df.apply( - lambda r: - np.s_[int(r.rel_y0): int(r.rel_y1), int(r.rel_x0): int(r.rel_x1), :, :], - axis=1, - result_type='reduce', - ) - df['binary_mask'] = df.apply( - lambda r: (acc_obj_ids.data == r.label).max(axis=-1)[r.y0: r.y1, r.x0: r.x1, 0], - axis=1, - result_type='reduce', + mask = np.zeros(acc_raw.get_mono(0, mip=True).shape, dtype=bool) + for p in polygons: + sl = draw.polygon(p[:, 1], p[:, 0]) + mask[sl] = True + return RoiSet.from_binary_mask( + acc_raw, + InMemoryDataAccessor(mask), + allow_3d=False, + connect_3d=False, + params=params, ) - return df - - @staticmethod - def filter_df(df: pd.DataFrame, filters: RoiFilter = None) -> pd.DataFrame: - query_str = 'label > 0' # always true - if filters is not None: # parse filters - for k, val in filters.dict(exclude_unset=True).items(): - assert k in ('area') - vmin = val['min'] - vmax = val['max'] - assert vmin >= 0 - query_str = query_str + f' & {k} > {vmin} & {k} < {vmax}' - return df.loc[df.query(query_str).index, :] def get_df(self) -> pd.DataFrame: return self._df @@ -275,19 +506,6 @@ class RoiSet(object): def add_df_col(self, name, se: pd.Series) -> None: self._df[name] = se - def get_multichannel_projection(self): - if self.count: - projected = project_stack_from_focal_points( - self._df['centroid-0'].to_numpy(), - self._df['centroid-1'].to_numpy(), - self._df['zi'].to_numpy(), - self.acc_raw, - degree=4, - ) - else: # else just return MIP - projected = self.acc_raw.data.max(axis=-1) - return projected - def get_patches_acc(self, channels: list = None, **kwargs) -> PatchStack: # padded, un-annotated 2d patches if channels and len(channels) == 1: patches_df = self.get_patches(white_channel=channels[0], **kwargs) @@ -301,7 +519,7 @@ class RoiSet(object): write_accessor_data_to_file(fp, annotated) return (prefix + '.tif') - def get_zmask(self, mask_type='boxes'): + def get_zmask(self, mask_type='boxes') -> np.ndarray: """ Return a mask of same dimensionality as raw data @@ -352,7 +570,7 @@ class RoiSet(object): :param channels: list of nc raw input channels to send to classifier :param object_classification_model: InstanceSegmentation model object :param derived_channel_functions: list 
of functions that each receive a PatchStack accessor with nc channels and - return a single-channel PatchStack accessor of the same shape + that return a single-channel PatchStack accessor of the same shape :return: None """ @@ -385,26 +603,67 @@ class RoiSet(object): self.get_patch_masks_acc(expanded=False, pad_to=None) ) - om = np.zeros(self.acc_obj_ids.shape, self.acc_obj_ids.dtype) + se = pd.Series(dtype='Int64', index=self._df.index) - self._df['classify_by_' + name] = pd.Series(dtype='Int64') - - # assign labels to object map: for i, roi in enumerate(self): oc = np.unique( mask_largest_object( obmap_patches.iat(i).data ) )[-1] - self._df.loc[roi.Index, 'classify_by_' + name] = oc - om[self.acc_obj_ids.data == roi.label] = oc - self.object_class_maps[name] = InMemoryDataAccessor(om) + se[roi.Index] = oc + self.set_classification(f'classify_by_{name}', se) - def export_dataframe(self, csv_path: Path) -> str: - csv_path.parent.mkdir(parents=True, exist_ok=True) - self._df.drop(['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], axis=1).to_csv(csv_path, index=False) - return csv_path.name + def get_instance_classification(self, roiset_from: Self, iou_min: float = 0.5) -> pd.DataFrame: + """ + Transfer instance classification labels from another RoiSet based on intersection over union (IOU) similarity + :param roiset_from: RoiSet source of classification labels, same shape as this RoiSet + :param iou_min: threshold IOU below which a label is not transferred + :return DataFrame of source RoiSet, including overlaps with this RoiSet and IOU metric + """ + if self.acc_raw.shape != roiset_from.acc_raw.shape: + raise ShapeMismatchError( + f'Expecting two RoiSets of same shape: {self.acc_raw.shape} != {roiset_from.acc_raw.shape}') + + columns = [f'classify_by_{c}' for c in roiset_from.classification_columns] + + if len(columns) == 0: + raise MissingInstanceLabelsError('Expecting at least on instance classification channel but none found') + + df_overlaps = filter_df_overlap_seg( + roiset_from.get_df(), + self.get_df() + ) + df_overlaps['transfer'] = df_overlaps.seg_iou > iou_min + df_merge = pd.merge( + roiset_from.get_df()[columns], + df_overlaps.loc[df_overlaps.transfer, ['overlaps_with']], + left_index=True, + right_index=True, + how='inner', + ).set_index('overlaps_with') + for col in columns: + self.set_classification(col, df_merge[col]) + + return df_overlaps + + def get_object_class_map(self, name: str) -> InMemoryDataAccessor: + """ + For a given classification result, return a map where object IDs are replaced by each object's class + :param name: name of the classification result, same as passed to RoiSet.classify_by() + :return: accessor of object class map + """ + colname = ('classify_by_' + name) + assert colname in self._df.columns + obj_ids = self.acc_obj_ids + om = np.zeros(obj_ids.shape, obj_ids.dtype) + + def _label_object_class(roi): + om[self.acc_obj_ids.data == roi.label] = roi[colname] + self._df.apply(_label_object_class, axis=1) + + return InMemoryDataAccessor(om) def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> pd.DataFrame: @@ -446,15 +705,13 @@ class RoiSet(object): patch = np.zeros((roi.ebb_h, roi.ebb_w, 1, 1), dtype='uint8') patch[roi.relative_slice][:, :, 0, 0] = roi.binary_mask * 255 else: - patch = np.zeros((roi.y1 - roi.y0, roi.x1 - roi.x0, 1, 1), dtype='uint8') - patch[:, :, 0, 0] = roi.binary_mask * 255 - + patch = (roi.binary_mask * 255).astype('uint8') if pad_to: patch = pad(patch, pad_to) - return 
patch + return np.expand_dims(patch, (2, 3)) dfe = self._df.copy() - dfe['patch_mask'] = dfe.apply(lambda r: _make_patch_mask(r), axis=1) + dfe['patch_mask'] = dfe.apply(_make_patch_mask, axis=1) return dfe def get_patch_masks_acc(self, **kwargs) -> PatchStack: @@ -482,7 +739,8 @@ class RoiSet(object): if white_channel: assert white_channel < raw.chroma - stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] + mono = raw.get_mono(white_channel).data_xyz + stack = np.stack([mono, mono, mono], axis=2) else: stack = np.zeros([*raw.shape[0:2], 3, raw.shape[3]], dtype=raw.dtype) @@ -491,10 +749,10 @@ class RoiSet(object): continue assert isinstance(ci, int) assert ci < raw.chroma - stack[:, :, ii, :] = _safe_add( + stack[:, :, ii, :] = safe_add( stack[:, :, ii, :], # either black or grayscale channel rgb_overlay_weights[ii], - raw.data[:, :, ci, :] + raw.get_mono(ci).data_xyz ) else: if white_channel is not None: # interpret as just a single channel @@ -509,9 +767,10 @@ class RoiSet(object): annotate_rgb = True break if annotate_rgb: # make RGB patches anyway to include annotation color - stack = raw.data[:, :, [white_channel, white_channel, white_channel], :] + mono = raw.get_mono(white_channel).data_xyz + stack = np.stack([mono, mono, mono], axis=2) else: # make monochrome patches - stack = raw.data[:, :, [white_channel], :] + stack = raw.get_mono(white_channel).data elif kwargs.get('channels'): stack = raw.get_channels(kwargs['channels']).data else: @@ -533,7 +792,7 @@ class RoiSet(object): # make a 2d patch, find optimal z-position determined by focus_metric function on each channel separately elif focus_metric is not None: - foc = _focus_metrics()[focus_metric] + foc = focus_metrics()[focus_metric] patch = np.zeros([ph, pw, pc, 1], dtype=patch3d.dtype) @@ -593,6 +852,22 @@ class RoiSet(object): dfe['patch'] = dfe.apply(lambda r: _make_patch(r), axis=1) return dfe + @property + def classification_columns(self): + """ + Return list of columns that describe instance classification results + """ + pr = 'classify_by_' + return [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)] + + def set_classification(self, colname: str, se: pd.Series): + """ + Set instance classification result as a column addition on dataframe + :param colname: name of classification result + :param se: series containing class information + """ + self._df[colname] = se + def run_exports(self, where: Path, channel, prefix, params: RoiSetExportParams) -> dict: """ Export various representations of ROIs, e.g. patches, annotated stacks, and object maps. @@ -633,10 +908,10 @@ class RoiSet(object): if k == 'annotated_zstacks': record[k] = str(Path(k) / self.export_annotated_zstack(subdir, prefix=pr, **kp)) if k == 'object_classes': - for kc, acc in self.object_class_maps.items(): - fp = subdir / kc / (pr + '.tif') - write_accessor_data_to_file(fp, acc) - record[f'{k}_{kc}'] = str(fp) + for n in self.classification_columns: + fp = subdir / n / (pr + '.tif') + write_accessor_data_to_file(fp, self.get_object_class_map(n)) + record[f'{k}_{n}'] = str(fp) if k == 'derived_channels': record[k] = [] for di, dacc in enumerate(self.accs_derived): @@ -650,7 +925,7 @@ class RoiSet(object): return record - def serialize(self, where: Path, prefix='') -> dict: + def serialize(self, where: Path, prefix='roiset') -> dict: """ Export the minimal information needed to recreate RoiSet object, i.e. 
CSV data file and tight patch masks :param where: path of directory in which to write files @@ -658,94 +933,136 @@ class RoiSet(object): :return: nested dict of Path objects describing the locations of export products """ record = {} - df_exp = self.export_patch_masks( - where / 'tight_patch_masks', - prefix=prefix, - pad_to=None, - expanded=False - ) - se_pa = df_exp.patch_mask_path.apply( - lambda x: str(Path('tight_patch_masks') / x) - ).rename('tight_patch_masks_path') - self._df = self._df.join(se_pa) - df_fn = self.export_dataframe(where / 'dataframe' / (prefix + '.csv')) - record['dataframe'] = str(Path('dataframe') / df_fn) - record['tight_patch_masks'] = list(se_pa) + if not self._df.binary_mask.apply(lambda x: np.all(x)).all(): # binary masks aren't just all True + df_exp = self.export_patch_masks( + where / 'tight_patch_masks', + prefix=prefix, + pad_to=None, + expanded=False + ) + # record patch masks paths to dataframe, then save static columns to CSV + se_pa = df_exp.patch_mask_path.apply( + lambda x: str(Path('tight_patch_masks') / x) + ).rename('tight_patch_masks_path') + self._df = self._df.join(se_pa) + record['tight_patch_masks'] = list(se_pa) + + csv_path = where / 'dataframe' / (prefix + '.csv') + csv_path.parent.mkdir(parents=True, exist_ok=True) + self._df.drop( + ['expanded_slice', 'slice', 'relative_slice', 'binary_mask'], + axis=1 + ).to_csv(csv_path, index=False) + + record['dataframe'] = str(Path('dataframe') / csv_path.name) + return record + + def get_polygons(self, poly_threshold=0, dilation_radius=1) -> pd.DataFrame: + self.coordinates_ = """ + Fit polygons to all object boundaries in the RoiSet + :param poly_threshold: threshold distance for polygon fit; a smaller number follows sharp features more closely + :param dilation_radius: radius of binary dilation to apply before fitting polygon + :return: Series of (variable x 2) np.ndarrays describing (x, y) polymer coordinates + """ + + pad_to = 1 + + def _poly_from_mask(roi): + mask = roi.binary_mask + if len(mask.shape) != 2: + raise PatchMaskShapeError(f'Patch mask needs to be two dimensions to fit a polygon') + + # label and fill holes + labeled = label(mask) + filled = [rp.image_filled for rp in regionprops(labeled)] + assert (np.unique(labeled)[-1] == 1) and (len(filled) == 1), 'Cannot fit multiple polygons in a single patch mask' + + closed = binary_dilation(filled[0], footprint=disk(dilation_radius)) + padded = np.pad(closed, pad_to) * 1.0 + all_contours = find_contours(padded) + + nc = len(all_contours) + for j in range(0, nc): + if all([points_in_poly(all_contours[k], all_contours[j]).all() for k in range(0, nc)]): + contour = all_contours[j] + break + + rel_polygon = approximate_polygon(contour[:, [1, 0]], poly_threshold) - [pad_to, pad_to] + return rel_polygon + [roi.x0, roi.y0] + + return self._df.apply(_poly_from_mask, axis=1) + + @property + def acc_obj_ids(self): + return make_object_ids_from_df(self._df, self.acc_raw.shape_dict) + @staticmethod - def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix=''): + def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self: + """ + Create an RoiSet object from saved files and an image accessor + :param acc_raw: accessor to image that contains ROIs + :param where: path to directory containing RoiSet serialization files, namely dataframe.csv and a subdirectory + named tight_patch_masks + :param prefix: starting prefix of patch mask filenames + :return: RoiSet object + """ df = pd.read_csv(where / 'dataframe' 
/ (prefix + '.csv'))[['label', 'zi', 'y0', 'y1', 'x0', 'x1']] - - id_mask = np.zeros((*acc_raw.hw, 1, acc_raw.nz), dtype='uint16') - def _label_obj(r): - sl = np.s_[r.y0:r.y1, r.x0:r.x1, :, r.zi:r.zi + 1] - ext = 'png' - fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' - try: - ma_acc = generate_file_accessor(where / 'tight_patch_masks' / fname) - bool_mask = ma_acc.data / np.iinfo(ma_acc.data.dtype).max - id_mask[sl] = id_mask[sl] + r.label * bool_mask - except Exception as e: - raise DeserializeRoiSet(e) - - df.apply(_label_obj, axis=1) - return RoiSet(acc_raw, InMemoryDataAccessor(id_mask)) - - -def project_stack_from_focal_points( - xx: np.ndarray, - yy: np.ndarray, - zz: np.ndarray, - stack: GenericImageDataAccessor, - degree: int = 2, -) -> np.ndarray: - """ - Given a set of 3D points, project a multichannel z-stack based on a surface fit of the provided points - :param xx: vector of point x-coordinates - :param yy: vector of point y-coordinates - :param zz: vector of point z-coordinates - :param stack: z-stack to project - :param degree: order of polynomial to fit - :return: multichannel 2d projected image array - """ - assert xx.shape == yy.shape - assert xx.shape == zz.shape - - poly = PolynomialFeatures(degree=degree) - X = np.stack([xx, yy]).T - features = poly.fit_transform(X, zz) - model = LinearRegression(fit_intercept=False) - model.fit(features, zz) - - xy_indices = np.indices(stack.hw).reshape(2, -1).T - xy_features = np.dot( - poly.fit_transform(xy_indices, zz), - model.coef_ - ) - zi_image = xy_features.reshape( - stack.hw - ).round().clip( - 0, (stack.nz - 1) - ).astype('uint16') - - return np.take_along_axis( - stack.data, - np.repeat( - np.expand_dims(zi_image, (2, 3)), - stack.chroma, - axis=2 - ), - axis=3 - ) + pa_masks = where / 'tight_patch_masks' + + if pa_masks.exists(): # import segmentation masks + def _read_binary_mask(r): + ext = 'png' + fname = f'{prefix}-la{r.label:04d}-zi{r.zi:04d}.{ext}' + try: + ma_acc = generate_file_accessor(pa_masks / fname) + assert ma_acc.chroma == 1 and ma_acc.nz == 1 + mask_data = ma_acc.data_xy / np.iinfo(ma_acc.data.dtype).max + return mask_data + except Exception as e: + raise DeserializeRoiSet(e) + + df['binary_mask'] = df.apply(_read_binary_mask, axis=1) + id_mask = make_object_ids_from_df(df, acc_raw.shape_dict) + return RoiSet.from_object_ids(acc_raw, id_mask) + + else: # assume bounding boxes only + df['y'] = df['y0'] + df['x'] = df['x0'] + df['h'] = df['y1'] - df['y0'] + df['w'] = df['x1'] - df['x0'] + return RoiSet.from_bounding_boxes( + acc_raw, + df[['y', 'x', 'h', 'w']].to_dict(orient='records'), + list(df['zi']) + ) class Error(Exception): pass +class BoundingBoxError(Error): + pass + class DeserializeRoiSet(Error): pass +class NoDeprojectChannelSpecifiedError(Error): + pass + class DerivedChannelError(Error): + pass + +class MissingSegmentationError(Error): + pass + +class PatchMaskShapeError(Error): + pass + +class ShapeMismatchError(Error): + pass + +class MissingInstanceLabelsError(Error): pass \ No newline at end of file diff --git a/model_server/base/util.py b/model_server/base/util.py index 9b95edca73fbef1e831d4e423e7a30bc82ad5907..81736d485cd322d9ee42af0dcc6d77604314e154 100644 --- a/model_server/base/util.py +++ b/model_server/base/util.py @@ -8,6 +8,7 @@ import pandas as pd from .accessors import InMemoryDataAccessor, write_accessor_data_to_file from .models import Model +from .roiset import filter_df_overlap_seg, RoiSet def autonumber_new_directory(where: str, prefix: str) -> str: """ @@ -163,4 
+164,4 @@ def loop_workflow( ) if len(failures) > 0: - pd.DataFrame(failures).to_csv(Path(output_folder_path) / 'failures.csv') + pd.DataFrame(failures).to_csv(Path(output_folder_path) / 'failures.csv') \ No newline at end of file diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py index a8a879f28ba238f62cfa70029d44672ed4477a3a..9bb3b7a8c5a50a4f9aedc1818df3febc2590eaa3 100644 --- a/tests/base/test_accessors.py +++ b/tests/base/test_accessors.py @@ -61,6 +61,52 @@ class TestCziImageFileAccess(unittest.TestCase): sc = cf.get_mono(c, mip=True) self.assertEqual(sc.shape, (h, w, 1, 1)) + def test_get_single_channel_argmax_from_zstack(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + c = 3 + cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) + am = cf.get_mono(c).get_z_argmax() + self.assertEqual(am.shape, (h, w, 1, 1)) + self.assertTrue(np.all(am.unique()[0] == range(0, nz))) + + def test_get_single_channel_z_series_from_zstack(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + c = 3 + cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz)) + zs = cf.get_mono(c).get_focus_vector() + self.assertEqual(zs.shape, (nz, )) + + def test_get_zi(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + zi = 5 + cf = InMemoryDataAccessor(_random_int(h, w, nc, nz)) + sz = cf.get_zi(zi) + self.assertEqual(sz.shape_dict['Z'], 1) + + self.assertTrue(np.all(sz.data[:, :, :, 0] == cf.data[:, :, :, zi])) + + def test_crop_yx(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + cf = InMemoryDataAccessor(_random_int(h, w, nc, nz)) + + yxhw = (100, 200, 10, 20) + sc = cf.crop_hw(yxhw) + self.assertEqual(sc.shape_dict['Z'], nz) + self.assertEqual(sc.shape_dict['C'], nc) + self.assertEqual(sc.hw, yxhw[2:]) + def test_write_single_channel_tif(self): ch = 4 cf = CziImageFileAccessor(data['czifile']['path']) diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index f993d2a00ba233b16a4d35d4a1a917cd27c5f98c..93e96df72b99640c418ece38cbb467fb0aeace92 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -6,16 +6,19 @@ from pathlib import Path import pandas as pd -from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams +from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSetExportParams, RoiSetMetaParams from model_server.base.roiset import RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack from model_server.base.models import DummyInstanceSegmentationModel +from model_server.base.process import smooth import model_server.conf.testing as conf data = conf.meta['image_files'] output_path = conf.meta['output_path'] params = conf.meta['roiset'] + + class BaseTestRoiSetMonoProducts(object): def setUp(self) -> None: @@ -28,7 +31,7 @@ class BaseTestRoiSetMonoProducts(object): class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def _make_roi_set(self, mask_type='boxes', **kwargs): - roiset = RoiSet.from_segmentation( + roiset = RoiSet.from_binary_mask( self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams( @@ -69,7 +72,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) self.assertEqual(acc_zstack_slice.nz, 1) - roiset = RoiSet.from_segmentation(acc_zstack_slice, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes')) + roiset = RoiSet.from_binary_mask(acc_zstack_slice, self.seg_mask, 
params=RoiSetMetaParams(mask_type='boxes')) zmask = roiset.get_zmask() zmask_acc = InMemoryDataAccessor(zmask) @@ -77,7 +80,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_create_roiset_with_no_objects(self): zero_obmap = InMemoryDataAccessor(np.zeros(self.seg_mask.shape, self.seg_mask.dtype)) - roiset = RoiSet(self.stack_ch_pa, zero_obmap) + roiset = RoiSet.from_object_ids(self.stack_ch_pa, zero_obmap) self.assertEqual(roiset.count, 0) def test_slices_are_valid(self): @@ -162,26 +165,6 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): result = generate_file_accessor(where / file) self.assertEqual(result.shape, roiset.acc_raw.shape) - def test_flatten_image(self): - roiset = RoiSet.from_segmentation(self.stack_ch_pa, self.seg_mask, params=RoiSetMetaParams(mask_type='boxes')) - df = roiset.get_df() - - from model_server.base.roiset import project_stack_from_focal_points - - img = project_stack_from_focal_points( - df['centroid-0'].to_numpy(), - df['centroid-1'].to_numpy(), - df['zi'].to_numpy(), - self.stack, - degree=4, - ) - - self.assertEqual(img.shape[0:2], self.stack.shape[0:2]) - - write_accessor_data_to_file( - output_path / 'flattened.tif', - InMemoryDataAccessor(img) - ) def test_make_binary_masks(self): roiset = self._make_roi_set() @@ -200,26 +183,51 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): roiset = self._make_roi_set() roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) - self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) + self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1])) return roiset def test_classify_by_multiple_channels(self): - roiset = RoiSet.from_segmentation(self.stack, self.seg_mask) + roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0)) roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) - self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) + self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1])) return roiset + def test_transfer_classification(self): + roiset1 = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0)) + + # prepare alternative mask and compare + smoothed_mask = self.seg_mask.apply(lambda x: smooth(x, sig=1.5)) + roiset2 = RoiSet.from_binary_mask(self.stack, smoothed_mask, params=RoiSetMetaParams(deproject_channel=0)) + dmask = (self.seg_mask.data / 255) + (smoothed_mask.data / 255) + self.assertTrue(np.all(np.unique(dmask) == [0, 1, 2])) + total_iou = (dmask == 2).sum() / ((dmask == 1).sum() + (dmask == 2).sum()) + self.assertGreater(total_iou, 0.6) + + # classify first RoiSet + roiset1.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) + + self.assertTrue('dummy_class' in roiset1.classification_columns) + self.assertFalse('dummy_class' in roiset2.classification_columns) + res = roiset2.get_instance_classification(roiset1) + self.assertTrue('dummy_class' in roiset2.classification_columns) + self.assertLess( + roiset2.get_df().classify_by_dummy_class.count(), + roiset1.get_df().classify_by_dummy_class.count(), + ) + + def test_classify_by_with_derived_channel(self): class 
ModelWithDerivedInputs(DummyInstanceSegmentationModel): def infer(self, img, mask): return PatchStack(super().infer(img, mask).data * img.chroma) - roiset = RoiSet.from_segmentation( + roiset = RoiSet.from_binary_mask( self.stack, self.seg_mask, params=RoiSetMetaParams( filters={'area': {'min': 1e3, 'max': 1e4}}, + deproject_channel=0, ) ) roiset.classify_by( @@ -232,7 +240,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ] ) self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4])) - self.assertTrue(all(np.unique(roiset.object_class_maps['multiple_input_model'].data) == [0, 4])) + self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4])) self.assertEqual(len(roiset.accs_derived), 2) for di in roiset.accs_derived: @@ -281,13 +289,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa def setUp(self) -> None: super().setUp() - self.roiset = RoiSet.from_segmentation( + self.roiset = RoiSet.from_binary_mask( self.stack, self.seg_mask, params=RoiSetMetaParams( expand_box_by=(128, 2), mask_type='boxes', filters={'area': {'min': 1e3, 'max': 1e4}}, + deproject_channel=0, ) ) @@ -520,7 +529,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(pacc.chroma, 1) -from model_server.base.roiset import _get_label_ids +from model_server.base.roiset import get_label_ids class TestRoiSetSerialization(unittest.TestCase): def setUp(self) -> None: @@ -528,6 +537,7 @@ class TestRoiSetSerialization(unittest.TestCase): self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path']) self.stack_ch_pa = self.stack.get_mono(params['segmentation_channel']) self.seg_mask_3d = generate_file_accessor(data['multichannel_zstack_mask3d']['path']) + self.seg_mask_2d = generate_file_accessor(data['multichannel_zstack_raw']['path']) @staticmethod def _label_is_2d(id_map, la): # single label's zmask has same counts as its MIP @@ -536,22 +546,22 @@ class TestRoiSetSerialization(unittest.TestCase): return mask_3d.sum() == mask_mip.sum() def test_id_map_connects_z(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True) + id_map = get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=True) labels = np.unique(id_map.data)[1:] is_2d = all([self._label_is_2d(id_map.data, la) for la in labels]) self.assertFalse(is_2d) def test_id_map_disconnects_z(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) + id_map = get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) labels = np.unique(id_map.data)[1:] is_2d = all([self._label_is_2d(id_map.data, la) for la in labels]) self.assertTrue(is_2d) def test_create_roiset_from_3d_obj_ids(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) + id_map = get_label_ids(self.seg_mask_3d, allow_3d=True, connect_3d=False) self.assertEqual(self.stack_ch_pa.shape, id_map.shape) - roiset = RoiSet( + roiset = RoiSet.from_object_ids( self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='contours') @@ -560,11 +570,11 @@ class TestRoiSetSerialization(unittest.TestCase): self.assertGreater(len(roiset.get_df()['zi'].unique()), 1) def test_create_roiset_from_2d_obj_ids(self): - id_map = _get_label_ids(self.seg_mask_3d, allow_3d=False) + id_map = get_label_ids(self.seg_mask_3d, allow_3d=False) self.assertEqual(self.stack_ch_pa.shape[0:3], id_map.shape[0:3]) self.assertEqual(id_map.nz, 1) - roiset = RoiSet( + 
roiset = RoiSet.from_object_ids( self.stack_ch_pa, id_map, params=RoiSetMetaParams(mask_type='contours') @@ -593,6 +603,7 @@ class TestRoiSetSerialization(unittest.TestCase): m_acc = generate_file_accessor(pmf) self.assertEqual((roi.h, roi.w), m_acc.hw) patch_filenames.append(pmf.name) + self.assertEqual(m_acc.nz, 1) # make another RoiSet from just the data table, raw images, and (tight) patch masks test_roiset = RoiSet.deserialize(self.stack_ch_pa, where_ser, prefix='ref') @@ -614,3 +625,191 @@ class TestRoiSetSerialization(unittest.TestCase): t_acc = generate_file_accessor(pt) self.assertTrue(np.all(r_acc.data == t_acc.data)) + +class TestRoiSetObjectDetection(unittest.TestCase): + + def setUp(self) -> None: + # set up test raw data and segmentation from file + self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path']) + self.stack_ch_pa = self.stack.get_mono(params['segmentation_channel']) + self.seg_mask_3d = generate_file_accessor(data['multichannel_zstack_mask3d']['path']) + + def test_create_roiset_from_bounding_boxes(self): + from skimage.measure import label, regionprops, regionprops_table + + mask = self.seg_mask_3d + labels = label(mask.data_xyz, connectivity=3) + table = pd.DataFrame( + regionprops_table(labels) + ).rename( + columns={'bbox-0': 'y', 'bbox-1': 'x', 'bbox-2': 'zi', 'bbox-3': 'y1', 'bbox-4': 'x1'} + ).drop( + columns=['bbox-5'] + ) + table['w'] = table['x1'] - table['x'] + table['h'] = table['y1'] - table['y'] + bboxes = table[['y', 'x', 'h', 'w']].to_dict(orient='records') + + roiset_bbox = RoiSet.from_bounding_boxes(self.stack_ch_pa, bboxes) + self.assertTrue('label' in roiset_bbox.get_df().columns) + patches_bbox = roiset_bbox.get_patches_acc() + self.assertEqual(len(table), patches_bbox.count) + + + # roiset w/ seg for comparison + roiset_seg = RoiSet.from_binary_mask(self.stack_ch_pa, mask, allow_3d=True) + patches_seg = roiset_seg.get_patches_acc() + + # test bounding box dimensions match those from RoiSet generated directly from segmentation + self.assertEqual(roiset_seg.count, roiset_bbox.count) + for i in range(0, roiset_seg.count): + self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape) + + # test that serialization does not write patch masks + roiset_ser_path = output_path / 'roiset_from_bbox' + dd = roiset_bbox.serialize(roiset_ser_path) + self.assertTrue('tight_patch_masks' not in dd.keys()) + self.assertFalse((roiset_ser_path / 'tight_patch_masks').exists()) + + # test that deserialized RoiSet matches the original + roiset_des = RoiSet.deserialize(self.stack_ch_pa, roiset_ser_path) + self.assertEqual(roiset_des.count, roiset_bbox.count) + for i in range(0, roiset_des.count): + self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape) + self.assertTrue((roiset_bbox.get_zmask() == roiset_des.get_zmask()).all()) + + +class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): + + def test_compute_polygons(self): + roiset_ref = RoiSet.from_binary_mask( + self.stack_ch_pa, + self.seg_mask, + params=RoiSetMetaParams( + mask_type='contours', + filters={'area': {'min': 1e1, 'max': 1e6}} + ) + ) + + poly = roiset_ref.get_polygons() + roiset_test = RoiSet.from_polygons_2d(self.stack_ch_pa, poly) + binary_poly = (roiset_test.acc_obj_ids.get_mono(0, mip=True).data > 0) + self.assertEqual(self.seg_mask.shape, binary_poly.shape) + + + # most mask pixels are within in fitted polygon + test_mask = np.logical_and( + np.logical_not(binary_poly), + (self.seg_mask.data 
== 255) + ) + self.assertLess(test_mask.sum() / test_mask.size, 0.001) + + # output results + od = output_path / 'polygons' + write_accessor_data_to_file(od / 'from_polygons.tif', InMemoryDataAccessor(binary_poly)) + write_accessor_data_to_file(od / 'ref_mask.tif', self.seg_mask) + write_accessor_data_to_file(od / 'diff.tif', InMemoryDataAccessor(test_mask)) + + + def test_overlap_bbox(self): + df = pd.DataFrame({ + 'x0': [0, 1, 2, 1, 1], + 'x1': [2, 3, 4, 3, 3], + 'y0': [0, 0, 0, 2, 0], + 'y1': [2, 2, 2, 3, 2], + 'zi': [0, 0, 0, 0, 1], + }) + + res = filter_df_overlap_bbox(df) + self.assertEqual(len(res), 4) + self.assertTrue((res.loc[0, 'overlaps_with'] == [1]).all()) + self.assertTrue((res.loc[1, 'overlaps_with'] == [0, 2]).all()) + self.assertTrue((res.bbox_intersec == 2).all()) + return res + + + def test_overlap_bbox_multiple(self): + df1 = pd.DataFrame({ + 'x0': [0, 1], + 'x1': [2, 3], + 'y0': [0, 0], + 'y1': [2, 2], + 'zi': [0, 0], + }) + df2 = pd.DataFrame({ + 'x0': [2], + 'x1': [4], + 'y0': [0], + 'y1': [2], + 'zi': [0], + }) + res = filter_df_overlap_bbox(df1, df2) + self.assertTrue((res.loc[1, 'overlaps_with'] == [0]).all()) + self.assertEqual(len(res), 1) + self.assertTrue((res.bbox_intersec == 2).all()) + + + def test_overlap_seg(self): + df = pd.DataFrame({ + 'x0': [0, 1, 2], + 'x1': [2, 3, 4], + 'y0': [0, 0, 0], + 'y1': [2, 2, 2], + 'zi': [0, 0, 0], + 'binary_mask': [ + [ + [1, 1], + [1, 0] + ], + [ + [0, 1], + [1, 1] + ], + [ + [1, 1], + [1, 1] + ], + ] + }) + + res = filter_df_overlap_seg(df) + self.assertTrue((res.loc[res.seg_overlaps, :].index == [1, 2]).all()) + self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all()) + + def test_overlap_seg_multiple(self): + df1 = pd.DataFrame({ + 'x0': [0, 1], + 'x1': [2, 3], + 'y0': [0, 0], + 'y1': [2, 2], + 'zi': [0, 0], + 'binary_mask': [ + [ + [1, 1], + [1, 0] + ], + [ + [0, 1], + [1, 1] + ], + ] + }) + df2 = pd.DataFrame({ + 'x0': [2], + 'x1': [4], + 'y0': [0], + 'y1': [2], + 'zi': [0], + 'binary_mask': [ + [ + [1, 1], + [1, 1] + ], + ] + }) + res = filter_df_overlap_seg(df1, df2) + self.assertTrue((res.loc[1, 'overlaps_with'] == [0]).all()) + self.assertEqual(len(res), 1) + self.assertTrue((res.bbox_intersec == 2).all()) + self.assertTrue((res.loc[res.seg_overlaps, :].index == [1]).all()) + self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all()) diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py index bd2be94a9e2c62e9ca874296523421e77a7158b4..fd6893a55feff793facb931efb6c4265abd9961d 100644 --- a/tests/test_ilastik/test_ilastik.py +++ b/tests/test_ilastik/test_ilastik.py @@ -5,7 +5,7 @@ import numpy as np from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file from model_server.extensions.ilastik import models as ilm from model_server.extensions.ilastik.workflows import infer_px_then_ob_model -from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams +from model_server.base.roiset import get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels import model_server.conf.testing as conf @@ -363,7 +363,7 @@ class TestIlastikOnMultichannelInputs(conf.TestServerBaseClass): acc_input = generate_file_accessor(self.pa_input_image) acc_obmap = generate_file_accessor(res.object_map_filepath) self.assertEqual(acc_obmap.hw, acc_input.hw) - self.assertEqual(len(acc_obmap._unique()[1]), 3) + self.assertEqual(len(acc_obmap.unique()[1]), 
3) def test_api(self): @@ -401,9 +401,9 @@ class TestIlastikObjectClassification(unittest.TestCase): stack_ch_pa = stack.get_mono(conf.meta['roiset']['patches_channel']) seg_mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path']) - self.roiset = RoiSet( + self.roiset = RoiSet.from_binary_mask( stack_ch_pa, - _get_label_ids(seg_mask), + seg_mask, params=RoiSetMetaParams( mask_type='boxes', filters={'area': {'min': 1e3, 'max': 1e4}},
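
Usage sketch of the factory methods introduced in this patch (illustrative only; the file paths, channel index, and bounding-box values below are placeholders, not repository fixtures):

# Illustrative usage of the new RoiSet factory methods; paths and indices are placeholders.
from pathlib import Path

from model_server.base.accessors import generate_file_accessor
from model_server.base.roiset import RoiSet, RoiSetMetaParams

stack = generate_file_accessor(Path('example_zstack.czi'))   # multichannel z-stack (placeholder)
mask = generate_file_accessor(Path('example_mask.tif'))      # binary segmentation mask (placeholder)

# ROIs from a binary mask; deproject_channel selects the channel whose argmax
# assigns a z-position to each object when the mask is 2D but the stack is 3D.
roiset = RoiSet.from_binary_mask(
    stack,
    mask,
    params=RoiSetMetaParams(
        deproject_channel=0,
        filters={'area': {'min': 1e3, 'max': 1e4}},
    ),
)

# ROIs from bounding boxes alone (Y, X, H, W per box); binary masks default to all-True.
roiset_bbox = RoiSet.from_bounding_boxes(
    stack,
    bbox_yxhw=[{'y': 10, 'x': 20, 'h': 50, 'w': 40}],
    params=RoiSetMetaParams(deproject_channel=0),
)

# Round trip: serialize() writes dataframe/<prefix>.csv (plus tight patch masks when the
# masks are not trivially all-True); deserialize() rebuilds an equivalent RoiSet.
roiset.serialize(Path('out'), prefix='roiset')
restored = RoiSet.deserialize(stack, Path('out'), prefix='roiset')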