Skip to content
Snippets Groups Projects
Commit d191e4c2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'staging' into dev_pipeline_and_smooth

# Conflicts:
#	model_server/base/accessors.py
#	model_server/base/roiset.py
#	model_server/base/util.py
#	tests/base/test_roiset.py
parents 18449743 eb7ed818
No related branches found
No related tags found
No related merge requests found
...@@ -61,16 +61,16 @@ class GenericImageDataAccessor(ABC): ...@@ -61,16 +61,16 @@ class GenericImageDataAccessor(ABC):
) )
) )
def get_mip(self):
"""
Return a new accessor of maximum intensity projection (MIP) along z-axis
"""
return self.apply(lambda x: x.max(axis=self._ga('Z'), keepdims=True))
def get_mono(self, channel: int, mip: bool = False):
def get_mono(self, channel: int, mip: bool = False, squeeze=False):
return self.get_channels([channel], mip=mip) return self.get_channels([channel], mip=mip)
def get_z_argmax(self):
return self.apply(lambda x: x.argmax(axis=self.get_axis('Z')))
def get_focus_vector(self):
return self.data.sum(axis=(0, 1, 2))
@property @property
def data_xy(self) -> np.ndarray: def data_xy(self) -> np.ndarray:
if not self.chroma == 1 and self.nz == 1: if not self.chroma == 1 and self.nz == 1:
...@@ -435,6 +435,11 @@ class PatchStack(InMemoryDataAccessor): ...@@ -435,6 +435,11 @@ class PatchStack(InMemoryDataAccessor):
else: else:
tifffile.imwrite(fpath, tzcyx, imagej=True) tifffile.imwrite(fpath, tzcyx, imagej=True)
def write(self, fp: Path, mkdir=True):
if mkdir:
fp.parent.mkdir(parents=True, exist_ok=True)
self.export_pyxcz(fp)
@property @property
def shape_dict(self): def shape_dict(self):
return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape)) return dict(zip(('P', 'Y', 'X', 'C', 'Z'), self.data.shape))
......
This diff is collapsed.
...@@ -61,6 +61,27 @@ class TestCziImageFileAccess(unittest.TestCase): ...@@ -61,6 +61,27 @@ class TestCziImageFileAccess(unittest.TestCase):
sc = cf.get_mono(c, mip=True) sc = cf.get_mono(c, mip=True)
self.assertEqual(sc.shape, (h, w, 1, 1)) self.assertEqual(sc.shape, (h, w, 1, 1))
def test_get_single_channel_argmax_from_zstack(self):
w = 256
h = 512
nc = 4
nz = 11
c = 3
cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz))
am = cf.get_mono(c).get_z_argmax()
self.assertEqual(am.shape, (h, w, 1, 1))
self.assertTrue(np.all(am.unique()[0] == range(0, nz)))
def test_get_single_channel_z_series_from_zstack(self):
w = 256
h = 512
nc = 4
nz = 11
c = 3
cf = InMemoryDataAccessor(np.random.rand(h, w, nc, nz))
zs = cf.get_mono(c).get_focus_vector()
self.assertEqual(zs.shape, (nz, ))
def test_get_zi(self): def test_get_zi(self):
w = 256 w = 256
h = 512 h = 512
......
...@@ -6,6 +6,7 @@ from pathlib import Path ...@@ -6,6 +6,7 @@ from pathlib import Path
import pandas as pd import pandas as pd
from model_server.base.process import smooth
from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSetExportParams, RoiSetMetaParams from model_server.base.roiset import filter_df_overlap_bbox, filter_df_overlap_seg, RoiSetExportParams, RoiSetMetaParams
from model_server.base.roiset import RoiSet from model_server.base.roiset import RoiSet
from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file, PatchStack
...@@ -188,12 +189,72 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ...@@ -188,12 +189,72 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
return roiset return roiset
def test_classify_by_multiple_channels(self): def test_classify_by_multiple_channels(self):
roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask) roiset = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0))
roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel()) roiset.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1]))
self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1])) self.assertTrue(all(np.unique(roiset.get_object_class_map('dummy_class').data) == [0, 1]))
return roiset return roiset
def test_transfer_classification(self):
roiset1 = RoiSet.from_binary_mask(self.stack, self.seg_mask, params=RoiSetMetaParams(deproject_channel=0))
# prepare alternative mask and compare
smoothed_mask = self.seg_mask.apply(lambda x: smooth(x, sig=1.5))
roiset2 = RoiSet.from_binary_mask(self.stack, smoothed_mask, params=RoiSetMetaParams(deproject_channel=0))
dmask = (self.seg_mask.data / 255) + (smoothed_mask.data / 255)
self.assertTrue(np.all(np.unique(dmask) == [0, 1, 2]))
total_iou = (dmask == 2).sum() / ((dmask == 1).sum() + (dmask == 2).sum())
self.assertGreater(total_iou, 0.6)
# classify first RoiSet
roiset1.classify_by('dummy_class', [0, 1], DummyInstanceSegmentationModel())
self.assertTrue('dummy_class' in roiset1.classification_columns)
self.assertFalse('dummy_class' in roiset2.classification_columns)
res = roiset2.get_instance_classification(roiset1)
self.assertTrue('dummy_class' in roiset2.classification_columns)
self.assertLess(
roiset2.get_df().classify_by_dummy_class.count(),
roiset1.get_df().classify_by_dummy_class.count(),
)
def test_classify_by_with_derived_channel(self):
class ModelWithDerivedInputs(DummyInstanceSegmentationModel):
def infer(self, img, mask):
return PatchStack(super().infer(img, mask).data * img.chroma)
roiset = RoiSet.from_binary_mask(
self.stack,
self.seg_mask,
params=RoiSetMetaParams(
filters={'area': {'min': 1e3, 'max': 1e4}},
deproject_channel=0,
)
)
roiset.classify_by(
'multiple_input_model',
[0, 1],
ModelWithDerivedInputs(),
derived_channel_functions=[
lambda acc: PatchStack(2 * acc.get_channels([0]).data),
lambda acc: PatchStack((0.5 * acc.get_channels([1]).data).astype('uint8'))
]
)
self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [4]))
self.assertTrue(all(np.unique(roiset.get_object_class_map('multiple_input_model').data) == [0, 4]))
self.assertEqual(len(roiset.accs_derived), 2)
for di in roiset.accs_derived:
self.assertEqual(roiset.get_patches_acc().hw, di.hw)
self.assertEqual(roiset.get_patches_acc().nz, di.nz)
self.assertEqual(roiset.get_patches_acc().count, di.count)
dpas = roiset.run_exports(output_path / 'derived_channels', 0, 'der', RoiSetExportParams(derived_channels=True))
for fp in dpas['derived_channels']:
assert Path(fp).exists()
return roiset
def test_export_object_classes(self): def test_export_object_classes(self):
record = self.test_classify_by().run_exports( record = self.test_classify_by().run_exports(
output_path / 'object_class_maps', output_path / 'object_class_maps',
...@@ -237,6 +298,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -237,6 +298,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
expand_box_by=(128, 2), expand_box_by=(128, 2),
mask_type='boxes', mask_type='boxes',
filters={'area': {'min': 1e3, 'max': 1e4}}, filters={'area': {'min': 1e3, 'max': 1e4}},
deproject_channel=0,
) )
) )
...@@ -367,7 +429,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -367,7 +429,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
self.assertEqual(result.nz, self.roiset.acc_raw.nz) self.assertEqual(result.nz, self.roiset.acc_raw.nz)
self.assertEqual(result.chroma, 1) self.assertEqual(result.chroma, 1)
def test_run_exports(self): def test_run_exports(self):
p = RoiSetExportParams(**{ p = RoiSetExportParams(**{
'patches_3d': {}, 'patches_3d': {},
...@@ -606,6 +667,60 @@ class TestRoiSetSerialization(unittest.TestCase): ...@@ -606,6 +667,60 @@ class TestRoiSetSerialization(unittest.TestCase):
t_acc = generate_file_accessor(pt) t_acc = generate_file_accessor(pt)
self.assertTrue(np.all(r_acc.data == t_acc.data)) self.assertTrue(np.all(r_acc.data == t_acc.data))
class TestRoiSetObjectDetection(unittest.TestCase):
def setUp(self) -> None:
# set up test raw data and segmentation from file
self.stack = generate_file_accessor(data['multichannel_zstack_raw']['path'])
self.stack_ch_pa = self.stack.get_mono(params['segmentation_channel'])
self.seg_mask_3d = generate_file_accessor(data['multichannel_zstack_mask3d']['path'])
def test_create_roiset_from_bounding_boxes(self):
from skimage.measure import label, regionprops, regionprops_table
mask = self.seg_mask_3d
labels = label(mask.data_xyz, connectivity=3)
table = pd.DataFrame(
regionprops_table(labels)
).rename(
columns={'bbox-0': 'y', 'bbox-1': 'x', 'bbox-2': 'zi', 'bbox-3': 'y1', 'bbox-4': 'x1'}
).drop(
columns=['bbox-5']
)
table['w'] = table['x1'] - table['x']
table['h'] = table['y1'] - table['y']
bboxes = table[['y', 'x', 'h', 'w']].to_dict(orient='records')
roiset_bbox = RoiSet.from_bounding_boxes(self.stack_ch_pa, bboxes)
self.assertTrue('label' in roiset_bbox.get_df().columns)
patches_bbox = roiset_bbox.get_patches_acc()
self.assertEqual(len(table), patches_bbox.count)
# roiset w/ seg for comparison
roiset_seg = RoiSet.from_binary_mask(self.stack_ch_pa, mask, allow_3d=True)
patches_seg = roiset_seg.get_patches_acc()
# test bounding box dimensions match those from RoiSet generated directly from segmentation
self.assertEqual(roiset_seg.count, roiset_bbox.count)
for i in range(0, roiset_seg.count):
self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape)
# test that serialization does not write patch masks
roiset_ser_path = output_path / 'roiset_from_bbox'
dd = roiset_bbox.serialize(roiset_ser_path)
self.assertTrue('tight_patch_masks' not in dd.keys())
self.assertFalse((roiset_ser_path / 'tight_patch_masks').exists())
# test that deserialized RoiSet matches the original
roiset_des = RoiSet.deserialize(self.stack_ch_pa, roiset_ser_path)
self.assertEqual(roiset_des.count, roiset_bbox.count)
for i in range(0, roiset_des.count):
self.assertEqual(patches_seg.iat(0, crop=True).shape, patches_bbox.iat(0, crop=True).shape)
self.assertTrue((roiset_bbox.get_zmask() == roiset_des.get_zmask()).all())
class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
def test_compute_polygons(self): def test_compute_polygons(self):
...@@ -649,10 +764,33 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): ...@@ -649,10 +764,33 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
res = filter_df_overlap_bbox(df) res = filter_df_overlap_bbox(df)
self.assertEqual(len(res), 4) self.assertEqual(len(res), 4)
self.assertTrue((res.loc[0, 'bbox_overlaps_with'] == [1]).all()) self.assertTrue((res.loc[0, 'overlaps_with'] == [1]).all())
self.assertTrue((res.loc[1, 'bbox_overlaps_with'] == [0, 2]).all()) self.assertTrue((res.loc[1, 'overlaps_with'] == [0, 2]).all())
self.assertTrue((res.bbox_intersec == 2).all())
return res return res
def test_overlap_bbox_multiple(self):
df1 = pd.DataFrame({
'x0': [0, 1],
'x1': [2, 3],
'y0': [0, 0],
'y1': [2, 2],
'zi': [0, 0],
})
df2 = pd.DataFrame({
'x0': [2],
'x1': [4],
'y0': [0],
'y1': [2],
'zi': [0],
})
res = filter_df_overlap_bbox(df1, df2)
self.assertTrue((res.loc[1, 'overlaps_with'] == [0]).all())
self.assertEqual(len(res), 1)
self.assertTrue((res.bbox_intersec == 2).all())
def test_overlap_seg(self): def test_overlap_seg(self):
df = pd.DataFrame({ df = pd.DataFrame({
'x0': [0, 1, 2], 'x0': [0, 1, 2],
...@@ -677,4 +815,43 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase): ...@@ -677,4 +815,43 @@ class TestRoiSetPolygons(BaseTestRoiSetMonoProducts, unittest.TestCase):
}) })
res = filter_df_overlap_seg(df) res = filter_df_overlap_seg(df)
self.assertTrue((res.loc[res.seg_overlaps, :].index == [1, 2]).all()) self.assertTrue((res.loc[res.seg_overlaps, :].index == [1, 2]).all())
\ No newline at end of file self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all())
def test_overlap_seg_multiple(self):
df1 = pd.DataFrame({
'x0': [0, 1],
'x1': [2, 3],
'y0': [0, 0],
'y1': [2, 2],
'zi': [0, 0],
'binary_mask': [
[
[1, 1],
[1, 0]
],
[
[0, 1],
[1, 1]
],
]
})
df2 = pd.DataFrame({
'x0': [2],
'x1': [4],
'y0': [0],
'y1': [2],
'zi': [0],
'binary_mask': [
[
[1, 1],
[1, 1]
],
]
})
res = filter_df_overlap_seg(df1, df2)
self.assertTrue((res.loc[1, 'overlaps_with'] == [0]).all())
self.assertEqual(len(res), 1)
self.assertTrue((res.bbox_intersec == 2).all())
self.assertTrue((res.loc[res.seg_overlaps, :].index == [1]).all())
self.assertTrue((res.loc[res.seg_overlaps, 'seg_iou'] == 0.4).all())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment