From 9dd88f7ee45424a47ae173d6be9bb7b0550a4b90 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 29 Sep 2023 11:04:38 +0200
Subject: [PATCH] Consolidated utility methods and added tests for 2, 3, 4
 dimension padding

---
 model_server/process.py | 52 ++++++++++++++++++++++++-----------------
 tests/test_process.py   | 30 ++++++++++++++++++++++++
 2 files changed, 61 insertions(+), 21 deletions(-)
 create mode 100644 tests/test_process.py

diff --git a/model_server/process.py b/model_server/process.py
index 26fd3005..2bc09422 100644
--- a/model_server/process.py
+++ b/model_server/process.py
@@ -6,38 +6,42 @@ from math import ceil, floor
 import numpy as np
 from skimage.exposure import rescale_intensity
 
-
-def pad(im, mpx):  # now in model_server.batch
-    '''Pads and crops image width edge values to specified dimension'''
-    dh = 0.5 * (mpx - im.shape[0])
-    dw = 0.5 * (mpx - im.shape[1])
+def pad(yxcz, mpx: int):
+    """
+    Pad and crop image data in Y and X axes to meet specific dimension
+    :param yxcz: np.ndarray
+    :param mpx: int pixel size of resulting square
+    :return: np.ndarray array of size (mpx, mpx, nc, nz)
+    """
+    assert len(yxcz.shape) == 4
+    nc = yxcz.shape[2]
+    nz = yxcz.shape[3]
+    dh = 0.5 * (mpx - yxcz.shape[0])
+    dw = 0.5 * (mpx - yxcz.shape[1])
 
     if dw < 0:
         x0 = floor(-dw)
         x1 = x0 + mpx
-        im = im[:, x0:x1]
+        yxcz = yxcz[:, x0:x1, :, :]
         dw = 0
     if dh < 0:
         y0 = floor(-dh)
         y1 = y0 + mpx
-        im = im[y0:y1, :]
+        yxcz = yxcz[y0:y1, :, :, :]
         dh = 0
 
-    border = ((floor(dh), ceil(dh)), (floor(dw), ceil(dw)))
-    padded = np.pad(im, border, mode='constant')
-    if padded.shape != (mpx, mpx):
-        raise Exception(f'Incorrect image shape: {padded.shape} v. {(mpx, mpx)}')
-    return padded
-
-def pad_3d(im, mpx): # im: [z x h x w]
-    assert(len(im.shape) == 3)
-    nz, h, w = im.shape
-    padded = np.zeros((nz, mpx, mpx), dtype=im.dtype)
-    for zi in range(nz):
-        padded[zi, :, :] = pad(im[zi, :, :], mpx)
+    border = ((floor(dh), ceil(dh)), (floor(dw), ceil(dw)), (0, 0), (0, 0))
+    padded = np.pad(yxcz, border, mode='constant')
     return padded
 
-def resample(nda, cmin=0, cmax=2**16): # now in model_server.batch
+def resample_to_8bit(nda, cmin=0, cmax=2**16):
+    """
+    Resample a 16 bit image to 8 bit, optionally bracketing a given intensity range
+    :param nda: np.ndarray input data of arbitrary dimension
+    :param cmin: intensity level on 16-bit scale that become zero in 8-bit scale
+    :param cmax: intensity level on 16-bit scale that become maximum (255) in 8-bit scale
+    :return: rescaled data of same dimension as input
+    """
     return rescale_intensity(
         np.clip(nda, cmin, cmax),
         in_range=(cmin, cmax + 1),
@@ -45,7 +49,13 @@ def resample(nda, cmin=0, cmax=2**16): # now in model_server.batch
     ).astype('uint8')
 
 
-def rescale(nda, clip=0.0): # now in model_server.batch
+def rescale(nda, clip=0.0):
+    """
+    Rescale an image for a given clipping ratio
+    :param nda: input data of arbitrary dimension and scale
+    :param clip: Ratio of clipping in the resulting image
+    :return: rescaled image of same dimension as input
+    """
     clip_pct = (100.0 * clip, 100.0 * (1.0 - clip))
     cmin, cmax = np.percentile(nda, clip_pct)
     rescaled = rescale_intensity(nda, in_range=(cmin, cmax))
diff --git a/tests/test_process.py b/tests/test_process.py
new file mode 100644
index 00000000..bef8426e
--- /dev/null
+++ b/tests/test_process.py
@@ -0,0 +1,30 @@
+import unittest
+
+import numpy as np
+
+from model_server.process import pad
+
+class TestProcessingUtilityMethods(unittest.TestCase):
+    def setUp(self) -> None:
+        w = 200
+        h = 300
+        nc = 4
+        nz = 11
+        self.data2d = (2**16 * np.random.rand(h, w, 1, 1)).astype('uint16')
+        self.data3d = (2**16 * np.random.rand(h, w, 1, nz)).astype('uint16')
+        self.data4d = (2**16 * np.random.rand(h, w, nc, nz)).astype('uint16')
+
+    def test_pad_2d(self):
+        padded = pad(self.data2d, 256)
+        self.assertEqual(padded.shape, (256, 256, 1, 1))
+
+    def test_pad_3d(self):
+        nz = self.data3d.shape[3]
+        padded = pad(self.data3d, 256)
+        self.assertEqual(padded.shape, (256, 256, 1, nz))
+
+    def test_pad_4d(self):
+        nc = self.data4d.shape[2]
+        nz = self.data4d.shape[3]
+        padded = pad(self.data4d, 256)
+        self.assertEqual(padded.shape, (256, 256, nc, nz))
\ No newline at end of file
-- 
GitLab