From c3e3ff61c697b32900577420a7008111d3148eb0 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Wed, 20 Dec 2023 15:30:53 +0100
Subject: [PATCH] Test covers instance segmentation base class

---
 model_server/models.py | 23 +++++++++++++++++++++--
 tests/test_model.py    | 29 ++++++++++++++++++-----------
 2 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/model_server/models.py b/model_server/models.py
index f0864147..06bef01c 100644
--- a/model_server/models.py
+++ b/model_server/models.py
@@ -90,8 +90,7 @@ class InstanceSegmentationModel(ImageToImageModel):
             raise InvalidInputImageError('Expect input image and mask to be the same shape')
 
 
-
-class DummySegmentationModel(SemanticSegmentationModel):
+class DummySemanticSegmentationModel(SemanticSegmentationModel):
 
     model_id = 'dummy_make_white_square'
 
@@ -111,6 +110,26 @@ class DummySegmentationModel(SemanticSegmentationModel):
         mask, _ = self.infer(img)
         return mask
 
+class DummyInstanceSegmentationModel(InstanceSegmentationModel):
+
+    model_id = 'dummy_pass_input_mask'
+
+    def load(self):
+        return True
+
+    def infer(
+            self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor
+    ) -> (GenericImageDataAccessor, dict):
+        return mask
+
+    def label_instance_class(
+            self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
+    ) -> GenericImageDataAccessor:
+        """
+        Returns a trivial segmentation, i.e. the input mask
+        """
+        super(DummyInstanceSegmentationModel, self).label_instance_class(img, mask, **kwargs)
+        return self.infer(img, mask)
 
 class Error(Exception):
     pass
diff --git a/tests/test_model.py b/tests/test_model.py
index 0d5f98ae..8730c666 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,49 +1,56 @@
 import unittest
 from conf.testing import czifile
 from model_server.accessors import CziImageFileAccessor
-from model_server.models import DummySegmentationModel, CouldNotLoadModelError
+from model_server.models import DummySemanticSegmentationModel, DummyInstanceSegmentationModel, CouldNotLoadModelError
 
 class TestCziImageFileAccess(unittest.TestCase):
     def setUp(self) -> None:
         self.cf = CziImageFileAccessor(czifile['path'])
 
     def test_instantiate_model(self):
-        model = DummySegmentationModel(params=None)
+        model = DummySemanticSegmentationModel(params=None)
         self.assertTrue(model.loaded)
 
     def test_instantiate_model_with_nondefault_kwarg(self):
-        model = DummySegmentationModel(autoload=False)
+        model = DummySemanticSegmentationModel(autoload=False)
         self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')
 
     def test_raise_error_if_cannot_load_model(self):
-        class UnloadableDummyImageToImageModel(DummySegmentationModel):
+        class UnloadableDummyImageToImageModel(DummySemanticSegmentationModel):
             def load(self):
                 return False
 
         with self.assertRaises(CouldNotLoadModelError):
             mi = UnloadableDummyImageToImageModel()
 
-    def test_czifile_is_correct_shape(self):
-        model = DummySegmentationModel()
-        img, _ = model.infer(self.cf)
+    def test_dummy_pixel_segmentation(self):
+        model = DummySemanticSegmentationModel()
+        img = self.cf.get_one_channel_data(0)
+        mask = model.label_pixel_class(img)
 
         w = czifile['w']
         h = czifile['h']
 
         self.assertEqual(
-            img.shape,
+            mask.shape,
             (h, w, 1, 1),
             'Inferred image is not the expected shape'
         )
 
         self.assertEqual(
-            img.data[int(w/2), int(h/2)],
+            mask.data[int(w/2), int(h/2)],
             255,
             'Middle pixel is not white as expected'
         )
 
         self.assertEqual(
-            img.data[0, 0],
+            mask.data[0, 0],
             0,
             'First pixel is not black as expected'
-        )
\ No newline at end of file
+        )
+        return img, mask
+
+    def test_dummy_instance_segmentation(self):
+        img, mask = self.test_dummy_pixel_segmentation()
+        model = DummyInstanceSegmentationModel()
+        obmap = model.label_instance_class(img, mask)
-- 
GitLab