From e147a54f3e84a0a721f17dce733f0b43a6ed2be6 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Thu, 31 Aug 2023 15:50:18 +0200
Subject: [PATCH] Session returns JSON-compliant model list

---
 api.py                  | 4 ++--
 model_server/session.py | 9 +++++++--
 tests/test_api.py       | 8 +++-----
 tests/test_session.py   | 7 ++++---
 4 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/api.py b/api.py
index 9518066f..ad40f5ef 100644
--- a/api.py
+++ b/api.py
@@ -23,8 +23,8 @@ def list_active_models():
     return session.describe_loaded_models()
 
 @app.put('/models/dummy/load/')
-def load_dummy_model(params: str = None) -> dict:
-    return session.load_model(DummyImageToImageModel, params)
+def load_dummy_model() -> dict:
+    return session.load_model(DummyImageToImageModel)
 
 @app.put('/models/ilastik/pixel_classification/load/')
 def load_ilastik_pixel_classification_model(params: str) -> dict:
diff --git a/model_server/session.py b/model_server/session.py
index 0ce71e13..6db22003 100644
--- a/model_server/session.py
+++ b/model_server/session.py
@@ -85,8 +85,13 @@ class Session(object):
         return self.describe_loaded_models()
 
     def describe_loaded_models(self) -> dict:
-        # TODO: explictly make this JSON-compatible
-        return self.models
+        return {
+            k: {
+                'class': self.models[k]['object'].__class__.__name__,
+                'params': self.models[k]['params'],
+            }
+            for k in self.models.keys()
+        }
 
     def restart(self):
         self.__init__()
diff --git a/tests/test_api.py b/tests/test_api.py
index 4e7c4819..d4f7d266 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -47,17 +47,15 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         self.assertEqual(resp.content, b'{}')
 
     def test_load_dummy_model(self):
-        model_id = DummyImageToImageModel.model_id
+        model_key = DummyImageToImageModel.__name__ + '_00'
         resp_load = requests.put(
             self.uri + f'models/dummy/load',
-            # params={'misc': {'d': 'e'}}
         )
         self.assertEqual(resp_load.status_code, 200, resp_load.json())
         resp_list = requests.get(self.uri + 'models')
         self.assertEqual(resp_list.status_code, 200)
         rj = resp_list.json()
-        self.assertEqual(rj[model_id]['class'], 'DummyImageToImageModel')
-        return model_id
+        self.assertEqual(rj[model_key]['class'], 'DummyImageToImageModel')
 
     def test_respond_with_error_when_invalid_model_loaded(self):
         model_id = 'not_a_real_model'
@@ -69,7 +67,7 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         print(resp.content)
 
     def test_respond_with_error_when_invalid_filepath_requested(self):
-        model_id = self.test_load_dummy_model()
+        # model_id = self.test_load_dummy_model()
         resp = requests.put(
             self.uri + f'i2i/infer/',
             params={
diff --git a/tests/test_session.py b/tests/test_session.py
index 2cfead8e..6a15c68b 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -44,12 +44,13 @@ class TestGetSessionObject(unittest.TestCase):
         success = sesh.load_model(MC)
         self.assertTrue(success)
         loaded_models = sesh.describe_loaded_models()
+        print(loaded_models)
         self.assertTrue(
             (MC.__name__ + '_00') in loaded_models.keys()
         )
         self.assertEqual(
-            loaded_models[MC.__name__ + '_00']['object'].__class__,
-            MC
+            loaded_models[MC.__name__ + '_00']['class'],
+            MC.__name__
         )
 
     def test_session_loads_second_instance_of_same_model(self):
@@ -69,4 +70,4 @@ class TestGetSessionObject(unittest.TestCase):
         success = sesh.load_model(MC, params=p)
         self.assertTrue(success)
         loaded_models = sesh.describe_loaded_models()
-        self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p)
\ No newline at end of file
+        self.assertEqual(loaded_models[MC.__name__ + '_00']['params'], p)
-- 
GitLab