From 0bf3572546900997a4e223874c0935201192c2e9 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 15 Aug 2024 16:14:22 +0200
Subject: [PATCH] RoiSet can also return ordered dictionary of its export
 products as accessors

---
 model_server/base/api.py                    |  2 +
 model_server/base/pipelines/roiset_obmap.py |  3 --
 model_server/base/roiset.py                 | 34 +++++++++++++++++
 tests/base/test_roiset.py                   | 41 +++++++++++++++++++++
 4 files changed, 77 insertions(+), 3 deletions(-)

diff --git a/model_server/base/api.py b/model_server/base/api.py
index db484d0f..1715d9ab 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -107,3 +107,5 @@ def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None)
         raise HTTPException(404, f'Did not find accessor with ID {accessor_id}')
     except WriteAccessorError as e:
         raise HTTPException(409, str(e))
+
+# TODO: endpoint to unload all accessors
\ No newline at end of file
diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py
index c4991aa3..376dbe3c 100644
--- a/model_server/base/pipelines/roiset_obmap.py
+++ b/model_server/base/pipelines/roiset_obmap.py
@@ -88,11 +88,8 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord:
     """
     record, rois = call_pipeline(roiset_object_map_pipeline, p)
 
-    # TODO: try labeling this correctly as input accessor_id instead of filename
-    assert isinstance(rois, RoiSet)
     table = rois.get_serializable_dataframe()
 
-    # TODO: instead, explicitly let trace include RoiSets
     session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table)
     ret = RoiSetToObjectMapRecord(
         roiset_table=table.to_dict(),
diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py
index bccbcb05..d5a78f77 100644
--- a/model_server/base/roiset.py
+++ b/model_server/base/roiset.py
@@ -1,3 +1,4 @@
+from collections import OrderedDict
 from itertools import combinations
 from math import sqrt, floor
 from pathlib import Path
@@ -772,6 +773,39 @@ class RoiSet(object):
 
         return record
 
+    def get_export_product_accessors(self, channel, params: RoiSetExportParams) -> dict:
+        """
+        Return various representations of ROIs, e.g. patches, annotated stacks, and object maps, as accessors
+        :param channel: color channel of products to export
+        :param params: RoiSetExportParams object describing which products to export and with which parameters
+        :return: ordered dict of accessors containing the specified products
+        """
+        interm = OrderedDict()
+        if not self.count:
+            return interm
+
+        for k, kp in params.dict().items():
+            if kp is None:
+                continue
+            if k == 'patches_3d':
+                interm[k] = self.get_patches_acc([channel], make_3d=True, **kp)
+            if k == 'annotated_patches_2d':
+                interm[k] = self.get_patches_acc(
+                    make_3d=False, white_channel=channel,
+                    bounding_box_channel=1, bounding_box_linewidth=2, **kp
+                )
+            if k == 'patches_2d':
+                interm[k] = self.get_patches_acc(make_3d=False, white_channel=channel, **kp)
+            if k == 'annotated_zstacks':
+                interm[k] = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kp))
+            if k == 'object_classes':
+                pr = 'classify_by_'
+                cnames = [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)]
+                for n in cnames:
+                    interm[f'{k}_{n}'] = self.get_object_class_map(n)
+
+        return interm
+
     def serialize(self, where: Path, prefix='') -> dict:
         """
         Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 44486110..32a2a4b5 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -367,6 +367,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
         self.assertEqual(result.nz, self.roiset.acc_raw.nz)
         self.assertEqual(result.chroma, 1)
 
+
     def test_run_exports(self):
         p = RoiSetExportParams(**{
             'patches_3d': {},
@@ -413,6 +414,46 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             for f in test_df[c]:
                 self.assertTrue((where / f).exists(), where / f)
 
+    def test_get_interm_prods(self):
+        p = RoiSetExportParams(**{
+            'patches_3d': None,
+            'annotated_patches_2d': {
+                'draw_bounding_box': True,
+                'rgb_overlay_channels': [3, None, None],
+                'rgb_overlay_weights': [0.2, 1.0, 1.0],
+                'pad_to': 512,
+            },
+            'patches_2d': {
+                'draw_bounding_box': False,
+                'draw_mask': False,
+            },
+            'annotated_zstacks': {},
+            'object_classes': True,
+        })
+        self.roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel())
+        interm = self.roiset.get_export_product_accessors(
+            channel=3,
+            params=p
+        )
+        self.assertNotIn('patches_3d', interm.keys())
+        self.assertEqual(
+            interm['annotated_patches_2d'].hw,
+            (self.roiset.get_df().h.max(), self.roiset.get_df().w.max())
+        )
+        self.assertEqual(
+            interm['patches_2d'].hw,
+            (self.roiset.get_df().h.max(), self.roiset.get_df().w.max())
+        )
+        self.assertEqual(
+            interm['annotated_zstacks'].hw,
+            self.stack.hw
+        )
+        self.assertEqual(
+            interm['object_classes_dummy_class'].hw,
+            self.stack.hw
+        )
+        self.assertTrue(np.all(interm['object_classes_dummy_class'].unique()[0] == [0, 1]))
+
     def test_run_export_expanded_2d_patch(self):
         p = RoiSetExportParams(**{
             'patches_2d': {
-- 
GitLab