Skip to content
Snippets Groups Projects
test_ilastik.py 12.34 KiB
import pathlib
import requests
import unittest

import numpy as np

from model_server.conf.testing import czifile, ilastik_classifiers, output_path, roiset_test_data
from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.extensions.ilastik import models as ilm
from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
from model_server.base.workflows import classify_pixels
from tests.test_api import TestServerBaseClass

def _random_int(*args):
    """Return a ``uint8`` array shaped ``args`` with uniform random values in [0, 255]."""
    shape = args
    upper = 1 << 8  # exclusive upper bound for randint -> covers the full uint8 range
    return np.random.randint(low=0, high=upper, size=shape, dtype='uint8')

class TestIlastikPixelClassification(unittest.TestCase):
    """In-process tests for ilastik pixel and object classifier models."""

    def setUp(self) -> None:
        # CZI test image shared by the tests that classify real data
        self.cf = CziImageFileAccessor(czifile['path'])


    def test_faulthandler(self): # recreate error that is messing up ilastik
        """Enabling faulthandler on ``sys.stdout`` raises ``io.UnsupportedOperation``.

        Reproduces the error noted as interfering with ilastik. NOTE(review): presumably
        this happens because ``sys.stdout`` is replaced by a stream without a usable
        file descriptor (e.g. test-runner output capture) — confirm.
        """
        import io
        import sys
        import faulthandler

        with self.assertRaises(io.UnsupportedOperation):
            faulthandler.enable(file=sys.stdout)


    def test_raise_error_if_autoload_disabled(self):
        """A model built with ``autoload=False`` raises ``AttributeError`` when asked to classify."""
        model = ilm.IlastikPixelClassifierModel(
            {'project_file': ilastik_classifiers['px']},
            autoload=False
        )
        w = 512
        h = 256

        # (Y, X, C, Z) ordering, consistent with the other tests in this class
        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))

        with self.assertRaises(AttributeError):
            model.label_pixel_class(input_img)


    def test_run_pixel_classifier_on_random_data(self):
        """Classifying random single-channel, single-slice data yields a mask of the same shape."""
        model = ilm.IlastikPixelClassifierModel(
            {'project_file': ilastik_classifiers['px']},
        )
        w = 512
        h = 256

        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))

        mask = model.label_pixel_class(input_img)
        self.assertEqual(mask.shape, (h, w, 1, 1))


    def test_run_pixel_classifier(self):
        """Classify one channel of the CZI test image and write the resulting mask to disk.

        Side effect: stores ``self.mono_image`` and ``self.mask`` so that the
        object-classifier tests can reuse them by calling this method first.
        """
        channel = 0
        model = ilm.IlastikPixelClassifierModel(
            {'project_file': ilastik_classifiers['px']}
        )
        cf = self.cf  # opened once in setUp
        mono_image = cf.get_one_channel_data(channel)

        self.assertEqual(mono_image.shape_dict['X'], czifile['w'])
        self.assertEqual(mono_image.shape_dict['Y'], czifile['h'])
        self.assertEqual(mono_image.shape_dict['C'], 1)
        self.assertEqual(mono_image.shape_dict['Z'], 1)

        mask = model.label_pixel_class(mono_image)

        self.assertTrue(mask.is_mask())
        self.assertEqual(mask.shape[0:2], cf.shape[0:2])
        self.assertEqual(mask.shape_dict['C'], 1)
        self.assertEqual(mask.shape_dict['Z'], 1)
        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
                mask
            )
        )

        self.mono_image = mono_image
        self.mask = mask

    def test_pixel_classifier_enforces_input_shape(self):
        """The 2D, single-channel pixel classifier rejects multi-channel and multi-Z input."""
        model = ilm.IlastikPixelClassifierModel(
            {'project_file': ilastik_classifiers['px']}
        )
        self.assertEqual(model.model_chroma, 1)
        self.assertEqual(model.model_3d, False)

        # correct data: one channel, one Z-slice
        self.assertIsInstance(
            model.label_pixel_class(
                InMemoryDataAccessor(
                    _random_int(512, 256, 1, 1)
                )
            ),
            InMemoryDataAccessor
        )

        # raise exception with input of multiple channels
        with self.assertRaises(ilm.IlastikInputShapeError):
            model.label_pixel_class(
                InMemoryDataAccessor(
                    _random_int(512, 256, 3, 1)
                )
            )

        # raise exception with input of multiple Z-slices
        with self.assertRaises(ilm.IlastikInputShapeError):
            model.label_pixel_class(
                InMemoryDataAccessor(
                    _random_int(512, 256, 1, 15)
                )
            )


    def test_run_object_classifier_from_pixel_predictions(self):
        """Run the pixel-map-to-object classifier on the output of the pixel classifier."""
        self.test_run_pixel_classifier()  # populates self.mono_image and self.mask
        fp = czifile['path']
        model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
            {'project_file': ilastik_classifiers['pxmap_to_obj']}
        )
        # infer() returns a pair; only the first element (the object map) is checked here
        objmap, _ = model.infer(self.mono_image, self.mask)

        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'obmap_{fp.stem}.tif',
                objmap,
            )
        )
        self.assertEqual(objmap.data.max(), 2)

    def test_make_seg_obj_model_from_pxmap_obj(self):
        """Derive an instance-segmentation model from the pixel-map object classifier and run it."""
        self.test_run_pixel_classifier()  # populates self.mono_image and self.mask
        fp = czifile['path']
        pxmap_model = ilm.IlastikObjectClassifierFromPixelPredictionsModel(
            {'project_file': ilastik_classifiers['pxmap_to_obj']}
        )
        seg_model = pxmap_model.make_instance_segmentation_model(px_ch=0)
        objmap = seg_model.label_instance_class(self.mono_image, self.mask)

        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'obmap_seg_from_pxmap_{fp.stem}.tif',
                objmap,
            )
        )
        self.assertEqual(objmap.data.max(), 2)

    def test_run_object_classifier_from_segmentation(self):
        """Run the segmentation-to-object classifier on the pixel classifier's mask."""
        self.test_run_pixel_classifier()  # populates self.mono_image and self.mask
        fp = czifile['path']
        model = ilm.IlastikObjectClassifierFromSegmentationModel(
            {'project_file': ilastik_classifiers['seg_to_obj']}
        )
        objmap = model.label_instance_class(self.mono_image, self.mask)

        self.assertTrue(
            write_accessor_data_to_file(
                output_path / f'obmap_from_seg_{fp.stem}.tif',
                objmap,
            )
        )
        self.assertEqual(objmap.data.max(), 2)

    def test_ilastik_pixel_classification_as_workflow(self):
        """The ``classify_pixels`` workflow succeeds with an ilastik pixel classifier."""
        result = classify_pixels(
            czifile['path'],
            ilm.IlastikPixelClassifierModel(
                {'project_file': ilastik_classifiers['px']}
            ),
            output_path,
            channel=0,
        )
        self.assertTrue(result.success)
        # NOTE(review): timer value is presumably in seconds; > 1.0 suggests inference really ran
        self.assertGreater(result.timer_results['inference'], 1.0)

class TestIlastikOverApi(TestServerBaseClass):
    """Exercise the ilastik extension's HTTP endpoints through the test server client."""

    def _load_model_and_check_class(self, endpoint, project_file, expected_class):
        """Load ``project_file`` via ``endpoint`` and assert it is registered as ``expected_class``.

        The status code is checked before the response body is parsed, so a failed
        load is reported as a status failure rather than as a KeyError on ``model_id``.
        Returns the ID of the loaded model.
        """
        resp_load = self._put(
            endpoint,
            {'project_file': str(project_file)},
        )
        self.assertEqual(resp_load.status_code, 200, resp_load.content.decode())
        model_id = resp_load.json()['model_id']
        resp_list = self._get('models')
        self.assertEqual(resp_list.status_code, 200)
        rj = resp_list.json()
        self.assertEqual(rj[model_id]['class'], expected_class)
        return model_id

    def test_httpexception_if_incorrect_project_file_loaded(self):
        """Loading a nonexistent project file returns HTTP 404."""
        resp_load = self._put(
            'ilastik/seg/load/',
            {'project_file': 'improper.ilp'},
        )
        self.assertEqual(resp_load.status_code, 404)


    def test_load_ilastik_pixel_model(self):
        """Load the pixel classifier; returns its model ID for reuse by other tests."""
        return self._load_model_and_check_class(
            'ilastik/seg/load/',
            ilastik_classifiers['px'],
            'IlastikPixelClassifierModel',
        )

    def test_load_another_ilastik_pixel_model(self):
        """``duplicate=True`` adds a second copy of a loaded model; ``duplicate=False`` does not."""
        self.test_load_ilastik_pixel_model()
        resp_list_1st = self._get('models').json()
        self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
        resp_load_2nd = self._put(
            'ilastik/seg/load/',
            {'project_file': str(ilastik_classifiers['px']), 'duplicate': True},
        )
        self.assertEqual(resp_load_2nd.status_code, 200, resp_load_2nd.content.decode())
        resp_list_2nd = self._get('models').json()
        self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
        resp_load_3rd = self._put(
            'ilastik/seg/load/',
            {'project_file': str(ilastik_classifiers['px']), 'duplicate': False},
        )
        self.assertEqual(resp_load_3rd.status_code, 200, resp_load_3rd.content.decode())
        resp_list_3rd = self._get('models').json()
        self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)

    def test_no_duplicate_model_with_different_path_formats(self):
        """Windows- and POSIX-style spellings of one path resolve to a single loaded model."""
        self._get('session/restart')
        resp_list_1 = self._get('models').json()
        self.assertEqual(len(resp_list_1), 0)
        ilp = ilastik_classifiers['px']

        # create and validate two copies of the same pathname with different string formats
        ilp_win = str(pathlib.PureWindowsPath(ilp))
        self.assertGreater(ilp_win.count('\\'), 0) # i.e. contains backslashes
        self.assertEqual(ilp_win.count('/'), 0)
        ilp_posx = ilp.as_posix()
        self.assertGreater(ilp_posx.count('/'), 0)
        self.assertEqual(ilp_posx.count('\\'), 0)
        self.assertEqual(pathlib.Path(ilp_win), pathlib.Path(ilp_posx))

        # load models with these paths
        resp1 = self._put(
            'ilastik/seg/load/',
            {'project_file': ilp_win, 'duplicate': False},
        )
        resp2 = self._put(
            'ilastik/seg/load/',
            {'project_file': ilp_posx, 'duplicate': False},
        )
        self.assertEqual(resp1.json(), resp2.json())

        # assert that only one copy of the model is loaded
        resp_list_2 = self._get('models').json()
        self.assertEqual(len(resp_list_2), 1, resp_list_2)


    def test_load_ilastik_pxmap_to_obj_model(self):
        """Load the pixel-map-to-object classifier; returns its model ID."""
        return self._load_model_and_check_class(
            'ilastik/pxmap_to_obj/load/',
            ilastik_classifiers['pxmap_to_obj'],
            'IlastikObjectClassifierFromPixelPredictionsModel',
        )

    def test_load_ilastik_seg_to_obj_model(self):
        """Load the segmentation-to-object classifier; returns its model ID."""
        return self._load_model_and_check_class(
            'ilastik/seg_to_obj/load/',
            ilastik_classifiers['seg_to_obj'],
            'IlastikObjectClassifierFromSegmentationModel',
        )

    def test_ilastik_infer_pixel_probability(self):
        """Run the generic segment workflow with a loaded ilastik pixel classifier."""
        self.copy_input_file_to_server()
        model_id = self.test_load_ilastik_pixel_model()

        resp_infer = self._put(
            'workflows/segment',
            {'model_id': model_id, 'input_filename': czifile['filename'], 'channel': 0},
        )
        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())

    def test_ilastik_infer_px_then_ob(self):
        """Run pixel classification followed by object classification through the API."""
        self.copy_input_file_to_server()
        px_model_id = self.test_load_ilastik_pixel_model()
        ob_model_id = self.test_load_ilastik_pxmap_to_obj_model()

        resp_infer = self._put(
            'ilastik/pixel_then_object_classification/infer/',
            {
                'px_model_id': px_model_id,
                'ob_model_id': ob_model_id,
                'input_filename': czifile['filename'],
                'channel': 0,
            }
        )
        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())

class TestIlastikObjectClassification(unittest.TestCase):
    """Apply a segmentation-to-object ilastik classifier to patches cut from an RoiSet."""

    def setUp(self):
        zstack = roiset_test_data['multichannel_zstack']

        stack = generate_file_accessor(zstack['path'])
        patches_channel = roiset_test_data['pipeline_params']['patches_channel']
        stack_ch_pa = stack.get_one_channel_data(patches_channel)
        seg_mask = generate_file_accessor(zstack['mask_path'])

        label_ids = _get_label_ids(seg_mask)
        meta = RoiSetMetaParams(
            mask_type='boxes',
            filters={'area': {'min': 1e3, 'max': 1e4}},
            expand_box_by=(64, 2),
        )
        self.roiset = RoiSet(stack_ch_pa, label_ids, params=meta)

        self.object_classifier = ilm.IlastikObjectClassifierFromSegmentationModel(
            params={'project_file': ilastik_classifiers['seg_to_obj']}
        )


    def test_classify_patches(self):
        """Every classified patch holds background plus exactly one multi-pixel label."""
        raw_patches = self.roiset.get_patches_acc()
        patch_masks = self.roiset.get_patch_masks_acc()
        res_patches = self.object_classifier.label_instance_class(raw_patches, patch_masks)
        self.assertEqual(res_patches.count, self.roiset.count)
        res_patches.export_pyxcz(output_path / 'res_patches.tif')
        for idx in range(res_patches.count):
            labels, counts = np.unique(res_patches.iat(idx).data, return_counts=True)
            # background + one real label; values covering a single pixel are ignored as anomalies
            self.assertEqual(np.count_nonzero(counts > 1), 2)
            # np.unique returns sorted values, so the first must be background (0)
            self.assertEqual(labels[0], 0)