From ffd3a19dc0b4cea718647eaf38c9a3bc47d0139b Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Tue, 6 Feb 2024 18:10:01 +0100
Subject: [PATCH] Progress on running workflow, still throwing errors but
 suspect these just require method signature updates

---
 model_server/extensions/chaeo/accessors.py |  1 +
 model_server/extensions/chaeo/zmask.py     | 10 +++++-----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/model_server/extensions/chaeo/accessors.py b/model_server/extensions/chaeo/accessors.py
index 31d043e3..bc65c6f3 100644
--- a/model_server/extensions/chaeo/accessors.py
+++ b/model_server/extensions/chaeo/accessors.py
@@ -77,6 +77,7 @@ class MonoPatchStackFromFile(MonoPatchStack):
     def fpath(self):
         return self.file_acc.fpath
 
+# TODO: unify this into one accessor
 class Multichannel3dPatchStack(InMemoryDataAccessor):
 
     def __init__(self, data):
diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py
index 242033d8..02e5922a 100644
--- a/model_server/extensions/chaeo/zmask.py
+++ b/model_server/extensions/chaeo/zmask.py
@@ -195,7 +195,7 @@ class RoiSet(object):
         else:
             patches_df = get_patches_from_zmask_meta(self, pad_to=pad_to)
         patches = list(patches_df['patch'])
-        if self.acc_raw.chroma == 1:
+        if channel is not None or self.acc_raw.chroma == 1:
             return MonoPatchStack(patches)
         else:
             return Multichannel3dPatchStack(patches)
@@ -256,13 +256,13 @@ class RoiSet(object):
         # df['classify_by_' + f]
 
         # assign labels to object map:
-        for i in idx:
+        for i in range(0, len(idx)):
             # object_id = self.zmask_meta[i]['info'].label
-            object_id = df.loc[i, 'label']
+            object_id = df.loc[idx[i], 'label']
             result_patch = mask_largest_object(obmap_patches.iat(i))
             object_class = np.unique(result_patch)[1]
             om[self.object_id_labels == object_id] = object_class
-            se.loc[i] = object_class
+            se.loc[idx[i]] = object_class
 
         self.add_df_col('classify_by_' + name, se)
         self.object_class_map[name] = InMemoryDataAccessor(om)
@@ -279,7 +279,7 @@ class RoiSet(object):
                 continue
             if k == 'patches_3d':
                 files = export_patches_from_zstack(
-                    subdir, raw_ch, self.zmask_meta, prefix=pr, make_3d=True, **kp
+                    subdir, self, white_channel=channel, prefix=pr, make_3d=True, **kp
                 )
             if k == 'patches_2d_for_annotation':
                 files = export_multichannel_patches_from_zstack(
-- 
GitLab