From dae64e307583191026de750d38a9421f5d90dc71 Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 25 Oct 2024 11:50:18 +0200
Subject: [PATCH] Can load binary segmentation model

---
 model_server/base/api.py    | 12 +++++++++++-
 model_server/base/models.py | 25 ++++++++++++++++---------
 tests/base/test_api.py      |  8 +++++++-
 tests/base/test_model.py    |  5 ++---
 4 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/model_server/base/api.py b/model_server/base/api.py
index 608738a1..ba047efe 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -1,10 +1,11 @@
+from pydantic import BaseModel, Field
 from typing import Union
 
 from fastapi import FastAPI, HTTPException
 from .accessors import generate_file_accessor
+from .models import BinaryThresholdSegmentationModel
 from .session import session, AccessorIdError, InvalidPathError, WriteAccessorError
 
-
 app = FastAPI(debug=True)
 
 from .pipelines.router import router
@@ -68,6 +69,15 @@ def list_session_log() -> list:
 def list_active_models():
     return session.describe_loaded_models()
 
+class BinaryThresholdSegmentationParams(BaseModel):
+    channel: int = Field(None, description='Channel to use for segmentation; use all channels if empty.')
+    tr: Union[int, float] = Field(0.5, description='Threshold for binary segmentation')
+
+@app.put('/models/seg/binary_threshold/load/')
+def load_binary_threshold_model(p: BinaryThresholdSegmentationParams, model_id=None) -> dict:
+    result = session.load_model(BinaryThresholdSegmentationModel, key=model_id, params=p)
+    session.log_info(f'Loaded binary threshold segmentation model {result}')
+    return {'model_id': result}
 
 @app.get('/accessors')
 def list_accessors():
diff --git a/model_server/base/models.py b/model_server/base/models.py
index eaded2c9..346bfaa4 100644
--- a/model_server/base/models.py
+++ b/model_server/base/models.py
@@ -1,15 +1,13 @@
 from abc import ABC, abstractmethod
-from math import floor
 
 import numpy as np
-from pydantic import BaseModel
 
-from .accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack
+from .accessors import GenericImageDataAccessor, PatchStack
 
 
 class Model(ABC):
 
-    def __init__(self, autoload=True, params: BaseModel = None):
+    def __init__(self, autoload=True, params: dict = None):
         """
         Abstract base class for an inference model that uses image data as an input.
 
@@ -18,7 +16,7 @@ class Model(ABC):
         """
         self.autoload = autoload
         if params:
-            self.params = params.dict()
+            self.params = params
         self.loaded = False
         if not autoload:
             return None
@@ -130,15 +128,24 @@ class InstanceSegmentationModel(ImageToImageModel):
 
 
 class BinaryThresholdSegmentationModel(SemanticSegmentationModel):
+    """
+    Trivial but functional model that labels all pixels above an intensity threshold as class 1
+    """
 
-    def __init__(self, tr: float = 0.5):
-        self.tr = tr
+    def __init__(self, params=None):
+        self.tr = params['tr']
+        self.channel = params['channel']
+        self.loaded = True
 
     def infer(self, img: GenericImageDataAccessor) -> (GenericImageDataAccessor, dict):
-        return img.apply(lambda x: x > self.tr), {'success': True}
+        if self.channel:
+            acc = img.get_mono(self.channel)
+        else:
+            acc = img
+        return acc.get_mono(self.channel).apply(lambda x: x > self.tr)
 
     def label_pixel_class(self, img: GenericImageDataAccessor, **kwargs) -> GenericImageDataAccessor:
-        return self.infer(img, **kwargs)[0]
+        return self.infer(img, **kwargs)
 
     def load(self):
         pass
diff --git a/tests/base/test_api.py b/tests/base/test_api.py
index 73a36188..9593b7f3 100644
--- a/tests/base/test_api.py
+++ b/tests/base/test_api.py
@@ -182,4 +182,10 @@ class TestApiFromAutomatedClient(TestServerBaseClass):
         sd = self.assertGetSuccess(f'accessors/{acc_id}')['shape_dict']
         self.assertEqual(self.assertGetSuccess(f'accessors/{acc_id}')['filepath'], '')
         acc_out = self.get_accessor(accessor_id=acc_id, filename='test_output.tif')
-        self.assertEqual(sd, acc_out.shape_dict)
\ No newline at end of file
+        self.assertEqual(sd, acc_out.shape_dict)
+
+    def test_load_binary_segmentation_model(self):
+        mid = self.assertPutSuccess(
+            '/models/seg/binary_threshold/load/', body={'channel': 0, 'tr': 10}
+        )['model_id']
+        return mid
\ No newline at end of file
diff --git a/tests/base/test_model.py b/tests/base/test_model.py
index d975f7cd..38e51c0b 100644
--- a/tests/base/test_model.py
+++ b/tests/base/test_model.py
@@ -56,9 +56,8 @@ class TestCziImageFileAccess(unittest.TestCase):
         return img, mask
 
     def test_binary_segmentation(self):
-        model = BinaryThresholdSegmentationModel(tr=3e4)
-        img = self.cf.get_mono(0)
-        res = model.label_pixel_class(img)
+        model = BinaryThresholdSegmentationModel({'tr': 3e4, 'channel': 0})
+        res = model.label_pixel_class(self.cf)
         self.assertTrue(res.is_mask())
 
     def test_dummy_instance_segmentation(self):
-- 
GitLab