diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index e5ec6708425c76ff00856d2221ed2087e9bf2e41..ffb8418f9429be3160ab9ef3c701b216fd1e82d5 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -475,6 +475,7 @@ class RoiSet(object): return dfe def run_exports(self, where, channel, prefix, params: RoiSetExportParams): + record = {} if not self.count: return raw_ch = self.acc_raw.get_one_channel_data(channel) @@ -505,12 +506,17 @@ class RoiSet(object): if k == 'annotated_zstacks': self.export_annotated_zstack(subdir, prefix=pr, **kp) if k == 'object_classes': - for k, acc in self.object_class_maps.items(): - write_accessor_data_to_file(subdir / k / (pr + '.tif'), acc) + record[k] = {} + for kc, acc in self.object_class_maps.items(): + fp = subdir / kc / (pr + '.tif') + write_accessor_data_to_file(fp, acc) + record[k][kc] = fp.__str__() if k == 'dataframe': dfpa = subdir / (pr + '.csv') dfpa.parent.mkdir(parents=True, exist_ok=True) self._df.to_csv(dfpa, index=False) + record[k] = dfpa + return record def project_stack_from_focal_points( diff --git a/tests/test_roiset.py b/tests/test_roiset.py index bbe0ee0cd0a073ec43b9eab8b8624f82b77a190c..4c95552848bcc334696e96ff8584c3bfbc03398b 100644 --- a/tests/test_roiset.py +++ b/tests/test_roiset.py @@ -5,7 +5,7 @@ from pathlib import Path from model_server.conf.testing import output_path, roiset_test_data -from model_server.base.roiset import RoiSetMetaParams +from model_server.base.roiset import RoiSetExportParams, RoiSetMetaParams from model_server.base.roiset import _get_label_ids, RoiSet from model_server.base.accessors import generate_file_accessor, InMemoryDataAccessor, write_accessor_data_to_file from model_server.base.models import DummyInstanceSegmentationModel @@ -139,6 +139,20 @@ class TestRoiSetMonoProducts(BaseTestRoiSetMonoProducts, unittest.TestCase): roiset.classify_by('dummy_class', 0, DummyInstanceSegmentationModel()) self.assertTrue(all(roiset.get_df()['classify_by_dummy_class'].unique() == [1])) self.assertTrue(all(np.unique(roiset.object_class_maps['dummy_class'].data) == [0, 1])) + return roiset + + def test_export_object_classes(self): + record = self.test_classify_by().run_exports( + output_path / 'object_class_maps', + 0, + 'obmap', + RoiSetExportParams(object_classes=True) + ) + opa = record['object_classes']['dummy_class'] + acc = generate_file_accessor(opa) + self.assertTrue(Path(opa).exists()) + self.assertTrue(all(np.unique(acc.data) == [0, 1])) + def test_raw_patches_are_correct_shape(self): roiset = self._make_roi_set()