# test_model.py — unit tests for the model classes in model_server.base.models.
import unittest

import numpy as np

import model_server.conf.testing as conf
from model_server.conf.testing import DummySemanticSegmentationModel, DummyInstanceSegmentationModel
from model_server.base.accessors import CziImageFileAccessor
from model_server.base.models import CouldNotLoadModelError, BinaryThresholdSegmentationModel, PermissiveInstanceSegmentationModel

# Metadata entry for the CZI test image, taken from the testing config.
# The tests below read its 'path', 'w' and 'h' keys.
czifile = conf.meta['image_files']['czifile']


class TestCziImageFileAccess(unittest.TestCase):
    """Tests for model loading and the dummy/binary/permissive segmentation models.

    Segmentation tests run against the CZI test image described by the
    module-level ``czifile`` metadata; a fresh accessor is opened per test.
    """

    def setUp(self) -> None:
        self.cf = CziImageFileAccessor(czifile['path'])

    def _run_dummy_pixel_segmentation(self):
        """Segment channel 0 with the dummy semantic model and check the mask.

        Shared by several tests. This is a helper rather than a test method
        because unittest deprecates test methods that return a value
        (Python 3.11+).

        Returns:
            tuple: ``(img, mask)`` - the mono input image and the inferred mask.
        """
        model = DummySemanticSegmentationModel()
        img = self.cf.get_mono(0)
        mask = model.label_pixel_class(img)

        w = czifile['w']
        h = czifile['h']

        self.assertEqual(
            mask.shape,
            (h, w, 1, 1),
            'Inferred image is not the expected shape'
        )

        # The mask is shaped (h, w, ...), so the row (h) index comes first.
        # Indexing as [w / 2, h / 2] would be transposed and only land on the
        # center (or even stay in bounds) when w and h are close.
        self.assertEqual(
            mask.data[h // 2, w // 2],
            255,
            'Middle pixel is not white as expected'
        )

        self.assertEqual(
            mask.data[0, 0],
            0,
            'First pixel is not black as expected'
        )
        return img, mask

    def test_instantiate_model(self):
        """A model constructed with default autoload ends up loaded."""
        model = DummySemanticSegmentationModel(params=None)
        self.assertTrue(model.loaded)

    def test_instantiate_model_with_nondefault_kwarg(self):
        """Keyword arguments such as ``autoload`` can be overridden in subclasses."""
        model = DummySemanticSegmentationModel(autoload=False)
        self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')

    def test_raise_error_if_cannot_load_model(self):
        """Construction raises CouldNotLoadModelError when ``load()`` returns False."""
        class UnloadableDummyImageToImageModel(DummySemanticSegmentationModel):
            def load(self):
                return False

        with self.assertRaises(CouldNotLoadModelError):
            UnloadableDummyImageToImageModel()

    def test_dummy_pixel_segmentation(self):
        """The dummy semantic model yields a correctly shaped mask, white at center."""
        self._run_dummy_pixel_segmentation()

    def test_binary_segmentation(self):
        """Thresholding the CZI image produces a mask."""
        model = BinaryThresholdSegmentationModel({'tr': 3e4})
        res = model.label_pixel_class(self.cf)
        self.assertTrue(res.is_mask())

    def test_dummy_instance_segmentation(self):
        """The dummy instance model labels exactly {0, 1}, each with nonzero count."""
        img, mask = self._run_dummy_pixel_segmentation()
        model = DummyInstanceSegmentationModel()
        obmap = model.label_instance_class(img, mask)
        # presumably (values, counts) as from np.unique(..., return_counts=True);
        # confirm against the accessor's unique() implementation.
        uniques = obmap.unique()
        labels, counts = uniques[0], uniques[1]
        # np.array_equal fails cleanly on a length mismatch instead of raising
        # from an element-wise comparison of differently sized arrays.
        self.assertTrue(np.array_equal(labels, [0, 1]))
        self.assertTrue(np.all(counts > 0))

    def test_permissive_instance_segmentation(self):
        """The permissive instance map reproduces the semantic mask (scaled to 255)."""
        img, mask = self._run_dummy_pixel_segmentation()
        model = PermissiveInstanceSegmentationModel()
        obmap = model.label_instance_class(img, mask)
        self.assertTrue(np.all(mask.data == 255 * obmap.data))