From 6d9f27b0f3129a97a9a4c4bed1401e89aedb47d7 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 21 Dec 2023 14:51:03 +0100
Subject: [PATCH] Dummy object classification exports map

---
 extensions/chaeo/params.py            |  1 +
 extensions/chaeo/tests/test_zstack.py |  3 ++-
 extensions/chaeo/workflows.py         | 18 +++++-------------
 extensions/chaeo/zmask.py             | 13 ++++++-------
 model_server/models.py                |  6 ++++--
 5 files changed, 18 insertions(+), 23 deletions(-)

diff --git a/extensions/chaeo/params.py b/extensions/chaeo/params.py
index 5a226925..0ba956fb 100644
--- a/extensions/chaeo/params.py
+++ b/extensions/chaeo/params.py
@@ -39,5 +39,6 @@ class RoiSetExportParams(BaseModel):
     patches_2d_for_training: Union[PatchParams, None] = None
     patch_masks: bool = False
     annotated_zstacks: Union[AnnotatedZStackParams, None] = None
+    object_classes: bool = False
 
 
diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py
index 9e6e3a74..7778370a 100644
--- a/extensions/chaeo/tests/test_zstack.py
+++ b/extensions/chaeo/tests/test_zstack.py
@@ -229,7 +229,8 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
                 'draw_mask': False,
             },
             'patch_masks': True,
-            'annotated_zstacks': {}
+            'annotated_zstacks': {},
+            'object_classes': True
         })
         infer_object_map_from_zstack(
             multichannel_zstack['path'],
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index 4b3f380d..7a0e0941 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -54,33 +54,25 @@ def infer_object_map_from_zstack(
     mip = InMemoryDataAccessor(
         zmask_data,
     )
-    # pxmap, _ = pixel_classifier.infer(mip)
+
     mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,)
     ti.click('classify_pixels')
 
     # make zmask
-    # rois = RoiSet(mip_mask, stack, mask_type=zmask_type, filters=zmask_filters, expand_box_by=meta.expand_box_by)
     rois = RoiSet(mip_mask, stack, params=roi_params)
     ti.click('generate_zmasks')
 
-    object_class_map = rois.classify_by(patches_channel, object_classifier)
-
-    # TODO: add ZMaskObjectTable method to export object map
-    output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif'))
-    write_accessor_data_to_file(
-        output_path,
-        object_class_map
-    )
-    ti.click('export_object_classes')
+    rois.classify_by(patches_channel, object_classifier)
+    ti.click('classify_objects')
 
     rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params)
     ti.click('export_roi_products')
 
     return {
         'timer_results': ti.events,
-        'dataframe':     rois.df,
+        'dataframe': rois.df,
         'interm': {},
-        'output_path': output_path.__str__(),
+        'output_path': output_folder_path,
     }
 
 
diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py
index 13c54b7d..1441e9ae 100644
--- a/extensions/chaeo/zmask.py
+++ b/extensions/chaeo/zmask.py
@@ -32,6 +32,7 @@ class RoiSet(object):
         self.acc_raw = acc_raw
         self.count = len(self.zmask_meta)
         self.object_id_labels = self.interm['label_map']
+        self.object_class_map = None
 
     def get_argmax(self):
         return self.interm.argmax
@@ -67,7 +68,7 @@ class RoiSet(object):
         )
 
         lamap = self.object_id_labels
-        output_map = np.zeros(lamap.shape, dtype=lamap.dtype)
+        om = np.zeros(lamap.shape, dtype=lamap.dtype)
         self.df['instance_class'] = np.nan
 
         # assign labels to object map:
@@ -75,10 +76,10 @@ class RoiSet(object):
             object_id = self.zmask_meta[ii]['info'].label
             result_patch = mask_largest_object(obmap_patches.iat(ii))
             object_class = np.unique(result_patch)[1]
-            output_map[self.object_id_labels == object_id] = object_class
+            om[self.object_id_labels == object_id] = object_class
             self.df[object_id, 'instance_class'] = object_class
 
-        return InMemoryDataAccessor(output_map)
+        self.object_class_map = InMemoryDataAccessor(om)
 
     # TODO: test
     def get_object_mask_by_id(self, obj_id):
@@ -91,9 +92,6 @@ class RoiSet(object):
     def get_object_patch_by_id(self, obj_id):
         pass
 
-    def get_object_map(self, filters: RoiFilter):
-        pass
-
     def run_exports(self, where, channel, prefix, params: RoiSetExportParams):
         if not self.count:
             return
@@ -129,7 +127,8 @@ class RoiSet(object):
                     draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp)
                 )
                 write_accessor_data_to_file(subdir / (pr + '.tif'), annotated)
-
+            if k == 'object_classes':
+                write_accessor_data_to_file(subdir / (pr + '.tif'), self.object_class_map)
 
 
 
diff --git a/model_server/models.py b/model_server/models.py
index 06bef01c..8239d1ca 100644
--- a/model_server/models.py
+++ b/model_server/models.py
@@ -120,13 +120,15 @@ class DummyInstanceSegmentationModel(InstanceSegmentationModel):
     def infer(
             self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor
     ) -> (GenericImageDataAccessor, dict):
-        return mask
+        return img.__class__(
+            (mask.data / mask.data.max()).astype('uint16')
+        )
 
     def label_instance_class(
             self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
     ) -> GenericImageDataAccessor:
         """
-        Returns a trivial segmentation, i.e. the input mask
+        Returns a trivial segmentation, i.e. the input mask with value 1
         """
         super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs)
         return self.infer(img, mask)
-- 
GitLab