From ce0166da7351b8846db7cb8bb941da19e08f7be6 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Tue, 26 Nov 2024 06:27:10 +0100 Subject: [PATCH] RoiSet.run_exports and .get_export_product_accessors no longer take a channel directly, but rather pull these from parameters of individual patch products --- model_server/base/pipelines/roiset_obmap.py | 2 +- model_server/base/roiset.py | 39 +++++++++++---------- tests/base/test_roiset.py | 13 +++---- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index 74380a02..3bc5e00d 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -84,7 +84,7 @@ def roiset_object_map_pipeline( rois = RoiSet.from_object_ids(d['input'], d['labeled'], RoiSetMetaParams(**k['roi_params'])) # optionally append RoiSet products - for ki, vi in rois.get_export_product_accessors(k['patches_channel'], RoiSetExportParams(**k['export_params'])).items(): + for ki, vi in rois.get_export_product_accessors(RoiSetExportParams(**k['export_params'])).items(): d[ki] = vi # optionally run an object classifier if specified diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index f7cc75f2..de34993c 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -25,16 +25,22 @@ from .process import mask_largest_object class PatchParams(BaseModel): + white_channel: Union[int, None] = None + channels: Union[List[int], None] = None draw_bounding_box: bool = False draw_contour: bool = False draw_mask: bool = False - rescale_clip: float = 0.001 - focus_metric: str = 'max_sobel' + rescale_clip: Union[float, None] = None + focus_metric: Union[str, None] = 'max_sobel' rgb_overlay_channels: Union[List[Union[int, None]], None] = None rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0] pad_to: Union[int, None] = 256 expanded: bool = False +class AnnotatedPatchParams(PatchParams): + bounding_box_channel: int = 1 + bounding_box_linewidth: int = 2 + class AnnotatedZStackParams(BaseModel): draw_label: bool = False @@ -60,7 +66,7 @@ class RoiSetMetaParams(BaseModel): class RoiSetExportParams(BaseModel): patches_3d: Union[PatchParams, None] = None - annotated_patches_2d: Union[PatchParams, None] = None + annotated_patches_2d: Union[AnnotatedPatchParams, None] = None patches_2d: Union[PatchParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None object_classes: bool = False @@ -704,7 +710,7 @@ class RoiSet(object): ext = 'tif' if make_3d or patch.chroma > 3 or kwargs.get('force_tif') else 'png' fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' - if patch.dtype is np.dtype('uint16'): + if patch.dtype == 'uint16': resampled = patch.apply(resample_to_8bit) write_accessor_data_to_file(where / fname, resampled) else: @@ -883,11 +889,10 @@ class RoiSet(object): """ self._df[colname] = se - def run_exports(self, where: Path, channel, prefix, params: RoiSetExportParams) -> dict: + def run_exports(self, where: Path, prefix, params: RoiSetExportParams) -> dict: """ Export various representations of ROIs, e.g. patches, annotated stacks, and object maps. :param where: path of directory in which to write all export products - :param channel: color channel of products to export :param prefix: prefix of the name of each product's file or subfolder :param params: RoiSetExportParams object describing which products to export and with which parameters :return: nested dict of Path objects describing the location of export products @@ -904,18 +909,17 @@ class RoiSet(object): continue if k == 'patches_3d': df_exp = self.export_patches( - subdir, white_channel=channel, prefix=pr, make_3d=True, **kp + subdir, prefix=pr, make_3d=True, **kp ) record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path] if k == 'annotated_patches_2d': df_exp = self.export_patches( - subdir, prefix=pr, make_3d=False, white_channel=channel, - bounding_box_channel=1, bounding_box_linewidth=2, **kp, + subdir, prefix=pr, make_3d=False, **kp, ) record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path] if k == 'patches_2d': df_exp = self.export_patches( - subdir, white_channel=channel, prefix=pr, make_3d=False, **kp + subdir, prefix=pr, make_3d=False, **kp ) self._df = self._df.join(df_exp.patch_path.apply(lambda x: str(Path('patches_2d') / x))) self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1) @@ -940,10 +944,9 @@ class RoiSet(object): return record - def get_export_product_accessors(self, channel, params: RoiSetExportParams) -> dict: + def get_export_product_accessors(self, params: RoiSetExportParams) -> dict: """ Return various representations of ROIs, e.g. patches, annotated stacks, and object maps, as accessors - :param channel: color channel of products to export :param params: RoiSetExportParams object describing which products to export and with which parameters :return: ordered dict of accessors containing the specified products """ @@ -955,14 +958,13 @@ class RoiSet(object): if kp is None: continue if k == 'patches_3d': - interm[k] = self.get_patches_acc([channel], make_3d=True, **kp) + interm[k] = self.get_patches_acc(make_3d=True, **kp) if k == 'annotated_patches_2d': interm[k] = self.get_patches_acc( - make_3d=False, white_channel=channel, - bounding_box_channel=1, bounding_box_linewidth=2, **kp + make_3d=False, **kp ) if k == 'patches_2d': - interm[k] = self.get_patches_acc(make_3d=False, white_channel=channel, **kp) + interm[k] = self.get_patches_acc(make_3d=False, **kp) if k == 'annotated_zstacks': interm[k] = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kp)) if k == 'object_classes': @@ -1147,16 +1149,15 @@ class RoiSetWithDerivedChannels(RoiSet): )[-1] self._df.loc[roi.Index, 'classify_by_' + name] = oc - def run_exports(self, where: Path, channel, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict: + def run_exports(self, where: Path, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict: """ Export various representations of ROIs, e.g. patches, annotated stacks, and object maps. :param where: path of directory in which to write all export products - :param channel: color channel of products to export :param prefix: prefix of the name of each product's file or subfolder :param params: RoiSetExportParams object describing which products to export and with which parameters :return: nested dict of Path objects describing the location of export products """ - record = super().run_exports(where, channel, prefix, params) + record = super().run_exports(where, prefix, params) k = 'derived_channels' if k in params.dict().keys(): diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index ee73aa4b..1d4742fc 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -221,7 +221,6 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_export_object_classes(self): record = self.test_classify_by().run_exports( output_path / 'object_class_maps', - 0, 'obmap', RoiSetExportParams(object_classes=True) ) @@ -400,12 +399,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa p = RoiSetExportParams(**{ 'patches_3d': {}, 'annotated_patches_2d': { + 'white_channel': 3, 'draw_bounding_box': True, 'rgb_overlay_channels': [3, None, None], 'rgb_overlay_weights': [0.2, 1.0, 1.0], 'pad_to': 512, }, 'patches_2d': { + 'white_channel': 3, 'draw_bounding_box': False, 'draw_mask': False, }, @@ -420,7 +421,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa where = output_path / 'run_exports' res = self.roiset.run_exports( where, - channel=3, prefix='test', params=p ) @@ -446,21 +446,22 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa p = RoiSetExportParams(**{ 'patches_3d': None, 'annotated_patches_2d': { + 'white_channel': 3, 'draw_bounding_box': True, 'rgb_overlay_channels': [3, None, None], 'rgb_overlay_weights': [0.2, 1.0, 1.0], 'pad_to': 512, }, 'patches_2d': { + 'white_channel': 3, 'draw_bounding_box': False, 'draw_mask': False, }, 'annotated_zstacks': {}, 'object_classes': True, }) - self.roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel()) + self.roiset.classify_by('dummy_class', [3], DummyInstanceMaskSegmentationModel()) interm = self.roiset.get_export_product_accessors( - channel=3, params=p ) self.assertNotIn('patches_3d', interm.keys()) @@ -485,6 +486,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa def test_run_export_expanded_2d_patch(self): p = RoiSetExportParams(**{ 'patches_2d': { + 'white_channel': -1, 'draw_bounding_box': False, 'draw_mask': False, 'expanded': True, @@ -497,7 +499,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa where = output_path / 'run_exports_expanded_2d_patch' res = self.roiset.run_exports( where, - channel=-1, prefix='test', params=p ) @@ -512,6 +513,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa def test_run_export_mono_2d_patch(self): p = RoiSetExportParams(**{ 'patches_2d': { + 'white_channel': -1, 'draw_bounding_box': False, 'draw_mask': False, 'expanded': True, @@ -525,7 +527,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa where = output_path / 'run_exports_mono_2d_patch' res = self.roiset.run_exports( where, - channel=-1, prefix='test', params=p ) -- GitLab