Skip to content
Snippets Groups Projects
Commit 3bf73138 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Test covering unified zmask workflow passes

parent c3e3ff61
No related branches found
No related tags found
No related merge requests found
......@@ -15,7 +15,10 @@ pixel_classifier = {
}
pipeline_params = {
'threshold': 0.6,
'segmentation_channel': 0,
'patches_channel': 4,
'pxmap_channel': 0,
'pxmap_threshold': 0.6,
}
output_path = root / 'testing' / 'output' / 'chaeo'
......
......@@ -6,30 +6,41 @@ from conf.testing import output_path
from extensions.chaeo.conf.testing import multichannel_zstack, pixel_classifier, pipeline_params
from extensions.chaeo.products import export_patches_from_zstack, export_multichannel_patches_from_zstack, export_patch_masks_from_zstack
from extensions.chaeo.workflows import infer_object_map_from_zstack
from extensions.chaeo.zmask import build_zmask_from_object_mask
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.models import DummyInstanceSegmentationModel
class TestZStackDerivedDataProducts(unittest.TestCase):
def setUp(self) -> None:
# need test data incl obj map
self.stack = generate_file_accessor(multichannel_zstack['path'])
self.stack_ch_seg = self.stack.get_one_channel_data(pipeline_params['segmentation_channel'])
self.stack_ch_pa = self.stack.get_one_channel_data(pipeline_params['patches_channel'])
pxmodel = IlastikPixelClassifierModel(
self.pxmodel = IlastikPixelClassifierModel(
{'project_file': pixel_classifier['path']},
)
mip = InMemoryDataAccessor(self.stack.get_one_channel_data(channel=0).data.max(axis=-1, keepdims=True))
self.pxmap, result = pxmodel.infer(mip)
mip = InMemoryDataAccessor(
self.stack_ch_seg.data.max(axis=-1, keepdims=True)
)
pxmap, _ = self.pxmodel.infer(mip)
# write_accessor_data_to_file(output_path / 'pxmap.tif', self.pxmap)
self.obmap = InMemoryDataAccessor(self.pxmap.data > pipeline_params['threshold'])
# write_accessor_data_to_file(output_path / 'obmap.tif', self.obmap)
write_accessor_data_to_file(output_path / 'pxmap.tif', pxmap)
self.seg_mask = InMemoryDataAccessor(
pxmap.get_one_channel_data(
pipeline_params['pxmap_channel']
).data > pipeline_params['pxmap_threshold']
)
write_accessor_data_to_file(output_path / 'seg_mask.tif', self.seg_mask)
def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs):
zmask, meta, df, interm = build_zmask_from_object_mask(
self.obmap.get_one_channel_data(0),
self.stack.get_one_channel_data(0),
self.seg_mask,
self.stack_ch_pa,
mask_type=mask_type,
**kwargs,
)
......@@ -58,10 +69,10 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
return zmask, meta
def test_zmask_works_on_non_zstacks(self, **kwargs):
acc_zstack_slice = InMemoryDataAccessor(self.stack.data[:, :, 0, 0])
acc_zstack_slice = InMemoryDataAccessor(self.stack_ch_pa.data[:, :, :, 0])
self.assertEqual(acc_zstack_slice.nz, 1)
zmask, meta, df, interm = build_zmask_from_object_mask(
self.obmap.get_one_channel_data(0),
self.seg_mask,
acc_zstack_slice,
mask_type='boxes',
**kwargs,
......@@ -85,7 +96,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
)
files = export_patches_from_zstack(
output_path / '2d_patches',
self.stack.get_one_channel_data(channel=1),
self.stack_ch_pa,
meta,
draw_bounding_box=True,
)
......@@ -98,15 +109,15 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
)
files = export_patches_from_zstack(
output_path / '3d_patches',
self.stack.get_one_channel_data(4),
self.stack_ch_pa,
meta,
make_3d=True)
self.assertGreaterEqual(len(files), 1)
def test_flatten_image(self):
zmask, meta, df, interm = build_zmask_from_object_mask(
self.obmap.get_one_channel_data(0),
self.stack.get_one_channel_data(4),
self.seg_mask,
self.stack_ch_pa,
mask_type='boxes',
)
......@@ -188,4 +199,20 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
InMemoryDataAccessor(self.stack.data),
meta,
)
self.assertGreaterEqual(len(files), 1)
\ No newline at end of file
self.assertGreaterEqual(len(files), 1)
def test_object_map_workflow(self):
pp = pipeline_params
models = [
self.pxmodel,
DummyInstanceSegmentationModel(),
]
infer_object_map_from_zstack(
multichannel_zstack['path'],
output_path,
models,
pxmap_foreground_channel=pp['pxmap_channel'],
pxmap_threshold=pp['pxmap_threshold'],
segmentation_channel=pp['segmentation_channel'],
patches_channel=pp['patches_channel'],
)
\ No newline at end of file
......@@ -19,7 +19,7 @@ from extensions.chaeo.zmask import project_stack_from_focal_points, RoiSet
from extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.models import Model
from model_server.models import Model, InstanceSegmentationModel, SemanticSegmentationModel
from model_server.process import rescale
from model_server.workflows import Timer
......@@ -240,14 +240,14 @@ def infer_object_map_from_zstack(
zmask_type: str = 'boxes',
zmask_filters: Dict = None,
# zmask_expand_box_by: int = None,
exports: RoiSetExportParams = None,
exports: RoiSetExportParams = RoiSetExportParams(),
**kwargs,
) -> Dict:
assert len(models) == 2
pixel_classifier = models[0]
assert isinstance(pixel_classifier, IlastikPixelClassifierModel)
assert isinstance(pixel_classifier, SemanticSegmentationModel)
object_classifier = models[1]
assert isinstance(object_classifier, PatchStackObjectClassifier)
assert isinstance(object_classifier, InstanceSegmentationModel)
ti = Timer()
stack = generate_file_accessor(Path(input_file_path))
......@@ -286,13 +286,10 @@ def infer_object_map_from_zstack(
stack.get_one_channel_data(segmentation_channel),
mask_type=zmask_type,
filters=zmask_filters,
expand_box_by=kwargs['zmask_expand_box_by'],
expand_box_by=kwargs.get('zmask_expand_box_by', (0, 0)),
)
ti.click('generate_zmasks')
# record pixel scale
rois.df['pixel_scale_in_micrometers'] = float(stack.pixel_scale_in_micrometers.get('X'))
# ti, stack, fstem, obmask, pxmap, obj_table = get_zmask_meta(
# input_file_path,
# pixel_classifier,
......@@ -341,13 +338,13 @@ def infer_object_map_from_zstack(
# output_map[labels_map == object_id] = object_class
# meta.append({'object_id': ii, 'object_class': object_id})
object_class_map = rois.classify_by(patches_channel)
object_class_map = rois.classify_by(patches_channel, object_classifier)
# TODO: add ZMaskObjectTable method to export object map
output_path = Path(output_folder_path) / ('obj_classes_' + (fstem + '.tif'))
write_accessor_data_to_file(
output_path,
InMemoryDataAccessor(object_class_map)
object_class_map
)
ti.click('export_object_classes')
......
......@@ -152,7 +152,7 @@ class RoiSet(object):
return projected
def get_raw_patches(self, channel):
return get_patches_from_zmask_meta(self.acc_raw(channel), self.zmask_meta)
return get_patches_from_zmask_meta(self.acc_raw, self.zmask_meta)
def get_patch_masks(self):
return get_patch_masks_from_zmask_meta(self.acc_raw, self.zmask_meta)
......
from fastapi import FastAPI, HTTPException
from model_server.models import DummySegmentationModel
from model_server.models import DummySemanticSegmentationModel
from model_server.session import Session, InvalidPathError
from model_server.validators import validate_workflow_inputs
from model_server.workflows import classify_pixels
......@@ -67,7 +67,7 @@ def list_active_models():
@app.put('/models/dummy/load/')
def load_dummy_model() -> dict:
return {'model_id': session.load_model(DummySegmentationModel)}
return {'model_id': session.load_model(DummySemanticSegmentationModel)}
@app.put('/workflows/segment')
def infer_img(model_id: str, input_filename: str, channel: int = None) -> dict:
......
......@@ -4,7 +4,7 @@ import requests
import unittest
from conf.testing import czifile
from model_server.models import DummySegmentationModel
from model_server.models import DummySemanticSegmentationModel
class TestServerBaseClass(unittest.TestCase):
def setUp(self) -> None:
......
import unittest
from conf.testing import czifile, output_path
from model_server.models import DummySegmentationModel
from model_server.models import DummySemanticSegmentationModel
from model_server.workflows import classify_pixels
class TestGetSessionObject(unittest.TestCase):
def setUp(self) -> None:
self.model = DummySegmentationModel()
self.model = DummySemanticSegmentationModel()
def test_single_session_instance(self):
result = classify_pixels(czifile['path'], self.model, output_path, channel=2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment