-
Christopher Randolph Rhodes authored
test_ilastik.py 16.33 KiB
from pathlib import Path
from shutil import copyfile
import unittest
import numpy as np
from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
from model_server.base.api import app
from model_server.extensions.ilastik import models as ilm
from model_server.extensions.ilastik.pipelines import px_then_ob
from model_server.extensions.ilastik.router import router
from model_server.base.roiset import RoiSet, RoiSetMetaParams
from model_server.base.pipelines import segment
import model_server.conf.testing as conf
# Fixture lookups from the testing configuration (model_server.conf.testing).
data = conf.meta['image_files']
czifile = data['czifile']
ilastik_classifiers = conf.meta['ilastik_classifiers']
params = conf.meta['roiset']
output_path = conf.meta['output_path']

# Register the ilastik extension's routes on the imported app; the API tests
# below address endpoints under 'ilastik/...'.
app.include_router(router)
def _random_int(*args):
    """Return a random uint8 array whose shape is given by the positional arguments."""
    upper_bound = 2 ** 8  # exclusive upper limit, i.e. values fall in 0..255
    return np.random.randint(0, upper_bound, size=args, dtype='uint8')
class TestIlastikPixelClassification(unittest.TestCase):
    """Exercise the ilastik pixel and object classifier models directly, without the API layer."""

    def setUp(self) -> None:
        """Create the CZI accessor, the pixel classifier model, and the mono image under test."""
        self.cf = CziImageFileAccessor(czifile['path'])
        self.channel = 0
        self.model = ilm.IlastikPixelClassifierModel(
            params=ilm.IlastikPixelClassifierParams(project_file=str(ilastik_classifiers['px']['path']))
        )
        self.mono_image = self.cf.get_mono(self.channel)

    def test_raise_error_if_autoload_disabled(self):
        """Labeling with a model constructed with autoload=False raises AttributeError."""
        model = ilm.IlastikPixelClassifierModel(
            params=ilm.IlastikPixelClassifierParams(project_file=str(ilastik_classifiers['px']['path'])),
            autoload=False
        )
        w = 512
        h = 256
        # height first, matching the (h, w, 1, 1) layout used by the other tests in this class
        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))
        with self.assertRaises(AttributeError):
            model.label_pixel_class(input_img)

    def test_run_pixel_classifier_on_random_data(self):
        """A mask labeled from random data has the same shape as the input."""
        w = 512
        h = 256
        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))
        mask = self.model.label_pixel_class(input_img)
        self.assertEqual(mask.shape, (h, w, 1, 1))

    def test_run_pixel_classifier(self):
        """Label the mono CZI image, check the mask's properties, and write it to the output path."""
        # the input should match the configured dimensions and be single-channel, single-Z
        self.assertEqual(self.mono_image.shape_dict['X'], czifile['w'])
        self.assertEqual(self.mono_image.shape_dict['Y'], czifile['h'])
        self.assertEqual(self.mono_image.shape_dict['C'], 1)
        self.assertEqual(self.mono_image.shape_dict['Z'], 1)

        mask = self.model.label_pixel_class(self.mono_image)
        self.assertTrue(mask.is_mask())
        self.assertEqual(mask.shape[0:2], self.cf.shape[0:2])
        self.assertEqual(mask.shape_dict['C'], 1)
        self.assertEqual(mask.shape_dict['Z'], 1)
        self.assertTrue(
            write_accessor_data_to_file(
                output_path / 'seg' / f'seg_{self.cf.fpath.stem}_ch{self.channel}.tif',
                mask
            )
        )

    def test_label_pixels_with_params(self):
        """Masks produced with different smoothing parameters have the same shape."""

        def _run_seg(tr, sig):
            # build a fresh model for each (threshold, smoothing) pair and write its mask to file
            mod = ilm.IlastikPixelClassifierModel(
                params=ilm.IlastikPixelClassifierParams(
                    project_file=str(ilastik_classifiers['px']['path']),
                    px_prob_threshold=tr,
                    px_smoothing=sig,
                ),
            )
            mask = mod.label_pixel_class(self.mono_image)
            write_accessor_data_to_file(
                output_path / 'seg' / f'seg_tr{int(10*tr)}_sig{int(10*sig)}.tif',
                mask
            )
            return mask

        mask1 = _run_seg(0.5, 0.0)
        mask2 = _run_seg(0.5, 0.2)
        self.assertEqual(mask1.shape, mask2.shape)

    def test_pixel_classifier_enforces_input_shape(self):
        """The model accepts single-channel 2D input and rejects multichannel or multi-Z input."""
        self.assertEqual(self.model.model_chroma, 1)
        self.assertEqual(self.model.model_3d, False)

        # correct data
        self.assertIsInstance(
            self.model.label_pixel_class(
                InMemoryDataAccessor(
                    _random_int(512, 256, 1, 1)
                )
            ),
            InMemoryDataAccessor
        )

        # raise except with input of multiple channels
        with self.assertRaises(ilm.IlastikInputShapeError):
            self.model.label_pixel_class(
                InMemoryDataAccessor(
                    _random_int(512, 256, 3, 1)
                )
            )

        # raise except with input of multiple z-positions
        with self.assertRaises(ilm.IlastikInputShapeError):
            self.model.label_pixel_class(
                InMemoryDataAccessor(
                    _random_int(512, 256, 1, 15)
                )
            )

    def test_ilastik_infer_pxmap_from_patchstack(self):
        """Labeling and inference on a patch stack preserve its dimensions and patch count."""
        # patches of differing heights; the assertions below confirm the stack reports the largest
        acc = PatchStack([_random_int(h, 512, 1, 1) for h in (256, 512, 256)])
        self.assertEqual(acc.hw, (512, 512))
        self.assertEqual(acc.iat(0, crop=True).hw, (256, 512))

        mask = self.model.label_patch_stack(acc)
        self.assertEqual(mask.dtype, bool)
        self.assertEqual(mask.chroma, 1)
        self.assertEqual(mask.hw, acc.hw)
        self.assertEqual(mask.nz, acc.nz)
        self.assertEqual(mask.count, acc.count)

        pxmap, _ = self.model.infer_patch_stack(acc)
        self.assertEqual(pxmap.dtype, float)
        self.assertEqual(pxmap.chroma, len(self.model.labels))
        self.assertEqual(pxmap.hw, acc.hw)
        self.assertEqual(pxmap.nz, acc.nz)
        self.assertEqual(pxmap.count, acc.count)

    def test_run_object_classifier_from_pixel_predictions(self):
        """Classify objects from the mono image plus the pixel classifier's mask."""
        # re-run the pixel classification checks first (this also writes the segmentation file)
        self.test_run_pixel_classifier()
        fp = czifile['path']
        model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
            params=ilm.IlastikParams(project_file=str(ilastik_classifiers['pxmap_to_obj']['path']))
        )
        mask = self.model.label_pixel_class(self.mono_image)
        objmap, _ = model.infer(self.mono_image, mask)

        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'obmap_{fp.stem}.tif',
                objmap,
            )
        )
        self.assertEqual(objmap.data.max(), 2)

    def test_run_object_classifier_from_segmentation(self):
        """Label instance classes from the mono image plus the pixel classifier's mask."""
        # re-run the pixel classification checks first (this also writes the segmentation file)
        self.test_run_pixel_classifier()
        fp = czifile['path']
        model = ilm.IlastikObjectClassifierFromSegmentationModel(
            params=ilm.IlastikParams(project_file=str(ilastik_classifiers['seg_to_obj']['path']))
        )
        mask = self.model.label_pixel_class(self.mono_image)
        objmap = model.label_instance_class(self.mono_image, mask)

        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'obmap_from_seg_{fp.stem}.tif',
                objmap,
            )
        )
        self.assertEqual(objmap.data.max(), 2)

    def test_ilastik_pixel_classification_as_workflow(self):
        """Run the segment pipeline with an ilastik pixel classifier and check the inference timing."""
        res = segment.segment_pipeline(
            accessors={
                'accessor': generate_file_accessor(czifile['path'])
            },
            models={
                'model': ilm.IlastikPixelClassifierModel(
                    params=ilm.IlastikPixelClassifierParams(
                        project_file=str(ilastik_classifiers['px']['path'])
                    ),
                ),
            },
            channel=0,
        )
        self.assertGreater(res.times['inference'], 0.1)
class TestServerTestCase(conf.TestServerBaseClass):
    """Server-backed base class shared by the API tests in this module."""

    # CZI fixture entry from the testing configuration
    input_data = czifile

    # 'module:attribute' string naming this module's app
    # NOTE(review): presumably consumed by conf.TestServerBaseClass to launch the server; confirm there
    app_name = 'tests.test_ilastik.test_ilastik:app'
class TestIlastikOverApi(TestServerTestCase):
    """Load ilastik models and run pipelines through the server API.

    Model loading that several tests share lives in private ``_load_*`` helpers,
    so that test methods themselves return None; unittest deprecates returning
    a value from a test method.
    """

    def _load_pixel_model(self):
        """Load the pixel classifier over the API, check its listed class, and return its model ID."""
        mid = self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(ilastik_classifiers['px']['path'])},
        )['model_id']
        rl = self.assertGetSuccess('models')
        self.assertEqual(rl[mid]['class'], 'IlastikPixelClassifierModel')
        return mid

    def _load_pxmap_to_obj_model(self):
        """Load the pixel-map-to-object classifier over the API, check its class, and return its model ID."""
        mid = self.assertPutSuccess(
            'ilastik/pxmap_to_obj/load/',
            body={'project_file': str(ilastik_classifiers['pxmap_to_obj']['path'])},
        )['model_id']
        rl = self.assertGetSuccess('models')
        self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel')
        return mid

    def test_httpexception_if_incorrect_project_file_loaded(self):
        """Loading an improper project file fails with status 500."""
        self.assertPutFailure(
            'ilastik/seg/load/',
            500,
            body={'project_file': 'improper.ilp'},
        )

    def test_load_ilastik_pixel_model(self):
        """The pixel classifier loads and is listed under the expected class name."""
        self._load_pixel_model()

    def test_load_another_ilastik_pixel_model(self):
        """Reloading the same project adds a model only when 'duplicate' is true."""
        self._load_pixel_model()
        self.assertEqual(len(self.assertGetSuccess('models')), 1)

        self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(ilastik_classifiers['px']['path']), 'duplicate': True},
        )
        self.assertEqual(len(self.assertGetSuccess('models')), 2)

        self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(ilastik_classifiers['px']['path']), 'duplicate': False},
        )
        self.assertEqual(len(self.assertGetSuccess('models')), 2)

    def test_load_ilastik_pixel_model_with_params(self):
        """Parameters passed at load time are reported back in the model listing."""
        load_params = {
            'project_file': str(ilastik_classifiers['px']['path']),
            'px_class': 0,
            'px_prob_threshold': 0.5
        }
        mid = self.assertPutSuccess(
            'ilastik/seg/load/',
            body=load_params,
        )['model_id']
        mods = self.assertGetSuccess('models')
        self.assertEqual(len(mods), 1)
        self.assertEqual(mods[mid]['params']['px_prob_threshold'], 0.5)

    def test_load_ilastik_pxmap_to_obj_model(self):
        """The pixel-map-to-object classifier loads and is listed under the expected class name."""
        self._load_pxmap_to_obj_model()

    def test_load_ilastik_model_with_model_id(self):
        """A model ID supplied as a query parameter is the one returned by the load call."""
        nmid = 'new_model_id'
        rmid = self.assertPutSuccess(
            'ilastik/pxmap_to_obj/load/',
            query={
                'model_id': nmid,
            },
            body={
                'project_file': str(ilastik_classifiers['pxmap_to_obj']['path']),
            },
        )['model_id']
        self.assertEqual(rmid, nmid)

    def test_load_ilastik_seg_to_obj_model(self):
        """The segmentation-to-object classifier loads and is listed under the expected class name."""
        mid = self.assertPutSuccess(
            'ilastik/seg_to_obj/load/',
            body={'project_file': str(ilastik_classifiers['seg_to_obj']['path'])},
        )['model_id']
        rl = self.assertGetSuccess('models')
        self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromSegmentationModel')

    def test_ilastik_infer_pixel_probability(self):
        """Run the segment pipeline over the API with a loaded pixel classifier."""
        fname = self.copy_input_file_to_server()
        mid = self._load_pixel_model()
        acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}')

        self.assertPutSuccess(
            'pipelines/segment',
            body={'model_id': mid, 'accessor_id': acc_id, 'channel': 0},
        )

    def test_ilastik_infer_px_then_ob(self):
        """Run the pixel-then-object classification pipeline over the API."""
        fname = self.copy_input_file_to_server()
        px_model_id = self._load_pixel_model()
        ob_model_id = self._load_pxmap_to_obj_model()

        in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}')

        self.assertPutSuccess(
            'ilastik/pipelines/pixel_then_object_classification/infer/',
            body={
                'px_model_id': px_model_id,
                'ob_model_id': ob_model_id,
                'accessor_id': in_acc_id,
                'channel': 0,
            }
        )
class TestIlastikOnMultichannelInputs(TestServerTestCase):
    """Run ilastik classifiers on the multichannel z-stack fixture, directly and over the API."""

    def setUp(self) -> None:
        """Run the base-class setup, then resolve classifier and image paths from the test config."""
        super().setUp()
        self.pa_px_classifier = ilastik_classifiers['px_color_zstack']['path']
        self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack']['path']
        self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack']['path']
        self.pa_input_image = data['multichannel_zstack_raw']['path']
        self.pa_mask = data['multichannel_zstack_mask3d']['path']

    def _classify_pixels(self):
        """Infer a pixel map from the multichannel input, check its dimensions, and return it.

        Kept as a helper so that test methods return None; unittest deprecates
        returning a value from a test method.
        """
        img = generate_file_accessor(self.pa_input_image)
        self.assertGreater(img.chroma, 1)
        mod = ilm.IlastikPixelClassifierModel(
            ilm.IlastikPixelClassifierParams(project_file=str(self.pa_px_classifier))
        )
        pxmap = mod.infer(img)[0]
        self.assertEqual(pxmap.hw, img.hw)
        self.assertEqual(pxmap.nz, img.nz)
        return pxmap

    def test_classify_pixels(self):
        """The pixel map inferred from multichannel input matches the input's height/width and depth."""
        self._classify_pixels()

    def test_classify_objects(self):
        """The object map inferred from image plus pixel map matches the input's height/width and depth."""
        pxmap = self._classify_pixels()
        img = generate_file_accessor(self.pa_input_image)
        mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
            ilm.IlastikParams(project_file=str(self.pa_ob_pxmap_classifier))
        )
        obmap = mod.infer(img, pxmap)[0]
        self.assertEqual(obmap.hw, img.hw)
        self.assertEqual(obmap.nz, img.nz)

    def test_workflow(self):
        """
        Test calling pixel then object map classification pipeline function directly
        """

        def _call_workflow(channel):
            # fresh accessor and models on every call, so the two invocations below are independent
            return px_then_ob.pixel_then_object_classification_pipeline(
                accessors={
                    'accessor': generate_file_accessor(self.pa_input_image)
                },
                models={
                    'px_model': ilm.IlastikPixelClassifierModel(
                        ilm.IlastikParams(project_file=str(self.pa_px_classifier)),
                    ),
                    'ob_model': ilm.IlastikObjectClassifierFromPixelPredictionsModel(
                        ilm.IlastikParams(project_file=str(self.pa_ob_pxmap_classifier)),
                    )
                },
                channel=channel,
            )

        # NOTE(review): presumably rejected because these classifiers expect multichannel input; confirm
        with self.assertRaises(ilm.IlastikInputShapeError):
            _call_workflow(channel=0)

        res = _call_workflow(channel=None)
        acc_input = generate_file_accessor(self.pa_input_image)
        acc_obmap = res['ob_map']
        self.assertEqual(acc_obmap.hw, acc_input.hw)
        self.assertEqual(len(acc_obmap.unique()[1]), 3)

    def test_api(self):
        """
        Test calling pixel then object map classification pipeline over API
        """
        # place the input image in the server's inbound directory, then read it into an accessor
        copyfile(
            self.pa_input_image,
            Path(self.assertGetSuccess('paths')['inbound_images']) / self.pa_input_image.name
        )

        in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{self.pa_input_image.name}')

        px_model_id = self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(self.pa_px_classifier)},
        )['model_id']

        ob_model_id = self.assertPutSuccess(
            'ilastik/pxmap_to_obj/load/',
            body={'project_file': str(self.pa_ob_pxmap_classifier)},
        )['model_id']

        # run the pipeline
        obmap_id = self.assertPutSuccess(
            'ilastik/pipelines/pixel_then_object_classification/infer/',
            body={
                'accessor_id': in_acc_id,
                'px_model_id': px_model_id,
                'ob_model_id': ob_model_id,
            }
        )['output_accessor_id']

        # retrieve the output object map and check that it is single-channel
        obmap_acc = self.get_accessor(obmap_id)
        self.assertEqual(obmap_acc.shape_dict['C'], 1)

        # compare dimensions to input image
        self.assertEqual(obmap_acc.hw, generate_file_accessor(self.pa_input_image).hw)
class TestIlastikObjectClassification(unittest.TestCase):
    """Classify RoiSet patches with the ilastik segmentation-to-object classifier."""

    def setUp(self):
        """Build a RoiSet from the z-stack and 2D mask fixtures, then load the object classifier."""
        zstack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
        patch_channel_img = zstack.get_mono(conf.meta['roiset']['patches_channel'])
        mask_acc = generate_file_accessor(data['multichannel_zstack_mask2d']['path'])

        roiset_params = RoiSetMetaParams(
            mask_type='boxes',
            filters={'area': {'min': 1e3, 'max': 1e4}},
            expand_box_by=(64, 2)
        )
        self.roiset = RoiSet.from_binary_mask(patch_channel_img, mask_acc, params=roiset_params)

        project_file = str(ilastik_classifiers['seg_to_obj']['path'])
        self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
            params=ilm.IlastikParams(project_file=project_file),
        )

        self.raw = self.roiset.get_patches_acc()
        self.masks = self.roiset.get_patch_masks_acc()

    def test_classify_patches(self):
        """Each classified patch holds background plus exactly one label covering more than one pixel."""
        res = self.classifier.label_patch_stack(self.raw, self.masks)
        self.assertEqual(res.count, self.roiset.count)
        res.export_pyxcz(output_path / 'res_patches.tif')

        # assert that there is only one nonzero label per patch
        for idx in range(res.count):
            labels, counts = np.unique(res.iat(idx).data, return_counts=True)
            # labels occurring only once are ignored as single-pixel anomalies
            self.assertEqual(np.sum(counts > 1), 2)
            self.assertEqual(labels[0], 0)