diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index 6ee5d2171de439bfc7ce7eedfbe34f821784b1cc..f3c0c9479c35d0cea4bb0fc660a9e692ca101686 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -59,6 +59,7 @@ class RoiSetExportParams(BaseModel): patches_2d: Union[PatchParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None object_classes: bool = False + derived_channels: bool = False @@ -134,6 +135,7 @@ class RoiSet(object): assert acc_obj_ids.chroma == 1 self.acc_obj_ids = acc_obj_ids self.acc_raw = acc_raw + self.accs_derived = [] self.params = params self._df = self.filter_df( @@ -331,10 +333,20 @@ class RoiSet(object): mono_data = [raw_acc.get_one_channel_data(c).data for c in range(0, raw_acc.chroma)] for fcn in derived_channel_functions: der = fcn(raw_acc) # returns patch stack - assert der.shape == mono_data[0].shape - mono_data.append(der.data) + if der.shape != mono_data[0].shape or der.dtype not in ['uint8', 'uint16']: + raise DerivedChannelError( + f'Error processing derived channel {der} with shape {der.shape_dict} and dtype {der.dtype}' + ) + self.accs_derived.append(der) + # combine channels - input_acc = PatchStack(np.concatenate(mono_data, axis=raw_acc._ga('C'))) + data_derived = [acc.data for acc in self.accs_derived] + input_acc = PatchStack( + np.concatenate( + [*mono_data, *data_derived], + axis=raw_acc._ga('C') + ) + ) else: input_acc = raw_acc @@ -595,6 +607,13 @@ class RoiSet(object): fp = subdir / kc / (pr + '.tif') write_accessor_data_to_file(fp, acc) record[f'{k}_{kc}'] = str(fp) + if k == 'derived_channels': + record[k] = [] + for di, dacc in enumerate(self.accs_derived): + fp = subdir / f'dc{di:01d}.tif' + fp.parent.mkdir(exist_ok=True, parents=True) + dacc.export_pyxcz(fp) + record[k].append(str(fp)) # export dataframe and patch masks record = {**record, **self.serialize(where, prefix=prefix)} @@ -696,4 +715,7 @@ class Error(Exception): pass class DeserializeRoiSet(Error): + pass + +class DerivedChannelError(Error): pass \ No newline at end of file diff --git a/tests/test_roiset.py b/tests/test_roiset.py index 91471d6312c1909adc84a05003337e98c7437b52..4785688835c71d1f26ae1b95362b1d3403d83262 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -219,7 +219,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): ModelWithDerivedInputs(), derived_channel_functions=[ lambda acc: PatchStack(2 * acc.data), - lambda acc: PatchStack(0.5 * acc.data) + lambda acc: PatchStack((0.5 * acc.data).astype('uint8')) ] ) self.assertTrue(all(roiset.get_df()['classify_by_multiple_input_model'].unique() == [3])) @@ -228,6 +228,10 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): self.assertEqual(len(roiset.accs_derived), 2) for di in roiset.accs_derived: self.assertEqual(roiset.get_patches_acc().shape, di.shape) + + dpas = roiset.run_exports(output_path / 'derived_channels', 0, 'der', RoiSetExportParams(derived_channels=True)) + for fp in dpas['derived_channels']: + assert Path(fp).exists() return roiset def test_export_object_classes(self):