diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 96162c95b6a069be79ab6771c63ed27d3ee52f00..8b7642774a4e84c9f476b5aa744ec62652ed551e 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -370,8 +370,9 @@ class RoiSet(object): """Expose ROI meta information via the Pandas.DataFrame API""" return self._df.itertuples(name='Roi') - @staticmethod + @classmethod def from_object_ids( + cls, acc_raw: GenericImageDataAccessor, acc_obj_ids: GenericImageDataAccessor, params: RoiSetMetaParams = RoiSetMetaParams(), @@ -395,10 +396,11 @@ class RoiSet(object): params.filters, ) - return RoiSet(acc_raw, df, params) + return cls(acc_raw, df, params) - @staticmethod + @classmethod def from_bounding_boxes( + cls, acc_raw: GenericImageDataAccessor, bbox_yxhw: List[Dict], bbox_zi: Union[List[int], int] = None, @@ -451,11 +453,12 @@ class RoiSet(object): axis=1, result_type='reduce', ) - return RoiSet(acc_raw, df, params) + return cls(acc_raw, df, params) - @staticmethod + @classmethod def from_binary_mask( + cls, acc_raw: GenericImageDataAccessor, acc_seg: GenericImageDataAccessor, allow_3d=False, @@ -470,7 +473,7 @@ class RoiSet(object): :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False :param params: optional arguments that influence the definition and representation of ROIs """ - return RoiSet.from_object_ids( + return cls.from_object_ids( acc_raw, get_label_ids( acc_seg, @@ -480,8 +483,9 @@ class RoiSet(object): params ) - @staticmethod + @classmethod def from_polygons_2d( + cls, acc_raw, polygons: List[np.ndarray], params: RoiSetMetaParams = RoiSetMetaParams() @@ -496,7 +500,7 @@ class RoiSet(object): for p in polygons: sl = draw.polygon(p[:, 1], p[:, 0]) mask[sl] = True - return RoiSet.from_binary_mask( + return cls.from_binary_mask( acc_raw, InMemoryDataAccessor(mask), allow_3d=False, @@ -566,35 +570,21 @@ class RoiSet(object): def classify_by( self, name: str, channels: list[int], object_classification_model: InstanceSegmentationModel, - derived_channel_functions: list[callable] = None ): """ Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing - specified inputs through an instance segmentation classifier. Optionally derive additional inputs for object - classification by passing a raw input channel through one or more functions. + specified inputs through an instance segmentation classifier. :param name: name of column to insert :param channels: list of nc raw input channels to send to classifier :param object_classification_model: InstanceSegmentation model object - :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and - that return a single-channel PatchStack accessor of the same shape :return: None """ if self.count == 0: self._df['classify_by_' + name] = None return True - # combine channels - data_derived = [acc.data for acc in self.accs_derived] - input_acc = PatchStack( - np.concatenate( - [*mono_data, *data_derived], - axis=raw_acc._ga('C') - ) - ) - - else: - input_acc = raw_acc + input_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None) # all channels # do this on a patch basis, i.e. only one object per frame obmap_patches = object_classification_model.label_patch_stack( @@ -1035,8 +1025,8 @@ class RoiSet(object): def acc_obj_ids(self): return make_object_ids_from_df(self._df, self.acc_raw.shape_dict) - @staticmethod - def deserialize(acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self: + @classmethod + def deserialize(cls, acc_raw: GenericImageDataAccessor, where: Path, prefix='roiset') -> Self: """ Create an RoiSet object from saved files and an image accessor :param acc_raw: accessor to image that contains ROIs @@ -1062,14 +1052,14 @@ class RoiSet(object): df['binary_mask'] = df.apply(_read_binary_mask, axis=1) id_mask = make_object_ids_from_df(df, acc_raw.shape_dict) - return RoiSet.from_object_ids(acc_raw, id_mask) + return cls.from_object_ids(acc_raw, id_mask) else: # assume bounding boxes only df['y'] = df['y0'] df['x'] = df['x0'] df['h'] = df['y1'] - df['y0'] df['w'] = df['x1'] - df['x0'] - return RoiSet.from_bounding_boxes( + return cls.from_bounding_boxes( acc_raw, df[['y', 'x', 'h', 'w']].to_dict(orient='records'), list(df['zi']) diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index de18e6908c185081b4b49288914fd417bbabd8b7..0a03973e78ca824083f2821157c70bb9adc77265 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -219,42 +219,6 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ) - def test_classify_by_with_derived_channel(self): - class ModelWithDerivedInputs(DummyInstanceSegmentationModel): - def infer(self, img, mask): - return PatchStack(super().infer(img, mask).data * img.chroma) - - roiset = RoiSet.from_binary_mask( - self.stack, - self.seg_mask, - params=RoiSetMetaParams( - filters={'area': {'min': 1e3, 'max': 1e4}}, - deproject_channel=0, - ) - ) - roiset.classify_by( - 'multiple_input_model', - [0, 1], - ModelWithDerivedInputs(), - derived_channel_functions=[ - lambda acc: PatchStack(2 * acc.get_channels([0]).data), - lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint8')) - ] - ) - self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4])) - self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4])) - - self.assertEqual(len(roiset.accs_derived), 2) - for di in roiset.accs_derived: - self.assertEqual(roiset.get_patches_acc().hw, di.hw) - self.assertEqual(roiset.get_patches_acc().nz, di.nz) - self.assertEqual(roiset.get_patches_acc().count, di.count) - - dpas = roiset.run_exports(output_path / 'derived_channels', 0, 'der', RoiSetExportParams(derived_channels=True)) - for fp in dpas['derived_channels']: - assert Path(fp).exists() - return roiset - def test_export_object_classes(self): record = self.test_classify_by().run_exports( output_path / 'object_class_maps',