diff --git a/extensions/chaeo/conf/testing.py b/extensions/chaeo/conf/testing.py index 00389e08d8adbe01a3bd14c19f76342a9c00ccea..c151234c73607bf5f8d5660b23d1187a80f72b84 100644 --- a/extensions/chaeo/conf/testing.py +++ b/extensions/chaeo/conf/testing.py @@ -15,7 +15,10 @@ pixel_classifier = { } pipeline_params = { - 'threshold': 0.6, + 'segmentation_channel': 0, + 'patches_channel': 4, + 'pxmap_channel': 0, + 'pxmap_threshold': 0.6, } output_path = root / 'testing' / 'output' / 'chaeo' diff --git a/extensions/chaeo/tests/test_zstack.py b/extensions/chaeo/tests/test_zstack.py index 53cdac09349611a2e801d2a7d96e5a13056da5d6..c15fdfee209fce56d09d0b49a846de2c49b8acc9 100644 --- a/extensions/chaeo/tests/test_zstack.py +++ b/extensions/chaeo/tests/test_zstack.py @@ -6,30 +6,41 @@ from conf.testing import output_path from extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack +from extensions.chaeo.workflows import infer_object_map_from_zstack from extensions.chaeo.zmask import build_zmask_from_object_mask from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from extensions.ilastik.models import IlastikPixelClassifierModel +from model_server.models import DummyInstanceSegmentationModel class TestZStackDerivedDataProducts(unittest.TestCase): def setUp(self) -> None: + # need test data incl obj map self.stack = generate_file_accessor(multichannel_zstack['path']) + self.stack_ch_seg = self.stack.get_one_channel_data(pipeline_params['segmentation_channel']) + self.stack_ch_pa = self.stack.get_one_channel_data(pipeline_params['patches_channel']) - pxmodel = IlastikPixelClassifierModel( + self.pxmodel = IlastikPixelClassifierModel( {'project_file': pixel_classifier['path']}, ) - mip = InMemoryDataAccessor(self.stack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True)) - self.pxmap, result = pxmodel.infer(mip) + mip = InMemoryDataAccessor( + self.stack_ch_seg.data.max(axis=-1, keepdims=True) + ) + pxmap, _ = self.pxmodel.infer(mip) - # write_accessor_data_to_file(output_path / 'pxmap.tif', self.pxmap) - self.obmap = InMemoryDataAccessor(self.pxmap.data > pipeline_params['threshold']) - # write_accessor_data_to_file(output_path / 'obmap.tif', self.obmap) + write_accessor_data_to_file(output_path / 'pxmap.tif', pxmap) + self.seg_mask = InMemoryDataAccessor( + pxmap.get_one_channel_data( + pipeline_params['pxmap_channel'] + ).data > pipeline_params['pxmap_threshold'] + ) + write_accessor_data_to_file(output_path / 'seg_mask.tif', self.seg_mask) def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs): zmask, meta, df, interm = build_zmask_from_object_mask( - self.obmap.get_one_channel_data(0), - self.stack.get_one_channel_data(0), + self.seg_mask, + self.stack_ch_pa, mask_type=mask_type, **kwargs, ) @@ -58,10 +69,10 @@ class TestZStackDerivedDataProducts(unittest.TestCase): return zmask, meta def test_zmask_works_on_non_zstacks(self, **kwargs): - acc_zstack_slice = InMemoryDataAccessor(self.stack.data[:, :, 0, 0]) + acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0]) self.assertEqual(acc_zstack_slice.nz, 1) zmask, meta, df, interm = build_zmask_from_object_mask( - self.obmap.get_one_channel_data(0), + self.seg_mask, acc_zstack_slice, mask_type='boxes', **kwargs, @@ -85,7 +96,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ) files = export_patches_from_zstack( output_path / '2d_patches', - self.stack.get_one_channel_data(channel=1), + self.stack_ch_pa, meta, draw_bounding_box=True, ) @@ -98,15 +109,15 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ) files = export_patches_from_zstack( output_path / '3d_patches', - self.stack.get_one_channel_data(4), + self.stack_ch_pa, meta, make_3d=True) self.assertGreaterEqual(len(files), 1) def test_flatten_image(self): zmask, meta, df, interm = build_zmask_from_object_mask( - self.obmap.get_one_channel_data(0), - self.stack.get_one_channel_data(4), + self.seg_mask, + self.stack_ch_pa, mask_type='boxes', ) @@ -188,4 +199,20 @@ class TestZStackDerivedDataProducts(unittest.TestCase): InMemoryDataAccessor(self.stack.data), meta, ) - self.assertGreaterEqual(len(files), 1) \ No newline at end of file + self.assertGreaterEqual(len(files), 1) + + def test_object_map_workflow(self): + pp = pipeline_params + models = [ + self.pxmodel, + DummyInstanceSegmentationModel(), + ] + infer_object_map_from_zstack( + multichannel_zstack['path'], + output_path, + models, + pxmap_foreground_channel=pp['pxmap_channel'], + pxmap_threshold=pp['pxmap_threshold'], + segmentation_channel=pp['segmentation_channel'], + patches_channel=pp['patches_channel'], + ) \ No newline at end of file diff --git a/extensions/chaeo/workflows.py b/extensions/chaeo/workflows.py index 50428f09b0c7f9a6f51e89ef586598bc016f9284..284a773a34fdd5c986ad131358318896447e358e 100644 --- a/extensions/chaeo/workflows.py +++ b/extensions/chaeo/workflows.py @@ -19,7 +19,7 @@ from extensions.chaeo.zmask import project_stack_from_focal_points, RoiSet from extensions.ilastik.models import IlastikPixelClassifierModel from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file -from model_server.models import Model +from model_server.models import Model, InstanceSegmentationModel, SemanticSegmentationModel from model_server.process import rescale from model_server.workflows import Timer @@ -240,14 +240,14 @@ def infer_object_map_from_zstack( zmask_type: str = 'boxes', zmask_filters: Dict = None, # zmask_expand_box_by: int = None, - exports: RoiSetExportParams = None, + exports: RoiSetExportParams = RoiSetExportParams(), **kwargs, ) -> Dict: assert len(models) == 2 pixel_classifier = models[0] - assert isinstance(pixel_classifier, IlastikPixelClassifierModel) + assert isinstance(pixel_classifier, SemanticSegmentationModel) object_classifier = models[1] - assert isinstance(object_classifier, PatchStackObjectClassifier) + assert isinstance(object_classifier, InstanceSegmentationModel) ti = Timer() stack = generate_file_accessor(Path(input_file_path)) @@ -286,13 +286,10 @@ def infer_object_map_from_zstack( stack.get_one_channel_data(segmentation_channel), mask_type=zmask_type, filters=zmask_filters, - expand_box_by=kwargs['zmask_expand_box_by'], + expand_box_by=kwargs.get('zmask_expand_box_by', (0, 0)), ) ti.click('generate_zmasks') - # record pixel scale - rois.df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X')) - # ti, stack, fstem, obmask, pxmap, obj_table = get_zmask_meta( # input_file_path, # pixel_classifier, @@ -341,13 +338,13 @@ def infer_object_map_from_zstack( # output_map[labels_map == object_id] = object_class # meta.append({'object_id': ii, 'object_class': object_id}) - object_class_map = rois.classify_by(patches_channel) + object_class_map = rois.classify_by(patches_channel, object_classifier) # TODO: add ZMaskObjectTable method to export object map output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif')) write_accessor_data_to_file( output_path, - InMemoryDataAccessor(object_class_map) + object_class_map ) ti.click('export_object_classes') diff --git a/extensions/chaeo/zmask.py b/extensions/chaeo/zmask.py index 69e128ef871fcd47ed9e385a6da9e128c38f2a76..1f93151a437244f0c1f28700ddf56d82840681e4 100644 --- a/extensions/chaeo/zmask.py +++ b/extensions/chaeo/zmask.py @@ -152,7 +152,7 @@ class RoiSet(object): return projected def get_raw_patches(self, channel): - return get_patches_from_zmask_meta(self.acc_raw(channel), self.zmask_meta) + return get_patches_from_zmask_meta(self.acc_raw, self.zmask_meta) def get_patch_masks(self): return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta) diff --git a/model_server/api.py b/model_server/api.py index a1c785313881d594e49fa8f335f24664e4429bf2..79bd68fcc0e96af6f07bbacce5d96dd76ecab97c 100644 --- a/model_server/api.py +++ b/model_server/api.py @@ -1,6 +1,6 @@ from fastapi import FastAPI, HTTPException -from model_server.models import DummySegmentationModel +from model_server.models import DummySemanticSegmentationModel from model_server.session import Session, InvalidPathError from model_server.validators import validate_workflow_inputs from model_server.workflows import classify_pixels @@ -67,7 +67,7 @@ def list_active_models(): @app.put('/models/dummy/load/') def load_dummy_model() -> dict: - return {'model_id': session.load_model(DummySegmentationModel)} + return {'model_id': session.load_model(DummySemanticSegmentationModel)} @app.put('/workflows/segment') def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: diff --git a/tests/test_api.py b/tests/test_api.py index e846be74218912a97b346d862bdce4850c0eb733..89ab62e449c399f890cb42138252c099d91a0c95 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,7 +4,7 @@ import requests import unittest from conf.testing import czifile -from model_server.models import DummySegmentationModel +from model_server.models import DummySemanticSegmentationModel class TestServerBaseClass(unittest.TestCase): def setUp(self) -> None: diff --git a/tests/test_workflow.py b/tests/test_workflow.py index a88d8791eb4cba9ce6544124e51edacfa3a79c0a..7f6474c0038739358bdf704bec6a369c13266abc 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,13 +1,13 @@ import unittest from conf.testing import czifile, output_path -from model_server.models import DummySegmentationModel +from model_server.models import DummySemanticSegmentationModel from model_server.workflows import classify_pixels class TestGetSessionObject(unittest.TestCase): def setUp(self) -> None: - self.model = DummySegmentationModel() + self.model = DummySemanticSegmentationModel() def test_single_session_instance(self): result = classify_pixels(czifile['path'], self.model, output_path, channel=2)