diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index d41056ba3ff74c0598dd3aed1a86fbf021923beb..0ed42890feb3c092678dc10b204d4a6e912ad19b 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -59,7 +59,6 @@ class RoiSetExportParams(BaseModel): patches_2d: Union[PatchParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None object_classes: bool = False - derived_channels: bool = False def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor: @@ -302,7 +301,6 @@ class RoiSet(object): :param params: optional arguments that influence the definition and representation of ROIs """ self.acc_raw = acc_raw - self.accs_derived = [] self.params = params self._df = df @@ -312,8 +310,9 @@ class RoiSet(object): """Expose ROI meta information via the Pandas.DataFrame API""" return self._df.itertuples(name='Roi') - @staticmethod + @classmethod def from_object_ids( + cls, acc_raw: GenericImageDataAccessor, acc_obj_ids: GenericImageDataAccessor, params: RoiSetMetaParams = RoiSetMetaParams(), @@ -335,11 +334,12 @@ class RoiSet(object): params.filters, ) - return RoiSet(acc_raw, df, params) + return cls(acc_raw, df, params) - @staticmethod + @classmethod def from_bounding_boxes( + cls, acc_raw: GenericImageDataAccessor, yxhw_list: List, params: RoiSetMetaParams = RoiSetMetaParams() @@ -352,11 +352,12 @@ class RoiSet(object): 'x1': yxhw[1] + yxhw[3], } for yxhw in yxhw_list ]) - return RoiSet(acc_raw, df, params) + return cls(acc_raw, df, params) - @staticmethod + @classmethod def from_binary_mask( + cls, acc_raw: GenericImageDataAccessor, acc_seg: GenericImageDataAccessor, allow_3d=False, @@ -371,7 +372,7 @@ class RoiSet(object): :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False :param params: optional arguments that influence the definition and representation of ROIs """ - return 
RoiSet.from_object_ids( + return cls.from_object_ids( acc_raw, get_label_ids( acc_seg, @@ -381,8 +382,9 @@ class RoiSet(object): params ) - @staticmethod + @classmethod def from_polygons_2d( + cls, acc_raw, polygons: List[np.ndarray], params: RoiSetMetaParams = RoiSetMetaParams() @@ -397,7 +399,7 @@ class RoiSet(object): for p in polygons: sl = draw.polygon(p[:, 1], p[:, 0]) mask[sl] = True - return RoiSet.from_binary_mask( + return cls.from_binary_mask( acc_raw, InMemoryDataAccessor(mask), allow_3d=False, @@ -467,43 +469,18 @@ class RoiSet(object): def classify_by( self, name: str, channels: list[int], object_classification_model: InstanceSegmentationModel, - derived_channel_functions: list[callable] = None ): """ Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing - specified inputs through an instance segmentation classifier. Optionally derive additional inputs for object - classification by passing a raw input channel through one or more functions. + specified inputs through an instance segmentation classifier. 
:param name: name of column to insert :param channels: list of nc raw input channels to send to classifier :param object_classification_model: InstanceSegmentation model object - :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and - that return a single-channel PatchStack accessor of the same shape :return: None """ - raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels - if derived_channel_functions is not None: - mono_data = [raw_acc.get_mono(c).data for c in range(0, raw_acc.chroma)] - for fcn in derived_channel_functions: - der = fcn(raw_acc) # returns patch stack - if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']: - raise DerivedChannelError( - f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}' - ) - self.accs_derived.append(der) - - # combine channels - data_derived = [acc.data for acc in self.accs_derived] - input_acc = PatchStack( - np.concatenate( - [*mono_data, *data_derived], - axis=raw_acc._ga('C') - ) - ) - - else: - input_acc = raw_acc + input_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels # do this on a patch basis, i.e. 
class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams):
    """
    Export parameters for RoiSetWithDerivedChannels: everything in RoiSetExportParams, plus a switch that controls
    whether derived channels are written to file.
    """

    # if True, RoiSetWithDerivedChannels.run_exports() also exports the derived-channel patch stacks
    derived_channels: bool = False


class RoiSetWithDerivedChannels(RoiSet):
    """
    RoiSet whose object classification can take additional input channels that are computed ("derived") from the raw
    patch data. Derived channels are retained in self.accs_derived so that they can be exported later.
    """

    def __init__(self, *a, **k):
        """
        Accepts the same arguments as RoiSet, which are passed through unchanged.
        """
        # list of derived-channel accessors, in the order in which they were generated by classify_by()
        self.accs_derived = []
        super().__init__(*a, **k)

    @staticmethod
    def _derive_channels(raw_acc, derived_channel_functions: list[callable], expected_shape) -> list:
        """
        Apply each function to the raw patch accessor and validate what it returns.

        :param raw_acc: accessor of raw patches that is passed to each function
        :param derived_channel_functions: list of functions that each take raw_acc and return an accessor
        :param expected_shape: shape that each derived accessor must have
        :return: list of validated derived accessors, one per function and in the same order
        :raises DerivedChannelError: if a derived accessor has the wrong shape or a dtype other than uint8 or uint16
        """
        derived = []
        for fcn in derived_channel_functions:
            der = fcn(raw_acc)  # returns patch stack
            if der.shape != expected_shape or der.dtype not in ['uint8', 'uint16']:
                raise DerivedChannelError(
                    f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}'
                )
            derived.append(der)
        return derived

    def classify_by(
            self, name: str, channels: list[int],
            object_classification_model: InstanceSegmentationModel,
            derived_channel_functions: list[callable] = None
    ):
        """
        Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing
        specified inputs through an instance segmentation classifier. Derive additional inputs for object
        classification by passing a raw input channel through one or more functions.

        Only the channels derived during this call are sent to the classifier; channels derived in earlier calls
        remain in self.accs_derived for export but are not fed to the model again.

        :param name: name of column to insert
        :param channels: list of nc raw input channels to send to classifier
        :param object_classification_model: InstanceSegmentation model object
        :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and
            that return a single-channel PatchStack accessor of the same shape
        :return: None
        :raises DerivedChannelError: if a derived channel has an unexpected shape or dtype; in that case
            self.accs_derived is left unchanged
        """
        raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)  # all channels

        if derived_channel_functions:
            mono_data = [raw_acc.get_mono(c).data for c in range(raw_acc.chroma)]

            # validate every derived channel before recording any, so a failure leaves no partial state behind
            new_derived = self._derive_channels(raw_acc, derived_channel_functions, mono_data[0].shape)
            self.accs_derived.extend(new_derived)

            # combine raw channels with the channels derived in this call only
            input_acc = PatchStack(
                np.concatenate(
                    [*mono_data, *(acc.data for acc in new_derived)],
                    axis=raw_acc._ga('C')
                )
            )
        else:
            # nothing to derive: the classifier sees the raw channels as they are
            input_acc = raw_acc

        # do this on a patch basis, i.e. only one object per frame
        obmap_patches = object_classification_model.label_patch_stack(
            input_acc,
            self.get_patch_masks_acc(expanded=False, pad_to=None)
        )

        column = 'classify_by_' + name
        self._df[column] = pd.Series(dtype='Int64')

        for i, roi in enumerate(self):
            # np.unique() returns sorted values, so [-1] is the highest label value left in the masked patch
            # NOTE(review): presumably mask_largest_object() zeroes all but the largest object -- confirm
            oc = np.unique(
                mask_largest_object(
                    obmap_patches.iat(i).data
                )
            )[-1]
            self._df.loc[roi.Index, column] = oc

    def run_exports(self, where: Path, channel, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict:
        """
        Export various representations of ROIs, e.g. patches, annotated stacks, and object maps; then export derived
        channels if this is requested in params.

        :param where: path of directory in which to write all export products
        :param channel: color channel of products to export
        :param prefix: prefix of the name of each product's file or subfolder
        :param params: RoiSetWithDerivedChannelsExportParams object describing which products to export and with
            which parameters
        :return: record returned by RoiSet.run_exports(); if derived channels are exported, it also has a
            'derived_channels' entry that lists the path of each exported file as a string
        """
        record = super().run_exports(where, channel, prefix, params)

        k = 'derived_channels'
        # honor the value of the flag, not just its presence; getattr() also tolerates a plain RoiSetExportParams
        if getattr(params, k, False):
            record[k] = []
            for di, dacc in enumerate(self.accs_derived):
                fp = where / k / f'dc{di:01d}.tif'
                fp.parent.mkdir(exist_ok=True, parents=True)
                dacc.export_pyxcz(fp)
                record[k].append(str(fp))
        return record


class Error(Exception):
    pass
from pathlib import Path
import unittest

import numpy as np

from model_server.base.roiset import RoiSetWithDerivedChannelsExportParams, RoiSetMetaParams
from model_server.base.roiset import RoiSetWithDerivedChannels
from model_server.base.accessors import generate_file_accessor, PatchStack
import model_server.conf.testing as conf
from tests.base.test_model import DummyInstanceSegmentationModel

data = conf.meta['image_files']
params = conf.meta['roiset']
output_path = conf.meta['output_path']


class TestDerivedChannels(unittest.TestCase):
    """Test object classification and export in RoiSetWithDerivedChannels."""

    def setUp(self) -> None:
        self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
        self.stack_ch_pa = self.stack.get_mono(params['patches_channel'])
        self.seg_mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path'])

    def test_classify_by_with_derived_channel(self):
        """
        Classify ROIs using two raw and two derived channels, then export the derived channels.

        :return: the classified RoiSetWithDerivedChannels object
        """

        class ModelWithDerivedInputs(DummyInstanceSegmentationModel):
            # scale the dummy model's labels by the number of input channels, so that the class assigned to each
            # object reveals how many channels reached the model
            def infer(self, img, mask):
                return PatchStack(super().infer(img, mask).data * img.chroma)

        # the classmethod constructor inherited from RoiSet must return an instance of the subclass
        roiset = RoiSetWithDerivedChannels.from_binary_mask(
            self.stack,
            self.seg_mask,
            params=RoiSetMetaParams(
                filters={'area': {'min': 1e3, 'max': 1e4}},
            )
        )
        self.assertIsInstance(roiset, RoiSetWithDerivedChannels)

        roiset.classify_by(
            'multiple_input_model',
            [0, 1],
            ModelWithDerivedInputs(),
            derived_channel_functions=[
                lambda acc: PatchStack(2 * acc.get_channels([0]).data),
                lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint8'))
            ]
        )

        # two raw channels plus two derived channels; expect class 4 for every object
        self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4]))
        self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4]))

        # each derived channel has the same patch dimensions and patch count as the raw patches
        self.assertEqual(len(roiset.accs_derived), 2)
        patches_acc = roiset.get_patches_acc()
        for di in roiset.accs_derived:
            self.assertEqual(patches_acc.hw, di.hw)
            self.assertEqual(patches_acc.nz, di.nz)
            self.assertEqual(patches_acc.count, di.count)

        dpas = roiset.run_exports(
            output_path / 'derived_channels', 0, 'der',
            RoiSetWithDerivedChannelsExportParams(derived_channels=True)
        )
        for fp in dpas['derived_channels']:
            # use a unittest assertion: a bare assert statement is skipped when Python runs with -O
            self.assertTrue(Path(fp).exists())
        return roiset