From 9dd88f7ee45424a47ae173d6be9bb7b0550a4b90 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Fri, 29 Sep 2023 11:04:38 +0200 Subject: [PATCH] Consolidated utility methods and added tests for 2, 3, 4 dimension padding --- model_server/process.py | 52 ++++++++++++++++++++++++----------------- tests/test_process.py | 30 ++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 21 deletions(-) create mode 100644 tests/test_process.py diff --git a/model_server/process.py b/model_server/process.py index 26fd3005..2bc09422 100644 --- a/model_server/process.py +++ b/model_server/process.py @@ -6,38 +6,42 @@ from math import ceil, floor import numpy as np from skimage.exposure import rescale_intensity - -def pad(im, mpx): # now in model_server.batch - '''Pads and crops image width edge values to specified dimension''' - dh = 0.5 * (mpx - im.shape[0]) - dw = 0.5 * (mpx - im.shape[1]) +def pad(yxcz, mpx: int): + """ + Pad and crop image data in Y and X axes to meet specific dimension + :param yxcz: np.ndarray + :param mpx: int pixel size of resulting square + :return: np.ndarray array of size (mpx, mpx, nc, nz) + """ + assert len(yxcz.shape) == 4 + nc = yxcz.shape[2] + nz = yxcz.shape[3] + dh = 0.5 * (mpx - yxcz.shape[0]) + dw = 0.5 * (mpx - yxcz.shape[1]) if dw < 0: x0 = floor(-dw) x1 = x0 + mpx - im = im[:, x0:x1] + yxcz = yxcz[:, x0:x1, :, :] dw = 0 if dh < 0: y0 = floor(-dh) y1 = y0 + mpx - im = im[y0:y1, :] + yxcz = yxcz[y0:y1, :, :, :] dh = 0 - border = ((floor(dh), ceil(dh)), (floor(dw), ceil(dw))) - padded = np.pad(im, border, mode='constant') - if padded.shape != (mpx, mpx): - raise Exception(f'Incorrect image shape: {padded.shape} v. {(mpx, mpx)}') - return padded - -def pad_3d(im, mpx): # im: [z x h x w] - assert(len(im.shape) == 3) - nz, h, w = im.shape - padded = np.zeros((nz, mpx, mpx), dtype=im.dtype) - for zi in range(nz): - padded[zi, :, :] = pad(im[zi, :, :], mpx) + border = ((floor(dh), ceil(dh)), (floor(dw), ceil(dw)), (0, 0), (0, 0)) + padded = np.pad(yxcz, border, mode='constant') return padded -def resample(nda, cmin=0, cmax=2**16): # now in model_server.batch +def resample_to_8bit(nda, cmin=0, cmax=2**16): + """ + Resample a 16 bit image to 8 bit, optionally bracketing a given intensity range + :param nda: np.ndarray input data of arbitrary dimension + :param cmin: intensity level on 16-bit scale that become zero in 8-bit scale + :param cmax: intensity level on 16-bit scale that become maximum (255) in 8-bit scale + :return: rescaled data of same dimension as input + """ return rescale_intensity( np.clip(nda, cmin, cmax), in_range=(cmin, cmax + 1), @@ -45,7 +49,13 @@ def resample(nda, cmin=0, cmax=2**16): # now in model_server.batch ).astype('uint8') -def rescale(nda, clip=0.0): # now in model_server.batch +def rescale(nda, clip=0.0): + """ + Rescale an image for a given clipping ratio + :param nda: input data of arbitrary dimension and scale + :param clip: Ratio of clipping in the resulting image + :return: rescaled image of same dimension as input + """ clip_pct = (100.0 * clip, 100.0 * (1.0 - clip)) cmin, cmax = np.percentile(nda, clip_pct) rescaled = rescale_intensity(nda, in_range=(cmin, cmax)) diff --git a/tests/test_process.py b/tests/test_process.py new file mode 100644 index 00000000..bef8426e --- /dev/null +++ b/tests/test_process.py @@ -0,0 +1,30 @@ +import unittest + +import numpy as np + +from model_server.process import pad + +class TestProcessingUtilityMethods(unittest.TestCase): + def setUp(self) -> None: + w = 200 + h = 300 + nc = 4 + nz = 11 + self.data2d = (2**16 * np.random.rand(h, w, 1, 1)).astype('uint16') + self.data3d = (2**16 * np.random.rand(h, w, 1, nz)).astype('uint16') + self.data4d = (2**16 * np.random.rand(h, w, nc, nz)).astype('uint16') + + def test_pad_2d(self): + padded = pad(self.data2d, 256) + self.assertEqual(padded.shape, (256, 256, 1, 1)) + + def test_pad_3d(self): + nz = self.data3d.shape[3] + padded = pad(self.data3d, 256) + self.assertEqual(padded.shape, (256, 256, 1, nz)) + + def test_pad_4d(self): + nc = self.data4d.shape[2] + nz = self.data4d.shape[3] + padded = pad(self.data4d, 256) + self.assertEqual(padded.shape, (256, 256, nc, nz)) \ No newline at end of file -- GitLab