diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 943aedf2cc205c7fca6d6305ad52f3484b1f4515..403ea765a5faf8cd4511f5207c3f451c8dca836d 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -257,18 +257,18 @@ class RoiSet(object): projected = self.acc_raw.data.max(axis=-1) return projected - def get_patches_acc(self, channel=None, **kwargs): # padded, un-annotated 2d patches + def get_patches_acc(self, channel=None, **kwargs) -> PatchStack: # padded, un-annotated 2d patches if channel: patches_df = self.get_patches(white_channel=channel, **kwargs) else: patches_df = self.get_patches(**kwargs) return PatchStack(list(patches_df.patch)) - def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> Path: + def export_annotated_zstack(self, where, prefix='zstack', **kwargs) -> str(Path): annotated = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kwargs)) fp = where / (prefix + '.tif') write_accessor_data_to_file(fp, annotated) - return fp + return str(fp) def get_zmask(self, mask_type='boxes'): """ @@ -337,24 +337,22 @@ class RoiSet(object): def export_patch_masks(self, where: Path, pad_to: int = None, prefix='mask', expanded=False) -> list: - patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded) + patches_df = self.get_patch_masks(pad_to=pad_to, expanded=expanded).copy() - exported = [] def _export_patch_mask(roi): patch = InMemoryDataAccessor(roi.patch_mask) ext = 'png' fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' write_accessor_data_to_file(where / fname, patch) - exported.append(fname) + return str(where / fname) - for roi in patches_df.itertuples(): # just used for label info - _export_patch_mask(roi) - return exported + patches_df['patch_mask_path'] = patches_df.apply(_export_patch_mask, axis=1) + return patches_df - def export_patches(self, where: Path, prefix='patch', **kwargs) -> list: + def export_patches(self, where: Path, prefix='patch', **kwargs) -> pd.DataFrame: make_3d = kwargs.get('make_3d', False) - patches_df = self.get_patches(**kwargs) + patches_df = self.get_patches(**kwargs).copy() def _export_patch(roi): patch = InMemoryDataAccessor(roi.patch) @@ -366,14 +364,10 @@ class RoiSet(object): write_accessor_data_to_file(where / fname, resampled) else: write_accessor_data_to_file(where / fname, patch) + return str(where / fname) - exported.append(where / fname) - - exported = [] - for roi in patches_df.itertuples(): # just used for label info - _export_patch(roi) - - return exported + patches_df['patch_path'] = patches_df.apply(_export_patch, axis=1) + return patches_df def get_patch_masks(self, pad_to: int = None, expanded: bool = False) -> pd.DataFrame: def _make_patch_mask(roi): @@ -552,29 +546,30 @@ class RoiSet(object): if kp is None: continue if k == 'patches_3d': - record[k] = self.export_patches( + df_exp = self.export_patches( subdir, white_channel=channel, prefix=pr, make_3d=True, expanded=True, **kp ) + record[k] = list(df_exp.patch_path) if k == 'annotated_patches_2d': - record[k] = self.export_patches( + df_exp = self.export_patches( subdir, prefix=pr, make_3d=False, white_channel=channel, bounding_box_channel=1, bounding_box_linewidth=2, **kp, ) + record[k] = list(df_exp.patch_path) if k == 'patches_2d': - files = self.export_patches( + df_exp = self.export_patches( subdir, white_channel=channel, prefix=pr, make_3d=False, **kp ) - df_patches = pd.DataFrame(files) - self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') + self._df = self._df.join(df_exp.patch_path) self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1) - record[k] = files + record[k] = list(df_exp.patch_path) if k == 'annotated_zstacks': record[k] = self.export_annotated_zstack(subdir, prefix=pr, **kp) if k == 'object_classes': for kc, acc in self.object_class_maps.items(): fp = subdir / kc / (pr + '.tif') write_accessor_data_to_file(fp, acc) - record[f'{k}_{kc}'] = fp + record[f'{k}_{kc}'] = str(fp) return record @@ -586,13 +581,15 @@ class RoiSet(object): :return: nested dict of Path objects describing the locations of export products """ record = {} - record['dataframe'] = self.export_dataframe(where / 'dataframe' / (prefix + '.csv')) - record['tight_patch_masks'] = self.export_patch_masks( + df_exp = self.export_patch_masks( where / 'tight_patch_masks', prefix=prefix, pad_to=None, expanded=False ) + self._df = self._df.join(df_exp.patch_mask_path) + record['dataframe'] = str(self.export_dataframe(where / 'dataframe' / (prefix + '.csv'))) + record['tight_patch_masks'] = list(df_exp.patch_mask_path) return record @staticmethod diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 0d507b3f90f3346335f9798a84e77efea94baf35..fad797404e8cbda6d552aba818c5c7a1c11817d6 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -102,14 +102,14 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_expanded_2d_patches(self): roiset = self._make_roi_set() - files = roiset.export_patches( + df_res = roiset.export_patches( output_path / 'expanded_2d_patches', draw_bounding_box=True, expanded=True, pad_to=256, ) df = roiset.get_df() - for f in files: + for f in df_res.patch_path: acc = generate_file_accessor(f) la = int(re.search(r'la([\d]+)', str(f)).group(1)) roi_q = df.loc[df.label == la, :] @@ -118,13 +118,13 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_tight_2d_patches(self): roiset = self._make_roi_set() - files = roiset.export_patches( + df_res = roiset.export_patches( output_path / 'tight_2d_patches', draw_bounding_box=True, expanded=False ) df = roiset.get_df() - for f in files: # all exported files are same shape as bounding boxes in RoiSet's datatable + for f in df_res.patch_path: # all exported files are same shape as bounding boxes in RoiSet's datatable acc = generate_file_accessor(f) la = int(re.search(r'la([\d]+)', str(f)).group(1)) roi_q = df.loc[df.label == la, :] @@ -134,13 +134,13 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_expanded_3d_patches(self): roiset = self._make_roi_set() - files = roiset.export_patches( + df_res = roiset.export_patches( output_path / '3d_patches', make_3d=True, expanded=True ) - self.assertGreaterEqual(len(files), 1) - for f in files: + self.assertGreaterEqual(len(df_res), 1) + for f in df_res.patch_path: acc = generate_file_accessor(f) self.assertGreater(acc.nz, 1) @@ -178,10 +178,10 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_binary_masks(self): roiset = self._make_roi_set() - files = roiset.export_patch_masks(output_path / '2d_mask_patches', ) + df_res = roiset.export_patch_masks(output_path / '2d_mask_patches', ) df = roiset.get_df() - for f in files: # all exported files are same shape as bounding boxes in RoiSet's datatable + for f in df_res.patch_mask_path: # all exported files are same shape as bounding boxes in RoiSet's datatable acc = generate_file_accessor(output_path / '2d_mask_patches' / f) la = int(re.search(r'la([\d]+)', str(f)).group(1)) roi_q = df.loc[df.label == la, :] @@ -244,18 +244,18 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa ) def test_multichannel_to_mono_2d_patches(self): - files = self.roiset.export_patches( + df_res = self.roiset.export_patches( output_path / 'multichannel' / 'mono_2d_patches', white_channel=3, draw_bounding_box=True, expanded=True, pad_to=256, ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 1) def test_multichannnel_to_mono_2d_patches_rgb_bbox(self): - files = self.roiset.export_patches( + df_res = self.roiset.export_patches( output_path / 'multichannel' / 'mono_2d_patches_rgb_bbox', white_channel=3, draw_bounding_box=True, @@ -263,11 +263,11 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa expanded=True, pad_to=256, ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_bbox(self): - files = self.roiset.export_patches( + df_res = self.roiset.export_patches( output_path / 'multichannel' / 'rgb_2d_patches_bbox', white_channel=4, rgb_overlay_channels=(3, None, None), @@ -278,11 +278,11 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa expanded=True, pad_to=256, ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_mask(self): - files = self.roiset.export_patches( + df_res = self.roiset.export_patches( output_path / 'multichannel' / 'rgb_2d_patches_mask', white_channel=4, rgb_overlay_channels=(3, None, None), @@ -292,11 +292,11 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa expanded=True, pad_to=256, ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 3) def test_multichannnel_to_rgb_2d_patches_contour(self): - files = self.roiset.export_patches( + df_res = self.roiset.export_patches( output_path / 'multichannel' / 'rgb_2d_patches_contour', rgb_overlay_channels=(3, None, None), draw_contour=True, @@ -305,17 +305,17 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa expanded=True, pad_to=256, ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 3) self.assertEqual(result.get_one_channel_data(2).data.max(), 0) # blue channel is black def test_multichannel_to_multichannel_tif_patches(self): - files = self.roiset.export_patches( + df_res = self.roiset.export_patches( output_path / 'multichannel' / 'multichannel_tif_patches', expanded=True, pad_to=256, ) - result = generate_file_accessor(files[0]) + result = generate_file_accessor(df_res.patch_path.iloc[0]) self.assertEqual(result.chroma, 5) self.assertEqual(result.nz, 1) @@ -342,6 +342,42 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(result.nz, self.roiset.acc_raw.nz) self.assertEqual(result.chroma, 1) + def test_run_exports(self): + p = RoiSetExportParams(**{ + 'patches_3d': {}, + 'annotated_patches_2d': { + 'draw_bounding_box': True, + 'rgb_overlay_channels': [3, None, None], + 'rgb_overlay_weights': [0.2, 1.0, 1.0], + 'pad_to': 512, + }, + 'patches_2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + }, + 'patch_masks': { + 'pad_to': 256, + }, + 'annotated_zstacks': {}, + 'object_classes': True, + 'dataframe': True, + }) + + res = self.roiset.run_exports( + output_path / 'run_exports', + channel=3, + prefix='test', + params=p + ) + + for k, v in res.items(): + if isinstance(v, list): + for f in v: + self.assertTrue(Path(f).exists()) + else: + self.assertTrue(Path(v).exists()) + + class TestRoiSetFromZmask(unittest.TestCase): def setUp(self) -> None: