# test_roiset_pipeline.py: tests for the RoiSet object-map pipeline, run both in-process and over the server API.
import json
from pathlib import Path
import unittest

import numpy as np

from model_server.base.accessors import generate_file_accessor
import model_server.conf.testing as conf
from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline

# Shortcuts into the shared testing configuration.
data, output_path, test_params = (
    conf.meta['image_files'],  # registered test image files, each with a 'path'
    conf.meta['output_path'],  # root directory tests write their outputs under
    conf.meta['roiset'],  # RoiSet-specific test parameters (e.g. channel indices)
)


class BaseTestRoiSetMonoProducts:
    """Shared input data and parameter sets for the RoiSet pipeline tests.

    Mixed into both the in-process and the over-the-API test cases so that they
    read the same input stack and use identical ROI and export settings.
    Every ``_get_*`` helper builds and returns a fresh object, so callers may
    mutate the result without affecting other tests.
    """

    @property
    def fpi(self):
        """Path (as ``str``) of the multichannel z-stack used as pipeline input."""
        return str(data['multichannel_zstack_raw']['path'])

    @property
    def stack(self):
        """Accessor for the input stack; a new accessor is created on every access."""
        return generate_file_accessor(self.fpi)

    @property
    def stack_ch_pa(self):
        """Single channel of the input stack selected by the configured ``patches_channel``."""
        return self.stack.get_mono(test_params['patches_channel'])

    @property
    def seg_mask(self):
        """Accessor for the test file registered as ``multichannel_zstack_mask2d``."""
        return generate_file_accessor(data['multichannel_zstack_mask2d']['path'])

    def _get_export_params(self):
        """Return export settings for the pipeline.

        2D patches are exported both annotated (bounding box drawn, padded to
        512) and plain. ``patches_3d`` and ``annotated_zstacks`` are set to
        None (presumably disabled). Object classes are exported.
        """
        return {
            'patches_3d': None,
            'annotated_patches_2d': {
                'draw_bounding_box': True,
                'rgb_overlay_channels': [3, None, None],
                'rgb_overlay_weights': [0.2, 1.0, 1.0],
                'pad_to': 512,
            },
            'patches_2d': {
                'draw_bounding_box': False,
                'draw_mask': False,
            },
            'annotated_zstacks': None,
            'object_classes': True,
        }

    def _get_roi_params(self):
        """Return ROI settings.

        The settings use box masks, a wide area filter (1 to 1e8), box
        expansion of ``[128, 2]`` and deprojection from channel 0.
        """
        return {
            'mask_type': 'boxes',
            'filters': {
                'area': {'min': 1e0, 'max': 1e8}
            },
            'expand_box_by': [128, 2],
            'deproject_channel': 0,
        }

    def _get_models(self):
        """Return freshly constructed in-process models, keyed by pipeline role.

        Each entry holds a display ``name`` and the ``model`` instance. The
        pixel classifier is a binary threshold at 0.2. The object classifier is
        an ``IntensityThresholdInstanceMaskSegmentationModel`` with default
        arguments.
        """
        from model_server.base.models import BinaryThresholdSegmentationModel
        from model_server.base.roiset import IntensityThresholdInstanceMaskSegmentationModel
        return {
            'pixel_classifier_segmentation': {
                'name': 'min_px_mod',
                'model': BinaryThresholdSegmentationModel(tr=0.2),
            },
            'object_classifier': {
                'name': 'min_ob_mod',
                'model': IntensityThresholdInstanceMaskSegmentationModel(),
            },
        }


class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
    """Run the RoiSet object-map pipeline in-process, without the model server."""

    def _pipeline_params(self):
        """Return keyword arguments for ``RoiSetObjectMapParams``.

        ``api`` is False here, presumably because the pipeline is invoked
        directly rather than through the server. Channel indices come from the
        ``roiset`` section of the testing configuration. The object classifier
        ID ``'ob_id'`` later appears as a key in the pipeline trace.
        """
        return {
            'api': False,
            'accessor_id': 'acc_id',
            'pixel_classifier_segmentation_model_id': 'px_id',
            'object_classifier_model_id': 'ob_id',
            'segmentation': {
                'channel': test_params['segmentation_channel'],
            },
            'patches_channel': test_params['patches_channel'],
            'roi_params': self._get_roi_params(),
            'export_params': self._get_export_params(),
        }

    def test_object_map_workflow(self):
        """Run the pipeline end to end and check patch, ROI, label and class counts."""
        acc_in = generate_file_accessor(self.fpi)
        params = RoiSetObjectMapParams(
            **self._pipeline_params(),
        )
        trace, rois = roiset_object_map_pipeline(
            {'': acc_in},
            # Model keys get a trailing underscore.
            # NOTE(review): presumably this matches the pipeline's expected naming; confirm.
            {f'{k}_': v['model'] for k, v in self._get_models().items()},
            **params.dict()
        )
        # Patch products are popped (and counted) before the remaining
        # intermediates in the trace are written to disk.
        self.assertEqual(trace.pop('annotated_patches_2d').count, 22)
        self.assertEqual(trace.pop('patches_2d').count, 22)
        trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False)
        # assertIn reports the available keys on failure, unlike assertTrue(x in y).
        self.assertIn('ob_id', trace.keys())
        self.assertEqual(len(trace['labeled'].unique()[0]), 40)
        self.assertEqual(rois.count, 22)
        self.assertEqual(len(trace['ob_id'].unique()[0]), 2)

# NOTE(review): a later class statement in this module reuses the name
# TestRoiSetWorkflowOverApi and rebinds it. Test discovery therefore never
# collects the tests defined in this body. Consider deleting this duplicate
# once any test unique to it has been carried over to the later definition.
class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
    """Run the RoiSet object-map pipeline through the model server's HTTP API."""

    input_data = data['multichannel_zstack_raw']


    def setUp(self) -> None:
        """Create the ``roiset`` output directory, then defer to the server base setUp."""
        # NOTE(review): assumes output_path is a pathlib.Path; a str would not support `/`.
        self.where_out = output_path / 'roiset'
        self.where_out.mkdir(parents=True, exist_ok=True)
        return conf.TestServerBaseClass.setUp(self)

    def test_trivial_api_response(self):
        """A GET on the server root succeeds."""
        self.assertGetSuccess('')

    def test_load_input_accessor(self):
        """Copy the input file to the server and load it as an accessor.

        This method doubles as a helper. Its return value is passed on as
        ``accessor_id`` by ``_object_map_workflow``.
        """
        fname = self.copy_input_file_to_server()
        return self.assertPutSuccess(f'accessors/read_from_file/{fname}')

    def test_load_pixel_classifier(self):
        """Load a threshold pixel classifier (tr=0.2) and return its model ID."""
        mid = self.assertPutSuccess(
            'models/seg/threshold/load/',
            query={'tr': 0.2},
        )['model_id']
        self.assertTrue(mid.startswith('BinaryThresholdSegmentationModel'))
        return mid

    def test_load_object_classifier(self):
        """Load a threshold object classifier (tr=0) and return its model ID."""
        mid = self.assertPutSuccess(
            'models/classify/threshold/load/',
            body={'tr': 0}
        )['model_id']
        self.assertTrue(mid.startswith('IntensityThresholdInstanceMaskSegmentation'))
        return mid

    def _object_map_workflow(self, ob_classifer_id):
        """Run ``roiset_to_obmap`` via the ``infer`` endpoint and return the response.

        :param ob_classifer_id: ID of a loaded object classifier, or None to run
            without one.
        """
        res = self.assertPutSuccess(
            'pipelines/roiset_to_obmap/infer',
            body={
                'accessor_id': self.test_load_input_accessor(),
                'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
                'object_classifier_model_id': ob_classifer_id,
                'segmentation': {'channel': 0},
                'patches_channel': 1,
                'roi_params': self._get_roi_params(),
                'export_params': self._get_export_params(),
            },
        )

        # check on automatically written RoiSet
        roiset_id = res['roiset_id']
        roiset_info = self.assertGetSuccess(f'rois/{roiset_id}')
        self.assertGreater(roiset_info['count'], 0)
        return res

    def test_workflow_with_object_classifier(self):
        """Pipeline object map matches the one regenerated from its RoiSet; RoiSet can be written."""
        obmod_id = self.test_load_object_classifier()
        res = self._object_map_workflow(obmod_id)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        # NOTE(review): `==` on arrays of different shapes broadcasts or raises;
        # np.testing.assert_array_equal would also check shape.
        self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))

        # get object map via RoiSet API
        roiset_id = res['roiset_id']
        obmap_id = self.assertPutSuccess(f'rois/obmap/{roiset_id}/{obmod_id}', query={'object_classes': True})
        acc_obmap_roiset = self.get_accessor(obmap_id)
        self.assertTrue(np.all(acc_obmap_roiset.data == acc_obmap.data))

        # check serialize RoiSet
        # After writing, the RoiSet is expected to report 'loaded' as falsy.
        self.assertPutSuccess(f'rois/write/{roiset_id}')
        self.assertFalse(
            self.assertGetSuccess(f'rois/{roiset_id}')['loaded']
        )


    def test_workflow_without_object_classifier(self):
        """The pipeline also runs with no object classifier and yields a 0/1 object map."""
        res = self._object_map_workflow(None)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))

class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
    """Run the RoiSet object-map pipeline through the model server's HTTP API."""

    input_data = data['multichannel_zstack_raw']

    def setUp(self) -> None:
        """Create the ``roiset`` output directory, then defer to the server base setUp."""
        # Path() accepts either a str or a Path, matching TestRoiSetWorkflow's
        # use of Path(output_path); a bare str would not support `/`.
        self.where_out = Path(output_path) / 'roiset'
        self.where_out.mkdir(parents=True, exist_ok=True)
        return conf.TestServerBaseClass.setUp(self)

    def test_trivial_api_response(self):
        """A GET on the server root succeeds."""
        self.assertGetSuccess('')

    def test_load_input_accessor(self):
        """Copy the input file to the server and load it as an accessor.

        This method doubles as a helper. Its return value is passed on as
        ``accessor_id`` by ``_object_map_workflow``.
        """
        fname = self.copy_input_file_to_server()
        return self.assertPutSuccess(f'accessors/read_from_file/{fname}')

    def test_load_pixel_classifier(self):
        """Load a threshold pixel classifier (tr=0.2) and return its model ID."""
        mid = self.assertPutSuccess(
            'models/seg/threshold/load/',
            query={'tr': 0.2},
        )['model_id']
        self.assertTrue(mid.startswith('BinaryThresholdSegmentationModel'))
        return mid

    def test_load_object_classifier(self):
        """Load a threshold object classifier (tr=0) and return its model ID."""
        mid = self.assertPutSuccess(
            'models/classify/threshold/load/',
            body={'tr': 0}
        )['model_id']
        self.assertTrue(mid.startswith('IntensityThresholdInstanceMaskSegmentation'))
        return mid

    def _object_map_workflow(self, ob_classifer_id):
        """Run ``roiset_to_obmap`` via the ``infer`` endpoint and return the response.

        :param ob_classifer_id: ID of a loaded object classifier, or None to run
            without one.
        """
        res = self.assertPutSuccess(
            'pipelines/roiset_to_obmap/infer',
            body={
                'accessor_id': self.test_load_input_accessor(),
                'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
                'object_classifier_model_id': ob_classifer_id,
                'segmentation': {'channel': 0},
                'patches_channel': 1,
                'roi_params': self._get_roi_params(),
                'export_params': self._get_export_params(),
            },
        )

        # The response references a RoiSet on the server; it must be non-empty.
        roiset_id = res['roiset_id']
        roiset_info = self.assertGetSuccess(f'rois/{roiset_id}')
        self.assertGreater(roiset_info['count'], 0)
        return res

    def test_workflow_with_object_classifier(self):
        """Pipeline object map matches the one regenerated from its RoiSet; RoiSet can be written."""
        obmod_id = self.test_load_object_classifier()
        res = self._object_map_workflow(obmod_id)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        # assert_array_equal also checks shape. A label set other than exactly
        # [0, 1] therefore fails with a readable diff instead of broadcasting
        # or raising a shape error.
        np.testing.assert_array_equal(acc_obmap.unique()[0], [0, 1])

        # get object map via RoiSet API
        roiset_id = res['roiset_id']
        obmap_id = self.assertPutSuccess(f'rois/obmap/{roiset_id}/{obmod_id}', query={'object_classes': True})
        acc_obmap_roiset = self.get_accessor(obmap_id)
        np.testing.assert_array_equal(acc_obmap_roiset.data, acc_obmap.data)

        # After writing the RoiSet out, it should report 'loaded' as falsy.
        self.assertPutSuccess(f'rois/write/{roiset_id}')
        self.assertFalse(
            self.assertGetSuccess(f'rois/{roiset_id}')['loaded']
        )

    def test_workflow_without_object_classifier(self):
        """The pipeline also runs with no object classifier and yields a 0/1 object map."""
        res = self._object_map_workflow(None)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        np.testing.assert_array_equal(acc_obmap.unique()[0], [0, 1])


class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
    """Re-run the API workflow tests, enqueuing the pipeline as a task and running it explicitly."""

    def _object_map_workflow(self, ob_classifer_id):
        """Enqueue ``roiset_to_obmap``, run the task, and return the run response.

        This overrides the parent method, so every inherited workflow test goes
        through the task queue instead of the ``infer`` endpoint.

        NOTE(review): the inherited tests index the returned value with
        'output_accessor_id' and 'roiset_id'. Confirm that the task-run
        response carries those keys.
        """
        request_body = {
            'accessor_id': self.test_load_input_accessor(),
            'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
            'object_classifier_model_id': ob_classifer_id,
            'segmentation': {'channel': 0},
            'patches_channel': 1,
            'roi_params': self._get_roi_params(),
            'export_params': self._get_export_params(),
        }
        queued = self.assertPutSuccess('pipelines/queue/roiset_to_obmap', body=request_body)
        task_id = queued['task_id']

        # Right after enqueuing, the task must not have started yet.
        waiting_info = self.assertGetSuccess(f'tasks/{task_id}')
        self.assertEqual(waiting_info['status'], 'WAITING')

        # Run the task, then confirm it reports completion.
        run_result = self.assertPutSuccess(f'tasks/{task_id}/run')
        self.assertTrue(run_result)
        finished_info = self.assertGetSuccess(f'tasks/{task_id}')
        self.assertEqual(finished_info['status'], 'FINISHED')

        return run_result