Skip to content
Snippets Groups Projects
test_process.py 2.84 KiB
Newer Older
import unittest

import numpy as np

from base.annotators import draw_contours_on_patch
from base.process import get_safe_contours, mask_largest_object, pad

class TestProcessingUtilityMethods(unittest.TestCase):
    def setUp(self) -> None:
        w = 200
        h = 300
        nc = 4
        nz = 11
        self.data2d = (2**16 * np.random.rand(h, w, 1, 1)).astype('uint16')
        self.data3d = (2**16 * np.random.rand(h, w, 1, nz)).astype('uint16')
        self.data4d = (2**16 * np.random.rand(h, w, nc, nz)).astype('uint16')

    def test_pad_2d(self):
        padded = pad(self.data2d, 256)
        self.assertEqual(padded.shape, (256, 256, 1, 1))

    def test_pad_3d(self):
        nz = self.data3d.shape[3]
        padded = pad(self.data3d, 256)
        self.assertEqual(padded.shape, (256, 256, 1, nz))

    def test_pad_4d(self):
        nc = self.data4d.shape[2]
        nz = self.data4d.shape[3]
        padded = pad(self.data4d, 256)
        self.assertEqual(padded.shape, (256, 256, nc, nz))


class TestMaskLargestObject(unittest.TestCase):
    """Tests for ``mask_largest_object`` on small 5x5 uint8 label arrays.

    Array comparisons use ``np.testing.assert_array_equal``. Unlike
    ``assertTrue(np.all(a == b))``, it also checks shapes (no silent
    broadcasting) and reports the mismatching values on failure.
    """

    def test_mask_largest_touching_object(self):
        # Two differently-valued regions that touch at rows 2/3, column 2:
        # a 3x3 block of 2s (9 px) and a 2x3 block of 4s (6 px).
        arr = np.zeros([5, 5], dtype='uint8')
        arr[0:3, 0:3] = 2
        arr[3:, 2:] = 4
        masked = mask_largest_object(arr)
        # Only background and the larger (value 2) object should remain.
        np.testing.assert_array_equal(np.unique(masked), [0, 2])
        np.testing.assert_array_equal(masked[4:5, 0:2], 0)
        np.testing.assert_array_equal(masked[0:3, 3:5], 0)

    def test_no_change(self):
        # A single object is already the largest, so the output must equal
        # the input (same shape and values).
        arr = np.zeros([5, 5], dtype='uint8')
        arr[0:3, 0:3] = 2
        masked = mask_largest_object(arr)
        np.testing.assert_array_equal(masked, arr)

    def test_mask_multiple_objects_in_binary_maks(self):
        # Binary (0/255) mask with two separated objects: a 3x3 block
        # (9 px) and a 1x3 strip on the last row (3 px), split by row 3.
        arr = np.zeros([5, 5], dtype='uint8')
        arr[0:3, 0:3] = 255
        arr[4, 2:5] = 255
        masked = mask_largest_object(arr)
        np.testing.assert_array_equal(np.unique(masked), [0, 255])
        # The smaller strip (row 4) must be removed; the right columns and
        # the bottom rows end up empty.
        np.testing.assert_array_equal(masked[:, 3:5], 0)
        np.testing.assert_array_equal(masked[3:5, :], 0)


class TestSafeContours(unittest.TestCase):
    def setUp(self) -> None:
        self.patch = np.ones((10, 20), dtype='uint8')
        self.mask_ref = np.zeros((10, 20), dtype=bool)
        self.mask_ref[0:5, 0:10] = True
        self.mask_test = np.ones((1, 20), dtype=bool)

    def test_contours_on_compliant_mask(self):
        con = get_safe_contours(self.mask_ref)
        patch = self.patch.copy()
        self.assertEqual((patch == 0).sum(), 0)
        patch = draw_contours_on_patch(patch, con)
        self.assertEqual((patch == 0).sum(), 14)

    def test_contours_on_noncompliant_mask(self):
        con = get_safe_contours(self.mask_test)
        patch = self.patch.copy()
        self.assertEqual((patch == 0).sum(), 0)
        patch = draw_contours_on_patch(self.patch, con)
        self.assertEqual((patch == 0).sum(), 20)
        self.assertEqual((patch[0, :] == 0).sum(), 20)