import json
from pathlib import Path
import unittest

import numpy as np

from model_server.base.accessors import generate_file_accessor
import model_server.conf.testing as conf
from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline

# Shared test configuration, read once at import time.
data = conf.meta['image_files']
output_path = conf.meta['output_path']
test_params = conf.meta['roiset']


class BaseTestRoiSetMonoProducts:
    """Mixin supplying input accessors and parameter dicts for the RoiSet tests."""

    @property
    def fpi(self):
        """Return the 'multichannel_zstack_raw' input path from the test config as a string."""
        return str(data['multichannel_zstack_raw']['path'])

    @property
    def stack(self):
        """Return a file accessor opened on ``self.fpi``."""
        return generate_file_accessor(self.fpi)

    @property
    def stack_ch_pa(self):
        """Return the mono channel of ``self.stack`` selected by ``test_params['patches_channel']``."""
        return self.stack.get_mono(test_params['patches_channel'])

    @property
    def seg_mask(self):
        """Return a file accessor opened on the 'multichannel_zstack_mask2d' path."""
        return generate_file_accessor(data['multichannel_zstack_mask2d']['path'])

    def _get_export_params(self):
        """Return the export-parameter dict passed to the pipeline as ``export_params``."""
        return {
            'patches_3d': None,
            'annotated_patches_2d': {
                'draw_bounding_box': True,
                'rgb_overlay_channels': [3, None, None],
                'rgb_overlay_weights': [0.2, 1.0, 1.0],
                'pad_to': 512,
            },
            'patches_2d': {
                'draw_bounding_box': False,
                'draw_mask': False,
            },
            'annotated_zstacks': None,
            'object_classes': True,
        }

    def _get_roi_params(self):
        """Return the ROI-parameter dict passed to the pipeline as ``roi_params``."""
        return {
            'mask_type': 'boxes',
            'filters': {
                'area': {'min': 1e0, 'max': 1e8}
            },
            'expand_box_by': [128, 2],
            'deproject_channel': 0,
        }

    def _get_models(self):
        """Return a dict of named model entries, each holding a 'name' and a 'model' instance."""
        # Imported locally, as in the original, so the module can be imported
        # without pulling in the model classes up front.
        from model_server.base.models import BinaryThresholdSegmentationModel
        from model_server.base.roiset import IntensityThresholdInstanceMaskSegmentationModel
        return {
            'pixel_classifier_segmentation': {
                'name': 'min_px_mod',
                'model': BinaryThresholdSegmentationModel(tr=0.2),
            },
            'object_classifier': {
                'name': 'min_ob_mod',
                'model': IntensityThresholdInstanceMaskSegmentationModel(),
            },
        }


class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
    """Call the object-map pipeline function directly, without the HTTP API."""

    def _pipeline_params(self):
        """Return keyword arguments used to build ``RoiSetObjectMapParams``."""
        return {
            'api': False,
            'accessor_id': 'acc_id',
            'pixel_classifier_segmentation_model_id': 'px_id',
            'object_classifier_model_id': 'ob_id',
            'segmentation': {
                'channel': test_params['segmentation_channel'],
            },
            'patches_channel': test_params['patches_channel'],
            'roi_params': self._get_roi_params(),
            'export_params': self._get_export_params(),
        }

    def test_object_map_workflow(self):
        """Run the pipeline on the raw stack and check the trace contents and ROI counts."""
        acc_in = generate_file_accessor(self.fpi)
        params = RoiSetObjectMapParams(
            **self._pipeline_params(),
        )
        trace, rois = roiset_object_map_pipeline(
            {'': acc_in},
            # model keys get a trailing underscore, e.g. 'object_classifier_'
            {f'{k}_': v['model'] for k, v in self._get_models().items()},
            **params.dict()
        )
        # These two entries are popped from the trace before it is written out.
        self.assertEqual(trace.pop('annotated_patches_2d').count, 22)
        self.assertEqual(trace.pop('patches_2d').count, 22)

        trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False)

        self.assertIn('ob_id', trace.keys())
        self.assertEqual(len(trace['labeled'].unique()[0]), 40)
        self.assertEqual(rois.count, 22)
        self.assertEqual(len(trace['ob_id'].unique()[0]), 2)


class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
    """Run the object-map pipeline through the server's HTTP API.

    This class was previously defined twice in the module; the second
    definition rebound the name, so ``test_trivial_api_response`` (present
    only in the first) was never collected. Both are merged here.
    """

    input_data = data['multichannel_zstack_raw']

    def setUp(self) -> None:
        """Create the 'roiset' output directory, then run the base-class setup."""
        # Path(...) mirrors how output_path is used in TestRoiSetWorkflow and
        # is a no-op when output_path is already a Path.
        self.where_out = Path(output_path) / 'roiset'
        self.where_out.mkdir(parents=True, exist_ok=True)
        return super().setUp()

    def test_trivial_api_response(self):
        """GET on the root endpoint succeeds."""
        self.assertGetSuccess('')

    def test_load_input_accessor(self):
        """Copy the input file to the server, load it, and return the PUT response.

        Also called as a helper by ``_object_map_workflow``.
        """
        fname = self.copy_input_file_to_server()
        return self.assertPutSuccess(f'accessors/read_from_file/{fname}')

    def test_load_pixel_classifier(self):
        """Load a threshold segmentation model (tr=0.2) and return its model id."""
        mid = self.assertPutSuccess(
            'models/seg/threshold/load/',
            query={'tr': 0.2},
        )['model_id']
        self.assertTrue(mid.startswith('BinaryThresholdSegmentationModel'))
        return mid

    def test_load_object_classifier(self):
        """Load a threshold object classifier (tr=0) and return its model id."""
        mid = self.assertPutSuccess(
            'models/classify/threshold/load/',
            body={'tr': 0}
        )['model_id']
        self.assertTrue(mid.startswith('IntensityThresholdInstanceMaskSegmentation'))
        return mid

    def _object_map_workflow(self, ob_classifer_id):
        """Call the object-map pipeline endpoint and return its response.

        :param ob_classifer_id: object classifier model id, or None to run
            without an object classifier
        :return: response of the ``pipelines/roiset_to_obmap/infer`` request
        """
        res = self.assertPutSuccess(
            'pipelines/roiset_to_obmap/infer',
            body={
                'accessor_id': self.test_load_input_accessor(),
                'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
                'object_classifier_model_id': ob_classifer_id,
                'segmentation': {'channel': 0},
                'patches_channel': 1,
                'roi_params': self._get_roi_params(),
                'export_params': self._get_export_params(),
            },
        )

        # check on automatically written RoiSet
        roiset_id = res['roiset_id']
        roiset_info = self.assertGetSuccess(f'rois/{roiset_id}')
        self.assertGreater(roiset_info['count'], 0)
        return res

    def test_workflow_with_object_classifier(self):
        """Run the workflow with an object classifier and cross-check via the RoiSet API."""
        obmod_id = self.test_load_object_classifier()
        res = self._object_map_workflow(obmod_id)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))

        # get object map via RoiSet API
        roiset_id = res['roiset_id']
        obmap_id = self.assertPutSuccess(f'rois/obmap/{roiset_id}/{obmod_id}', query={'object_classes': True})
        acc_obmap_roiset = self.get_accessor(obmap_id)
        self.assertTrue(np.all(acc_obmap_roiset.data == acc_obmap.data))

        # check serialize RoiSet: after writing, it is reported as not loaded
        self.assertPutSuccess(f'rois/write/{roiset_id}')
        self.assertFalse(
            self.assertGetSuccess(f'rois/{roiset_id}')['loaded']
        )

    def test_workflow_without_object_classifier(self):
        """Run the workflow with no object classifier and check the output values."""
        res = self._object_map_workflow(None)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        self.assertTrue(np.all(acc_obmap.unique()[0] == [0, 1]))


class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
    """Same API tests, but the pipeline is enqueued as a task and then run explicitly."""

    def _object_map_workflow(self, ob_classifer_id):
        """Enqueue the object-map pipeline, run the task, and return the run response.

        :param ob_classifer_id: object classifier model id, or None to run
            without an object classifier
        :return: response of the ``tasks/{task_id}/run`` request
        """
        res_queue = self.assertPutSuccess(
            'pipelines/queue/roiset_to_obmap',
            body={
                'accessor_id': self.test_load_input_accessor(),
                'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
                'object_classifier_model_id': ob_classifer_id,
                'segmentation': {'channel': 0},
                'patches_channel': 1,
                'roi_params': self._get_roi_params(),
                'export_params': self._get_export_params(),
            }
        )

        # check that task is enqueued
        task_id = res_queue['task_id']
        task_info = self.assertGetSuccess(f'tasks/{task_id}')
        self.assertEqual(task_info['status'], 'WAITING')

        # run the task
        res_run = self.assertPutSuccess(
            f'tasks/{task_id}/run'
        )
        self.assertTrue(res_run)
        self.assertEqual(self.assertGetSuccess(f'tasks/{task_id}')['status'], 'FINISHED')
        return res_run