From ce0166da7351b8846db7cb8bb941da19e08f7be6 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 26 Nov 2024 06:27:10 +0100
Subject: [PATCH] RoiSet.run_exports and .get_export_product_accessors no
 longer take a channel directly, but rather pull these from parameters of
 individual patch products

---
 model_server/base/pipelines/roiset_obmap.py |  2 +-
 model_server/base/roiset.py                 | 39 +++++++++++----------
 tests/base/test_roiset.py                   | 13 +++----
 3 files changed, 28 insertions(+), 26 deletions(-)

diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py
index 74380a02..3bc5e00d 100644
--- a/model_server/base/pipelines/roiset_obmap.py
+++ b/model_server/base/pipelines/roiset_obmap.py
@@ -84,7 +84,7 @@ def roiset_object_map_pipeline(
     rois = RoiSet.from_object_ids(d['input'], d['labeled'], RoiSetMetaParams(**k['roi_params']))
 
     # optionally append RoiSet products
-    for ki, vi in rois.get_export_product_accessors(k['patches_channel'], RoiSetExportParams(**k['export_params'])).items():
+    for ki, vi in rois.get_export_product_accessors(RoiSetExportParams(**k['export_params'])).items():
         d[ki] = vi
 
     # optionally run an object classifier if specified
diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index f7cc75f2..de34993c 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -25,16 +25,22 @@ from .process import mask_largest_object
 
 
 class PatchParams(BaseModel):
+    white_channel: Union[int, None] = None
+    channels: Union[List[int], None] = None
     draw_bounding_box: bool = False
     draw_contour: bool = False
     draw_mask: bool = False
-    rescale_clip: float = 0.001
-    focus_metric: str = 'max_sobel'
+    rescale_clip: Union[float, None] = None
+    focus_metric: Union[str, None] = 'max_sobel'
     rgb_overlay_channels: Union[List[Union[int, None]], None] = None
     rgb_overlay_weights: List[float] = [1.0, 1.0, 1.0]
     pad_to: Union[int, None] = 256
     expanded: bool = False
 
+class AnnotatedPatchParams(PatchParams):
+    bounding_box_channel: int = 1
+    bounding_box_linewidth: int = 2
+
 
 class AnnotatedZStackParams(BaseModel):
     draw_label: bool = False
@@ -60,7 +66,7 @@ class RoiSetMetaParams(BaseModel):
 
 class RoiSetExportParams(BaseModel):
     patches_3d: Union[PatchParams, None] = None
-    annotated_patches_2d: Union[PatchParams, None] = None
+    annotated_patches_2d: Union[AnnotatedPatchParams, None] = None
     patches_2d: Union[PatchParams, None] = None
     annotated_zstacks: Union[AnnotatedZStackParams, None] = None
     object_classes: bool = False
@@ -704,7 +710,7 @@ class RoiSet(object):
             ext = 'tif' if make_3d or patch.chroma > 3 or kwargs.get('force_tif') else 'png'
             fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}'
 
-            if patch.dtype is np.dtype('uint16'):
+            if patch.dtype == 'uint16':
                 resampled = patch.apply(resample_to_8bit)
                 write_accessor_data_to_file(where / fname, resampled)
             else:
@@ -883,11 +889,10 @@ class RoiSet(object):
         """
         self._df[colname] = se
 
-    def run_exports(self, where: Path, channel, prefix, params: RoiSetExportParams) -> dict:
+    def run_exports(self, where: Path, prefix, params: RoiSetExportParams) -> dict:
         """
         Export various representations of ROIs, e.g. patches, annotated stacks, and object maps.
         :param where: path of directory in which to write all export products
-        :param channel: color channel of products to export
         :param prefix: prefix of the name of each product's file or subfolder
         :param params: RoiSetExportParams object describing which products to export and with which parameters
         :return: nested dict of Path objects describing the location of export products
@@ -904,18 +909,17 @@ class RoiSet(object):
                 continue
             if k == 'patches_3d':
                 df_exp = self.export_patches(
-                    subdir, white_channel=channel, prefix=pr, make_3d=True, **kp
+                    subdir, prefix=pr, make_3d=True, **kp
                 )
                 record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
             if k == 'annotated_patches_2d':
                 df_exp = self.export_patches(
-                    subdir, prefix=pr, make_3d=False, white_channel=channel,
-                    bounding_box_channel=1, bounding_box_linewidth=2, **kp,
+                    subdir, prefix=pr, make_3d=False, **kp,
                 )
                 record[k] = [str(Path(k) / fn) for fn in df_exp.patch_path]
             if k == 'patches_2d':
                 df_exp = self.export_patches(
-                    subdir, white_channel=channel, prefix=pr, make_3d=False, **kp
+                    subdir, prefix=pr, make_3d=False, **kp
                 )
                 self._df = self._df.join(df_exp.patch_path.apply(lambda x: str(Path('patches_2d') / x)))
                 self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1)
@@ -940,10 +944,9 @@ class RoiSet(object):
 
         return record
 
-    def get_export_product_accessors(self, channel, params: RoiSetExportParams) -> dict:
+    def get_export_product_accessors(self, params: RoiSetExportParams) -> dict:
         """
         Return various representations of ROIs, e.g. patches, annotated stacks, and object maps, as accessors
-        :param channel: color channel of products to export
         :param params: RoiSetExportParams object describing which products to export and with which parameters
         :return: ordered dict of accessors containing the specified products
         """
@@ -955,14 +958,13 @@ class RoiSet(object):
             if kp is None:
                 continue
             if k == 'patches_3d':
-                interm[k] = self.get_patches_acc([channel], make_3d=True, **kp)
+                interm[k] = self.get_patches_acc(make_3d=True, **kp)
             if k == 'annotated_patches_2d':
                 interm[k] = self.get_patches_acc(
-                    make_3d=False, white_channel=channel,
-                    bounding_box_channel=1, bounding_box_linewidth=2, **kp
+                    make_3d=False, **kp
                 )
             if k == 'patches_2d':
-                interm[k] = self.get_patches_acc(make_3d=False, white_channel=channel, **kp)
+                interm[k] = self.get_patches_acc(make_3d=False, **kp)
             if k == 'annotated_zstacks':
                 interm[k] = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kp))
             if k == 'object_classes':
@@ -1147,16 +1149,15 @@ class RoiSetWithDerivedChannels(RoiSet):
             )[-1]
             self._df.loc[roi.Index, 'classify_by_' + name] = oc
 
-    def run_exports(self, where: Path, channel, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict:
+    def run_exports(self, where: Path, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict:
         """
         Export various representations of ROIs, e.g. patches, annotated stacks, and object maps.
         :param where: path of directory in which to write all export products
-        :param channel: color channel of products to export
         :param prefix: prefix of the name of each product's file or subfolder
         :param params: RoiSetExportParams object describing which products to export and with which parameters
         :return: nested dict of Path objects describing the location of export products
         """
-        record = super().run_exports(where, channel, prefix, params)
+        record = super().run_exports(where, prefix, params)
 
         k = 'derived_channels'
         if k in params.dict().keys():
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index ee73aa4b..1d4742fc 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -221,7 +221,6 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
     def test_export_object_classes(self):
         record = self.test_classify_by().run_exports(
             output_path / 'object_class_maps',
-            0,
             'obmap',
             RoiSetExportParams(object_classes=True)
         )
@@ -400,12 +399,14 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         p = RoiSetExportParams(**{
             'patches_3d': {},
             'annotated_patches_2d': {
+                'white_channel': 3,
                 'draw_bounding_box': True,
                 'rgb_overlay_channels': [3, None, None],
                 'rgb_overlay_weights': [0.2, 1.0, 1.0],
                 'pad_to': 512,
             },
             'patches_2d': {
+                'white_channel': 3,
                 'draw_bounding_box': False,
                 'draw_mask': False,
             },
@@ -420,7 +421,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         where = output_path / 'run_exports'
         res = self.roiset.run_exports(
             where,
-            channel=3,
             prefix='test',
             params=p
         )
@@ -446,21 +446,22 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         p = RoiSetExportParams(**{
             'patches_3d': None,
             'annotated_patches_2d': {
+                'white_channel': 3,
                 'draw_bounding_box': True,
                 'rgb_overlay_channels': [3, None, None],
                 'rgb_overlay_weights': [0.2, 1.0, 1.0],
                 'pad_to': 512,
             },
             'patches_2d': {
+                'white_channel': 3,
                 'draw_bounding_box': False,
                 'draw_mask': False,
             },
             'annotated_zstacks': {},
             'object_classes': True,
         })
-        self.roiset.classify_by('dummy_class', [0], DummyInstanceMaskSegmentationModel())
+        self.roiset.classify_by('dummy_class', [3], DummyInstanceMaskSegmentationModel())
         interm = self.roiset.get_export_product_accessors(
-            channel=3,
             params=p
         )
         self.assertNotIn('patches_3d', interm.keys())
@@ -485,6 +486,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
     def test_run_export_expanded_2d_patch(self):
         p = RoiSetExportParams(**{
             'patches_2d': {
+                'white_channel': -1,
                 'draw_bounding_box': False,
                 'draw_mask': False,
                 'expanded': True,
@@ -497,7 +499,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         where = output_path / 'run_exports_expanded_2d_patch'
         res = self.roiset.run_exports(
             where,
-            channel=-1,
             prefix='test',
             params=p
         )
@@ -512,6 +513,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
     def test_run_export_mono_2d_patch(self):
         p = RoiSetExportParams(**{
             'patches_2d': {
+                'white_channel': -1,
                 'draw_bounding_box': False,
                 'draw_mask': False,
                 'expanded': True,
@@ -525,7 +527,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         where = output_path / 'run_exports_mono_2d_patch'
         res = self.roiset.run_exports(
             where,
-            channel=-1,
             prefix='test',
             params=p
         )
-- 
GitLab