From 2464806cd18ab4be5187d019d7568cb690b951d2 Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 8 Aug 2024 11:44:22 +0200 Subject: [PATCH] Implemented and tested method to write intermediate pipeline products to file --- model_server/base/util.py | 49 +++++++++++++++++++++++++++++++++--- tests/base/test_pipelines.py | 8 +++++- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/model_server/base/util.py b/model_server/base/util.py index 97e01ab3..f4299979 100644 --- a/model_server/base/util.py +++ b/model_server/base/util.py @@ -171,6 +171,12 @@ class PipelineTrace(OrderedDict): tfunc = perf_counter def __init__(self, enforce_accessors=True, allow_overwrite=False): + """ + A container and timer for data at each stage of a pipeline. + + :param enforce_accessors: if True, only allow accessors to be appended as items + :param allow_overwrite: if True, allow an item to be overwritten + """ self.enforce_accessors = enforce_accessors self.allow_overwrite = allow_overwrite self.last_time = self.tfunc() @@ -187,18 +193,53 @@ class PipelineTrace(OrderedDict): self.last_time = self.tfunc() return super().__setitem__(key, value) - @property - def accessors(self): - return self.items() - @property def times(self): + """ + Return an ordered dictionary of incremental times for each item that is appended + """ return {k: self.timer[k] for k in self.keys()} @property def last(self): + """ + Return most recently appended item + :return: + """ return list(self.values())[-1] + def write_interm( + self, + where: Path, + prefix: str = 'interm', + skip_first=True, + skip_last=True, + debug=False + ) -> List[Path]: + """ + Write accessor data to TIF files under specified path + :param where: directory in which to write image files + :param prefix: (optional) file prefix + :param skip_first: if True, do not write first item in trace + :param skip_last: if False, do not write last item in trace + :param debug: if True, report destination filepaths but do not write files + :return: list of destination filepaths + """ + paths = [] + for i, item in enumerate(self.items()): + k, v = item + if not isinstance(v, GenericImageDataAccessor): + continue + if skip_first and k == list(self.keys())[0]: + continue + if skip_last and k == list(self.keys())[-1]: + continue + fp = where / f'{prefix}_{i:02d}_{k}.tif' + paths.append(fp) + if not debug: + v.write(fp) + return paths + class Error(Exception): pass diff --git a/tests/base/test_pipelines.py b/tests/base/test_pipelines.py index 774b789f..580c5e3d 100644 --- a/tests/base/test_pipelines.py +++ b/tests/base/test_pipelines.py @@ -39,4 +39,10 @@ class TestGetSessionObject(unittest.TestCase): img[0, 0], 0, 'First pixel is not black as expected' - ) \ No newline at end of file + ) + + interm_fps = trace.write_interm( + output_path / 'pipelines' / 'segment_interm', + prefix=czifile['name'] + ) + self.assertTrue([ofp.stem.split('_')[-1] for ofp in interm_fps] == ['mono', 'inference', 'smooth']) \ No newline at end of file -- GitLab