diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index bc74f919f3d2cc294398e921cffcdb7673754232..702388158c8b083f8c49898c8dc67dc1599d629e 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -73,6 +73,9 @@ class GenericImageDataAccessor(ABC): def dtype(self): return self.data.dtype + def write(self, fp: Path, mkdir=True): + write_accessor_data_to_file(fp, self, mkdir=mkdir) + def get_axis(self, ch): return self.axes.index(ch.upper()) @@ -122,6 +125,10 @@ class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded raise FileAccessorError(f'Could not find file at {fpath}') self.fpath = fpath + @staticmethod + def read(fp: Path): + return generate_file_accessor(fp) + class TifSingleSeriesFileAccessor(GenericImageFileAccessor): def __init__(self, fpath: Path): super().__init__(fpath) diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 28c1ce2d03d31695e91de1f98d1c428ed2cd7a19..2e47c7c58bcdde1a0ed4a0e0aa4e0511c363d36c 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -73,14 +73,18 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne """ if allow_3d and connect_3d: nda_la = label( - acc_seg_mask.data[:, :, 0, :] + acc_seg_mask.data[:, :, 0, :], + connectivity=3, ).astype('uint16') return InMemoryDataAccessor(np.expand_dims(nda_la, 2)) elif allow_3d and not connect_3d: nla = 0 la_3d = np.zeros((*acc_seg_mask.hw, 1, acc_seg_mask.nz), dtype='uint16') for zi in range(0, acc_seg_mask.nz): - la_2d = label(acc_seg_mask.data[:, :, 0, zi]).astype('uint16') + la_2d = label( + acc_seg_mask.data[:, :, 0, zi], + connectivity=2, + ).astype('uint16') la_2d[la_2d > 0] = la_2d[la_2d > 0] + nla nla = la_2d.max() la_3d[:, :, 0, zi] = la_2d @@ -88,7 +92,8 @@ def _get_label_ids(acc_seg_mask: GenericImageDataAccessor, allow_3d=False, conne else: return InMemoryDataAccessor( label( - acc_seg_mask.data[:, :, 0, :].max(axis=-1) + acc_seg_mask.data[:, :, 0, :].max(axis=-1), + connectivity=1, ).astype('uint16') ) diff --git a/tests/test_roiset.py b/tests/test_roiset.py index c1d773c6d5ac975ecf0102e5bf349bb3e385e98b..d45c6132ea6dd52447f3f2614d0d9f35aef2a355 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -455,7 +455,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertTrue(pa.exists()) pacc = generate_file_accessor(pa) self.assertEqual(pacc.hw, (256, 256)) - print('res') def test_run_export_mono_2d_patch(self): p = RoiSetExportParams(**{