From e9da5e6abb6e91251531da68e617b3030d1de3ba Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 8 Aug 2024 15:42:44 +0200
Subject: [PATCH] Session can add and delete accessors to its scope

---
 model_server/base/accessors.py | 15 ++++++++++++-
 model_server/base/session.py   | 41 +++++++++++++++++++++++++++++++++-
 tests/base/test_session.py     | 17 +++++++++++++-
 3 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 5be57a4b..271b8d4a 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -153,6 +153,14 @@ class GenericImageDataAccessor(ABC):
             func(self.data)
         )
 
+    @property
+    def info(self):
+        return {
+            'shape_dict': self.shape_dict,
+            'dtype': self.dtype,
+            'loaded': True,
+        }
+
 class InMemoryDataAccessor(GenericImageDataAccessor):
     def __init__(self, data):
         self._data = self.conform_data(data)
@@ -172,6 +180,12 @@ class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded
     def read(fp: Path):
         return generate_file_accessor(fp)
 
+    @property
+    def info(self):
+        d = super().info()
+        d['filepath'] = self.fpath.__str__()
+        return d
+
 class TifSingleSeriesFileAccessor(GenericImageFileAccessor):
     def __init__(self, fpath: Path):
         super().__init__(fpath)
@@ -470,7 +484,6 @@ def make_patch_stack_from_file(fpath):  # interpret t-dimension as patch positio
     return PatchStack(pyxcz)
 
 
-
 class Error(Exception):
     pass
 
diff --git a/model_server/base/session.py b/model_server/base/session.py
index 24c5b31c..3876a769 100644
--- a/model_server/base/session.py
+++ b/model_server/base/session.py
@@ -1,3 +1,4 @@
+from collections import OrderedDict
 import logging
 import os
 
@@ -9,11 +10,11 @@ from typing import Union
 import pandas as pd
 
 from ..conf import defaults
+from .accessors import GenericImageDataAccessor
 from .models import Model
 
 logger = logging.getLogger(__name__)
 
-
 class CsvTable(object):
     def __init__(self, fpath: Path):
         self.path = fpath
@@ -40,6 +41,7 @@ class _Session(object):
     def __init__(self, root: str = None):
         self.models = {} # model_id : model object
         self.paths = self.make_paths(root)
+        self.accessors = OrderedDict()
 
         self.logfile = self.paths['logs'] / f'session.log'
         logging.basicConfig(filename=self.logfile, level=logging.INFO, force=True, format=self.log_format)
@@ -80,6 +82,40 @@ class _Session(object):
             raise InvalidPathError(f'Could not find {path}')
         self.paths[key] = Path(path)
 
+    def add_accessor(self, acc: GenericImageDataAccessor, accessor_id: str = None):
+        """
+        Add an accessor to session context
+        :param acc: accessor to add
+        :param accessor_id: unique ID, or autogenerate if None
+        :return: ID of accessor
+        """
+        if accessor_id in self.accessors.keys():
+            raise AccessorIdError(f'Access with ID {accessor_id} already exists')
+        if accessor_id is None:
+            idx = len(self.accessors)
+            accessor_id = f'auto_{idx:06d}'
+        self.accessors[accessor_id] = acc
+        return accessor_id
+
+    def del_accessor(self, accessor_id: str):
+        """
+        Replace accessor object with its info dictionary
+        :param accessor_id: accessor's ID
+        :return: ID of accessor
+        """
+        if accessor_id not in self.accessors.keys():
+            raise AccessorIdError(f'No accessor with ID {accessor_id} is registered')
+        v = self.accessors[accessor_id]
+        if isinstance(v, dict) and v['loaded'] is False:
+            logger.warning(f'Accessor {accessor_id} is already deleted')
+            info = v
+        else:
+            assert isinstance(v, GenericImageDataAccessor)
+            info = v.info
+            info['loaded'] = False
+            self.accessors[accessor_id] = info
+        return accessor_id
+
     @staticmethod
     def make_paths(root: str = None) -> dict:
         """
@@ -208,6 +244,9 @@ class InferenceRecordError(Error):
 class CouldNotInstantiateModelError(Error):
     pass
 
+class AccessorIdError(Error):
+    pass
+
 class CouldNotCreateDirectory(Error):
     pass
 
diff --git a/tests/base/test_session.py b/tests/base/test_session.py
index 56ce4632..1e333844 100644
--- a/tests/base/test_session.py
+++ b/tests/base/test_session.py
@@ -1,8 +1,9 @@
 from os.path import exists
 import pathlib
-from pydantic import BaseModel
 import unittest
 
+import numpy as np
+from model_server.base.accessors import InMemoryDataAccessor
 from model_server.base.session import session
 
 class TestGetSessionObject(unittest.TestCase):
@@ -85,3 +86,17 @@ class TestGetSessionObject(unittest.TestCase):
         self.assertEqual(len(dfv), len(data))
         self.assertEqual(dfv.columns[0], 'X')
         self.assertEqual(dfv.columns[1], 'Y')
+
+
+    def test_add_and_remove_accessor(self):
+        w = 256
+        h = 512
+        nc = 4
+        nz = 11
+        sh = (h, w, nc, nz)
+        acc = InMemoryDataAccessor(np.random.randint(0, 2 ** 8, size=sh, dtype='uint8'))
+        acc_id = session.add_accessor(acc)
+        shd = acc.shape_dict
+        self.assertEqual(session.accessors[acc_id].shape, sh)
+        session.del_accessor(acc_id)
+        self.assertEqual(session.accessors[acc_id]['shape_dict'], shd)
-- 
GitLab