Skip to content
Snippets Groups Projects
Commit 5cdfcf06 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Method to export derived channels as patch stacks

parent dd1cdf8a
No related branches found
No related tags found
3 merge requests!37Release 2024.04.19,!34Revert "Temporary error-handling for debug...",!30Accessor changes to support object classification
......@@ -59,6 +59,7 @@ class RoiSetExportParams(BaseModel):
patches_2d: Union[PatchParams, None] = None
annotated_zstacks: Union[AnnotatedZStackParams, None] = None
object_classes: bool = False
derived_channels: bool = False
......@@ -134,6 +135,7 @@ class RoiSet(object):
assert acc_obj_ids.chroma == 1
self.acc_obj_ids = acc_obj_ids
self.acc_raw = acc_raw
self.accs_derived = []
self.params = params
self._df = self.filter_df(
......@@ -331,10 +333,20 @@ class RoiSet(object):
mono_data = [raw_acc.get_one_channel_data(c).data for c in range(0, raw_acc.chroma)]
for fcn in derived_channel_functions:
der = fcn(raw_acc) # returns patch stack
assert der.shape == mono_data[0].shape
mono_data.append(der.data)
if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']:
raise DerivedChannelError(
f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}'
)
self.accs_derived.append(der)
# combine channels
input_acc = PatchStack(np.concatenate(mono_data, axis=raw_acc._ga('C')))
data_derived = [acc.data for acc in self.accs_derived]
input_acc = PatchStack(
np.concatenate(
[*mono_data, *data_derived],
axis=raw_acc._ga('C')
)
)
else:
input_acc = raw_acc
......@@ -595,6 +607,13 @@ class RoiSet(object):
fp = subdir / kc / (pr + '.tif')
write_accessor_data_to_file(fp, acc)
record[f'{k}_{kc}'] = str(fp)
if k == 'derived_channels':
record[k] = []
for di, dacc in enumerate(self.accs_derived):
fp = subdir / f'dc{di:01d}.tif'
fp.parent.mkdir(exist_ok=True, parents=True)
dacc.export_pyxcz(fp)
record[k].append(str(fp))
# export dataframe and patch masks
record = {**record, **self.serialize(where, prefix=prefix)}
......@@ -696,4 +715,7 @@ class Error(Exception):
pass
class DeserializeRoiSet(Error):
pass
class DerivedChannelError(Error):
pass
\ No newline at end of file
......@@ -219,7 +219,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
ModelWithDerivedInputs(),
derived_channel_functions=[
lambda acc: PatchStack(2 * acc.data),
lambda acc: PatchStack(0.5 * acc.data)
lambda acc: PatchStack((0.5 * acc.data).astype('uint8'))
]
)
self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [3]))
......@@ -228,6 +228,10 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase):
self.assertEqual(len(roiset.accs_derived), 2)
for di in roiset.accs_derived:
self.assertEqual(roiset.get_patches_acc().shape, di.shape)
dpas = roiset.run_exports(output_path / 'derived_channels', 0, 'der', RoiSetExportParams(derived_channels=True))
for fp in dpas['derived_channels']:
assert Path(fp).exists()
return roiset
def test_export_object_classes(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment