-
Christopher Randolph Rhodes authored
models.py 11.63 KiB
import json
import os
from pathlib import Path
import numpy as np
import vigra
import model_server.extensions.ilastik.conf
from model_server.base.accessors import PatchStack
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.base.models import Model, ImageToImageModel, InstanceSegmentationModel, InvalidInputImageError, ParameterExpectedError, SemanticSegmentationModel
class IlastikModel(Model):
    """
    Base class for models that run via the ilastik shell API.

    Subclasses provide a ``get_workflow()`` static method returning the ilastik workflow class that their project
    files must describe; ``load()`` checks the opened project against it.
    """

    def __init__(self, params, autoload=True, enforce_embedded=True):
        """
        Base class for models that run via ilastik shell API
        :param params:
            project_file: path to ilastik project file; a relative path is resolved against
                model_server.extensions.ilastik.conf.paths['project_files'].
                NOTE: params['project_file'] is normalized to a str in place in the caller's dict.
        :param autoload: automatically load model into memory if true
        :param enforce_embedded:
            raise an error if all input data are not embedded in the project file, i.e. on the filesystem
        :raises ParameterExpectedError: if params has no 'project_file' entry
        :raises FileNotFoundError: if the resolved project file does not exist
        """
        # Check for the key before indexing it; otherwise a missing key surfaces as a bare KeyError.
        if 'project_file' not in params:
            raise ParameterExpectedError('Ilastik model expects a project (*.ilp) file')
        self.project_file = Path(params['project_file'])
        self.enforce_embedded = enforce_embedded
        params['project_file'] = str(self.project_file)

        if self.project_file.is_absolute():
            pap = self.project_file
        else:
            pap = model_server.extensions.ilastik.conf.paths['project_files'] / self.project_file
        self.project_file_abspath = pap
        if not pap.exists():
            raise FileNotFoundError(f'Project file does not exist: {pap}')

        self.shell = None
        super().__init__(autoload, params)

    def load(self):
        """
        Start a headless ilastik shell on the project file and verify that it matches this model.
        :return: True once the shell is stored on self.shell
        :raises IlastikInputEmbedding: if enforce_embedded is set and any input lane/role records a data location
            other than 'ProjectInternal'
        :raises ParameterExpectedError: if the project's workflow is not an instance of self.get_workflow()
        """
        # ilastik is imported lazily so that this module can be imported without starting ilastik.
        from ilastik import app
        from ilastik.applets.dataSelection.opDataSelection import PreloadedArrayDatasetInfo
        self.PreloadedArrayDatasetInfo = PreloadedArrayDatasetInfo

        # Fixed lazyflow thread count and RAM budget; presumably read by ilastik during app.main() — confirm.
        os.environ["LAZYFLOW_THREADS"] = "8"
        os.environ["LAZYFLOW_TOTAL_RAM_MB"] = "24000"

        args = app.parse_args([])
        args.headless = True
        args.project = str(self.project_file_abspath)
        shell = app.main(args, init_logging=False)

        # validate that inputs are embedded in the project file (only needed when enforcement is requested)
        if self.enforce_embedded:
            h5 = shell.projectManager.currentProjectFile
            for lane in h5['Input Data/infos'].keys():
                for role in h5[f'Input Data/infos/{lane}'].keys():
                    grp = h5[f'Input Data/infos/{lane}/{role}']
                    if ('location' in grp.keys()) and grp['location'][()] != b'ProjectInternal':
                        raise IlastikInputEmbedding('Cannot load ilastik project file where inputs are on filesystem')

        expected_workflow = self.get_workflow()
        if not isinstance(shell.workflow, expected_workflow):
            # Report both the expected and the actual workflow; the actual class alone is misleading.
            raise ParameterExpectedError(
                f'Ilastik project file {self.project_file} describes an instance of {shell.workflow.__class__}, '
                f'not the expected {expected_workflow}'
            )
        self.shell = shell
        return True

    @property
    def model_shape_dict(self):
        """
        Shape of the first lane's 'Raw Data' input, as recorded in the loaded project file.
        Requires load() to have run (reads self.shell).
        :return: dict mapping upper-cased axistag keys to sizes; 'T', 'C' and 'Z' default to 1 when absent
        """
        raw_info = self.shell.projectManager.currentProjectFile['Input Data']['infos']['lane0000']['Raw Data']
        axistags = json.loads(raw_info['axistags'][()])
        ax_keys = [axis['key'].upper() for axis in axistags['axes']]
        dd = dict(zip(ax_keys, raw_info['shape'][()]))
        for ci in 'TCZ':
            dd.setdefault(ci, 1)
        return dd

    @property
    def model_chroma(self):
        """Number of channels ('C') the project's raw input was trained on."""
        return self.model_shape_dict['C']

    @property
    def model_3d(self):
        """True if the project's raw input has more than one Z slice."""
        return self.model_shape_dict['Z'] > 1
class IlastikPixelClassifierModel(IlastikModel, SemanticSegmentationModel):
    """
    Semantic segmentation model backed by an ilastik pixel classification project.
    """
    model_id = 'ilastik_pixel_classification'
    operations = ['segment', ]

    @staticmethod
    def get_workflow():
        """Workflow class that project files for this model must describe."""
        from ilastik.workflows import PixelClassificationWorkflow
        return PixelClassificationWorkflow

    @property
    def labels(self):
        """Label names stored in the project file, decoded from bytes to str."""
        h5 = self.shell.projectManager.currentProjectFile
        return [name.decode() for name in h5['PixelClassification/LabelNames'][()]]

    def infer(self, input_img: GenericImageDataAccessor) -> (InMemoryDataAccessor, dict):
        """
        Run pixel classification on one image.
        :param input_img: image whose chroma and 3D-ness match the project's raw input
        :return: (YXCZ pixel probability map, {'success': True}); presumably one channel per label, as
            infer_patch_stack assumes
        :raises IlastikInputShapeError: if chroma or 3D-ness do not match the model
        """
        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
            raise IlastikInputShapeError(
                f'Model expects chroma={self.model_chroma} and 3d={self.model_3d}, '
                f'but input has chroma={input_img.chroma} and 3d={input_img.is_3d()}'
            )
        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
            }
        ]
        pxmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)  # [z x h x w x n]
        assert len(pxmaps) == 1, 'ilastik generated more than one pixel map'
        # reorder ZYXC -> YXCZ
        yxcz = np.moveaxis(
            pxmaps[0],
            [1, 2, 3, 0],
            [0, 1, 2, 3]
        )
        return InMemoryDataAccessor(data=yxcz), {'success': True}

    def infer_patch_stack(self, img: PatchStack, **kwargs) -> (PatchStack, dict):
        """
        Iterate over a patch stack, calling inference separately on each cropped patch.
        Each result is written into the region given by img.get_slice_at(i); pixels outside that region stay zero.
        :param img: stack of patches
        :return: (PatchStack of float pixel maps interpreted as PYXCZ with one channel per label, {'success': True})
        """
        nc = len(self.labels)
        data = np.zeros((img.count, *img.hw, nc, img.nz), dtype=float)  # interpret as PYXCZ
        for i in range(img.count):
            sl = img.get_slice_at(i)
            # channel axis is taken whole: the output has nc channels regardless of the input slice's channels
            data[i][sl[0], sl[1], :, sl[3]] = self.infer(img.iat(i, crop=True))[0].data
        return PatchStack(data), {'success': True}

    def label_pixel_class(self, img: GenericImageDataAccessor, px_class: int = 0, px_prob_threshold=0.5, **kwargs):
        """
        Threshold one class of the pixel probability map into a mask.
        :param img: input image
        :param px_class: channel index of the pixel map to threshold
        :param px_prob_threshold: pixels with probability strictly greater than this are True
        :return: InMemoryDataAccessor wrapping a boolean YXZ-indexed mask (channel axis dropped by the indexing)
        """
        pxmap, _ = self.infer(img)
        mask = pxmap.data[:, :, px_class, :] > px_prob_threshold
        return InMemoryDataAccessor(mask)
class IlastikObjectClassifierFromSegmentationModel(IlastikModel, InstanceSegmentationModel):
    """
    Instance segmentation model backed by an ilastik object classification project whose second input is a binary
    segmentation image.
    """
    model_id = 'ilastik_object_classification_from_segmentation'

    @staticmethod
    def _make_8bit_mask(nda):
        """Convert a boolean array to uint8 with True -> 255; arrays of any other dtype are returned unchanged."""
        if nda.dtype == 'bool':
            return 255 * nda.astype('uint8')
        else:
            return nda

    @staticmethod
    def get_workflow():
        """Workflow class that project files for this model must describe."""
        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowBinary
        return ObjectClassificationWorkflowBinary

    def infer(
            self,
            input_img: GenericImageDataAccessor,
            segmentation_img: GenericImageDataAccessor
    ) -> (GenericImageDataAccessor, dict):
        """
        Classify the objects in a segmentation image.
        :param input_img: raw image, or a PatchStack of raw images
        :param segmentation_img: mask matching input_img; must be a PatchStack if input_img is one
        :return: (object map as a PatchStack if input_img is a PatchStack, else an InMemoryDataAccessor in YXCZ
            order, {'success': True})
        :raises IlastikInputShapeError: if chroma or 3D-ness of input_img do not match the model
        """
        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
            raise IlastikInputShapeError(
                f'Model expects chroma={self.model_chroma} and 3d={self.model_3d}, '
                f'but input has chroma={input_img.chroma} and 3d={input_img.is_3d()}'
            )
        # NOTE(review): asserts are stripped under `python -O`; kept as asserts so callers still see AssertionError.
        assert segmentation_img.is_mask(), 'Segmentation image must be a mask'
        is_patch_stack = isinstance(input_img, PatchStack)
        if is_patch_stack:
            assert isinstance(segmentation_img, PatchStack), \
                'Segmentation image must be a PatchStack when the input image is a PatchStack'
            axes = 'tczyx'
            input_nda = input_img.pczyx
            seg_nda = segmentation_img.pczyx
        else:
            axes = 'yxcz'
            input_nda = input_img.data
            seg_nda = segmentation_img.data
        tagged_input_data = vigra.taggedView(input_nda, axes)
        tagged_seg_data = vigra.taggedView(self._make_8bit_mask(seg_nda), axes)

        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
            }
        ]
        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
        assert len(obmaps) == 1, 'ilastik generated more than one object map'

        if is_patch_stack:
            # presumably exported as [p x z x h x w x n]; reorder to PYXCZ
            pyxcz = np.moveaxis(
                obmaps[0],
                [0, 1, 2, 3, 4],
                [0, 4, 1, 2, 3]
            )
            return PatchStack(data=pyxcz), {'success': True}
        # exported as [z x h x w x n]; reorder to YXCZ
        yxcz = np.moveaxis(
            obmaps[0],
            [1, 2, 3, 0],
            [0, 1, 2, 3]
        )
        return InMemoryDataAccessor(data=yxcz), {'success': True}

    def label_instance_class(self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs):
        """
        Assign a class to each object in mask.
        :param img: raw image
        :param mask: binary segmentation of img
        :return: object map accessor returned by infer()
        """
        # base-class hook is called first; its return value is not used here
        super().label_instance_class(img, mask, **kwargs)
        obmap, _ = self.infer(img, mask)
        return obmap
class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImageModel):
    """
    Image-to-image model backed by an ilastik object classification project whose second input is a pixel
    prediction map.
    """
    model_id = 'ilastik_object_classification_from_pixel_predictions'

    @staticmethod
    def get_workflow():
        """Workflow class that project files for this model must describe."""
        from ilastik.workflows.objectClassification.objectClassificationWorkflow import ObjectClassificationWorkflowPrediction
        return ObjectClassificationWorkflowPrediction

    def infer(
            self,
            input_img: GenericImageDataAccessor,
            pxmap_img: GenericImageDataAccessor
    ) -> (InMemoryDataAccessor, dict):
        """
        Classify objects given a raw image and a prediction map.
        :param input_img: raw image whose chroma and 3D-ness match the project's raw input
        :param pxmap_img: prediction map passed to ilastik as 'Prediction Maps'
        :return: (YXCZ object map, {'success': True})
        :raises IlastikInputShapeError: if chroma or 3D-ness of input_img do not match the model
        """
        if self.model_chroma != input_img.chroma or self.model_3d != input_img.is_3d():
            raise IlastikInputShapeError(
                f'Model expects chroma={self.model_chroma} and 3d={self.model_3d}, '
                f'but input has chroma={input_img.chroma} and 3d={input_img.is_3d()}'
            )
        tagged_input_data = vigra.taggedView(input_img.data, 'yxcz')
        tagged_pxmap_data = vigra.taggedView(pxmap_img.data, 'yxcz')
        dsi = [
            {
                'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
                'Prediction Maps': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_pxmap_data),
            }
        ]
        obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)  # [z x h x w x n]
        assert len(obmaps) == 1, 'ilastik generated more than one object map'
        # reorder ZYXC -> YXCZ
        yxcz = np.moveaxis(
            obmaps[0],
            [1, 2, 3, 0],
            [0, 1, 2, 3]
        )
        return InMemoryDataAccessor(data=yxcz), {'success': True}

    def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs):
        """
        Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is
        assigned a class.
        :param img: input image
        :param pxmap: map of pixel probabilities
        :param kwargs:
            pixel_classification_channel: channel of pxmap used to segment objects (default 0)
            pixel_classification_threshold: threshold of pxmap used to segment objects (default 0.5)
        :return: object map accessor returned by infer()
        :raises InvalidInputImageError: if img and pxmap differ in shape
        """
        if img.shape != pxmap.shape:
            raise InvalidInputImageError('Expecting input image and pixel probabilities to be the same shape')
        pxch = kwargs.get('pixel_classification_channel', 0)
        pxtr = kwargs.get('pixel_classification_threshold', 0.5)
        # the thresholded (boolean) channel is what gets passed to ilastik as the prediction map
        mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
        obmap, _ = self.infer(img, mask)
        return obmap

    def make_instance_segmentation_model(self, px_ch: int):
        """
        Generate an instance segmentation model, i.e. one that takes binary masks instead of pixel probabilities as a
        second input. The new model loads the same project file with the same enforce_embedded setting as this one.
        :param px_ch: channel of pixel probability map to use
        :return:
            InstanceSegmentationModel object
        """
        class _Mod(self.__class__, InstanceSegmentationModel):
            def label_instance_class(
                    self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
            ) -> GenericImageDataAccessor:
                # rescale the mask to float32 in [0, 1]: booleans directly, integers by their dtype's maximum
                if mask.dtype == 'bool':
                    norm_mask = 1.0 * mask.data
                else:
                    norm_mask = mask.data / np.iinfo(mask.dtype).max
                norm_mask_acc = InMemoryDataAccessor(norm_mask.astype('float32'))
                return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch)

        # Propagate enforce_embedded; otherwise a model built with enforce_embedded=False yields a derived model
        # that re-enables the check and may fail to load the same project.
        return _Mod(params={'project_file': self.project_file}, enforce_embedded=self.enforce_embedded)
class Error(Exception):
    """Common base class for the ilastik-specific exceptions defined in this module."""
class IlastikInputEmbedding(Error):
    """Raised when a project file's input data are stored on the filesystem instead of embedded in the project."""
class IlastikInputShapeError(Error):
    """
    Raised when an ilastik classifier receives data whose shape (e.g. channel count or 3D-ness) is incompatible
    with the input shape recorded in its project file.
    """