diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py index 48a18758ea486523cb6c38c31a7e7de695d10800..d17e69545c0aa746c3ebd17ddc89580ace56eadd 100644 --- a/model_server/extensions/ilastik/models.py +++ b/model_server/extensions/ilastik/models.py @@ -1,3 +1,4 @@ +import json import os from pathlib import Path @@ -73,12 +74,31 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): model_id = 'ilastik_pixel_classification' operations = ['segment', ] + @property + def model_shape_dict(self): + raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data'] + ax = raw_info['axistags'][()] + ax_keys = [ax['key'].upper() for ax in json.loads(ax)['axes']] + shape = raw_info['shape'][()] + return dict(zip(ax_keys, shape)) + + @property + def model_chroma(self): + return self.model_shape_dict['C'] + + @property + def model_3d(self): + return self.model_shape_dict['Z'] > 1 + @staticmethod def get_workflow(): from ilastik.workflows import PixelClassificationWorkflow return PixelClassificationWorkflow def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict): + if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d(): + raise IlastikInputShapeError() + tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') dsi = [ { @@ -221,4 +241,8 @@ class Error(Exception): pass class IlastikInputEmbedding(Error): + pass + +class IlastikInputShapeError(Error): + """Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape""" pass \ No newline at end of file diff --git a/model_server/extensions/ilastik/tests/test_ilastik.py b/model_server/extensions/ilastik/tests/test_ilastik.py index e81a057ff6934374ba583f7650c6b876c59829f2..7adde1a4cbfcfaa953ff927ca9c0d7f1b4221a9f 100644 --- a/model_server/extensions/ilastik/tests/test_ilastik.py +++ b/model_server/extensions/ilastik/tests/test_ilastik.py @@ -11,6 +11,9 @@ from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams from model_server.base.workflows import classify_pixels from tests.test_api import TestServerBaseClass +def _random_int(*args): + return np.random.randint(0, 2 ** 8, size=args, dtype='uint8') + class TestIlastikPixelClassification(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) @@ -83,6 +86,40 @@ class TestIlastikPixelClassification(unittest.TestCase): self.mono_image = mono_image self.mask = mask + def test_pixel_classifier_enforces_input_shape(self): + model = ilm.IlastikPixelClassifierModel( + {'project_file': ilastik_classifiers['px']} + ) + self.assertEqual(model.model_chroma, 1) + self.assertEqual(model.model_3d, False) + + # correct data + self.assertIsInstance( + model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 1, 1) + ) + ), + InMemoryDataAccessor + ) + + # raise except with input of multiple channels + with self.assertRaises(ilm.IlastikInputShapeError): + mask = model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 3, 1) + ) + ) + + # raise except with input of multiple channels + with self.assertRaises(ilm.IlastikInputShapeError): + mask = model.label_pixel_class( + InMemoryDataAccessor( + _random_int(512, 256, 1, 15) + ) + ) + + def test_run_object_classifier_from_pixel_predictions(self): self.test_run_pixel_classifier() fp = czifile['path']