Skip to content
Snippets Groups Projects
Commit 0d0a86bf authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Force export of dataset and patches; just need to override bounding box expansion

parent 1b9a07f8
No related branches found
No related tags found
No related merge requests found
......@@ -56,11 +56,8 @@ class RoiSetExportParams(BaseModel):
patches_3d: Union[PatchParams, None] = None
annotated_patches_2d: Union[PatchParams, None] = None
patches_2d: Union[PatchParams, None] = None
patch_masks: Union[PatchParams, None] = None
annotated_zstacks: Union[AnnotatedZStackParams, None] = None
object_classes: bool = False
dataframe: bool = False
......@@ -350,7 +347,13 @@ class RoiSet(object):
om[self.acc_obj_ids.data == roi.label] = oc
self.object_class_maps[name] = InMemoryDataAccessor(om)
def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask', **kwargs) -> list:
def export_dataframe(self, csv_path: Path):
csv_path.parent.mkdir(parents=True, exist_ok=True)
self._df.drop(['slice', 'relative_slice', 'mask'], axis=1).to_csv(csv_path, index=False)
return csv_path
def export_patch_masks(self, where: Path, pad_to: int = 256, prefix='mask') -> list:
patches_acc = self.get_patch_masks(pad_to=pad_to)
exported = []
......@@ -538,6 +541,11 @@ class RoiSet(object):
if not self.count:
return
raw_ch = self.acc_raw.get_one_channel_data(channel)
# export dataframe and patch masks
record['dataframe'] = self.export_dataframe(where / 'dataframe' / (prefix + '.csv'))
self.export_patch_masks(where / 'patch_masks', prefix=prefix, pad_to=None)
for k in params.dict().keys():
subdir = where / k
pr = prefix
......@@ -560,8 +568,6 @@ class RoiSet(object):
df_patches = pd.DataFrame(files)
self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index')
self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1)
if k == 'patch_masks':
self.export_patch_masks(subdir, prefix=pr, **kp)
if k == 'annotated_zstacks':
self.export_annotated_zstack(subdir, prefix=pr, **kp)
if k == 'object_classes':
......@@ -569,11 +575,7 @@ class RoiSet(object):
fp = subdir / kc / (pr + '.tif')
write_accessor_data_to_file(fp, acc)
record[f'{k}_{kc}'] = fp
if k == 'dataframe':
dfpa = subdir / (pr + '.csv')
dfpa.parent.mkdir(parents=True, exist_ok=True)
self._df.drop(['slice', 'relative_slice', 'mask'], axis=1).to_csv(dfpa, index=False)
record[k] = dfpa
return record
......
import os
import re
import unittest
import numpy as np
......@@ -333,21 +335,37 @@ class TestRoiSetFromZmask(unittest.TestCase):
df_test = pd.read_csv(where_df)
# zmask = np.zeros((*self.stack.hw, 1, self.stack.nz), dtype=bool)
print('hi')
fn = output_path / 'roiset_from_3d' / 'patch_masks' / 'ref-la{:04d}-zi{:04d}.png'
patch_masks = {}
def _label_obj(r):
sl = np.s_[r.ebb_y0:r.ebb_y1, r.ebb_x0:r.ebb_x1, :, r.zi:r.zi + 1]
self.assertEqual(str(sl), r.slice)
patch_masks[r.label] = generate_file_accessor(str(fn).format(r.label, r.zi)).data
# zmask[sl] = True
df_test.apply(lambda x: _label_obj(x), axis=1)
roiset_test = RoiSet.from_df_and_patch_masks(self.stack, df_test, patch_masks)
print('')
# check that patches are correct size
where_patch_masks = output_path / 'roiset_from_3d' / 'patch_masks'
for pmf in where_patch_masks.iterdir():
self.assertTrue(pmf.suffix.upper() == '.PNG')
la = int(re.search(r'la([\d]+)', str(pmf)).group(1))
roi_q = df_test.loc[df_test.label == la, :]
self.assertEqual(len(roi_q), 1)
roi = roi_q.iloc[0]
h = int(roi.y1 - roi.y0)
w = int(roi.x1 - roi.x0)
m_acc = generate_file_accessor(pmf)
self.assertEqual((h, w), m_acc.hw)
# df_test = pd.read_csv(where_df)
#
# # zmask = np.zeros((*self.stack.hw, 1, self.stack.nz), dtype=bool)
# print('hi')
#
# fn = output_path / 'roiset_from_3d' / 'patch_masks' / 'ref-la{:04d}-zi{:04d}.png'
# patch_masks = {}
#
# def _label_obj(r):
# sl = np.s_[r.ebb_y0:r.ebb_y1, r.ebb_x0:r.ebb_x1, :, r.zi:r.zi + 1]
# self.assertEqual(str(sl), r.slice)
# patch_masks[r.label] = generate_file_accessor(str(fn).format(r.label, r.zi)).data
# # zmask[sl] = True
#
# df_test.apply(lambda x: _label_obj(x), axis=1)
#
# roiset_test = RoiSet.from_df_and_patch_masks(self.stack, df_test, patch_masks)
# print('')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment