From 1aac54b4732e76908e167a720290d45be269d5f9 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Sun, 10 Nov 2024 08:03:07 +0100
Subject: [PATCH] File-backed accessors no longer keep file handle open; lazy
 loading passes in principle

---
 model_server/base/accessors.py | 54 ++++++++++++++++++++--------------
 tests/base/test_accessors.py   | 11 +++++--
 2 files changed, 41 insertions(+), 24 deletions(-)

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 43d400f9..04daa7a2 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 import os
+from importlib.metadata import metadata
 from pathlib import Path
 
 import numpy as np
@@ -213,15 +214,29 @@ class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded
             raise FileAccessorError(f'Could not find file at {fpath}')
         self.fpath = fpath
 
+        self._data = None
+        self._metadata = None
+
         if not lazy:
-            self._data = self.load()
-        else:
-            self._data = None
+            self.load()
+
 
     @abstractmethod
-    def load(self):
+    def load(self) -> (np.ndarray, dict):
         pass
 
+    @property
+    def data(self):
+        if self._data is None:
+            self.load()
+        return self._data
+
+    @property
+    def metadata(self):
+        if self._metadata is None:
+            self.load()
+        return self._metadata
+
     @staticmethod
     def read(fp: Path):
         return generate_file_accessor(fp)
@@ -238,7 +253,6 @@ class TifSingleSeriesFileAccessor(GenericImageFileAccessor):
 
         try:
             tf = tifffile.TiffFile(fpath)
-            self.tf = tf # TODO: close file connection
         except Exception:
             raise FileAccessorError(f'Unable to access data in {fpath}')
 
@@ -264,12 +278,9 @@ class TifSingleSeriesFileAccessor(GenericImageFileAccessor):
             [axs.index(k) for k in order],
             [0, 1, 2, 3]
         )
-
-        return self.conform_data(yxcz.reshape(yxcz.shape[0:4]))
-
-    # TODO: remove
-    def __del__(self):
-        self.tf.close()
+        tf.close()
+        self._data = self.conform_data(yxcz.reshape(yxcz.shape[0:4]))
+        self._metadata = {}
 
 class PngFileAccessor(GenericImageFileAccessor):
     def load(self):
@@ -281,9 +292,10 @@ class PngFileAccessor(GenericImageFileAccessor):
             FileAccessorError(f'Unable to access data in {fpath}')
 
         if len(arr.shape) == 3: # rgb
-            return np.expand_dims(arr, 3)
+            self._data = np.expand_dims(arr, 3)
         else: # mono
-            return np.expand_dims(arr, (2, 3))
+            self._data = np.expand_dims(arr, (2, 3))
+        self._metadata = {}
 
 class CziImageFileAccessor(GenericImageFileAccessor):
     """
@@ -293,15 +305,14 @@ class CziImageFileAccessor(GenericImageFileAccessor):
     def load(self):
         fpath = self.fpath
         try:
-            # TODO: persist metadata then remove file connection
             cf = czifile.CziFile(fpath)
-            self.czifile = cf
+            metadata = cf.metadata(raw=False)
         except Exception:
             raise FileAccessorError(f'Unable to access CZI data in {fpath}')
 
+        # check for incompatible compression type
         try:
-            md = cf.metadata(raw=False)
-            compmet = md['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod']
+            compmet = metadata['ImageDocument']['Metadata']['Information']['Image']['OriginalCompressionMethod']
         except KeyError:
             raise InvalidCziCompression('Could not find metadata key OriginalCompressionMethod')
         if compmet.upper() != 'UNCOMPRESSED':
@@ -317,14 +328,13 @@ class CziImageFileAccessor(GenericImageFileAccessor):
             [cf.axes.index(ch) for ch in idx],
             [0, 1, 2, 3]
         )
-        return self.conform_data(yxcz.reshape(yxcz.shape[0:4]))
-
-    def __del__(self):
-        self.czifile.close()
+        cf.close()
+        self._data = self.conform_data(yxcz.reshape(yxcz.shape[0:4]))
+        self._metadata = metadata
 
     @property
     def pixel_scale_in_micrometers(self):
-        scale_meta = self.czifile.metadata(raw=False)['ImageDocument']['Metadata']['Scaling']['Items']['Distance']
+        scale_meta = self.metadata['ImageDocument']['Metadata']['Scaling']['Items']['Distance']
         sc = {}
         for m in scale_meta:
             if m['DefaultUnitFormat'].encode() == b'\xc2\xb5m' and m['Id'] in self.shape_dict.keys():  # literal mu-m
diff --git a/tests/base/test_accessors.py b/tests/base/test_accessors.py
index f1705398..c70e2d67 100644
--- a/tests/base/test_accessors.py
+++ b/tests/base/test_accessors.py
@@ -229,8 +229,15 @@ class TestCziImageFileAccess(unittest.TestCase):
         )
 
     def test_lazy_load(self):
-        cf = generate_file_accessor(data['czifile']['path'])
-        self.assertEqual(1, 0)
+        acc_cf = generate_file_accessor(data['czifile']['path'], lazy=True)
+        self.assertEqual(acc_cf._data, None)
+        self.assertEqual(acc_cf._metadata, None)
+        acc_cf.load()
+        self.assertIsNotNone(acc_cf._data)
+        self.assertIsNotNone(acc_cf._metadata)
+        self.assertIsInstance(acc_cf._data, np.ndarray)
+        self.assertIsInstance(acc_cf._metadata, dict)
+
 
 class TestPatchStackAccessor(unittest.TestCase):
     def setUp(self) -> None:
-- 
GitLab