import unittest

import numpy as np

from model_server.process import pad

class TestProcessingUtilityMethods(unittest.TestCase):
    """Shape checks for ``pad`` on single-plane, z-stack and multi-channel inputs.

    Every fixture is a 4-axis ``uint16`` array built as (h, w, nc, nz); the
    "2d" and "3d" variants simply keep the channel (and z) axes at size 1.
    Each test pads a fixture to a 256 target and checks only the output shape.
    """

    # Target size passed to ``pad`` in every test.
    _TARGET = 256

    def setUp(self) -> None:
        height, width = 300, 200
        channels, slices = 4, 11

        def random_uint16(*shape):
            # np.random.rand samples [0, 1), so scaling by 2**16 keeps every
            # value strictly below 65536 and the uint16 cast cannot wrap.
            return (2**16 * np.random.rand(*shape)).astype('uint16')

        # Created in the same order as before so the global RNG is consumed identically.
        self.data2d = random_uint16(height, width, 1, 1)
        self.data3d = random_uint16(height, width, 1, slices)
        self.data4d = random_uint16(height, width, channels, slices)

    def test_pad_2d(self):
        side = self._TARGET
        result = pad(self.data2d, side)
        self.assertEqual(result.shape, (side, side, 1, 1))

    def test_pad_3d(self):
        # Only the trailing z axis is expected to carry through unchanged.
        *_, slices = self.data3d.shape
        side = self._TARGET
        result = pad(self.data3d, side)
        self.assertEqual(result.shape, (side, side, 1, slices))

    def test_pad_4d(self):
        # Both channel and z axes are expected to carry through unchanged.
        _, _, channels, slices = self.data4d.shape
        side = self._TARGET
        expected = (side, side, channels, slices)
        self.assertEqual(pad(self.data4d, side).shape, expected)