Skip to content
Snippets Groups Projects
Commit 95e5232e authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Split into two test classes to simplify things a bit

parent 2397e58d
No related branches found
No related tags found
No related merge requests found
......@@ -14,7 +14,7 @@ from model_server.base.accessors import generate_file_accessor, InMemoryDataAcce
from model_server.extensions.ilastik.models import IlastikPixelClassifierModel
from model_server.base.models import DummyInstanceSegmentationModel
class TestZStackDerivedDataProducts(unittest.TestCase):
class BaseTestRoiSetMonoProducts(object):
def setUp(self) -> None:
......@@ -39,18 +39,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
)
write_accessor_data_to_file(output_path / 'seg_mask.tif', self.seg_mask)
id_map = get_label_ids(self.seg_mask)
self.roiset = RoiSet(
id_map,
self.stack,
params=RoiSetMetaParams(
expand_box_by=(128, 2),
mask_type='boxes',
filters={'area': {'min': 1e3, 'max': 1e4}},
)
)
class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_zmask_makes_correct_boxes(self, mask_type='boxes', **kwargs):
def _make_roi_set(self, mask_type='boxes', **kwargs):
id_map = get_label_ids(self.seg_mask)
roiset = RoiSet(
id_map,
......@@ -59,6 +50,10 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
mask_type=mask_type, filters=kwargs.get('filters')
)
)
return roiset
def test_zmask_makes_correct_boxes(self, **kwargs):
roiset = self._make_roi_set(mask_type='boxes', **kwargs)
zmask = roiset.get_zmask()
meta = roiset.zmask_meta
interm = roiset.interm
......@@ -98,23 +93,23 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
self.assertTrue(zmask_acc.is_mask())
def test_zmask_makes_correct_contours(self):
return self.test_zmask_makes_correct_boxes(mask_type='contours')
return self._make_roi_set(mask_type='contours')
def test_zmask_makes_correct_boxes_with_filters(self):
return self.test_zmask_makes_correct_boxes(filters={'area': {'min': 1e3, 'max': 1e4}})
return self._make_roi_set(filters={'area': {'min': 1e3, 'max': 1e4}})
def test_zmask_makes_correct_expanded_boxes(self):
return self.test_zmask_makes_correct_boxes(expand_box_by=(64, 2))
return self._make_roi_set(expand_box_by=(64, 2))
def test_zmask_slices_are_valid(self):
roiset = self.test_zmask_makes_correct_boxes()
roiset = self._make_roi_set()
for s in roiset.get_slices():
ebb = roiset.acc_raw.data[s]
self.assertEqual(len(ebb.shape), 4)
self.assertTrue(np.all([si >= 1 for si in ebb.shape]))
def test_zmask_rel_slices_are_valid(self):
roiset = self.test_zmask_makes_correct_boxes()
roiset = self._make_roi_set()
for i, s in enumerate(roiset.get_slices()):
ebb = roiset.acc_raw.data[s]
self.assertEqual(len(ebb.shape), 4)
......@@ -125,7 +120,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
self.assertTrue(np.all([si >= 1 for si in rbb.shape]))
def test_make_2d_patches_from_zmask(self):
roiset = self.test_zmask_makes_correct_boxes(
roiset = self._make_roi_set(
filters={'area': {'min': 1e3, 'max': 1e4}},
expand_box_by=(64, 2)
)
......@@ -137,7 +132,7 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
self.assertGreaterEqual(len(files), 1)
def test_make_3d_patches_from_zmask(self):
roiset = self.test_zmask_makes_correct_boxes(
roiset = self._make_roi_set(
filters={'area': {'min': 1e3, 'max': 1e4}},
expand_box_by=(64, 2),
)
......@@ -170,7 +165,31 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
InMemoryDataAccessor(img)
)
def test_make_binary_masks_from_zmask(self):
roiset = self._make_roi_set(
filters={'area': {'min': 1e3, 'max': 1e4}},
expand_box_by=(128, 2)
)
files = roiset.export_patch_masks(output_path / '2d_mask_patches', )
self.assertGreaterEqual(len(files), 1)
class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
id_map = get_label_ids(self.seg_mask)
self.roiset = RoiSet(
id_map,
self.stack,
params=RoiSetMetaParams(
expand_box_by=(128, 2),
mask_type='boxes',
filters={'area': {'min': 1e3, 'max': 1e4}},
)
)
def test_multichannel_to_mono_2d_patches(self):
files = export_multichannel_patches_from_zstack(
output_path / 'test_multichannel_to_mono_2d_patches',
self.roiset,
......@@ -226,15 +245,6 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
result = generate_file_accessor(Path(files[0]['location']) / files[0]['patch_filename'])
self.assertEqual(result.chroma, 5)
# TODO: rewrite with direct call to RoiSet methods
def test_make_binary_masks_from_zmask(self):
roiset = self.test_zmask_makes_correct_boxes(
filters={'area': {'min': 1e3, 'max': 1e4}},
expand_box_by=(128, 2)
)
files = roiset.export_patch_masks(output_path / '2d_mask_patches', )
self.assertGreaterEqual(len(files), 1)
def test_object_map_workflow(self):
pp = pipeline_params
models = [
......@@ -276,7 +286,9 @@ class TestZStackDerivedDataProducts(unittest.TestCase):
'draw_bounding_box': False,
'draw_mask': False,
},
'patch_masks': True,
'patch_masks': {
'pad_to': 256,
},
'annotated_zstacks': {},
'object_classes': True,
'dataframe': True,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment