Skip to content
Snippets Groups Projects
Commit b99720e2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Validate chroma and dimensionality of inputs to pixel classification model

parent b15e0b2b
No related branches found
No related tags found
2 merge requests!16Completed (de)serialization of RoiSet,!5Resolve "ilastik models do not validate dimensionality of input data"
import json
import os import os
from pathlib import Path from pathlib import Path
...@@ -73,12 +74,31 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel): ...@@ -73,12 +74,31 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
model_id = 'ilastik_pixel_classification' model_id = 'ilastik_pixel_classification'
operations = ['segment', ] operations = ['segment', ]
@property
def model_shape_dict(self):
raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
ax = raw_info['axistags'][()]
ax_keys = [ax['key'].upper() for ax in json.loads(ax)['axes']]
shape = raw_info['shape'][()]
return dict(zip(ax_keys, shape))
@property
def model_chroma(self):
return self.model_shape_dict['C']
@property
def model_3d(self):
return self.model_shape_dict['Z'] > 1
@staticmethod @staticmethod
def get_workflow(): def get_workflow():
from ilastik.workflows import PixelClassificationWorkflow from ilastik.workflows import PixelClassificationWorkflow
return PixelClassificationWorkflow return PixelClassificationWorkflow
def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict): def infer(self, input_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz') tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
dsi = [ dsi = [
{ {
...@@ -221,4 +241,8 @@ class Error(Exception): ...@@ -221,4 +241,8 @@ class Error(Exception):
pass pass
class IlastikInputEmbedding(Error): class IlastikInputEmbedding(Error):
pass
class IlastikInputShapeError(Error):
"""Raised when an ilastik classifier is asked to infer on data that is incompatible with its input shape"""
pass pass
\ No newline at end of file
...@@ -11,6 +11,9 @@ from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams ...@@ -11,6 +11,9 @@ from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
from model_server.base.workflows import classify_pixels from model_server.base.workflows import classify_pixels
from tests.test_api import TestServerBaseClass from tests.test_api import TestServerBaseClass
def _random_int(*args):
return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')
class TestIlastikPixelClassification(unittest.TestCase): class TestIlastikPixelClassification(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path']) self.cf = CziImageFileAccessor(czifile['path'])
...@@ -83,6 +86,40 @@ class TestIlastikPixelClassification(unittest.TestCase): ...@@ -83,6 +86,40 @@ class TestIlastikPixelClassification(unittest.TestCase):
self.mono_image = mono_image self.mono_image = mono_image
self.mask = mask self.mask = mask
def test_pixel_classifier_enforces_input_shape(self):
model = ilm.IlastikPixelClassifierModel(
{'project_file': ilastik_classifiers['px']}
)
self.assertEqual(model.model_chroma, 1)
self.assertEqual(model.model_3d, False)
# correct data
self.assertIsInstance(
model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 1, 1)
)
),
InMemoryDataAccessor
)
# raise except with input of multiple channels
with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 3, 1)
)
)
# raise except with input of multiple channels
with self.assertRaises(ilm.IlastikInputShapeError):
mask = model.label_pixel_class(
InMemoryDataAccessor(
_random_int(512, 256, 1, 15)
)
)
def test_run_object_classifier_from_pixel_predictions(self): def test_run_object_classifier_from_pixel_predictions(self):
self.test_run_pixel_classifier() self.test_run_pixel_classifier()
fp = czifile['path'] fp = czifile['path']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment