From 3fa8292a1da8b43c623f5c346af9407d82f21789 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 15 Aug 2024 16:46:49 +0200
Subject: [PATCH] API maps RoiSet products as accessors

---
 model_server/base/models.py                 |  3 +--
 model_server/base/pipelines/roiset_obmap.py | 19 +++++--------------
 model_server/base/session.py                |  7 +++++--
 tests/test_ilastik/test_roiset_workflow.py  | 12 ++++--------
 4 files changed, 15 insertions(+), 26 deletions(-)

diff --git a/model_server/base/models.py b/model_server/base/models.py
index ffaa346f..eaded2c9 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -50,7 +50,7 @@ class Model(ABC):
 
     @property
     def name(self):
-        return f'{self.__class__}'
+        return f'{self.__class__.__name__}'
 
 
 
@@ -131,7 +131,6 @@ class InstanceSegmentationModel(ImageToImageModel):
 
 class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
 
-    # TODO: also allow relative threshold
     def __init__(self, tr: float = 0.5):
         self.tr = tr
 
diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py
index 376dbe3c..96e03390 100644
--- a/model_server/base/pipelines/roiset_obmap.py
+++ b/model_server/base/pipelines/roiset_obmap.py
@@ -53,20 +53,7 @@ class RoiSetObjectMapParams(PipelineParams):
         },
         'expand_box_by': [128, 2]
     })
-    # TODO: maybe don't support all these exports here; instead leverage interm accessors
-    export_params: RoiSetExportParams = RoiSetExportParams(**{
-        'annotated_patches_2d': {
-            'draw_bounding_box': True,
-            'pad_to': 256,
-        },
-        'patches_2d': {
-            'draw_bounding_box': True,
-            'draw_mask': False,
-        },
-        'patch_masks': {},
-        'object_classes': True,
-        'dataframe': True,
-    })
+    export_params: RoiSetExportParams = RoiSetExportParams()
     derived_channels_input_channel: Union[int, None] = Field(
         None,
         description='Channel of input image from which to compute derived channels; use all if empty'
@@ -114,6 +101,10 @@ def roiset_object_map_pipeline(
     d['labeled'] = get_label_ids(d.last)
     rois = RoiSet.from_object_ids(d['input'], d['labeled'], RoiSetMetaParams(**k['roi_params']))
 
+    # optionally append RoiSet products
+    for ki, vi in rois.get_export_product_accessors(k['patches_channel'], RoiSetExportParams(**k['export_params'])).items():
+        d[ki] = vi
+
     # optionally run an object classifier if specified
     if obmod := models.get('object_classifier_model'):
         obmod_name = k['object_classifier_model_id']
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 8358a957..f22655c6 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -10,7 +10,7 @@ from typing import Union
 import pandas as pd
 
 from ..conf import defaults
-from .accessors import GenericImageDataAccessor
+from .accessors import GenericImageDataAccessor, PatchStack
 from .models import Model
 
 logger = logging.getLogger(__name__)
@@ -166,7 +166,10 @@ class _Session(object):
                 f'Cannot overwrite accessor that is already written to {old_fp}'
             )
 
-        acc.write(fp)
+        if isinstance(acc, PatchStack):
+            acc.export_pyxcz(fp)
+        else:
+            acc.write(fp)
         self.accessors[acc_id]['filepath'] = fp.__str__()
         return fp.name
 
diff --git a/tests/test_ilastik/test_roiset_workflow.py b/tests/test_ilastik/test_roiset_workflow.py
index cd6d2ded..c02dd6d9 100644
--- a/tests/test_ilastik/test_roiset_workflow.py
+++ b/tests/test_ilastik/test_roiset_workflow.py
@@ -39,8 +39,7 @@ class BaseTestRoiSetMonoProducts(object):
 
     def _get_export_params(self):
         return {
-            'pixel_probabilities': True,
-            'patches_3d': {},
+            'patches_3d': None,
             'annotated_patches_2d': {
                 'draw_bounding_box': True,
                 'rgb_overlay_channels': [3, None, None],
@@ -51,12 +50,8 @@ class BaseTestRoiSetMonoProducts(object):
                 'draw_bounding_box': False,
                 'draw_mask': False,
             },
-            'patch_masks': {
-                'pad_to': 256,
-            },
-            'annotated_zstacks': {},
+            'annotated_zstacks': None,
             'object_classes': True,
-            'dataframe': True,
         }
 
     def _get_roi_params(self):
@@ -119,7 +114,8 @@ class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
             {f'{k}_model': v['model'] for k, v in self._get_models().items()},
             **params.dict()
         )
-
+        self.assertEqual(trace.pop('annotated_patches_2d').count, 13)
+        self.assertEqual(trace.pop('patches_2d').count, 13)
         trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False)
         self.assertTrue('ob_id' in trace.keys())
         self.assertEqual(len(trace['labeled'].unique()[0]), 14)
-- 
GitLab