Skip to content
Snippets Groups Projects
Commit c5b70e7d authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'issue0017' into expose_session_logs_in_api

parents 43f84ad3 b15e0b2b
No related branches found
No related tags found
2 merge requests!16Completed (de)serialization of RoiSet,!9Session exposes a python log
This commit is part of merge request !9. Comments created here will be created in the context of that merge request.
......@@ -162,6 +162,14 @@ class CziImageFileAccessor(GenericImageFileAccessor):
except Exception:
raise FileAccessorError(f'Unable to access CZI data in {fpath}')
try:
md = cf.metadata(raw=False)
compmet = md['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod']
except KeyError:
raise InvalidCziCompression('Could not find metadata key OriginalCompressionMethod')
if compmet.upper() != 'UNCOMPRESSED':
raise InvalidCziCompression(f'Unsupported compression method {compmet}')
sd = {ch: cf.shape[cf.axes.index(ch)] for ch in cf.axes}
if (sd.get('S') and (sd['S'] > 1)) or (sd.get('T') and (sd['T'] > 1)):
raise DataShapeError(f'Cannot handle image with multiple positions or time points: {sd}')
......@@ -284,6 +292,22 @@ class PatchStack(InMemoryDataAccessor):
def count(self):
return self.shape_dict['P']
def export_pyxcz(self, fpath: Path):
tzcyx = np.moveaxis(
self.pyxcz, # yxcz
[0, 4, 3, 1, 2],
[0, 1, 2, 3, 4]
)
if self.is_mask():
if self.dtype == 'bool':
data = (tzcyx * 255).astype('uint8')
else:
data = tzcyx.astype('uint8')
tifffile.imwrite(fpath, data, imagej=True)
else:
tifffile.imwrite(fpath, tzcyx, imagej=True)
@property
def shape_dict(self):
return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
......@@ -313,16 +337,32 @@ class PatchStack(InMemoryDataAccessor):
return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
def make_patch_stack_from_file(fpath): # interpret z-dimension as patch position
def make_patch_stack_from_file(fpath): # interpret t-dimension as patch position
if not Path(fpath).exists():
raise FileNotFoundError(f'Could not find {fpath}')
pyxc = np.moveaxis(
generate_file_accessor(fpath).data, # yxcz
[0, 1, 2, 3],
[1, 2, 3, 0]
try:
tf = tifffile.TiffFile(fpath)
except Exception:
raise FileAccessorError(f'Unable to access data in {fpath}')
if len(tf.series) != 1:
raise DataShapeError(f'Expect only one series in {fpath}')
se = tf.series[0]
axs = [a for a in se.axes if a in [*'TZCYX']]
sd = dict(zip(axs, se.shape))
for a in [*'TZC']:
if a not in axs:
sd[a] = 1
tzcyx = se.asarray().reshape([sd[k] for k in [*'TZCYX']])
pyxcz = np.moveaxis(
tzcyx,
[0, 3, 4, 2, 1],
[0, 1, 2, 3, 4],
)
pyxcz = np.expand_dims(pyxc, axis=3)
return PatchStack(pyxcz)
......@@ -345,6 +385,9 @@ class FileWriteError(Error):
class InvalidAxisKey(Error):
pass
class InvalidCziCompression(Error):
pass
class InvalidDataShape(Error):
pass
......
......@@ -12,8 +12,17 @@ from model_server.base.models import Model, ImageToImageModel, InstanceSegmentat
class IlastikModel(Model):
def __init__(self, params, autoload=True):
def __init__(self, params, autoload=True, enforce_embedded=True):
"""
Base class for models that run via ilastik shell API
:param params:
project_file: path to ilastik project file
:param autoload: automatically load model into memory if true
:param enforce_embedded:
raise an error if all input data are not embedded in the project file, i.e. on the filesystem
"""
self.project_file = Path(params['project_file'])
self.enforce_embedded = enforce_embedded
params['project_file'] = self.project_file.__str__()
if self.project_file.is_absolute():
pap = self.project_file
......@@ -42,6 +51,15 @@ class IlastikModel(Model):
args.project = self.project_file_abspath.__str__()
shell = app.main(args, init_logging=False)
# validate if inputs are embedded in project file
input_groups = shell.projectManager.currentProjectFile['Input Data']['infos']
lanes = input_groups.keys()
for ll in lanes:
input_types = input_groups[ll]
for tt in input_types:
ds_loc = input_groups[ll][tt].get('location', False)
if self.enforce_embedded and ds_loc and ds_loc[()] == b'FileSystem':
raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem')
if not isinstance(shell.workflow, self.get_workflow()):
raise ParameterExpectedError(
f'Ilastik project file {self.project_file} does not describe an instance of {shell.workflow.__class__}'
......@@ -87,22 +105,33 @@ class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
model_id = 'ilastik_object_classification_from_segmentation'
@staticmethod
def _make_8bit_mask(nda):
if nda.dtype == 'bool':
return 255 * nda.astype('uint8')
else:
return nda
@staticmethod
def get_workflow():
from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
return ObjectClassificationWorkflowBinary
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
assert segmentation_img.is_mask()
if segmentation_img.dtype == 'bool':
seg = 255 * segmentation_img.data.astype('uint8')
if isinstance(input_img, PatchStack):
assert isinstance(segmentation_img, PatchStack)
tagged_input_data = vigra.taggedView(input_img.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(
255 * segmentation_img.data.astype('uint8'),
'yxcz'
self._make_8bit_mask(segmentation_img.pczyx),
'tczyx'
)
else:
tagged_seg_data = vigra.taggedView(segmentation_img.data, 'yxcz')
tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
tagged_seg_data = vigra.taggedView(
self._make_8bit_mask(segmentation_img.data),
'yxcz'
)
dsi = [
{
......@@ -115,12 +144,21 @@ class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegment
assert len(obmaps) == 1, 'ilastik generated more than one object map'
yxcz = np.moveaxis(
obmaps[0],
[1, 2, 3, 0],
[0, 1, 2, 3]
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
if isinstance(input_img, PatchStack):
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
return PatchStack(data=pyxcz), {'success': True}
else:
yxcz = np.moveaxis(
obmaps[0],
[1, 2, 3, 0],
[0, 1, 2, 3]
)
return InMemoryDataAccessor(data=yxcz), {'success': True}
def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
......@@ -172,48 +210,15 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
"""
if not img.shape == pxmap.shape:
raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
# TODO: check that pxmap is in-range
pxch = kwargs.get('pixel_classification_channel', 0)
pxtr = kwargs('pixel_classification_threshold', 0.5)
mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
# super(IlastikObjectClassifierFromSegmentationModel, self).label_instance_class(img, mask, **kwargs)
obmap, _ = self.infer(img, mask)
return obmap
class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
"""
Wrap ilastik object classification for inputs comprising single-object series of raw images and binary
segmentation masks.
"""
def infer(self, input_acc: PatchStack, segmentation_acc: PatchStack) -> (np.ndarray, dict):
assert segmentation_acc.is_mask()
if not input_acc.chroma == 1:
raise InvalidInputImageError('Object classifier expects only monochrome patches')
if not input_acc.nz == 1:
raise InvalidInputImageError('Object classifier expects only 2d patches')
tagged_input_data = vigra.taggedView(input_acc.pczyx, 'tczyx')
tagged_seg_data = vigra.taggedView(segmentation_acc.pczyx, 'tczyx')
dsi = [
{
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
}
]
obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
assert len(obmaps) == 1, 'ilastik generated more than one object map'
# for some reason ilastik scrambles these axes to P(1)YX(1); unclear which should be Z and C
assert obmaps[0].shape == (input_acc.count, 1, input_acc.hw[0], input_acc.hw[1], 1)
pyxcz = np.moveaxis(
obmaps[0],
[0, 1, 2, 3, 4],
[0, 4, 1, 2, 3]
)
class Error(Exception):
pass
return PatchStack(data=pyxcz), {'success': True}
\ No newline at end of file
class IlastikInputEmbedding(Error):
pass
\ No newline at end of file
......@@ -97,7 +97,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
objmap,
)
)
self.assertEqual(objmap.data.max(), 3)
self.assertEqual(objmap.data.max(), 2)
def test_run_object_classifier_from_segmentation(self):
self.test_run_pixel_classifier()
......@@ -113,7 +113,7 @@ class TestIlastikPixelClassification(unittest.TestCase):
objmap,
)
)
self.assertEqual(objmap.data.max(), 3)
self.assertEqual(objmap.data.max(), 2)
def test_ilastik_pixel_classification_as_workflow(self):
result = classify_pixels(
......@@ -269,16 +269,18 @@ class TestIlastikObjectClassification(unittest.TestCase):
)
)
self.object_classifier = ilm.PatchStackObjectClassifier(
self.object_classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
params={'project_file': ilastik_classifiers['seg_to_obj']}
)
def test_classify_patches(self):
raw_patches = self.roiset.get_raw_patches()
patch_masks = self.roiset.get_patch_masks()
res_patches, _ = self.object_classifier.infer(raw_patches, patch_masks)
res_patches = self.object_classifier.label_instance_class(raw_patches, patch_masks)
self.assertEqual(res_patches.count, self.roiset.count)
res_patches.export_pyxcz(output_path / 'res_patches.tif')
for pi in range(0, res_patches.count): # assert that there is only one nonzero label per patch
unique = np.unique(res_patches.iat(pi).data)
self.assertEqual(len(unique), 2)
self.assertEqual(unique[0], 0)
la, ct = np.unique(res_patches.iat(pi).data, return_counts=True)
self.assertEqual(np.sum(ct > 1), 2) # exclude single-pixel anomaly
self.assertEqual(la[0], 0)
......@@ -7,6 +7,9 @@ from model_server.base.accessors import PatchStack, make_patch_stack_from_file,
from model_server.conf.testing import czifile, output_path, monopngfile, rgbpngfile, tifffile, monozstackmask
from model_server.base.accessors import CziImageFileAccessor, DataShapeError, generate_file_accessor, InMemoryDataAccessor, PngFileAccessor, write_accessor_data_to_file, TifSingleSeriesFileAccessor
def _random_int(*args):
return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')
class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
......@@ -40,7 +43,7 @@ class TestCziImageFileAccess(unittest.TestCase):
nc = 4
nz = 11
c = 3
cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz))
cf = InMemoryDataAccessor(_random_int(h, w, nc, nz))
sc = cf.get_one_channel_data(c)
self.assertEqual(sc.shape, (h, w, 1, nz))
......@@ -70,7 +73,7 @@ class TestCziImageFileAccess(unittest.TestCase):
def test_conform_data_shorter_than_xycz(self):
h = 256
w = 512
data = np.random.rand(h, w, 1)
data = _random_int(h, w, 1)
acc = InMemoryDataAccessor(data)
self.assertEqual(
InMemoryDataAccessor.conform_data(data).shape,
......@@ -82,7 +85,7 @@ class TestCziImageFileAccess(unittest.TestCase):
)
def test_conform_data_longer_than_xycz(self):
data = np.random.rand(256, 512, 12, 8, 3)
data = _random_int(256, 512, 12, 8, 3)
with self.assertRaises(DataShapeError):
acc = InMemoryDataAccessor(data)
......@@ -93,7 +96,7 @@ class TestCziImageFileAccess(unittest.TestCase):
c = 3
nz = 10
yxcz = (2**8 * np.random.rand(h, w, c, nz)).astype('uint8')
yxcz = _random_int(h, w, c, nz)
acc = InMemoryDataAccessor(yxcz)
fp = output_path / f'rand3d.tif'
self.assertTrue(
......@@ -138,7 +141,7 @@ class TestPatchStackAccessor(unittest.TestCase):
w = 256
h = 512
n = 4
acc = PatchStack(np.random.rand(n, h, w, 1, 1))
acc = PatchStack(_random_int(n, h, w, 1, 1))
self.assertEqual(acc.count, n)
self.assertEqual(acc.hw, (h, w))
self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1))
......@@ -147,7 +150,7 @@ class TestPatchStackAccessor(unittest.TestCase):
w = 256
h = 512
n = 4
acc = PatchStack([np.random.rand(h, w, 1, 1) for _ in range(0, n)])
acc = PatchStack([_random_int(h, w, 1, 1) for _ in range(0, n)])
self.assertEqual(acc.count, n)
self.assertEqual(acc.hw, (h, w))
self.assertEqual(acc.pyxcz.shape, (n, h, w, 1, 1))
......@@ -176,8 +179,8 @@ class TestPatchStackAccessor(unittest.TestCase):
nz = 5
n = 4
patches = [np.random.rand(h, w, c, nz) for _ in range(0, n)]
patches.append(np.random.rand(h, 2 * w, c, nz))
patches = [_random_int(h, w, c, nz) for _ in range(0, n)]
patches.append(_random_int(h, 2 * w, c, nz))
acc = PatchStack(patches)
self.assertEqual(acc.count, n + 1)
self.assertEqual(acc.hw, (h, 2 * w))
......@@ -191,7 +194,15 @@ class TestPatchStackAccessor(unittest.TestCase):
n = 4
nz = 15
nc = 2
acc = PatchStack(np.random.rand(n, h, w, nc, nz))
acc = PatchStack(_random_int(n, h, w, nc, nz))
self.assertEqual(acc.count, n)
self.assertEqual(acc.pczyx.shape, (n, nc, nz, h, w))
self.assertEqual(acc.hw, (h, w))
return acc
def test_export_pczyx_patch_hyperstack(self):
acc = self.test_pczyx()
fp = output_path / 'patch_hyperstack.tif'
acc.export_pyxcz(fp)
acc2 = make_patch_stack_from_file(fp)
self.assertEqual(acc.shape, acc2.shape)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment