Skip to content
Snippets Groups Projects
Commit d2287ee5 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'dev_connectivity_patch_obmaps' into 'staging'

Connectivity of labeling

See merge request !44
parents 87d32037 e2d864b6
No related branches found
No related tags found
No related merge requests found
...@@ -73,6 +73,9 @@ class GenericImageDataAccessor(ABC): ...@@ -73,6 +73,9 @@ class GenericImageDataAccessor(ABC):
def dtype(self): def dtype(self):
return self.data.dtype return self.data.dtype
def write(self, fp: Path, mkdir=True):
write_accessor_data_to_file(fp, self, mkdir=mkdir)
def get_axis(self, ch): def get_axis(self, ch):
return self.axes.index(ch.upper()) return self.axes.index(ch.upper())
...@@ -122,6 +125,10 @@ class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded ...@@ -122,6 +125,10 @@ class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded
raise FileAccessorError(f'Could not find file at {fpath}') raise FileAccessorError(f'Could not find file at {fpath}')
self.fpath = fpath self.fpath = fpath
@staticmethod
def read(fp: Path):
return generate_file_accessor(fp)
class TifSingleSeriesFileAccessor(GenericImageFileAccessor): class TifSingleSeriesFileAccessor(GenericImageFileAccessor):
def __init__(self, fpath: Path): def __init__(self, fpath: Path):
super().__init__(fpath) super().__init__(fpath)
......
...@@ -73,14 +73,18 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne ...@@ -73,14 +73,18 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne
""" """
if allow_3d and connect_3d: if allow_3d and connect_3d:
nda_la = label( nda_la = label(
acc_seg_mask.data[:, :, 0, :] acc_seg_mask.data[:, :, 0, :],
connectivity=3,
).astype('uint16') ).astype('uint16')
return InMemoryDataAccessor(np.expand_dims(nda_la, 2)) return InMemoryDataAccessor(np.expand_dims(nda_la, 2))
elif allow_3d and not connect_3d: elif allow_3d and not connect_3d:
nla = 0 nla = 0
la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16') la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16')
for zi in range(0, acc_seg_mask.nz): for zi in range(0, acc_seg_mask.nz):
la_2d = label(acc_seg_mask.data[:, :, 0, zi]).astype('uint16') la_2d = label(
acc_seg_mask.data[:, :, 0, zi],
connectivity=2,
).astype('uint16')
la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla
nla = la_2d.max() nla = la_2d.max()
la_3d[:, :, 0, zi] = la_2d la_3d[:, :, 0, zi] = la_2d
...@@ -88,7 +92,8 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne ...@@ -88,7 +92,8 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne
else: else:
return InMemoryDataAccessor( return InMemoryDataAccessor(
label( label(
acc_seg_mask.data[:, :, 0, :].max(axis=-1) acc_seg_mask.data[:, :, 0, :].max(axis=-1),
connectivity=1,
).astype('uint16') ).astype('uint16')
) )
......
...@@ -455,7 +455,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ...@@ -455,7 +455,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
self.assertTrue(pa.exists()) self.assertTrue(pa.exists())
pacc = generate_file_accessor(pa) pacc = generate_file_accessor(pa)
self.assertEqual(pacc.hw, (256, 256)) self.assertEqual(pacc.hw, (256, 256))
print('res')
def test_run_export_mono_2d_patch(self): def test_run_export_mono_2d_patch(self):
p = RoiSetExportParams(**{ p = RoiSetExportParams(**{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment