From 6b416419ee899c3f04f55b21a1f7500e8af93ec9 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 14 Aug 2024 15:17:44 +0200
Subject: [PATCH] Separated out derived channel functionality from main RoiSet
 class and tests

---
 model_server/base/roiset.py       | 144 +++++++++++++++++++++---------
 tests/base/test_roiset.py         |  35 --------
 tests/base/test_roiset_derived.py |  59 ++++++++++++
 3 files changed, 159 insertions(+), 79 deletions(-)
 create mode 100644 tests/base/test_roiset_derived.py

diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index d41056ba..0ed42890 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -59,7 +59,6 @@ class RoiSetExportParams(BaseModel):
     patches_2d: Union[PatchParams, None] = None
     annotated_zstacks: Union[AnnotatedZStackParams, None] = None
     object_classes: bool = False
-    derived_channels: bool = False
 
 
 def get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, connect_3d=True) -> InMemoryDataAccessor:
@@ -302,7 +301,6 @@ class RoiSet(object):
         :param params: optional arguments that influence the definition and representation of ROIs
         """
         self.acc_raw = acc_raw
-        self.accs_derived = []
         self.params = params
 
         self._df = df
@@ -312,8 +310,9 @@ class RoiSet(object):
         """Expose ROI meta information via the Pandas.DataFrame API"""
         return self._df.itertuples(name='Roi')
 
-    @staticmethod
+    @classmethod
     def from_object_ids(
+            cls,
             acc_raw: GenericImageDataAccessor,
             acc_obj_ids: GenericImageDataAccessor,
             params: RoiSetMetaParams = RoiSetMetaParams(),
@@ -335,11 +334,12 @@ class RoiSet(object):
             params.filters,
         )
 
-        return RoiSet(acc_raw, df, params)
+        return cls(acc_raw, df, params)
 
 
-    @staticmethod
+    @classmethod
     def from_bounding_boxes(
+        cls,
         acc_raw: GenericImageDataAccessor,
         yxhw_list: List,
         params: RoiSetMetaParams = RoiSetMetaParams()
@@ -352,11 +352,12 @@ class RoiSet(object):
                 'x1': yxhw[1] + yxhw[3],
             } for yxhw in yxhw_list
         ])
-        return RoiSet(acc_raw, df, params)
+        return cls(acc_raw, df, params)
 
 
-    @staticmethod
+    @classmethod
     def from_binary_mask(
+            cls,
             acc_raw: GenericImageDataAccessor,
             acc_seg: GenericImageDataAccessor,
             allow_3d=False,
@@ -371,7 +372,7 @@ class RoiSet(object):
         :param connect_3d: objects can span multiple z-positions if True; objects are unique to a single z if False
         :param params: optional arguments that influence the definition and representation of ROIs
         """
-        return RoiSet.from_object_ids(
+        return cls.from_object_ids(
             acc_raw,
             get_label_ids(
                 acc_seg,
@@ -381,8 +382,9 @@ class RoiSet(object):
             params
         )
 
-    @staticmethod
+    @classmethod
     def from_polygons_2d(
+            cls,
             acc_raw,
             polygons: List[np.ndarray],
             params: RoiSetMetaParams = RoiSetMetaParams()
@@ -397,7 +399,7 @@ class RoiSet(object):
         for p in polygons:
             sl = draw.polygon(p[:, 1], p[:, 0])
             mask[sl] = True
-        return RoiSet.from_binary_mask(
+        return cls.from_binary_mask(
             acc_raw,
             InMemoryDataAccessor(mask),
             allow_3d=False,
@@ -467,43 +469,18 @@ class RoiSet(object):
     def classify_by(
             self, name: str, channels: list[int],
             object_classification_model: InstanceSegmentationModel,
-            derived_channel_functions: list[callable] = None
     ):
         """
         Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing
-        specified inputs through an instance segmentation classifier.  Optionally derive additional inputs for object
-        classification by passing a raw input channel through one or more functions.
+        specified inputs through an instance segmentation classifier.
 
         :param name: name of column to insert
         :param channels: list of nc raw input channels to send to classifier
         :param object_classification_model: InstanceSegmentation model object
-        :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and
-            that return a single-channel PatchStack accessor of the same shape
         :return: None
         """
 
-        raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)  # all channels
-        if derived_channel_functions is not None:
-            mono_data = [raw_acc.get_mono(c).data for c in range(0, raw_acc.chroma)]
-            for fcn in derived_channel_functions:
-                der = fcn(raw_acc) # returns patch stack
-                if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']:
-                    raise DerivedChannelError(
-                        f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}'
-                    )
-                self.accs_derived.append(der)
-
-            # combine channels
-            data_derived = [acc.data for acc in self.accs_derived]
-            input_acc = PatchStack(
-                np.concatenate(
-                    [*mono_data, *data_derived],
-                    axis=raw_acc._ga('C')
-                )
-            )
-
-        else:
-            input_acc = raw_acc
+        input_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)  # all channels
 
         # do this on a patch basis, i.e. only one object per frame
         obmap_patches = object_classification_model.label_patch_stack(
@@ -779,13 +756,6 @@ class RoiSet(object):
                     fp = subdir / n / (pr + '.tif')
                     write_accessor_data_to_file(fp, self.get_object_class_map(n))
                     record[f'{k}_{n}'] = str(fp)
-            if k == 'derived_channels':
-                record[k] = []
-                for di, dacc in enumerate(self.accs_derived):
-                    fp = subdir / f'dc{di:01d}.tif'
-                    fp.parent.mkdir(exist_ok=True, parents=True)
-                    dacc.export_pyxcz(fp)
-                    record[k].append(str(fp))
 
         # export dataframe and patch masks
         record = {**record, **self.serialize(where, prefix=prefix)}
@@ -882,6 +852,92 @@ class RoiSet(object):
         return RoiSet.from_object_ids(acc_raw, id_mask)
 
 
+class RoiSetWithDerivedChannelsExportParams(RoiSetExportParams):
+    derived_channels: bool = False
+
+class RoiSetWithDerivedChannels(RoiSet):
+
+    def __init__(self, *a, **k):
+        self.accs_derived = []
+        return super().__init__(*a, **k)
+
+    def classify_by(
+            self, name: str, channels: list[int],
+            object_classification_model: InstanceSegmentationModel,
+            derived_channel_functions: list[callable] = None
+    ):
+        """
+        Insert a column in RoiSet data table that associates each ROI with an integer class, determined by passing
+        specified inputs through an instance segmentation classifier.  Derive additional inputs for object
+        classification by passing a raw input channel through one or more functions.
+
+        :param name: name of column to insert
+        :param channels: list of nc raw input channels to send to classifier
+        :param object_classification_model: InstanceSegmentation model object
+        :param derived_channel_functions: list of functions that each receive a PatchStack accessor with nc channels and
+            that return a single-channel PatchStack accessor of the same shape
+        :return: None
+        """
+
+        raw_acc = self.get_patches_acc(channels=channels, expanded=False, pad_to=None)  # all channels
+        if derived_channel_functions is not None:
+            mono_data = [raw_acc.get_mono(c).data for c in range(0, raw_acc.chroma)]
+            for fcn in derived_channel_functions:
+                der = fcn(raw_acc) # returns patch stack
+                if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']:
+                    raise DerivedChannelError(
+                        f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}'
+                    )
+                self.accs_derived.append(der)
+
+            # combine channels
+            data_derived = [acc.data for acc in self.accs_derived]
+            input_acc = PatchStack(
+                np.concatenate(
+                    [*mono_data, *data_derived],
+                    axis=raw_acc._ga('C')
+                )
+            )
+
+        else:
+            input_acc = raw_acc
+
+        # do this on a patch basis, i.e. only one object per frame
+        obmap_patches = object_classification_model.label_patch_stack(
+            input_acc,
+            self.get_patch_masks_acc(expanded=False, pad_to=None)
+        )
+
+        self._df['classify_by_' + name] = pd.Series(dtype='Int64')
+
+        for i, roi in enumerate(self):
+            oc = np.unique(
+                mask_largest_object(
+                    obmap_patches.iat(i).data
+                )
+            )[-1]
+            self._df.loc[roi.Index, 'classify_by_' + name] = oc
+
+    def run_exports(self, where: Path, channel, prefix, params: RoiSetWithDerivedChannelsExportParams) -> dict:
+        """
+        Export various representations of ROIs, e.g. patches, annotated stacks, and object maps.
+        :param where: path of directory in which to write all export products
+        :param channel: color channel of products to export
+        :param prefix: prefix of the name of each product's file or subfolder
+        :param params: RoiSetExportParams object describing which products to export and with which parameters
+        :return: nested dict of Path objects describing the location of export products
+        """
+        record = super().run_exports(where, channel, prefix, params)
+
+        k = 'derived_channels'
+        if k in params.dict().keys():
+            record[k] = []
+            for di, dacc in enumerate(self.accs_derived):
+                fp = where / k / f'dc{di:01d}.tif'
+                fp.parent.mkdir(exist_ok=True, parents=True)
+                dacc.export_pyxcz(fp)
+                record[k].append(str(fp))
+        return record
 
 class Error(Exception):
     pass
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index cf52835e..28885525 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -192,41 +192,6 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
         self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1]))
         return roiset
 
-    def test_classify_by_with_derived_channel(self):
-        class ModelWithDerivedInputs(DummyInstanceSegmentationModel):
-            def infer(self, img, mask):
-                return PatchStack(super().infer(img, mask).data * img.chroma)
-
-        roiset = RoiSet.from_binary_mask(
-            self.stack,
-            self.seg_mask,
-            params=RoiSetMetaParams(
-                filters={'area': {'min': 1e3, 'max': 1e4}},
-            )
-        )
-        roiset.classify_by(
-            'multiple_input_model',
-            [0, 1],
-            ModelWithDerivedInputs(),
-            derived_channel_functions=[
-                lambda acc: PatchStack(2 * acc.get_channels([0]).data),
-                lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint8'))
-            ]
-        )
-        self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4]))
-        self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4]))
-
-        self.assertEqual(len(roiset.accs_derived), 2)
-        for di in roiset.accs_derived:
-            self.assertEqual(roiset.get_patches_acc().hw, di.hw)
-            self.assertEqual(roiset.get_patches_acc().nz, di.nz)
-            self.assertEqual(roiset.get_patches_acc().count, di.count)
-
-        dpas = roiset.run_exports(output_path / 'derived_channels', 0, 'der', RoiSetExportParams(derived_channels=True))
-        for fp in dpas['derived_channels']:
-            assert Path(fp).exists()
-        return roiset
-
     def test_export_object_classes(self):
         record = self.test_classify_by().run_exports(
             output_path / 'object_class_maps',
diff --git a/tests/base/test_roiset_derived.py b/tests/base/test_roiset_derived.py
new file mode 100644
index 00000000..49a535f0
--- /dev/null
+++ b/tests/base/test_roiset_derived.py
@@ -0,0 +1,59 @@
+from pathlib import Path
+import unittest
+
+import numpy as np
+
+from model_server.base.roiset import RoiSetWithDerivedChannelsExportParams, RoiSetMetaParams
+from model_server.base.roiset import RoiSetWithDerivedChannels
+from model_server.base.accessors import generate_file_accessor, PatchStack
+import model_server.conf.testing as conf
+from tests.base.test_model import DummyInstanceSegmentationModel
+
+data = conf.meta['image_files']
+params = conf.meta['roiset']
+output_path = conf.meta['output_path']
+
+class TestDerivedChannels(unittest.TestCase):
+    def setUp(self) -> None:
+        self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
+        self.stack_ch_pa = self.stack.get_mono(params['patches_channel'])
+        self.seg_mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path'])
+
+    def test_classify_by_with_derived_channel(self):
+        class ModelWithDerivedInputs(DummyInstanceSegmentationModel):
+            def infer(self, img, mask):
+                return PatchStack(super().infer(img, mask).data * img.chroma)
+
+        roiset = RoiSetWithDerivedChannels.from_binary_mask(
+            self.stack,
+            self.seg_mask,
+            params=RoiSetMetaParams(
+                filters={'area': {'min': 1e3, 'max': 1e4}},
+            )
+        )
+        self.assertIsInstance(roiset, RoiSetWithDerivedChannels)
+        roiset.classify_by(
+            'multiple_input_model',
+            [0, 1],
+            ModelWithDerivedInputs(),
+            derived_channel_functions=[
+                lambda acc: PatchStack(2 * acc.get_channels([0]).data),
+                lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint8'))
+            ]
+        )
+        self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4]))
+        self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4]))
+
+        self.assertEqual(len(roiset.accs_derived), 2)
+        for di in roiset.accs_derived:
+            self.assertEqual(roiset.get_patches_acc().hw, di.hw)
+            self.assertEqual(roiset.get_patches_acc().nz, di.nz)
+            self.assertEqual(roiset.get_patches_acc().count, di.count)
+
+        dpas = roiset.run_exports(
+            output_path / 'derived_channels', 0, 'der',
+            RoiSetWithDerivedChannelsExportParams(derived_channels=True)
+        )
+        for fp in dpas['derived_channels']:
+            assert Path(fp).exists()
+        return roiset
\ No newline at end of file
-- 
GitLab