From 9965d5ea2b9486ae705ea8da60ebfd9b272af3be Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sat, 3 Feb 2024 10:22:21 +0100
Subject: [PATCH] Pass RoiSet object to products, much easier

---
 model_server/extensions/chaeo/products.py     | 40 ++++++++++++++++
 .../extensions/chaeo/tests/test_zstack.py     |  2 +-
 model_server/extensions/chaeo/zmask.py        | 46 +++----------------
 3 files changed, 47 insertions(+), 41 deletions(-)

diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py
index ee773a20..d68b4c36 100644
--- a/model_server/extensions/chaeo/products.py
+++ b/model_server/extensions/chaeo/products.py
@@ -57,6 +57,46 @@ def write_patch_to_file(where, fname, yxcz):
         raise Exception(f'Unsupported file extension: {ext}')
 
 
+def get_patch_masks(roiset, pad_to: int = 256) -> MonoPatchStack:
+
+    patches = []
+    for mi in roiset.zmask_meta:
+        sl = mi['slice']
+
+        rbb = mi['relative_bounding_box']
+        x0 = rbb['x0']
+        y0 = rbb['y0']
+        x1 = rbb['x1']
+        y1 = rbb['y1']
+
+        sp_sl = np.s_[y0: y1, x0: x1, :, :]
+
+        h, w = roiset.acc_raw.data[sl].shape[0:2]
+        patch = np.zeros((h, w, 1, 1), dtype='uint8')
+        patch[sp_sl][:, :, 0, 0] = mi['mask'] * 255
+
+        if pad_to:
+            patch = pad(patch, pad_to)
+
+        patches.append(patch)
+    return MonoPatchStack(patches)
+
+
+def export_patch_masks_from_zstack(roiset, where: Path, pad_to: int = 256, prefix='mask') -> list:
+    patches_acc = roiset.get_patch_masks(pad_to=pad_to)
+
+    exported = []
+    for i in range(0, roiset.count):
+        mi = roiset.zmask_meta[i]
+        obj = mi['info']
+        patch = patches_acc.iat_yxcz(i)
+        ext = 'png'
+        fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}'
+        write_patch_to_file(where, fname, patch)
+        exported.append(fname)
+    return exported
+
+
 def get_patches_from_zmask_meta(
         stack: GenericImageDataAccessor,
         zmask_meta: list,
diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py
index 66d25594..a925a503 100644
--- a/model_server/extensions/chaeo/tests/test_zstack.py
+++ b/model_server/extensions/chaeo/tests/test_zstack.py
@@ -199,7 +199,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
             filters={'area': {'min': 1e3, 'max': 1e4}},
             expand_box_by=(128, 2)
         )
-        files = roiset.export_patch_masks_from_zstack(output_path / '2d_mask_patches', )
+        files = roiset.export_patch_masks(output_path / '2d_mask_patches', )
         self.assertGreaterEqual(len(files), 1)
 
     def test_object_map_workflow(self):
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index e30b2c18..1111510c 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -4,7 +4,6 @@ from math import floor
 import numpy as np
 import pandas as pd
 from pathlib import Path
-from typing import List
 
 from skimage.measure import find_contours, label, regionprops_table
 from sklearn.preprocessing import PolynomialFeatures
@@ -19,7 +18,7 @@ from model_server.extensions.chaeo.products import export_patches_from_zstack, e
 from extensions.chaeo.params import RoiSetMetaParams, RoiSetExportParams
 from model_server.extensions.chaeo.accessors import MonoPatchStack
 from model_server.extensions.chaeo.process import mask_largest_object
-from model_server.extensions.chaeo.products import get_patches_from_zmask_meta, write_patch_to_file
+from model_server.extensions.chaeo.products import get_patches_from_zmask_meta, get_patch_masks, export_patch_masks_from_zstack
 
 
 def get_label_ids(acc_seg_mask):
@@ -62,30 +61,11 @@ class RoiSet(object):
     def get_object_mask_by_class(self, class_id):
         return self.object_id_labels == class_id
 
-    def get_patch_masks(self, pad_to: int = 256) -> MonoPatchStack:
-
-        patches = []
-        for mi in self.zmask_meta:
-            sl = mi['slice']
-
-            rbb = mi['relative_bounding_box']
-            x0 = rbb['x0']
-            y0 = rbb['y0']
-            x1 = rbb['x1']
-            y1 = rbb['y1']
-
-            sp_sl = np.s_[y0: y1, x0: x1, :, :]
-
-            h, w = self.acc_raw.data[sl].shape[0:2]
-            patch = np.zeros((h, w, 1, 1), dtype='uint8')
-            patch[sp_sl][:, :, 0, 0] = mi['mask'] * 255
-
-            if pad_to:
-                patch = pad(patch, pad_to)
-
-            patches.append(patch)
-        return MonoPatchStack(patches)
+    def get_patch_masks(self, **kwargs) -> MonoPatchStack:
+        return get_patch_masks(self, **kwargs)
 
+    def export_patch_masks(self, where, **kwargs) -> list:
+        return export_patch_masks_from_zstack(self, where, **kwargs)
 
     def get_raw_patches(self, channel):
         return get_patches_from_zmask_meta(
@@ -99,20 +79,6 @@ class RoiSet(object):
     def get_zmask(self): # TODO: on-the-fly generation of zmask array
         return self.zmask
 
-    def export_patch_masks_from_zstack(self, where: Path, pad_to: int = 256, prefix='mask') -> List:
-        patches_acc = self.get_patch_masks(pad_to=pad_to)
-
-        exported = []
-        for i in range(0, self.count):
-            mi = self.zmask_meta[i]
-            obj = mi['info']
-            patch = patches_acc.iat_yxcz(i)
-            ext = 'png'
-            fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}'
-            write_patch_to_file(where, fname, patch)
-            exported.append(fname)
-        return exported
-
     def classify_by(self, channel, object_classification_model: InstanceSegmentationModel):
         # do this on a patch basis, i.e. only one object per frame
         obmap_patches = object_classification_model.label_instance_class(
@@ -161,7 +127,7 @@ class RoiSet(object):
                 self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
                 self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1)
             if k == 'patch_masks':
-                self.export_patch_masks_from_zstack(subdir, prefix=pr)
+                self.export_patch_masks(subdir, prefix=pr)
             if k == 'annotated_zstacks':
                 annotated = InMemoryDataAccessor(
                     draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp)
-- 
GitLab