From 1387df21aea92a4421d126d7350bc9cefc491b97 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 5 Jul 2024 13:27:23 +0200
Subject: [PATCH] base.util implements test data loading method

---
 model_server/base/util.py          | 40 ++++++++++++++++++++++++++++++
 tests/base/conf.py                 | 31 +++--------------------
 tests/base/test_process.py         |  1 -
 tests/base/test_roiset.py          |  1 -
 tests/base/test_session.py         |  1 -
 tests/test_ilastik/test_ilastik.py |  9 +++++--
 6 files changed, 51 insertions(+), 32 deletions(-)

diff --git a/model_server/base/util.py b/model_server/base/util.py
index 7f335dfc..6185ae21 100644
--- a/model_server/base/util.py
+++ b/model_server/base/util.py
@@ -3,6 +3,7 @@ import unittest
 from math import ceil
 from multiprocessing import Process
 from pathlib import Path
+import os
 import re
 from time import localtime, strftime
 from typing import List
@@ -172,6 +173,10 @@ def loop_workflow(
 
 
 class TestServerBaseClass(unittest.TestCase):
+    """
+    Base class for unittests of API functionality.  Implements both server and clients for testing.
+    """
+
     app_name = 'model_server.base.api:app'
 
     def setUp(self) -> None:
@@ -210,3 +215,38 @@ class TestServerBaseClass(unittest.TestCase):
     def tearDown(self) -> None:
         self.server_process.terminate()
         self.server_process.join()
+
+def setup_test_data():
+    """
+    Look for test data, create test output directory, parse and return meta information
+    :return:
+        meta (dict) of test data and paths
+    """
+    # places to look for test data
+    data_paths = [
+        os.environ.get('UNITTEST_DATA_ROOT'),
+        Path.home() / 'model_server' / 'testing',
+        os.getcwd(),
+    ]
+    root = None
+
+    # look for first instance of summary.json
+    for dp in data_paths:
+        if dp is None:
+            continue
+        sf = (Path(dp) / 'summary.json')
+        if sf.exists():
+            with open(sf, 'r') as fh:
+                meta = json.load(fh)
+                root = Path(dp)
+                break
+
+    if root is None:
+        raise Exception('Could not find test data, try setting environmental variable UNITTEST_DATA_ROOT.')
+
+    op_ev = os.environ.get('UNITTEST_OUTPUT', (root / 'test_output').__str__())
+    output_path = Path(op_ev)
+    output_path.mkdir(parents=True, exist_ok=True)
+    meta['root'] = root.__str__()
+    meta['output_path'] = output_path.__str__()
+    return meta
\ No newline at end of file
diff --git a/tests/base/conf.py b/tests/base/conf.py
index 44bc1ceb..18d49871 100644
--- a/tests/base/conf.py
+++ b/tests/base/conf.py
@@ -1,33 +1,10 @@
-import json
-import os
-from os import environ
 from pathlib import Path
+from model_server.base.util import setup_test_data
 
-# places to look for test data
-data_paths = [
-    environ.get('UNITTEST_DATA_ROOT'),
-    Path.home() / 'model_server' / 'testing',
-    os.getcwd(),
-]
-root = None
+meta = setup_test_data()
 
-# look for first instance of summary.json
-for dp in data_paths:
-    if dp is None:
-        continue
-    sf = (Path(dp) / 'summary.json')
-    if sf.exists():
-        with open(sf, 'r') as fh:
-            meta = json.load(fh)
-            root = Path(dp)
-            break
-
-if root is None:
-    raise Exception('Could not find test data, try setting environmental variable UNITTEST_DATA_ROOT.')
-
-op_ev = environ.get('UNITTEST_OUTPUT', (root / 'test_output').__str__())
-output_path = Path(op_ev)
-output_path.mkdir(parents=True, exist_ok=True)
+root = Path(meta['root'])
+output_path = Path(meta['output_path'])
 
 # resolve paths in test image files
 imgs = meta['image_files']
diff --git a/tests/base/test_process.py b/tests/base/test_process.py
index 9a5b8380..abd5777d 100644
--- a/tests/base/test_process.py
+++ b/tests/base/test_process.py
@@ -52,7 +52,6 @@ class TestMaskLargestObject(unittest.TestCase):
         arr[0:3, 0:3] = 255
         arr[4, 2:5] = 255
         masked = mask_largest_object(arr)
-        print(np.unique(masked))
         self.assertTrue(np.all(np.unique(masked) == [0, 255]))
         self.assertTrue(np.all(masked[:, 3:5] == 0))
         self.assertTrue(np.all(masked[3:5, :] == 0))
diff --git a/tests/base/test_roiset.py b/tests/base/test_roiset.py
index 532b0753..c7bf2752 100644
--- a/tests/base/test_roiset.py
+++ b/tests/base/test_roiset.py
@@ -519,7 +519,6 @@ class TestRoiSetMultichannelProducts(BaseTestRoiSetMonoProducts, unittest.TestCa
             self.assertTrue(pa.exists())
             pacc = generate_file_accessor(pa)
             self.assertEqual(pacc.chroma, 1)
-        print('res')
 
 
 from model_server.base.roiset import _get_label_ids
diff --git a/tests/base/test_session.py b/tests/base/test_session.py
index a7c13c07..d8b3f270 100644
--- a/tests/base/test_session.py
+++ b/tests/base/test_session.py
@@ -26,7 +26,6 @@ class TestGetSessionObject(unittest.TestCase):
 
     def test_change_session_subdirectory(self):
         old_paths = self.sesh.get_paths()
-        print(old_paths)
         self.sesh.set_data_directory('outbound_images', old_paths['inbound_images'])
         self.assertEqual(self.sesh.paths['outbound_images'], self.sesh.paths['inbound_images'])
 
diff --git a/tests/test_ilastik/test_ilastik.py b/tests/test_ilastik/test_ilastik.py
index c01bfb31..5153e854 100644
--- a/tests/test_ilastik/test_ilastik.py
+++ b/tests/test_ilastik/test_ilastik.py
@@ -2,8 +2,6 @@ import unittest
 
 import numpy as np
 
-from tests.base._conf import czifile, output_path, roiset_test_data
-from tests.test_ilastik._conf import ilastik_classifiers
 from model_server.base.accessors import CziImageFileAccessor, generate_file_accessor, InMemoryDataAccessor, PatchStack, write_accessor_data_to_file
 from model_server.extensions.ilastik import models as ilm
 from model_server.extensions.ilastik.workflows import infer_px_then_ob_model
@@ -11,6 +9,13 @@ from model_server.base.roiset import _get_label_ids, RoiSet, RoiSetMetaParams
 from model_server.base.workflows import classify_pixels
 from model_server.base.util import TestServerBaseClass
 
+from tests.test_ilastik import conf
+
+data = conf.meta['image_files']
+output_path = conf.output_path
+params = conf.meta['roiset']
+czifile = conf.meta['image_files']['czifile']
+ilastik_classifiers = conf.meta['ilastik_classifiers']
 
 def _random_int(*args):
     return np.random.randint(0, 2 ** 8, size=args, dtype='uint8')
-- 
GitLab