Skip to content
Snippets Groups Projects
Commit 3bf73138 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test covering unified zmask workflow passes

parent c3e3ff61
No related branches found
No related tags found
No related merge requests found
...@@ -15,7 +15,10 @@ pixel_classifier = { ...@@ -15,7 +15,10 @@ pixel_classifier = {
} }
pipeline_params = { pipeline_params = {
'threshold': 0.6, 'segmentation_channel': 0,
'patches_channel': 4,
'pxmap_channel': 0,
'pxmap_threshold': 0.6,
} }
output_path = root / 'testing' / 'output' / 'chaeo' output_path = root / 'testing' / 'output' / 'chaeo'
......
...@@ -6,30 +6,41 @@ from conf.testing import output_path ...@@ -6,30 +6,41 @@ from conf.testing import output_path
from extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params from extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params
from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack
from extensions.chaeo.workflows import infer_object_map_from_zstack
from extensions.chaeo.zmask import build_zmask_from_object_mask from extensions.chaeo.zmask import build_zmask_from_object_mask
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from extensions.ilastik.models import IlastikPixelClassifierModel from extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.models import DummyInstanceSegmentationModel
class TestZStackDerivedDataProducts(unittest.TestCase): class TestZStackDerivedDataProducts(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
# need test data incl obj map # need test data incl obj map
self.stack = generate_file_accessor(multichannel_zstack['path']) self.stack = generate_file_accessor(multichannel_zstack['path'])
self.stack_ch_seg = self.stack.get_one_channel_data(pipeline_params['segmentation_channel'])
self.stack_ch_pa = self.stack.get_one_channel_data(pipeline_params['patches_channel'])
pxmodel = IlastikPixelClassifierModel( self.pxmodel = IlastikPixelClassifierModel(
{'project_file': pixel_classifier['path']}, {'project_file': pixel_classifier['path']},
) )
mip = InMemoryDataAccessor(self.stack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True)) mip = InMemoryDataAccessor(
self.pxmap, result = pxmodel.infer(mip) self.stack_ch_seg.data.max(axis=-1, keepdims=True)
)
pxmap, _ = self.pxmodel.infer(mip)
# write_accessor_data_to_file(output_path / 'pxmap.tif', self.pxmap) write_accessor_data_to_file(output_path / 'pxmap.tif', pxmap)
self.obmap = InMemoryDataAccessor(self.pxmap.data > pipeline_params['threshold']) self.seg_mask = InMemoryDataAccessor(
# write_accessor_data_to_file(output_path / 'obmap.tif', self.obmap) pxmap.get_one_channel_data(
pipeline_params['pxmap_channel']
).data > pipeline_params['pxmap_threshold']
)
write_accessor_data_to_file(output_path / 'seg_mask.tif', self.seg_mask)
def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs): def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs):
zmask, meta, df, interm = build_zmask_from_object_mask( zmask, meta, df, interm = build_zmask_from_object_mask(
self.obmap.get_one_channel_data(0), self.seg_mask,
self.stack.get_one_channel_data(0), self.stack_ch_pa,
mask_type=mask_type, mask_type=mask_type,
**kwargs, **kwargs,
) )
...@@ -58,10 +69,10 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ...@@ -58,10 +69,10 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
return zmask, meta return zmask, meta
def test_zmask_works_on_non_zstacks(self, **kwargs): def test_zmask_works_on_non_zstacks(self, **kwargs):
acc_zstack_slice = InMemoryDataAccessor(self.stack.data[:, :, 0, 0]) acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
self.assertEqual(acc_zstack_slice.nz, 1) self.assertEqual(acc_zstack_slice.nz, 1)
zmask, meta, df, interm = build_zmask_from_object_mask( zmask, meta, df, interm = build_zmask_from_object_mask(
self.obmap.get_one_channel_data(0), self.seg_mask,
acc_zstack_slice, acc_zstack_slice,
mask_type='boxes', mask_type='boxes',
**kwargs, **kwargs,
...@@ -85,7 +96,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ...@@ -85,7 +96,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
) )
files = export_patches_from_zstack( files = export_patches_from_zstack(
output_path / '2d_patches', output_path / '2d_patches',
self.stack.get_one_channel_data(channel=1), self.stack_ch_pa,
meta, meta,
draw_bounding_box=True, draw_bounding_box=True,
) )
...@@ -98,15 +109,15 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ...@@ -98,15 +109,15 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
) )
files = export_patches_from_zstack( files = export_patches_from_zstack(
output_path / '3d_patches', output_path / '3d_patches',
self.stack.get_one_channel_data(4), self.stack_ch_pa,
meta, meta,
make_3d=True) make_3d=True)
self.assertGreaterEqual(len(files), 1) self.assertGreaterEqual(len(files), 1)
def test_flatten_image(self): def test_flatten_image(self):
zmask, meta, df, interm = build_zmask_from_object_mask( zmask, meta, df, interm = build_zmask_from_object_mask(
self.obmap.get_one_channel_data(0), self.seg_mask,
self.stack.get_one_channel_data(4), self.stack_ch_pa,
mask_type='boxes', mask_type='boxes',
) )
...@@ -188,4 +199,20 @@ class TestZStackDerivedDataProducts(unittest.TestCase): ...@@ -188,4 +199,20 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
InMemoryDataAccessor(self.stack.data), InMemoryDataAccessor(self.stack.data),
meta, meta,
) )
self.assertGreaterEqual(len(files), 1) self.assertGreaterEqual(len(files), 1)
\ No newline at end of file
def test_object_map_workflow(self):
pp = pipeline_params
models = [
self.pxmodel,
DummyInstanceSegmentationModel(),
]
infer_object_map_from_zstack(
multichannel_zstack['path'],
output_path,
models,
pxmap_foreground_channel=pp['pxmap_channel'],
pxmap_threshold=pp['pxmap_threshold'],
segmentation_channel=pp['segmentation_channel'],
patches_channel=pp['patches_channel'],
)
\ No newline at end of file
...@@ -19,7 +19,7 @@ from extensions.chaeo.zmask import project_stack_from_focal_points, RoiSet ...@@ -19,7 +19,7 @@ from extensions.chaeo.zmask import project_stack_from_focal_points, RoiSet
from extensions.ilastik.models import IlastikPixelClassifierModel from extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import Model from model_server.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
from model_server.process import rescale from model_server.process import rescale
from model_server.workflows import Timer from model_server.workflows import Timer
...@@ -240,14 +240,14 @@ def infer_object_map_from_zstack( ...@@ -240,14 +240,14 @@ def infer_object_map_from_zstack(
zmask_type: str = 'boxes', zmask_type: str = 'boxes',
zmask_filters: Dict = None, zmask_filters: Dict = None,
# zmask_expand_box_by: int = None, # zmask_expand_box_by: int = None,
exports: RoiSetExportParams = None, exports: RoiSetExportParams = RoiSetExportParams(),
**kwargs, **kwargs,
) -> Dict: ) -> Dict:
assert len(models) == 2 assert len(models) == 2
pixel_classifier = models[0] pixel_classifier = models[0]
assert isinstance(pixel_classifier, IlastikPixelClassifierModel) assert isinstance(pixel_classifier, SemanticSegmentationModel)
object_classifier = models[1] object_classifier = models[1]
assert isinstance(object_classifier, PatchStackObjectClassifier) assert isinstance(object_classifier, InstanceSegmentationModel)
ti = Timer() ti = Timer()
stack = generate_file_accessor(Path(input_file_path)) stack = generate_file_accessor(Path(input_file_path))
...@@ -286,13 +286,10 @@ def infer_object_map_from_zstack( ...@@ -286,13 +286,10 @@ def infer_object_map_from_zstack(
stack.get_one_channel_data(segmentation_channel), stack.get_one_channel_data(segmentation_channel),
mask_type=zmask_type, mask_type=zmask_type,
filters=zmask_filters, filters=zmask_filters,
expand_box_by=kwargs['zmask_expand_box_by'], expand_box_by=kwargs.get('zmask_expand_box_by', (0, 0)),
) )
ti.click('generate_zmasks') ti.click('generate_zmasks')
# record pixel scale
rois.df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X'))
# ti, stack, fstem, obmask, pxmap, obj_table = get_zmask_meta( # ti, stack, fstem, obmask, pxmap, obj_table = get_zmask_meta(
# input_file_path, # input_file_path,
# pixel_classifier, # pixel_classifier,
...@@ -341,13 +338,13 @@ def infer_object_map_from_zstack( ...@@ -341,13 +338,13 @@ def infer_object_map_from_zstack(
# output_map[labels_map == object_id] = object_class # output_map[labels_map == object_id] = object_class
# meta.append({'object_id': ii, 'object_class': object_id}) # meta.append({'object_id': ii, 'object_class': object_id})
object_class_map = rois.classify_by(patches_channel) object_class_map = rois.classify_by(patches_channel, object_classifier)
# TODO: add ZMaskObjectTable method to export object map # TODO: add ZMaskObjectTable method to export object map
output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif')) output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif'))
write_accessor_data_to_file( write_accessor_data_to_file(
output_path, output_path,
InMemoryDataAccessor(object_class_map) object_class_map
) )
ti.click('export_object_classes') ti.click('export_object_classes')
......
...@@ -152,7 +152,7 @@ class RoiSet(object): ...@@ -152,7 +152,7 @@ class RoiSet(object):
return projected return projected
def get_raw_patches(self, channel): def get_raw_patches(self, channel):
return get_patches_from_zmask_meta(self.acc_raw(channel), self.zmask_meta) return get_patches_from_zmask_meta(self.acc_raw, self.zmask_meta)
def get_patch_masks(self): def get_patch_masks(self):
return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta) return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta)
......
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from model_server.models import DummySegmentationModel from model_server.models import DummySemanticSegmentationModel
from model_server.session import Session, InvalidPathError from model_server.session import Session, InvalidPathError
from model_server.validators import validate_workflow_inputs from model_server.validators import validate_workflow_inputs
from model_server.workflows import classify_pixels from model_server.workflows import classify_pixels
...@@ -67,7 +67,7 @@ def list_active_models(): ...@@ -67,7 +67,7 @@ def list_active_models():
@app.put('/models/dummy/load/') @app.put('/models/dummy/load/')
def load_dummy_model() -> dict: def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummySegmentationModel)} return {'model_id': session.load_model(DummySemanticSegmentationModel)}
@app.put('/workflows/segment') @app.put('/workflows/segment')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict: def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
......
...@@ -4,7 +4,7 @@ import requests ...@@ -4,7 +4,7 @@ import requests
import unittest import unittest
from conf.testing import czifile from conf.testing import czifile
from model_server.models import DummySegmentationModel from model_server.models import DummySemanticSegmentationModel
class TestServerBaseClass(unittest.TestCase): class TestServerBaseClass(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
......
import unittest import unittest
from conf.testing import czifile, output_path from conf.testing import czifile, output_path
from model_server.models import DummySegmentationModel from model_server.models import DummySemanticSegmentationModel
from model_server.workflows import classify_pixels from model_server.workflows import classify_pixels
class TestGetSessionObject(unittest.TestCase): class TestGetSessionObject(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.model = DummySegmentationModel() self.model = DummySemanticSegmentationModel()
def test_single_session_instance(self): def test_single_session_instance(self):
result = classify_pixels(czifile['path'], self.model, output_path, channel=2) result = classify_pixels(czifile['path'], self.model, output_path, channel=2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment