-
Christopher Randolph Rhodes authored
test_roiset_pipeline.py 10.36 KiB
import json
from pathlib import Path
import unittest
import numpy as np
from model_server.base.accessors import generate_file_accessor
import model_server.conf.testing as conf
from model_server.base.pipelines.roiset_obmap import RoiSetObjectMapParams, roiset_object_map_pipeline
# Shared fixtures, all read from the `meta` mapping of the testing configuration.
data = conf.meta['image_files']  # indexed below by keys such as 'multichannel_zstack_raw'
output_path = conf.meta['output_path']  # base location for files written by these tests
test_params = conf.meta['roiset']  # supplies 'patches_channel' and 'segmentation_channel'
class BaseTestRoiSetMonoProducts:
    """Mixin with the input data and parameter sets shared by the RoiSet tests.

    This class defines no tests of its own; the test classes in this module
    list it as a base class next to their test-case base in order to reuse
    the properties and parameter builders below.
    """

    @property
    def fpi(self):
        """Path of the 'multichannel_zstack_raw' input file, as a string."""
        return str(data['multichannel_zstack_raw']['path'])

    @property
    def stack(self):
        """Accessor for the raw input file; generated anew on every access."""
        return generate_file_accessor(self.fpi)

    @property
    def stack_ch_pa(self):
        """Mono view of ``stack`` taken at the configured 'patches_channel'."""
        return self.stack.get_mono(test_params['patches_channel'])

    @property
    def seg_mask(self):
        """Accessor for the file registered as 'multichannel_zstack_mask2d'."""
        return generate_file_accessor(data['multichannel_zstack_mask2d']['path'])

    def _get_export_params(self):
        """Return a fresh dict of export settings, used as 'export_params'.

        Products mapped to ``None`` carry no settings here; the others carry
        their own option dicts.
        """
        return {
            'patches_3d': None,
            'annotated_patches_2d': {
                'draw_bounding_box': True,
                'rgb_overlay_channels': [3, None, None],
                'rgb_overlay_weights': [0.2, 1.0, 1.0],
                'pad_to': 512,
            },
            'patches_2d': {
                'draw_bounding_box': False,
                'draw_mask': False,
            },
            'annotated_zstacks': None,
            'object_classes': True,
        }

    def _get_roi_params(self):
        """Return a fresh dict of ROI settings, used as 'roi_params'."""
        return {
            'mask_type': 'boxes',
            'filters': {
                'area': {'min': 1e0, 'max': 1e8},
            },
            'expand_box_by': [128, 2],
            'deproject_channel': 0,
        }

    def _get_models(self):
        """Return the two threshold-based models used by the in-process test.

        The result maps a role ('pixel_classifier_segmentation',
        'object_classifier') to a dict holding a 'name' and a 'model' instance.
        The model classes are imported here rather than at module level.
        """
        from model_server.base.models import BinaryThresholdSegmentationModel
        from model_server.base.roiset import IntensityThresholdInstanceMaskSegmentationModel
        return {
            'pixel_classifier_segmentation': {
                'name': 'min_px_mod',
                'model': BinaryThresholdSegmentationModel(tr=0.2),
            },
            'object_classifier': {
                'name': 'min_ob_mod',
                'model': IntensityThresholdInstanceMaskSegmentationModel(),
            },
        }
class TestRoiSetWorkflow(BaseTestRoiSetMonoProducts, unittest.TestCase):
    """Run the RoiSet object-map pipeline by calling the pipeline function directly."""

    def _pipeline_params(self):
        """Return the keyword arguments used to build ``RoiSetObjectMapParams``."""
        return {
            'api': False,
            'accessor_id': 'acc_id',
            'pixel_classifier_segmentation_model_id': 'px_id',
            'object_classifier_model_id': 'ob_id',
            'segmentation': {
                'channel': test_params['segmentation_channel'],
            },
            'patches_channel': test_params['patches_channel'],
            'roi_params': self._get_roi_params(),
            'export_params': self._get_export_params(),
        }

    def test_object_map_workflow(self):
        """Run the pipeline on the raw input and check ROI counts and trace contents."""
        acc_in = generate_file_accessor(self.fpi)
        params = RoiSetObjectMapParams(**self._pipeline_params())

        # Accessors are keyed by '' and models by '<role>_'.
        # NOTE(review): presumably these keys are prefixes of the
        # '*accessor_id' / '*model_id' parameter names -- confirm against the pipeline.
        models = {f'{k}_': v['model'] for k, v in self._get_models().items()}
        trace, rois = roiset_object_map_pipeline(
            {'': acc_in},
            models,
            **params.dict()
        )

        # The patch products are checked and removed from the trace before it is written.
        self.assertEqual(trace.pop('annotated_patches_2d').count, 22)
        self.assertEqual(trace.pop('patches_2d').count, 22)

        # NOTE(review): the 'worfklow' spelling is kept so the output location is unchanged.
        trace.write_interm(Path(output_path) / 'trace', 'roiset_worfklow_trace', skip_first=False, skip_last=False)

        # assertIn reports both operands on failure, unlike assertTrue(x in y).
        self.assertIn('ob_id', trace.keys())
        self.assertEqual(len(trace['labeled'].unique()[0]), 40)
        self.assertEqual(rois.count, 22)
        self.assertEqual(len(trace['ob_id'].unique()[0]), 2)
class TestRoiSetWorkflowOverApi(conf.TestServerBaseClass, BaseTestRoiSetMonoProducts):
    """Exercise the RoiSet object-map pipeline through the test server's API.

    The ``test_load_*`` methods double as helpers: each returns the ID it
    obtained, and the workflow methods (here and in subclasses) call them to
    set up the accessor and models they need.
    """

    input_data = data['multichannel_zstack_raw']

    def setUp(self) -> None:
        """Create the 'roiset' output directory, then run the base set-up."""
        # Path(...) accepts either a str or a Path for output_path.
        self.where_out = Path(output_path) / 'roiset'
        self.where_out.mkdir(parents=True, exist_ok=True)
        return conf.TestServerBaseClass.setUp(self)

    def test_trivial_api_response(self):
        """A GET on the empty endpoint succeeds."""
        self.assertGetSuccess('')

    def test_load_input_accessor(self):
        """Copy the input file to the server, read it there, and return the response."""
        fname = self.copy_input_file_to_server()
        return self.assertPutSuccess(f'accessors/read_from_file/{fname}')

    def test_load_pixel_classifier(self):
        """Load a threshold segmentation model (tr=0.2) and return its model ID."""
        mid = self.assertPutSuccess(
            'models/seg/threshold/load/',
            query={'tr': 0.2},
        )['model_id']
        self.assertTrue(mid.startswith('BinaryThresholdSegmentationModel'))
        return mid

    def test_load_object_classifier(self):
        """Load a threshold object classifier (tr=0) and return its model ID."""
        mid = self.assertPutSuccess(
            'models/classify/threshold/load/',
            body={'tr': 0}
        )['model_id']
        self.assertTrue(mid.startswith('IntensityThresholdInstanceMaskSegmentation'))
        return mid

    def _object_map_workflow(self, ob_classifer_id):
        """Run the object-map pipeline endpoint and return its response.

        ``ob_classifer_id`` is sent as 'object_classifier_model_id' and may be
        ``None``. Also checks that the RoiSet named in the response reports a
        positive count.
        """
        res = self.assertPutSuccess(
            'pipelines/roiset_to_obmap/infer',
            body={
                'accessor_id': self.test_load_input_accessor(),
                'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
                'object_classifier_model_id': ob_classifer_id,
                'segmentation': {'channel': 0},
                'patches_channel': 1,
                'roi_params': self._get_roi_params(),
                'export_params': self._get_export_params(),
            },
        )

        # check on automatically written RoiSet
        roiset_id = res['roiset_id']
        roiset_info = self.assertGetSuccess(f'rois/{roiset_id}')
        self.assertGreater(roiset_info['count'], 0)
        return res

    def test_workflow_with_object_classifier(self):
        """Run the workflow with an object classifier, then query and write the RoiSet."""
        obmod_id = self.test_load_object_classifier()
        res = self._object_map_workflow(obmod_id)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        # array_equal returns False on a shape mismatch instead of raising.
        self.assertTrue(np.array_equal(acc_obmap.unique()[0], [0, 1]))

        # get object map via RoiSet API
        roiset_id = res['roiset_id']
        obmap_id = self.assertPutSuccess(f'rois/obmap/{roiset_id}/{obmod_id}', query={'object_classes': True})
        acc_obmap_roiset = self.get_accessor(obmap_id)
        self.assertTrue(np.all(acc_obmap_roiset.data == acc_obmap.data))

        # check serialize RoiSet: after writing, it should no longer report as loaded
        self.assertPutSuccess(f'rois/write/{roiset_id}')
        self.assertFalse(
            self.assertGetSuccess(f'rois/{roiset_id}')['loaded']
        )

    def test_workflow_without_object_classifier(self):
        """Run the workflow with no object classifier and check the output values."""
        res = self._object_map_workflow(None)
        acc_obmap = self.get_accessor(res['output_accessor_id'])
        # array_equal returns False on a shape mismatch instead of raising.
        self.assertTrue(np.array_equal(acc_obmap.unique()[0], [0, 1]))
class TestTaskQueuedRoiSetWorkflowOverApi(TestRoiSetWorkflowOverApi):
    """Repeat the inherited API tests, submitting the pipeline through the task queue."""

    def _object_map_workflow(self, ob_classifer_id):
        """Queue the pipeline as a task, run it, and return the run response.

        ``ob_classifer_id`` is sent as 'object_classifier_model_id'. The task
        is expected to be 'WAITING' once queued and 'FINISHED' after the run.
        """
        request_body = {
            'accessor_id': self.test_load_input_accessor(),
            'pixel_classifier_segmentation_model_id': self.test_load_pixel_classifier(),
            'object_classifier_model_id': ob_classifer_id,
            'segmentation': {'channel': 0},
            'patches_channel': 1,
            'roi_params': self._get_roi_params(),
            'export_params': self._get_export_params(),
        }
        queued = self.assertPutSuccess('pipelines/queue/roiset_to_obmap', body=request_body)

        # check that the task is enqueued
        task_id = queued['task_id']
        status_when_queued = self.assertGetSuccess(f'tasks/{task_id}')['status']
        self.assertEqual(status_when_queued, 'WAITING')

        # run the task
        outcome = self.assertPutSuccess(f'tasks/{task_id}/run')
        self.assertTrue(outcome)

        status_after_run = self.assertGetSuccess(f'tasks/{task_id}')['status']
        self.assertEqual(status_after_run, 'FINISHED')
        return outcome