From 6a24b7e2c0a17a1177a77250484dd06f6656048c Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 21 Dec 2023 13:42:43 +0100
Subject: [PATCH] Removed direct pixel map access from workflow function

---
 extensions/chaeo/params.py            | 25 +++++++++-------
 extensions/chaeo/tests/test_zstack.py | 16 +++++++----
 extensions/chaeo/workflows.py         | 41 ++++++++-------------------
 extensions/chaeo/zmask.py             | 15 +++++-----
 4 files changed, 44 insertions(+), 53 deletions(-)

diff --git a/extensions/chaeo/params.py b/extensions/chaeo/params.py
index ee4dabd6..5a226925 100644
--- a/extensions/chaeo/params.py
+++ b/extensions/chaeo/params.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Union
+from typing import List, Union
 
 from pydantic import BaseModel
 
@@ -17,11 +17,22 @@ class PatchParams(BaseModel):
 class AnnotatedZStackParams(BaseModel):
     draw_label: bool = False
 
-class RoiSetExportMetaParams(BaseModel):
+class RoiFilterRange(BaseModel):
+    min: float
+    max: float
+
+class RoiFilter(BaseModel):
+    area: Union[RoiFilterRange, None] = None
+    solidity: Union[RoiFilterRange, None] = None
+
+
+class RoiSetMetaParams(BaseModel):
+    mask_type: str = 'boxes'
+    filters: RoiFilter = None
     expand_box_by: List[int] = [128, 0]
 
+
 class RoiSetExportParams(BaseModel):
-    meta: RoiSetExportMetaParams = RoiSetExportMetaParams()
     pixel_probabilities: bool = False
     patches_3d: Union[PatchParams, None] = None
     patches_2d_for_annotation: Union[PatchParams, None] = None
@@ -30,11 +41,3 @@ class RoiSetExportParams(BaseModel):
     annotated_zstacks: Union[AnnotatedZStackParams, None] = None
 
 
-class RoiFilterRange(BaseModel):
-    min: float
-    max: float
-
-
-class RoiFilter(BaseModel):
-    area: Union[RoiFilterRange, None] = None
-    solidity: Union[RoiFilterRange, None] = None
diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py
index be14b104..9e6e3a74 100644
--- a/extensions/chaeo/tests/test_zstack.py
+++ b/extensions/chaeo/tests/test_zstack.py
@@ -5,7 +5,7 @@ import numpy as np
 from conf.testing import output_path
 
 from extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params
-from extensions.chaeo.params import RoiSetExportParams
+from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
 from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack
 from extensions.chaeo.workflows import infer_object_map_from_zstack
 from extensions.chaeo.zmask import build_zmask_from_object_mask
@@ -208,8 +208,14 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             self.pxmodel,
             DummyInstanceSegmentationModel(),
         ]
+
+        roi_params = RoiSetMetaParams(**{
+            'mask_type': 'boxes',
+            'filters': {},
+            'expand_box_by': [128, 2]
+        })
+
         export_params = RoiSetExportParams(**{
-            'expand_box_by': [128, 2],
             'pixel_probabilities': True,
             'patches_3d': {},
             'patches_2d_for_annotation': {
@@ -229,10 +235,10 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             multichannel_zstack['path'],
             output_path / 'roiset' / 'workflow',
             models,
-            pxmap_foreground_channel=pp['pxmap_channel'],
-            pxmap_threshold=pp['pxmap_threshold'],
+            pixel_class=pp['pxmap_channel'],
+            pixel_probability_threshold=pp['pxmap_threshold'],
             segmentation_channel=pp['segmentation_channel'],
             patches_channel=pp['patches_channel'],
-            exports=export_params,
+            export_params=export_params,
         )
 
diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py
index 9171aaa0..4b3f380d 100644
--- a/extensions/chaeo/workflows.py
+++ b/extensions/chaeo/workflows.py
@@ -9,7 +9,7 @@ from skimage.measure import label, regionprops_table
 from skimage.morphology import dilation
 from sklearn.model_selection import train_test_split
 
-from extensions.chaeo.params import RoiSetExportParams
+from extensions.chaeo.params import RoiSetExportParams, RoiSetMetaParams
 from extensions.chaeo.process import mask_largest_object
 from extensions.chaeo.zmask import RoiSet
 
@@ -23,15 +23,14 @@ def infer_object_map_from_zstack(
         input_file_path: str,
         output_folder_path: str,
         models: List[Model],
-        pxmap_foreground_channel: int,
-        pxmap_threshold: float,
         segmentation_channel: int,
         patches_channel: int,
         zmask_zindex: int = None,  # None for MIP,
         zmask_clip: int = None,
-        zmask_type: str = 'boxes',
-        zmask_filters: Dict = None,
-        exports: RoiSetExportParams = RoiSetExportParams(),
+        roi_params: RoiSetMetaParams = RoiSetMetaParams(),
+        export_params: RoiSetExportParams = RoiSetExportParams(),
+        pixel_class=0,
+        pixel_probability_threshold=0.6,
 ) -> Dict:
     assert len(models) == 2
     pixel_classifier = models[0]
@@ -55,29 +54,13 @@ def infer_object_map_from_zstack(
     mip = InMemoryDataAccessor(
         zmask_data,
     )
-    pxmap, _ = pixel_classifier.infer(mip)
-    ti.click('infer_pixel_probability')
-
-    # if exports.pixel_probabilities:
-    #     write_accessor_data_to_file(
-    #         Path(output_folder_path) / 'pixel_probabilities' / (fstem + '.tif'),
-    #         pxmap
-    #     )
-    #     ti.click('export_pixel_probability')
-
-    obmask = InMemoryDataAccessor(
-        pxmap.data > pxmap_threshold
-    )
-    ti.click('threshold_pixel_mask')
+    # pxmap, _ = pixel_classifier.infer(mip)
+    mip_mask = pixel_classifier.label_pixel_class(mip, pixel_class, pixel_probability_threshold,)
+    ti.click('classify_pixels')
 
     # make zmask
-    rois = RoiSet(
-        obmask.get_one_channel_data(pxmap_foreground_channel),
-        stack,
-        mask_type=zmask_type,
-        filters=zmask_filters,
-        expand_box_by=exports.meta.expand_box_by,
-    )
+    # rois = RoiSet(mip_mask, stack, mask_type=zmask_type, filters=zmask_filters, expand_box_by=meta.expand_box_by)
+    rois = RoiSet(mip_mask, stack, params=roi_params)
     ti.click('generate_zmasks')
 
     object_class_map = rois.classify_by(patches_channel, object_classifier)
@@ -90,7 +73,8 @@ def infer_object_map_from_zstack(
     )
     ti.click('export_object_classes')
 
-    rois.run_exports(Path(output_folder_path), patches_channel, fstem, exports)
+    rois.run_exports(Path(output_folder_path), patches_channel, fstem, export_params)
+    ti.click('export_roi_products')
 
     return {
         'timer_results': ti.events,
@@ -100,7 +84,6 @@ def infer_object_map_from_zstack(
     }
 
 
-
 def transfer_ecotaxa_labels_to_patch_stacks(
     where_masks: str,
     where_patches: str,
diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py
index 95d2a73f..900ab02e 100644
--- a/extensions/chaeo/zmask.py
+++ b/extensions/chaeo/zmask.py
@@ -9,7 +9,7 @@ from sklearn.linear_model import LinearRegression
 
 from extensions.chaeo.annotators import draw_boxes_on_3d_image
 from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack, get_patches_from_zmask_meta, get_patch_masks_from_zmask_meta
-from extensions.chaeo.params import AnnotatedZStackParams, PatchParams, RoiFilter, RoiSetExportParams
+from extensions.chaeo.params import AnnotatedZStackParams, PatchParams, RoiFilter, RoiSetMetaParams, RoiSetExportParams
 from extensions.chaeo.process import mask_largest_object
 from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
 from model_server.models import InstanceSegmentationModel
@@ -21,16 +21,14 @@ class RoiSet(object):
             self,
             acc_mask: GenericImageDataAccessor,
             acc_raw: GenericImageDataAccessor,
-            filters=None,
-            mask_type='contours',
-            expand_box_by=(0, 0),
+            params: RoiSetMetaParams = RoiSetMetaParams(),
     ):
         self.zmask, self.zmask_meta, self.df, self.interm = build_zmask_from_object_mask(
             acc_mask,
             acc_raw,
-            filters=filters,
-            mask_type=mask_type,
-            expand_box_by=expand_box_by
+            filters=params.filters,
+            mask_type=params.mask_type,
+            expand_box_by=params.expand_box_by
         )
 
         self.acc_raw = acc_raw
@@ -106,7 +104,7 @@ class RoiSet(object):
             subdir = where / k
             pr = prefix
             kp = params.dict()[k]
-            if k == 'meta' or kp is None:
+            if kp is None:
                 continue
             if k == 'patches_3d':
                 files = export_patches_from_zstack(
@@ -136,6 +134,7 @@ class RoiSet(object):
 
 
 
+
 def build_zmask_from_object_mask(
         obmask: GenericImageDataAccessor,
         zstack: GenericImageDataAccessor,
-- 
GitLab