From 0bf3572546900997a4e223874c0935201192c2e9 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 15 Aug 2024 16:14:22 +0200 Subject: [PATCH] RoiSet can also return ordered dictionary of its export products as accessors --- model_server/base/api.py | 2 + model_server/base/pipelines/roiset_obmap.py | 3 -- model_server/base/roiset.py | 34 +++++++++++++++++ tests/base/test_roiset.py | 41 +++++++++++++++++++++ 4 files changed, 77 insertions(+), 3 deletions(-) diff --git a/model_server/base/api.py b/model_server/base/api.py index db484d0f..1715d9ab 100644 --- a/model_server/base/api.py +++ b/model_server/base/api.py @@ -107,3 +107,5 @@ def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None) raise HTTPException(404, f'Did not find accessor with ID {accessor_id}') except WriteAccessorError as e: raise HTTPException(409, str(e)) + +# TODO: endpoint to unload all accessors \ No newline at end of file diff --git a/model_server/base/pipelines/roiset_obmap.py b/model_server/base/pipelines/roiset_obmap.py index c4991aa3..376dbe3c 100644 --- a/model_server/base/pipelines/roiset_obmap.py +++ b/model_server/base/pipelines/roiset_obmap.py @@ -88,11 +88,8 @@ def roiset_object_map(p: RoiSetObjectMapParams) -> RoiSetToObjectMapRecord: """ record, rois = call_pipeline(roiset_object_map_pipeline, p) - # TODO: try labeling this correctly as input accessor_id instead of filename - assert isinstance(rois, RoiSet) table = rois.get_serializable_dataframe() - # TODO: instead, explicitly let trace include RoiSets session.write_to_table('RoiSet', {'input_filename': p.accessor_id}, table) ret = RoiSetToObjectMapRecord( roiset_table=table.to_dict(), diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index bccbcb05..d5a78f77 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from itertools import combinations from math import sqrt, floor from pathlib import Path @@ -772,6 +773,39 @@ class RoiSet(object): return record + def get_export_product_accessors(self, channel, params: RoiSetExportParams) -> dict: + """ + Return various representations of ROIs, e.g. patches, annotated stacks, and object maps, as accessors + :param channel: color channel of products to export + :param params: RoiSetExportParams object describing which products to export and with which parameters + :return: ordered dict of accessors containing the specified products + """ + interm = OrderedDict() + if not self.count: + return interm + + for k, kp in params.dict().items(): + if kp is None: + continue + if k == 'patches_3d': + interm[k] = self.get_patches_acc([channel], make_3d=True, **kp) + if k == 'annotated_patches_2d': + interm[k] = self.get_patches_acc( + make_3d=False, white_channel=channel, + bounding_box_channel=1, bounding_box_linewidth=2, **kp + ) + if k == 'patches_2d': + interm[k] = self.get_patches_acc(make_3d=False, white_channel=channel, **kp) + if k == 'annotated_zstacks': + interm[k] = InMemoryDataAccessor(draw_boxes_on_3d_image(self, **kp)) + if k == 'object_classes': + pr = 'classify_by_' + cnames = [c.split(pr)[1] for c in self._df.columns if c.startswith(pr)] + for n in cnames: + interm[f'{k}_{n}'] = self.get_object_class_map(n) + + return interm + def serialize(self, where: Path, prefix='') -> dict: """ Export the minimal information needed to recreate RoiSet object, i.e. CSV data file and tight patch masks diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py index 44486110..32a2a4b5 100644 --- a/tests/base/test_roiset.py +++ b/tests/base/test_roiset.py @@ -367,6 +367,7 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa self.assertEqual(result.nz, self.roiset.acc_raw.nz) self.assertEqual(result.chroma, 1) + def test_run_exports(self): p = RoiSetExportParams(**{ 'patches_3d': {}, @@ -413,6 +414,46 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa for f in test_df[c]: self.assertTrue((where / f).exists(), where / f) + def test_get_interm_prods(self): + p = RoiSetExportParams(**{ + 'patches_3d': None, + 'annotated_patches_2d': { + 'draw_bounding_box': True, + 'rgb_overlay_channels': [3, None, None], + 'rgb_overlay_weights': [0.2, 1.0, 1.0], + 'pad_to': 512, + }, + 'patches_2d': { + 'draw_bounding_box': False, + 'draw_mask': False, + }, + 'annotated_zstacks': {}, + 'object_classes': True, + }) + self.roiset.classify_by('dummy_class', [0], DummyInstanceSegmentationModel()) + interm = self.roiset.get_export_product_accessors( + channel=3, + params=p + ) + self.assertNotIn('patches_3d', interm.keys()) + self.assertEqual( + interm['annotated_patches_2d'].hw, + (self.roiset.get_df().h.max(), self.roiset.get_df().w.max()) + ) + self.assertEqual( + interm['patches_2d'].hw, + (self.roiset.get_df().h.max(), self.roiset.get_df().w.max()) + ) + self.assertEqual( + interm['annotated_zstacks'].hw, + self.stack.hw + ) + self.assertEqual( + interm['object_classes_dummy_class'].hw, + self.stack.hw + ) + self.assertTrue(np.all(interm['object_classes_dummy_class'].unique()[0] == [0, 1])) + def test_run_export_expanded_2d_patch(self): p = RoiSetExportParams(**{ 'patches_2d': { -- GitLab