diff --git a/model_server/extensions/chaeo/params.py b/model_server/extensions/chaeo/params.py index 71b9497dd76ef8f2a116e1573d37d55a22012520..9e06b286498a21bf1115b15f26f99b10547d9808 100644 --- a/model_server/extensions/chaeo/params.py +++ b/model_server/extensions/chaeo/params.py @@ -34,8 +34,8 @@ class RoiSetMetaParams(BaseModel): class RoiSetExportParams(BaseModel): pixel_probabilities: bool = False patches_3d: Union[PatchParams, None] = None - patches_2d_for_annotation: Union[PatchParams, None] = None - patches_2d_for_training: Union[PatchParams, None] = None + annotated_patches_2d: Union[PatchParams, None] = None + patches_2d: Union[PatchParams, None] = None patch_masks: Union[PatchParams, None] = None annotated_zstacks: Union[AnnotatedZStackParams, None] = None object_classes: bool = False diff --git a/model_server/extensions/chaeo/products.py b/model_server/extensions/chaeo/products.py index a99e8c63b2bf9702146d8e3c375a1e998b3129be..bd643df7775d91c556885b51f43f49647ec73b10 100644 --- a/model_server/extensions/chaeo/products.py +++ b/model_server/extensions/chaeo/products.py @@ -227,21 +227,13 @@ def export_patches_from_zstack( make_3d = kwargs.get('make_3d', False) patches_df = get_patches_from_zmask_meta(roiset, **kwargs) - pc = roiset.acc_raw.chroma # TODO: this should follow from generated patches, not roiset - # patches = list(patches_df['patch']) - # if not make_3d and pc == 1: - # patches_acc = MonoPatchStack(patches) - # else: - # patches_acc = Multichannel3dPatchStack(patches) - patches_acc = roiset.get_raw_patches(make_3d=make_3d) - def _export_patch(roi): - patch = patches_acc.iat_yxcz(i) - ext = 'tif' if make_3d or patches_acc.chroma > 3 else 'png' + patch = InMemoryDataAccessor(roi.patch) + ext = 'tif' if make_3d or patch.chroma > 3 else 'png' fname = f'{prefix}-la{roi.label:04d}-zi{roi.zi:04d}.{ext}' if patch.dtype is np.dtype('uint16'): - write_patch_to_file(where, fname, resample_to_8bit(patch)) + write_patch_to_file(where, fname, resample_to_8bit(patch.data)) else: write_patch_to_file(where, fname, patch) @@ -252,7 +244,7 @@ def export_patches_from_zstack( }) exported = [] - for i, roi in enumerate(patches_df.itertuples()): # just used for label info + for roi in patches_df.itertuples(): # just used for label info _export_patch(roi) return exported diff --git a/model_server/extensions/chaeo/tests/test_zstack.py b/model_server/extensions/chaeo/tests/test_zstack.py index 61565ec4006fadf3d5e8007f5b70908922015529..c07deb807d44afc0c13a0b3fa811eaf0d36725bd 100644 --- a/model_server/extensions/chaeo/tests/test_zstack.py +++ b/model_server/extensions/chaeo/tests/test_zstack.py @@ -47,13 +47,14 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): id_map, self.stack_ch_pa, params=RoiSetMetaParams( - mask_type=mask_type, filters=kwargs.get('filters') + mask_type=mask_type, + filters=kwargs.get('filters', {'area': {'min': 1e3, 'max': 1e4}}) ) ) return roiset def test_zmask_makes_correct_boxes(self, **kwargs): - roiset = self._make_roi_set(mask_type='boxes', **kwargs) + roiset = self._make_roi_set(**kwargs) zmask = roiset.get_zmask() meta = roiset.zmask_meta interm = roiset.interm @@ -96,7 +97,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): return self._make_roi_set(mask_type='contours') def test_zmask_makes_correct_boxes_with_filters(self): - return self._make_roi_set(filters={'area': {'min': 1e3, 'max': 1e4}}) + return self._make_roi_set() def test_zmask_makes_correct_expanded_boxes(self): return self._make_roi_set(expand_box_by=(64, 2)) @@ -110,18 +111,18 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_zmask_rel_slices_are_valid(self): roiset = self._make_roi_set() - for i, s in enumerate(roiset.get_slices()): - ebb = roiset.acc_raw.data[s] + # for i, s in enumerate(roiset.get_slices()): + for roi in roiset.get_df().itertuples(): + ebb = roiset.acc_raw.data[roi.slice] self.assertEqual(len(ebb.shape), 4) self.assertTrue(np.all([si >= 1 for si in ebb.shape])) - rel_slices = roiset.get_df()['relative_slice'] - rbb = ebb[rel_slices[i]] + rbb = ebb[roi.relative_slice] self.assertEqual(len(rbb.shape), 4) self.assertTrue(np.all([si >= 1 for si in rbb.shape])) def test_make_2d_patches_from_zmask(self): roiset = self._make_roi_set( - filters={'area': {'min': 1e3, 'max': 1e4}}, + # filters={'area': {'min': 1e3, 'max': 1e4}}, expand_box_by=(64, 2) ) files = export_patches_from_zstack( @@ -133,7 +134,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_3d_patches_from_zmask(self): roiset = self._make_roi_set( - filters={'area': {'min': 1e3, 'max': 1e4}}, + # filters={'area': {'min': 1e3, 'max': 1e4}}, expand_box_by=(64, 2), ) files = export_patches_from_zstack( @@ -142,6 +143,17 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): make_3d=True) self.assertGreaterEqual(len(files), 1) + def test_export_annotated_zstack(self): + roiset = self._make_roi_set( + # filters={'area': {'min': 1e3, 'max': 1e4}}, + expand_box_by=(64, 2), + ) + file = roiset.export_annotated_zstack( + output_path / 'annotated_stack', + ) + result = generate_file_accessor(Path(file['location']) / file['filename']) + self.assertEqual(result.shape, roiset.acc_raw.shape) + def test_flatten_image(self): id_map = get_label_ids(self.seg_mask) @@ -167,7 +179,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_make_binary_masks_from_zmask(self): roiset = self._make_roi_set( - filters={'area': {'min': 1e3, 'max': 1e4}}, + # filters={'area': {'min': 1e3, 'max': 1e4}}, expand_box_by=(128, 2) ) files = roiset.export_patch_masks(output_path / '2d_mask_patches', ) @@ -175,11 +187,7 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): def test_classify_by(self): roiset = self._make_roi_set() - roiset.classify_by( - 'dummy_class', - pipeline_params['patches_channel'], - DummyInstanceSegmentationModel() - ) + roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1.])) def test_object_map_workflow(self): @@ -214,13 +222,13 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): export_params = RoiSetExportParams(**{ 'pixel_probabilities': True, 'patches_3d': {}, - 'patches_2d_for_annotation': { + 'annotated_patches_2d': { 'draw_bounding_box': True, 'rgb_overlay_channels': [3, None, None], 'rgb_overlay_weights': [0.2, 1.0, 1.0], 'pad_to': 512, }, - 'patches_2d_for_training': { + 'patches_2d': { 'draw_bounding_box': False, 'draw_mask': False, }, diff --git a/model_server/extensions/chaeo/zmask.py b/model_server/extensions/chaeo/zmask.py index 02e5922adc1adadc9dfaef567b90d05cc1aada7c..31d5166b1ecc4ab0a6936ba1975f0b350306e7f8 100644 --- a/model_server/extensions/chaeo/zmask.py +++ b/model_server/extensions/chaeo/zmask.py @@ -167,12 +167,11 @@ class RoiSet(object): self._df[name] = se def get_multichannel_projection(self): # TODO: document and test - dff = self.df[self.df['keeper']] if self.count: projected = project_stack_from_focal_points( - dff['centroid-0'].to_numpy(), - dff['centroid-1'].to_numpy(), - dff['zi'].to_numpy(), + self._df['centroid-0'].to_numpy(), + self._df['centroid-1'].to_numpy(), + self._df['zi'].to_numpy(), self.acc_raw, degree=4, ) @@ -200,6 +199,13 @@ class RoiSet(object): else: return Multichannel3dPatchStack(patches) + def export_annotated_zstack(self, where, prefix='zstack', **kwargs): + annotated = InMemoryDataAccessor( + draw_boxes_on_3d_image(self.acc_raw.data, self.zmask_meta, **kwargs) # TODO remove zmask_meta ref + ) + success = write_accessor_data_to_file(where / (prefix + '.tif'), annotated) + return {'location': where.__str__(), 'filename': prefix + '.tif'} + def get_zmask(self, mask_type='boxes'): """ Return a mask of same dimensionality as raw data @@ -237,6 +243,7 @@ class RoiSet(object): return zi_st + # TODO: channel restriction as an argument def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): @@ -248,12 +255,10 @@ class RoiSet(object): lamap = self.object_id_labels om = np.zeros(lamap.shape, dtype=lamap.dtype) - # self.df['instance_class'] = np.nan df = self.get_df() idx = df.index se = pd.Series(data=np.nan, index=idx) - # df['classify_by_' + f] # assign labels to object map: for i in range(0, len(idx)): @@ -281,31 +286,28 @@ class RoiSet(object): files = export_patches_from_zstack( subdir, self, white_channel=channel, prefix=pr, make_3d=True, **kp ) - if k == 'patches_2d_for_annotation': + if k == 'annotated_patches_2d': files = export_multichannel_patches_from_zstack( - subdir, self.acc_raw, self.zmask_meta, prefix=pr, make_3d=False, rgb_white_channel=channel, + subdir, self, prefix=pr, make_3d=False, white_channel=channel, bounding_box_channel=1, bounding_box_linewidth=2, **kp, ) - if k == 'patches_2d_for_training': + if k == 'patches_2d': files = export_multichannel_patches_from_zstack( - subdir, self.acc_raw, self.zmask_meta, rgb_white_channel=channel, prefix=pr, make_3d=False, **kp + subdir, self, white_channel=channel, prefix=pr, make_3d=False, **kp ) df_patches = pd.DataFrame(files) - self.df = pd.merge(self.df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') - self.df['patch_id'] = self.df.apply(lambda _: uuid4(), axis=1) + self._df = pd.merge(self._df, df_patches, left_index=True, right_on='df_index').drop(columns='df_index') + self._df['patch_id'] = self._df.apply(lambda _: uuid4(), axis=1) if k == 'patch_masks': - self.export_patch_masks(subdir, prefix=pr, **params.patch_masks) + self.export_patch_masks(subdir, prefix=pr, **kp) if k == 'annotated_zstacks': - annotated = InMemoryDataAccessor( - draw_boxes_on_3d_image(raw_ch.data, self.zmask_meta, **kp) # TODO remove zmask_meta ref - ) - write_accessor_data_to_file(subdir / (pr + '.tif'), annotated) + self.export_annotated_zstack(prefix=pr, **kp) if k == 'object_classes': write_accessor_data_to_file(subdir / (pr + '.tif'), self.object_class_map) if k == 'dataframe': dfpa = subdir / (pr + '.csv') dfpa.parent.mkdir(parents=True, exist_ok=True) - self.df.to_csv(dfpa, index=False) + self._df.to_csv(dfpa, index=False) def project_stack_from_focal_points(