Skip to content
Snippets Groups Projects
Commit cf492a3e authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'dev_multichannel_ilastik' into 'staging'

ilastik workflow and API allow multichannel inputs

See merge request rhodes/model_server!40
parents bcf1efbe 2310314b
Branches
Tags
2 merge requests!50Release 2024.06.03,!40ilastik workflow and API allow multichannel inputs
......@@ -57,6 +57,9 @@ ilastik_classifiers = {
'px': root / 'ilastik' / 'demo_px.ilp',
'pxmap_to_obj': root / 'ilastik' / 'demo_obj.ilp',
'seg_to_obj': root / 'ilastik' / 'demo_obj_seg.ilp',
'px_color_zstack': root / 'ilastik' / 'px-3d-color.ilp',
'ob_pxmap_color_zstack': root / 'ilastik' / 'ob-pxmap-color-zstack.ilp',
'ob_seg_color_zstack': root / 'ilastik' / 'ob-seg-color-zstack.ilp',
}
roiset_test_data = {
......
......@@ -104,8 +104,15 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
return [l.decode() for l in h5['PixelClassification/LabelNames'][()]]
def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
if self.model_chroma != input_img.chroma:
raise IlastikInputShapeError(
f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
)
if self.model_3d != input_img.is_3d():
if self.model_3d:
raise IlastikInputShapeError(f'Model is 3D but input image is 2D')
else:
raise IlastikInputShapeError(f'Model is 2D but input image is 3D')
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
dsi = [
......@@ -164,7 +171,7 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma:
raise IlastikInputShapeError(
f'Model {self} expects {self.model_chroma} input channels but received only {input_img.chroma}'
f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
)
if self.model_3d != input_img.is_3d():
if self.model_3d:
......@@ -229,8 +236,15 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
return ObjectClassificationWorkflowPrediction
def infer(self, input_img: GenericImageDataAccessor, pxmap_img: GenericImageDataAccessor) -> (np.ndarray, dict):
if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
raise IlastikInputShapeError()
if self.model_chroma != input_img.chroma:
raise IlastikInputShapeError(
f'Model {self} expects {self.model_chroma} input channels but received {input_img.chroma}'
)
if self.model_3d != input_img.is_3d():
if self.model_3d:
raise IlastikInputShapeError(f'Model is 3D but input image is 2D')
else:
raise IlastikInputShapeError(f'Model is 2D but input image is 3D')
if isinstance(input_img, PatchStack):
assert isinstance(pxmap_img, PatchStack)
......
......@@ -7,6 +7,7 @@ import numpy as np
from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data
from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
from model_server.extensions.ilastik import models as ilm
from model_server.extensions.ilastik.workflows import infer_px_then_ob_model
from model_server.base.models import InvalidObjectLabelsError
from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
from model_server.base.workflows import classify_pixels
......@@ -320,6 +321,92 @@ class TestIlastikOverApi(TestServerBaseClass):
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
class TestIlastikOnMultichannelInputs(TestServerBaseClass):
def setUp(self) -> None:
super(TestIlastikOnMultichannelInputs, self).setUp()
self.pa_px_classifier = ilastik_classifiers['px_color_zstack']
self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack']
self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack']
self.pa_input_image = roiset_test_data['multichannel_zstack']['path']
self.pa_mask = roiset_test_data['multichannel_zstack']['mask_path_3d']
def _copy_input_file_to_server(self):
from shutil import copyfile
pa_data = roiset_test_data['multichannel_zstack']['path']
resp = self._get('paths')
pa = resp.json()['inbound_images']
outpath = pathlib.Path(pa) / pa_data.name
copyfile(
czifile['path'],
outpath
)
def test_classify_pixels(self):
img = generate_file_accessor(self.pa_input_image)
self.assertGreater(img.chroma, 1)
mod = ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier})
pxmap = mod.infer(img)[0]
self.assertEqual(pxmap.hw, img.hw)
self.assertEqual(pxmap.nz, img.nz)
return pxmap
def test_classify_objects(self):
pxmap = self.test_classify_pixels()
img = generate_file_accessor(self.pa_input_image)
mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier})
obmap = mod.infer(img, pxmap)[0]
self.assertEqual(obmap.hw, img.hw)
self.assertEqual(obmap.nz, img.nz)
def _call_workflow(self, channel):
return infer_px_then_ob_model(
self.pa_input_image,
ilm.IlastikPixelClassifierModel(params={'project_file': self.pa_px_classifier}),
ilm.IlastikObjectClassifierFromPixelPredictionsModel(params={'project_file': self.pa_ob_pxmap_classifier}),
output_path,
channel=channel,
)
def test_workflow(self):
with self.assertRaises(ilm.IlastikInputShapeError):
self._call_workflow(channel=0)
res = self._call_workflow(channel=None)
acc_input = generate_file_accessor(self.pa_input_image)
acc_obmap = generate_file_accessor(res.object_map_filepath)
self.assertEqual(acc_obmap.hw, acc_input.hw)
self.assertEqual(len(acc_obmap._unique()[1]), 3)
def test_api(self):
resp_load = self._put(
'ilastik/seg/load/',
query={'project_file': str(self.pa_px_classifier)},
)
self.assertEqual(resp_load.status_code, 200, resp_load.json())
px_model_id = resp_load.json()['model_id']
resp_load = self._put(
'ilastik/pxmap_to_obj/load/',
query={'project_file': str(self.pa_ob_pxmap_classifier)},
)
self.assertEqual(resp_load.status_code, 200, resp_load.json())
ob_model_id = resp_load.json()['model_id']
resp_infer = self._put(
'ilastik/pixel_then_object_classification/infer/',
query={
'px_model_id': px_model_id,
'ob_model_id': ob_model_id,
'input_filename': self.pa_input_image,
}
)
self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())
acc_input = generate_file_accessor(self.pa_input_image)
acc_obmap = generate_file_accessor(resp_infer.json()['object_map_filepath'])
self.assertEqual(acc_obmap.hw, acc_input.hw)
class TestIlastikObjectClassification(unittest.TestCase):
def setUp(self):
stack = generate_file_accessor(roiset_test_data['multichannel_zstack']['path'])
......
......@@ -26,6 +26,7 @@ def infer_px_then_ob_model(
px_model: IlastikPixelClassifierModel,
ob_model: IlastikObjectClassifierFromPixelPredictionsModel,
where_output: Path,
channel: int = None,
**kwargs
) -> WorkflowRunRecord:
"""
......@@ -35,6 +36,7 @@ def infer_px_then_ob_model(
:param px_model: model instance for pixel classification
:param ob_model: model instance for object classification
:param where_output: Path object that references output image directory
:param channel: input image channel to pass to pixel classification, or all channels if None
:param kwargs: variable-length keyword arguments
:return:
"""
......@@ -42,8 +44,12 @@ def infer_px_then_ob_model(
assert isinstance(ob_model, IlastikObjectClassifierFromPixelPredictionsModel)
ti = Timer()
ch = kwargs.get('channel')
img = generate_file_accessor(fpi).get_one_channel_data(ch, mip=kwargs.get('mip', False))
raw_acc = generate_file_accessor(fpi)
if channel is not None:
channels = [channel]
else:
channels = range(0, raw_acc.chroma)
img = raw_acc.get_channels(channels, mip=kwargs.get('mip', False))
ti.click('file_input')
px_map, _ = px_model.infer(img)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment