# test_ilastik.py — unit and API tests for the ilastik extension of model_server.
from pathlib import Path
from shutil import copyfile
import unittest

import numpy as np

from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
from model_server.base.api import app
from model_server.extensions.ilastik import models as ilm
from model_server.extensions.ilastik.pipelines import px_then_ob
from model_server.extensions.ilastik.router import router
from model_server.base.roiset import RoiSet, RoiSetMetaParams
from model_server.base.pipelines import segment
import model_server.conf.testing as conf

# Test fixtures and output locations supplied by the testing configuration.
data = conf.meta['image_files']
czifile = data['czifile']
output_path = conf.meta['output_path']
params = conf.meta['roiset']
ilastik_classifiers = conf.meta['ilastik_classifiers']

# Mount the ilastik endpoints on the shared app; the server-backed tests below target this module's app.
app.include_router(router)

def _random_int(*args):
    return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')

class TestIlastikPixelClassification(unittest.TestCase):
    """In-process tests of ilastik pixel classification, and of object classifiers fed by its output."""

    @staticmethod
    def _px_params(**overrides):
        """Return IlastikPixelClassifierParams for the shared 'px' test project, with optional field overrides."""
        return ilm.IlastikPixelClassifierParams(
            project_file=str(ilastik_classifiers['px']['path']),
            **overrides,
        )

    def setUp(self) -> None:
        self.cf = CziImageFileAccessor(czifile['path'])
        self.channel = 0
        self.model = ilm.IlastikPixelClassifierModel(params=self._px_params())
        # Single-channel view of the CZI test image; most tests below classify this.
        self.mono_image = self.cf.get_mono(self.channel)

    def test_raise_error_if_autoload_disabled(self):
        """Inference with a model constructed with autoload=False raises AttributeError."""
        model = ilm.IlastikPixelClassifierModel(params=self._px_params(), autoload=False)
        w = 512
        h = 256

        # Rows (h) first, matching test_run_pixel_classifier_on_random_data.
        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))

        with self.assertRaises(AttributeError):
            model.label_pixel_class(input_img)

    def test_run_pixel_classifier_on_random_data(self):
        """A random float image yields a mask with the same shape as the input."""
        w = 512
        h = 256

        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))

        mask = self.model.label_pixel_class(input_img)
        self.assertEqual(mask.shape, (h, w, 1, 1))

    def test_run_pixel_classifier(self):
        """Classify one channel of the CZI image and write the resulting mask under output_path/seg."""
        # Sanity-check the single-channel, single-slice input against the configured dimensions.
        self.assertEqual(self.mono_image.shape_dict['X'], czifile['w'])
        self.assertEqual(self.mono_image.shape_dict['Y'], czifile['h'])
        self.assertEqual(self.mono_image.shape_dict['C'], 1)
        self.assertEqual(self.mono_image.shape_dict['Z'], 1)

        mask = self.model.label_pixel_class(self.mono_image)

        self.assertTrue(mask.is_mask())
        self.assertEqual(mask.shape[0:2], self.cf.shape[0:2])
        self.assertEqual(mask.shape_dict['C'], 1)
        self.assertEqual(mask.shape_dict['Z'], 1)
        self.assertTrue(
            write_accessor_data_to_file(
                output_path / 'seg' / f'seg_{self.cf.fpath.stem}_ch{self.channel}.tif',
                mask
            )
        )

    def test_label_pixels_with_params(self):
        """Threshold and smoothing params are accepted, and changing smoothing does not change mask shape."""
        def _run_seg(tr, sig):
            # Fresh model per (threshold, smoothing) pair; each mask is saved for visual inspection.
            mod = ilm.IlastikPixelClassifierModel(
                params=self._px_params(px_prob_threshold=tr, px_smoothing=sig),
            )
            mask = mod.label_pixel_class(self.mono_image)
            write_accessor_data_to_file(
                output_path / 'seg' / f'seg_tr{int(10*tr)}_sig{int(10*sig)}.tif',
                mask
            )
            return mask
        mask1 = _run_seg(0.5, 0.0)
        mask2 = _run_seg(0.5, 0.2)
        self.assertEqual(mask1.shape, mask2.shape)

    def test_pixel_classifier_enforces_input_shape(self):
        """The 2D single-channel model accepts a mono 2D image and rejects multichannel or z-stack input."""
        self.assertEqual(self.model.model_chroma, 1)
        self.assertEqual(self.model.model_3d, False)

        # correct data: one channel, one z-slice
        self.assertIsInstance(
            self.model.label_pixel_class(
                InMemoryDataAccessor(
                    _random_int(512, 256, 1, 1)
                )
            ),
            InMemoryDataAccessor
        )

        # raise exception with input of multiple channels (axis 2)
        with self.assertRaises(ilm.IlastikInputShapeError):
            self.model.label_pixel_class(
                InMemoryDataAccessor(
                    _random_int(512, 256, 3, 1)
                )
            )

        # raise exception with input of multiple z-slices (axis 3) to a 2D model
        with self.assertRaises(ilm.IlastikInputShapeError):
            self.model.label_pixel_class(
                InMemoryDataAccessor(
                    _random_int(512, 256, 1, 15)
                )
            )

    def test_ilastik_infer_pxmap_from_patchstack(self):
        """Masks and probability maps inferred from a PatchStack match the stack's geometry."""
        # Patches of heights 256, 512, 256 share width 512, so the stack's hw is (512, 512).
        acc = PatchStack([_random_int(h, 512, 1, 1) for h in (256, 512, 256)])
        self.assertEqual(acc.hw, (512, 512))
        self.assertEqual(acc.iat(0, crop=True).hw, (256, 512))

        mask = self.model.label_patch_stack(acc)
        self.assertEqual(mask.dtype, bool)
        self.assertEqual(mask.chroma, 1)
        self.assertEqual(mask.hw, acc.hw)
        self.assertEqual(mask.nz, acc.nz)
        self.assertEqual(mask.count, acc.count)

        # One probability channel per class label known to the model.
        pxmap, _ = self.model.infer_patch_stack(acc)
        self.assertEqual(pxmap.dtype, float)
        self.assertEqual(pxmap.chroma, len(self.model.labels))
        self.assertEqual(pxmap.hw, acc.hw)
        self.assertEqual(pxmap.nz, acc.nz)
        self.assertEqual(pxmap.count, acc.count)

    def test_run_object_classifier_from_pixel_predictions(self):
        """Object classification from pixel predictions yields an object map whose max label is 2."""
        self.test_run_pixel_classifier()
        fp = czifile['path']
        model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
            params=ilm.IlastikParams(project_file=str(ilastik_classifiers['pxmap_to_obj']['path']))
        )
        # NOTE(review): the binary mask from label_pixel_class (not a probability map) is passed as the
        # pixel predictions here; confirm this is intended.
        mask = self.model.label_pixel_class(self.mono_image)
        objmap, _ = model.infer(self.mono_image, mask)

        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'obmap_{fp.stem}.tif',
                objmap,
            )
        )
        self.assertEqual(objmap.data.max(), 2)

    def test_run_object_classifier_from_segmentation(self):
        """Object classification from a binary segmentation yields an object map whose max label is 2."""
        self.test_run_pixel_classifier()
        fp = czifile['path']
        model = ilm.IlastikObjectClassifierFromSegmentationModel(
            params=ilm.IlastikParams(project_file=str(ilastik_classifiers['seg_to_obj']['path']))
        )
        mask = self.model.label_pixel_class(self.mono_image)
        objmap = model.label_instance_class(self.mono_image, mask)

        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'obmap_from_seg_{fp.stem}.tif',
                objmap,
            )
        )
        self.assertEqual(objmap.data.max(), 2)

    def test_ilastik_pixel_classification_as_workflow(self):
        """Run pixel classification through the generic segment pipeline and check its recorded timing."""
        res = segment.segment_pipeline(
            accessors={
                'accessor': generate_file_accessor(czifile['path'])
            },
            models={
                'model': ilm.IlastikPixelClassifierModel(params=self._px_params()),
            },
            channel=0,
        )
        # NOTE(review): wall-clock lower bound on inference time; may be fragile on very fast hardware.
        self.assertGreater(res.times['inference'], 0.1)


class TestServerTestCase(conf.TestServerBaseClass):
    """Point the shared server test harness at this module's ``app`` and at the CZI test image."""
    input_data = czifile
    app_name = 'tests.test_ilastik.test_ilastik:app'


class TestIlastikOverApi(TestServerTestCase):
    """Exercise the ilastik model-loading and pipeline endpoints through the server API."""

    # NOTE(review): several test_load_* methods return the new model ID so other tests can reuse them.
    # Starting with Python 3.11, unittest emits a DeprecationWarning when a test method returns a value
    # other than None; consider moving the loading logic into non-test helper methods.

    def test_httpexception_if_incorrect_project_file_loaded(self):
        """Loading an improper project file path is rejected with HTTP 500."""
        self.assertPutFailure(
            'ilastik/seg/load/',
            500,
            body={'project_file': 'improper.ilp'},
        )

    def test_load_ilastik_pixel_model(self):
        """Load the pixel classifier over the API; return its model ID."""
        mid = self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(ilastik_classifiers['px']['path'])},
        )['model_id']
        rl = self.assertGetSuccess('models')
        self.assertEqual(rl[mid]['class'], 'IlastikPixelClassifierModel')
        return mid

    def test_load_another_ilastik_pixel_model(self):
        """duplicate=True registers a second copy; duplicate=False leaves the model count unchanged."""
        self.test_load_ilastik_pixel_model()
        self.assertEqual(len(self.assertGetSuccess('models')), 1)
        self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(ilastik_classifiers['px']['path']), 'duplicate': True},
        )
        self.assertEqual(len(self.assertGetSuccess('models')), 2)
        self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(ilastik_classifiers['px']['path']), 'duplicate': False},
        )
        self.assertEqual(len(self.assertGetSuccess('models')), 2)

    def test_load_ilastik_pixel_model_with_params(self):
        """Extra pixel-classifier params in the load body are reflected in the model listing."""
        params = {
            'project_file': str(ilastik_classifiers['px']['path']),
            'px_class': 0,
            'px_prob_threshold': 0.5
        }
        mid = self.assertPutSuccess(
            'ilastik/seg/load/',
            body=params,
        )['model_id']
        mods = self.assertGetSuccess('models')
        self.assertEqual(len(mods), 1)
        self.assertEqual(mods[mid]['params']['px_prob_threshold'], 0.5)

    def test_load_ilastik_pxmap_to_obj_model(self):
        """Load the pixel-map-to-object classifier over the API; return its model ID."""
        mid = self.assertPutSuccess(
            'ilastik/pxmap_to_obj/load/',
            body={'project_file': str(ilastik_classifiers['pxmap_to_obj']['path'])},
        )['model_id']
        rl = self.assertGetSuccess('models')
        self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromPixelPredictionsModel')
        return mid

    def test_load_ilastik_model_with_model_id(self):
        """A model_id given in the query string is used as the loaded model's ID."""
        nmid = 'new_model_id'
        rmid = self.assertPutSuccess(
            'ilastik/pxmap_to_obj/load/',
            query={
                'model_id': nmid,
            },
            body={
                'project_file': str(ilastik_classifiers['pxmap_to_obj']['path']),
            },
        )['model_id']
        self.assertEqual(rmid, nmid)

    def test_load_ilastik_seg_to_obj_model(self):
        """Load the segmentation-to-object classifier over the API; return its model ID."""
        mid = self.assertPutSuccess(
            'ilastik/seg_to_obj/load/',
            body={'project_file': str(ilastik_classifiers['seg_to_obj']['path'])},
        )['model_id']
        rl = self.assertGetSuccess('models')
        self.assertEqual(rl[mid]['class'], 'IlastikObjectClassifierFromSegmentationModel')
        return mid

    def test_ilastik_infer_pixel_probability(self):
        """Run the generic segment pipeline with a loaded ilastik pixel classifier."""
        fname = self.copy_input_file_to_server()
        mid = self.test_load_ilastik_pixel_model()
        acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}')

        self.assertPutSuccess(
            'pipelines/segment',
            body={'model_id': mid, 'accessor_id': acc_id, 'channel': 0},
        )

    def test_ilastik_infer_px_then_ob(self):
        """Run the pixel-then-object classification pipeline over the API on channel 0."""
        fname = self.copy_input_file_to_server()
        px_model_id = self.test_load_ilastik_pixel_model()
        ob_model_id = self.test_load_ilastik_pxmap_to_obj_model()

        in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{fname}')

        self.assertPutSuccess(
            'ilastik/pipelines/pixel_then_object_classification/infer/',
            body={
                'px_model_id': px_model_id,
                'ob_model_id': ob_model_id,
                'accessor_id': in_acc_id,
                'channel': 0,
            }
        )


class TestIlastikOnMultichannelInputs(TestServerTestCase):
    """Pixel and object classification of a multichannel z-stack, both in-process and over the API."""

    def setUp(self) -> None:
        super().setUp()
        self.pa_px_classifier = ilastik_classifiers['px_color_zstack']['path']
        self.pa_ob_pxmap_classifier = ilastik_classifiers['ob_pxmap_color_zstack']['path']
        self.pa_ob_seg_classifier = ilastik_classifiers['ob_seg_color_zstack']['path']
        self.pa_input_image = data['multichannel_zstack_raw']['path']
        self.pa_mask = data['multichannel_zstack_mask3d']['path']

    def test_classify_pixels(self):
        """Pixel-classify the multichannel stack; the probability map matches its hw and nz.

        Returns the probability map so test_classify_objects can reuse it. NOTE(review): Python 3.11+
        unittest warns when a test method returns a value other than None.
        """
        img = generate_file_accessor(self.pa_input_image)
        self.assertGreater(img.chroma, 1)
        mod = ilm.IlastikPixelClassifierModel(
            ilm.IlastikPixelClassifierParams(project_file=str(self.pa_px_classifier))
        )
        pxmap = mod.infer(img)[0]
        self.assertEqual(pxmap.hw, img.hw)
        self.assertEqual(pxmap.nz, img.nz)
        return pxmap

    def test_classify_objects(self):
        """Object-classify the stack from its pixel probability map; the object map matches hw and nz."""
        pxmap = self.test_classify_pixels()
        img = generate_file_accessor(self.pa_input_image)
        mod = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
            ilm.IlastikParams(project_file=str(self.pa_ob_pxmap_classifier))
        )
        obmap = mod.infer(img, pxmap)[0]
        self.assertEqual(obmap.hw, img.hw)
        self.assertEqual(obmap.nz, img.nz)

    def test_workflow(self):
        """
        Test calling pixel then object map classification pipeline function directly
        """
        def _call_workflow(channel):
            # Fresh accessor and models per call so the two runs below are independent.
            return px_then_ob.pixel_then_object_classification_pipeline(
                accessors={
                    'accessor': generate_file_accessor(self.pa_input_image)
                },
                models={
                    # Same params class as test_classify_pixels uses for this model type.
                    'px_model': ilm.IlastikPixelClassifierModel(
                        ilm.IlastikPixelClassifierParams(project_file=str(self.pa_px_classifier)),
                    ),
                    'ob_model': ilm.IlastikObjectClassifierFromPixelPredictionsModel(
                        ilm.IlastikParams(project_file=str(self.pa_ob_pxmap_classifier)),
                    )
                },
                channel=channel,
            )

        # Selecting a single channel is rejected because these classifiers were trained on multichannel input.
        with self.assertRaises(ilm.IlastikInputShapeError):
            _call_workflow(channel=0)
        res = _call_workflow(channel=None)
        acc_input = generate_file_accessor(self.pa_input_image)
        acc_obmap = res['ob_map']
        self.assertEqual(acc_obmap.hw, acc_input.hw)
        self.assertEqual(len(acc_obmap.unique()[1]), 3)

    def test_api(self):
        """
        Test calling pixel then object map classification pipeline over API
        """
        # Stage the input image in the server's inbound directory so it can be read by name.
        copyfile(
            self.pa_input_image,
            Path(self.assertGetSuccess('paths')['inbound_images']) / self.pa_input_image.name
        )

        in_acc_id = self.assertPutSuccess(f'accessors/read_from_file/{self.pa_input_image.name}')

        px_model_id = self.assertPutSuccess(
            'ilastik/seg/load/',
            body={'project_file': str(self.pa_px_classifier)},
        )['model_id']

        ob_model_id = self.assertPutSuccess(
            'ilastik/pxmap_to_obj/load/',
            body={'project_file': str(self.pa_ob_pxmap_classifier)},
        )['model_id']

        # run the pipeline
        obmap_id = self.assertPutSuccess(
            'ilastik/pipelines/pixel_then_object_classification/infer/',
            body={
                'accessor_id': in_acc_id,
                'px_model_id': px_model_id,
                'ob_model_id': ob_model_id,
            }
        )['output_accessor_id']

        # fetch the output object map and check it is single-channel
        obmap_acc = self.get_accessor(obmap_id)
        self.assertEqual(obmap_acc.shape_dict['C'], 1)

        # compare dimensions to input image
        self.assertEqual(obmap_acc.hw, generate_file_accessor(self.pa_input_image).hw)


class TestIlastikObjectClassification(unittest.TestCase):
    """Classify a stack of ROI patches with the segmentation-to-object ilastik model."""

    def setUp(self):
        raw_stack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
        patches_channel = conf.meta['roiset']['patches_channel']
        mono = raw_stack.get_mono(patches_channel)
        mask2d = generate_file_accessor(data['multichannel_zstack_mask2d']['path'])

        # ROI settings: box-shaped masks, an area filter between 1e3 and 1e4, boxes expanded by (64, 2).
        roi_params = RoiSetMetaParams(
            mask_type='boxes',
            filters={'area': {'min': 1e3, 'max': 1e4}},
            expand_box_by=(64, 2),
        )
        self.roiset = RoiSet.from_binary_mask(mono, mask2d, params=roi_params)

        project = ilastik_classifiers['seg_to_obj']['path'].__str__()
        self.classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
            params=ilm.IlastikParams(project_file=project),
        )
        self.raw = self.roiset.get_patches_acc()
        self.masks = self.roiset.get_patch_masks_acc()

    def test_classify_patches(self):
        """Every classified patch has exactly two values occurring more than once, the smallest being 0."""
        res = self.classifier.label_patch_stack(self.raw, self.masks)
        self.assertEqual(res.count, self.roiset.count)
        res.export_pyxcz(output_path / 'res_patches.tif')
        for idx in range(res.count):
            labels, counts = np.unique(res.iat(idx).data, return_counts=True)
            # Values seen only once are ignored as single-pixel anomalies; expect background + one label.
            multi_pixel_labels = np.sum(counts > 1)
            self.assertEqual(multi_pixel_labels, 2)
            self.assertEqual(labels[0], 0)