From fcb1dedf2ff715b2ae800bdc6c9f7cb46dcd4d68 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 2 Nov 2023 15:43:28 +0100
Subject: [PATCH] Support 3D patch accessor; RGB patch tests still failing

---
 extensions/chaeo/products.py | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/extensions/chaeo/products.py b/extensions/chaeo/products.py
index a9310e37..3f93588c 100644
--- a/extensions/chaeo/products.py
+++ b/extensions/chaeo/products.py
@@ -9,7 +9,7 @@ from skimage.io import imsave
 from skimage.measure import find_contours, shannon_entropy
 from tifffile import imwrite
 
-from extensions.chaeo.accessors import MonoPatchStack
+from extensions.chaeo.accessors import MonoPatchStack, PatchStack3D
 from extensions.chaeo.annotators import draw_box_on_patch, draw_contours_on_patch
 from model_server.accessors import GenericImageDataAccessor, InMemoryDataAccessor
 from model_server.process import pad, rescale, resample_to_8bit
@@ -31,25 +31,25 @@ def _focus_metrics():
         'moment': lambda x: moment(x.flatten(), moment=2),
     }
 
-def _write_patch_to_file(where, fname, data):
+def _write_patch_to_file(where, fname, yxcz):
     ext = fname.split('.')[-1].upper()
     where.mkdir(parents=True, exist_ok=True)
 
     if ext == 'PNG':
-        assert data.dtype == 'uint8', f'Invalid data type {data.dtype}'
-        assert data.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs'
-        assert data.shape[3] == 1, f'Cannot export z-stacks as PNGs'
-        if data.shape[2] == 1:
-            outdata = data[:, :, 0, 0]
-        elif data.shape[2] == 2: # add a blank blue channel
-            outdata = _make_rgb(data)
+        assert yxcz.dtype == 'uint8', f'Invalid data type {yxcz.dtype}'
+        assert yxcz.shape[2] <= 3, f'Cannot export images with more than 3 channels as PNGs'
+        assert yxcz.shape[3] == 1, f'Cannot export z-stacks as PNGs'
+        if yxcz.shape[2] == 1:
+            outdata = yxcz[:, :, 0, 0]
+        elif yxcz.shape[2] == 2: # add a blank blue channel
+            outdata = _make_rgb(yxcz)
         else: # preserve RGB order
-            outdata = data[:, :, :, 0]
+            outdata = yxcz[:, :, :, 0]
         imsave(where / fname, outdata, check_contrast=False)
         return True
 
     elif ext in ['TIF', 'TIFF']:
-        zcyx = np.moveaxis(data, [3, 2, 0, 1], [0, 1, 2, 3])
+        zcyx = np.moveaxis(yxcz, [3, 2, 0, 1], [0, 1, 2, 3])
         imwrite(where / fname, zcyx, imagej=True)
         return True
 
@@ -103,7 +103,7 @@ def export_patch_masks_from_zstack(
     for i in range(0, len(zmask_meta)):
         mi = zmask_meta[i]
         obj = mi['info']
-        patch = patches_acc.iat(i)
+        patch = patches_acc.iat_yxcz(i)
         ext = 'png'
         fname = f'{prefix}-la{obj.label:04d}-zi{obj.zi:04d}.{ext}'
         _write_patch_to_file(where, fname, patch)
@@ -197,7 +197,10 @@ def get_patches_from_zmask_meta(
             patch = pad(patch, pad_to)
 
         patches.append(patch)
-    return MonoPatchStack(patches)
+    if not make_3d:
+        return MonoPatchStack(patches)
+    else:
+        return PatchStack3D(patches)
 
 def export_patches_from_zstack(
         where: Path,
@@ -224,7 +227,7 @@ def export_patches_from_zstack(
     exported = []
     for i in range(0, len(zmask_meta)):
         mi = zmask_meta[i]
-        patch = patches_acc.iat(i)
+        patch = patches_acc.iat_yxcz(i)
         obj = mi['info']
         idx = mi['df_index']
         ext = 'tif' if make_3d else 'png'
-- 
GitLab