From e9da5e6abb6e91251531da68e617b3030d1de3ba Mon Sep 17 00:00:00 2001 From: Christopher Rhodes <christopher.rhodes@embl.de> Date: Thu, 8 Aug 2024 15:42:44 +0200 Subject: [PATCH] Session can add and delete accessors to its scope --- model_server/base/accessors.py | 15 ++++++++++++- model_server/base/session.py | 41 +++++++++++++++++++++++++++++++++- tests/base/test_session.py | 17 +++++++++++++- 3 files changed, 70 insertions(+), 3 deletions(-) diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py index 5be57a4b..271b8d4a 100644 --- a/model_server/base/accessors.py +++ b/model_server/base/accessors.py @@ -153,6 +153,14 @@ class GenericImageDataAccessor(ABC): func(self.data) ) + @property + def info(self): + return { + 'shape_dict': self.shape_dict, + 'dtype': self.dtype, + 'loaded': True, + } + class InMemoryDataAccessor(GenericImageDataAccessor): def __init__(self, data): self._data = self.conform_data(data) @@ -172,6 +180,12 @@ class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded def read(fp: Path): return generate_file_accessor(fp) + @property + def info(self): + d = super().info() + d['filepath'] = self.fpath.__str__() + return d + class TifSingleSeriesFileAccessor(GenericImageFileAccessor): def __init__(self, fpath: Path): super().__init__(fpath) @@ -470,7 +484,6 @@ def make_patch_stack_from_file(fpath): # interpret t-dimension as patch positio return PatchStack(pyxcz) - class Error(Exception): pass diff --git a/model_server/base/session.py b/model_server/base/session.py index 24c5b31c..3876a769 100644 --- a/model_server/base/session.py +++ b/model_server/base/session.py @@ -1,3 +1,4 @@ +from collections import OrderedDict import logging import os @@ -9,11 +10,11 @@ from typing import Union import pandas as pd from ..conf import defaults +from .accessors import GenericImageDataAccessor from .models import Model logger = logging.getLogger(__name__) - class CsvTable(object): def __init__(self, fpath: Path): self.path = fpath @@ -40,6 +41,7 @@ class _Session(object): def __init__(self, root: str = None): self.models = {} # model_id : model object self.paths = self.make_paths(root) + self.accessors = OrderedDict() self.logfile = self.paths['logs'] / f'session.log' logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format) @@ -80,6 +82,40 @@ class _Session(object): raise InvalidPathError(f'Could not find {path}') self.paths[key] = Path(path) + def add_accessor(self, acc: GenericImageDataAccessor, accessor_id: str = None): + """ + Add an accessor to session context + :param acc: accessor to add + :param accessor_id: unique ID, or autogenerate if None + :return: ID of accessor + """ + if accessor_id in self.accessors.keys(): + raise AccessorIdError(f'Access with ID {accessor_id} already exists') + if accessor_id is None: + idx = len(self.accessors) + accessor_id = f'auto_{idx:06d}' + self.accessors[accessor_id] = acc + return accessor_id + + def del_accessor(self, accessor_id: str): + """ + Replace accessor object with its info dictionary + :param accessor_id: accessor's ID + :return: ID of accessor + """ + if accessor_id not in self.accessors.keys(): + raise AccessorIdError(f'No accessor with ID {accessor_id} is registered') + v = self.accessors[accessor_id] + if isinstance(v, dict) and v['loaded'] is False: + logger.warning(f'Accessor {accessor_id} is already deleted') + info = v + else: + assert isinstance(v, GenericImageDataAccessor) + info = v.info + info['loaded'] = False + self.accessors[accessor_id] = info + return accessor_id + @staticmethod def make_paths(root: str = None) -> dict: """ @@ -208,6 +244,9 @@ class InferenceRecordError(Error): class CouldNotInstantiateModelError(Error): pass +class AccessorIdError(Error): + pass + class CouldNotCreateDirectory(Error): pass diff --git a/tests/base/test_session.py b/tests/base/test_session.py index 56ce4632..1e333844 100644 --- a/tests/base/test_session.py +++ b/tests/base/test_session.py @@ -1,8 +1,9 @@ from os.path import exists import pathlib -from pydantic import BaseModel import unittest +import numpy as np +from model_server.base.accessors import InMemoryDataAccessor from model_server.base.session import session class TestGetSessionObject(unittest.TestCase): @@ -85,3 +86,17 @@ class TestGetSessionObject(unittest.TestCase): self.assertEqual(len(dfv), len(data)) self.assertEqual(dfv.columns[0], 'X') self.assertEqual(dfv.columns[1], 'Y') + + + def test_add_and_remove_accessor(self): + w = 256 + h = 512 + nc = 4 + nz = 11 + sh = (h, w, nc, nz) + acc = InMemoryDataAccessor(np.random.randint(0, 2 ** 8, size=sh, dtype='uint8')) + acc_id = session.add_accessor(acc) + shd = acc.shape_dict + self.assertEqual(session.accessors[acc_id].shape, sh) + session.del_accessor(acc_id) + self.assertEqual(session.accessors[acc_id]['shape_dict'], shd) -- GitLab