"""Tests for the ilastik extension: pixel/object classifier models, pipelines, and HTTP API routes."""

from pathlib import Path
from shutil import copyfile
import unittest

import numpy as np

from model_server.base.accessors import (
    CziImageFileAccessor,
    generate_file_accessor,
    InMemoryDataAccessor,
    PatchStack,
    write_accessor_data_to_file,
)
from model_server.base.api import app
from model_server.extensions.ilastik import models as ilm
from model_server.extensions.ilastik.pipelines import px_then_ob
from model_server.extensions.ilastik.router import router
from model_server.base.roiset import RoiSet, RoiSetMetaParams
from model_server.base.pipelines import segment
import model_server.conf.testing as conf

# Test fixtures, all read from the shared testing configuration
data = conf.meta['image_files']
output_path = conf.meta['output_path']
params = conf.meta['roiset']
czifile = conf.meta['image_files']['czifile']
ilastik_classifiers = conf.meta['ilastik_classifiers']

# Mount the ilastik routes on the shared app so the server-based tests below can reach them
app.include_router(router)


def _random_int(*args):
    """Return a uint8 array of shape ``args`` with random values in [0, 255]."""
    return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')


class TestIlastikPixelClassification(unittest.TestCase):
    """Direct (non-API) tests of ilastik pixel and object classifier models."""

    def setUp(self) -> None:
        self.cf = CziImageFileAccessor(czifile['path'])
        self.channel = 0
        self.model = ilm.IlastikPixelClassifierModel(
            params=ilm.IlastikPixelClassifierParams(project_file=str(ilastik_classifiers['px']['path']))
        )
        self.mono_image = self.cf.get_mono(self.channel)

    def test_raise_error_if_autoload_disabled(self):
        """Inference on a model constructed with autoload=False raises AttributeError."""
        model = ilm.IlastikPixelClassifierModel(
            params=ilm.IlastikPixelClassifierParams(project_file=str(ilastik_classifiers['px']['path'])),
            autoload=False,
        )
        w = 512
        h = 256
        # (h, w) ordering, consistent with the other tests in this class
        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))
        with self.assertRaises(AttributeError):
            model.label_pixel_class(input_img)

    def test_run_pixel_classifier_on_random_data(self):
        """The pixel-class mask has the same shape as the input data."""
        w = 512
        h = 256
        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))
        mask = self.model.label_pixel_class(input_img)
        self.assertEqual(mask.shape, (h, w, 1, 1))

    def test_run_pixel_classifier(self):
        """Classify one channel of the CZI test image and write the resulting mask to disk."""
        self.assertEqual(self.mono_image.shape_dict['X'], czifile['w'])
        self.assertEqual(self.mono_image.shape_dict['Y'], czifile['h'])
        self.assertEqual(self.mono_image.shape_dict['C'], 1)
        self.assertEqual(self.mono_image.shape_dict['Z'], 1)

        mask = self.model.label_pixel_class(self.mono_image)
        self.assertTrue(mask.is_mask())
        self.assertEqual(mask.shape[0:2], self.cf.shape[0:2])
        self.assertEqual(mask.shape_dict['C'], 1)
        self.assertEqual(mask.shape_dict['Z'], 1)
        self.assertTrue(
            write_accessor_data_to_file(
                output_path / 'seg' / f'seg_{self.cf.fpath.stem}_ch{self.channel}.tif',
                mask,
            )
        )

    def test_label_pixels_with_params(self):
        """Different threshold/smoothing parameters still produce masks of identical shape."""
        def _run_seg(tr, sig):
            mod = ilm.IlastikPixelClassifierModel(
                params=ilm.IlastikPixelClassifierParams(
                    project_file=str(ilastik_classifiers['px']['path']),
                    px_prob_threshold=tr,
                    px_smoothing=sig,
                ),
            )
            mask = mod.label_pixel_class(self.mono_image)
            write_accessor_data_to_file(
                output_path / 'seg' / f'seg_tr{int(10*tr)}_sig{int(10*sig)}.tif',
                mask,
            )
            return mask

        mask1 = _run_seg(0.5, 0.0)
        mask2 = _run_seg(0.5, 0.2)
        self.assertEqual(mask1.shape, mask2.shape)

    def test_pixel_classifier_enforces_input_shape(self):
        """A single-channel 2D model rejects multichannel and multi-z inputs."""
        self.assertEqual(self.model.model_chroma, 1)
        self.assertEqual(self.model.model_3d, False)

        # correct data
        self.assertIsInstance(
            self.model.label_pixel_class(
                InMemoryDataAccessor(_random_int(512, 256, 1, 1))
            ),
            InMemoryDataAccessor,
        )

        # raise exception with input of multiple channels
        with self.assertRaises(ilm.IlastikInputShapeError):
            self.model.label_pixel_class(
                InMemoryDataAccessor(_random_int(512, 256, 3, 1))
            )

        # raise exception with input of multiple z-slices
        with self.assertRaises(ilm.IlastikInputShapeError):
            self.model.label_pixel_class(
                InMemoryDataAccessor(_random_int(512, 256, 1, 15))
            )

    def test_ilastik_infer_pxmap_from_patchstack(self):
        """Mask and probability map inferred from a PatchStack match its geometry."""
        # patches of differing heights; the stack is padded to the largest
        acc = PatchStack([
            _random_int(256, 512, 1, 1),
            _random_int(512, 512, 1, 1),
            _random_int(256, 512, 1, 1),
        ])
        self.assertEqual(acc.hw, (512, 512))
        self.assertEqual(acc.iat(0, crop=True).hw, (256, 512))

        mask = self.model.label_patch_stack(acc)
        self.assertEqual(mask.dtype, bool)
        self.assertEqual(mask.chroma, 1)
        self.assertEqual(mask.hw, acc.hw)
        self.assertEqual(mask.nz, acc.nz)
        self.assertEqual(mask.count, acc.count)

        pxmap, _ = self.model.infer_patch_stack(acc)
        self.assertEqual(pxmap.dtype, float)
        # one probability channel per model label
        self.assertEqual(pxmap.chroma, len(self.model.labels))
        self.assertEqual(pxmap.hw, acc.hw)
        self.assertEqual(pxmap.nz, acc.nz)
        self.assertEqual(pxmap.count, acc.count)

    def test_run_object_classifier_from_pixel_predictions(self):
        """Object classification from pixel predictions yields a map whose maximum label is 2."""
        self.test_run_pixel_classifier()
        fp = czifile['path']
        model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
            params=ilm.IlastikParams(project_file=str(ilastik_classifiers['pxmap_to_obj']['path']))
        )
        mask = self.model.label_pixel_class(self.mono_image)
        objmap, _ = model.infer(self.mono_image, mask)

        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'obmap_{fp.stem}.tif',
                objmap,
            )
        )
        self.assertEqual(objmap.data.max(), 2)

    def test_run_object_classifier_from_segmentation(self):
        """Object classification from a segmentation mask yields a map whose maximum label is 2."""
        self.test_run_pixel_classifier()
        fp = czifile['path']
        model = ilm.IlastikObjectClassifierFromSegmentationModel(
            params=ilm.IlastikParams(project_file=str(ilastik_classifiers['seg_to_obj']['path']))
        )
        mask = self.model.label_pixel_class(self.mono_image)
        objmap = model.label_instance_class(self.mono_image, mask)
        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'obmap_from_seg_{fp.stem}.tif',
                objmap,
            )
        )
        self.assertEqual(objmap.data.max(), 2)

    def test_ilastik_pixel_classification_as_workflow(self):
        """Running the generic segment pipeline with an ilastik pixel model records inference time."""
        res = segment.segment_pipeline(
            accessors={
                'accessor': generate_file_accessor(czifile['path'])
            },
            models={
                'model': ilm.IlastikPixelClassifierModel(
                    params=ilm.IlastikPixelClassifierParams(
                        project_file=str(ilastik_classifiers['px']['path'])
                    ),
                ),
            },
            channel=0,
        )
        self.assertGreater(res.times['inference'], 0.1)


class TestServerTestCase(conf.TestServerBaseClass):
    """Server-backed test base configured with this module's app and the CZI test image."""
    app_name = 'tests.test_ilastik.test_ilastik:app'
    input_data = czifile


class TestIlastikOverApi(TestServerTestCase):
    """Load ilastik models and run ilastik pipelines through the HTTP API."""

    def test_httpexception_if_incorrect_project_file_loaded(self):
        self.assertPutFailure(
            'ilastik/seg/load/',
            500,
            body={'project_file': 'improper.ilp'},
        )

    def test_load_ilastik_pixel_model(self):
        """Load the pixel classifier over the API; return its model ID for reuse by other tests."""
        mid = self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(ilastik_classifiers['px']['path'])},
        )['model_id']
        rl = self.assertGetSuccess('models')
        self.assertEqual(rl[mid]['class'], 'IlastikPixelClassifierModel')
        return mid

    def test_load_another_ilastik_pixel_model(self):
        """Reloading the same project adds a model only when 'duplicate' is True."""
        self.test_load_ilastik_pixel_model()
        self.assertEqual(len(self.assertGetSuccess('models')), 1)
        self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(ilastik_classifiers['px']['path']), 'duplicate': True},
        )
        self.assertEqual(len(self.assertGetSuccess('models')), 2)
        self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(ilastik_classifiers['px']['path']), 'duplicate': False},
        )
        self.assertEqual(len(self.assertGetSuccess('models')), 2)

    def test_load_ilastik_pixel_model_with_params(self):
        """Pixel-model parameters passed in the load body are reported back by the models listing."""
        # named to avoid shadowing the module-level ``params``
        load_body = {
            'project_file': str(ilastik_classifiers['px']['path']),
            'px_class': 0,
            'px_prob_threshold': 0.5,
        }
        mid = self.assertPutSuccess(
            'ilastik/seg/load/',
            body=load_body,
        )['model_id']
        mods = self.assertGetSuccess('models')
        self.assertEqual(len(mods), 1)
        self.assertEqual(mods[mid]['params']['px_prob_threshold'], 0.5)

    def test_load_ilastik_pxmap_to_obj_model(self):
        """Load the pxmap-to-object classifier; return its model ID for reuse by other tests."""
        mid = self.assertPutSuccess(
            'ilastik/pxmap_to_obj/load/',
            body={'project_file': str(ilastik_classifiers['pxmap_to_obj']['path'])},
        )['model_id']

        rl = self.assertGetSuccess('models')
        self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel')
        return mid

    def test_load_ilastik_model_with_model_id(self):
        """A caller-supplied model_id query parameter is used as the returned model ID."""
        nmid = 'new_model_id'
        rmid = self.assertPutSuccess(
            'ilastik/pxmap_to_obj/load/',
            query={
                'model_id': nmid,
            },
            body={
                'project_file': str(ilastik_classifiers['pxmap_to_obj']['path']),
            },
        )['model_id']
        self.assertEqual(rmid, nmid)

    def test_load_ilastik_seg_to_obj_model(self):
        """Load the segmentation-to-object classifier; return its model ID for reuse by other tests."""
        mid = self.assertPutSuccess(
            'ilastik/seg_to_obj/load/',
            body={'project_file': str(ilastik_classifiers['seg_to_obj']['path'])},
        )['model_id']
        rl = self.assertGetSuccess('models')
        self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromSegmentationModel')
        return mid

    def test_ilastik_infer_pixel_probability(self):
        """Run the generic segment pipeline over the API with an ilastik pixel model."""
        fname = self.copy_input_file_to_server()
        mid = self.test_load_ilastik_pixel_model()
        acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}')

        self.assertPutSuccess(
            'pipelines/segment',
            body={'model_id': mid, 'accessor_id': acc_id, 'channel': 0},
        )

    def test_ilastik_infer_px_then_ob(self):
        """Run the pixel-then-object classification pipeline over the API."""
        fname = self.copy_input_file_to_server()
        px_model_id = self.test_load_ilastik_pixel_model()
        ob_model_id = self.test_load_ilastik_pxmap_to_obj_model()

        in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}')

        self.assertPutSuccess(
            'ilastik/pipelines/pixel_then_object_classification/infer/',
            body={
                'px_model_id': px_model_id,
                'ob_model_id': ob_model_id,
                'accessor_id': in_acc_id,
                'channel': 0,
            },
        )


class TestIlastikOnMultichannelInputs(TestServerTestCase):
    """Pixel and object classification on a multichannel z-stack, both directly and over the API."""

    def setUp(self) -> None:
        super().setUp()
        self.pa_px_classifier = ilastik_classifiers['px_color_zstack']['path']
        self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack']['path']
        self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack']['path']
        self.pa_input_image = data['multichannel_zstack_raw']['path']
        self.pa_mask = data['multichannel_zstack_mask3d']['path']

    def test_classify_pixels(self):
        """Classify the multichannel stack; return the pixel map for reuse by other tests."""
        img = generate_file_accessor(self.pa_input_image)
        self.assertGreater(img.chroma, 1)
        mod = ilm.IlastikPixelClassifierModel(
            ilm.IlastikPixelClassifierParams(project_file=str(self.pa_px_classifier))
        )
        pxmap = mod.infer(img)[0]
        self.assertEqual(pxmap.hw, img.hw)
        self.assertEqual(pxmap.nz, img.nz)
        return pxmap

    def test_classify_objects(self):
        """Object map from a multichannel pixel map preserves the input's XY and Z extent."""
        pxmap = self.test_classify_pixels()
        img = generate_file_accessor(self.pa_input_image)
        mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
            ilm.IlastikParams(project_file=str(self.pa_ob_pxmap_classifier))
        )
        obmap = mod.infer(img, pxmap)[0]
        self.assertEqual(obmap.hw, img.hw)
        self.assertEqual(obmap.nz, img.nz)

    def test_workflow(self):
        """
        Test calling pixel then object map classification pipeline function directly
        """
        def _call_workflow(channel):
            return px_then_ob.pixel_then_object_classification_pipeline(
                accessors={
                    'accessor': generate_file_accessor(self.pa_input_image)
                },
                models={
                    # same params class as every other pixel-model construction in this file
                    'px_model': ilm.IlastikPixelClassifierModel(
                        ilm.IlastikPixelClassifierParams(project_file=str(self.pa_px_classifier)),
                    ),
                    'ob_model': ilm.IlastikObjectClassifierFromPixelPredictionsModel(
                        ilm.IlastikParams(project_file=str(self.pa_ob_pxmap_classifier)),
                    ),
                },
                channel=channel,
            )

        # selecting a single channel is rejected by the multichannel pixel model
        with self.assertRaises(ilm.IlastikInputShapeError):
            _call_workflow(channel=0)
        res = _call_workflow(channel=None)
        acc_input = generate_file_accessor(self.pa_input_image)
        acc_obmap = res['ob_map']
        self.assertEqual(acc_obmap.hw, acc_input.hw)
        self.assertEqual(len(acc_obmap.unique()[1]), 3)

    def test_api(self):
        """
        Test calling pixel then object map classification pipeline over API
        """
        copyfile(
            self.pa_input_image,
            Path(self.assertGetSuccess('paths')['inbound_images']) / self.pa_input_image.name
        )

        in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{self.pa_input_image.name}')

        px_model_id = self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(self.pa_px_classifier)},
        )['model_id']

        ob_model_id = self.assertPutSuccess(
            'ilastik/pxmap_to_obj/load/',
            body={'project_file': str(self.pa_ob_pxmap_classifier)},
        )['model_id']

        # run the pipeline
        obmap_id = self.assertPutSuccess(
            'ilastik/pipelines/pixel_then_object_classification/infer/',
            body={
                'accessor_id': in_acc_id,
                'px_model_id': px_model_id,
                'ob_model_id': ob_model_id,
            },
        )['output_accessor_id']

        # fetch the output object map and check that it is single-channel
        obmap_acc = self.get_accessor(obmap_id)
        self.assertEqual(obmap_acc.shape_dict['C'], 1)

        # compare dimensions to input image
        self.assertEqual(obmap_acc.hw, generate_file_accessor(self.pa_input_image).hw)


class TestIlastikObjectClassification(unittest.TestCase):
    """Object classification of RoiSet patch stacks with a segmentation-to-object model."""

    def setUp(self):
        stack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
        stack_ch_pa = stack.get_mono(conf.meta['roiset']['patches_channel'])
        seg_mask = generate_file_accessor(data['multichannel_zstack_mask2d']['path'])

        self.roiset = RoiSet.from_binary_mask(
            stack_ch_pa,
            seg_mask,
            params=RoiSetMetaParams(
                mask_type='boxes',
                filters={'area': {'min': 1e3, 'max': 1e4}},
                expand_box_by=(64, 2),
            ),
        )

        self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
            params=ilm.IlastikParams(project_file=str(ilastik_classifiers['seg_to_obj']['path'])),
        )
        self.raw = self.roiset.get_patches_acc()
        self.masks = self.roiset.get_patch_masks_acc()

    def test_classify_patches(self):
        """Each classified patch holds background plus exactly one object label."""
        res = self.classifier.label_patch_stack(self.raw, self.masks)
        self.assertEqual(res.count, self.roiset.count)
        res.export_pyxcz(output_path / 'res_patches.tif')
        for pi in range(res.count):
            la, ct = np.unique(res.iat(pi).data, return_counts=True)
            # exactly two labels cover more than one pixel: background (0) and one object label;
            # labels covering a single pixel are tolerated as anomalies
            self.assertEqual(np.sum(ct > 1), 2)
            self.assertEqual(la[0], 0)