From 3a4cd2808b77bf2deb3b4b2bb5814b5c08afee6a Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Mon, 12 Aug 2024 13:08:10 +0200
Subject: [PATCH] Removed YOLOv8 wrappers; delint

---
 model_server/base/accessors.py             |  4 +-
 model_server/base/api.py                   | 21 ++++++++++-
 model_server/extensions/ilastik/models.py  |  4 +-
 model_server/extensions/yolov8/__init__.py |  0
 model_server/extensions/yolov8/models.py   | 43 ----------------------
 5 files changed, 23 insertions(+), 49 deletions(-)
 delete mode 100644 model_server/extensions/yolov8/__init__.py
 delete mode 100644 model_server/extensions/yolov8/models.py

diff --git a/model_server/base/accessors.py b/model_server/base/accessors.py
index 9b5f6461..7fc28bb3 100644
--- a/model_server/base/accessors.py
+++ b/model_server/base/accessors.py
@@ -287,7 +287,7 @@ class CziImageFileAccessor(GenericImageFileAccessor):
 
 def write_accessor_data_to_file(fpath: Path, acc: GenericImageDataAccessor, mkdir=True) -> bool:
     """
-    Export an image accessor to file.
+    Export an image accessor to file
     :param fpath: complete path including filename and extension
     :param acc: image accessor to be written
     :param mkdir: create any needed subdirectories in fpath if True
@@ -334,7 +334,7 @@ def write_accessor_data_to_file(fpath: Path, acc: GenericImageDataAccessor, mkdi
 def generate_file_accessor(fpath):
     """
     Given an image file path, return an image accessor, assuming the file is a supported format and represents
-    a single position array, which may be single or multi-channel, single plane or z-stack.
+    a single position array, which may be single or multichannel, single plane or z-stack.
     """
     if str(fpath).upper().endswith('.TIF') or str(fpath).upper().endswith('.TIFF'):
         return TifSingleSeriesFileAccessor(fpath)
diff --git a/model_server/base/api.py b/model_server/base/api.py
index a54bbf6a..610515c2 100644
--- a/model_server/base/api.py
+++ b/model_server/base/api.py
@@ -1,4 +1,3 @@
-from pathlib import Path
 from typing import Union
 
 from fastapi import FastAPI, HTTPException
@@ -8,23 +7,29 @@ from .session import session, AccessorIdError, InvalidPathError, WriteAccessorEr
 app = FastAPI(debug=True)
 
 import model_server.extensions.ilastik.router
+
 app.include_router(model_server.extensions.ilastik.router.router)
 
 import model_server.base.pipelines.segment
+
 app.include_router(model_server.base.pipelines.segment.router)
 
+
 @app.on_event("startup")
 def startup():
     pass
 
+
 @app.get('/')
 def read_root():
     return {'success': True}
 
+
 @app.get('/paths')
 def list_session_paths():
     return session.get_paths()
 
+
 @app.get('/status')
 def show_session_status():
     return {
@@ -33,51 +38,62 @@ def show_session_status():
         'paths': session.get_paths(),
     }
 
+
 def _change_path(key, path):
     try:
         session.set_data_directory(key, path)
     except InvalidPathError as e:
         raise HTTPException(404, f'Did not find valid folder at: {path}')
 
+
 @app.put('/paths/watch_input')
 def watch_input_path(path: str):
     return _change_path('inbound_images', path)
 
+
 @app.put('/paths/watch_output')
 def watch_output_path(path: str):
     return _change_path('outbound_images', path)
 
+
 @app.get('/session/restart')
 def restart_session(root: str = None) -> dict:
     session.restart(root=root)
     return session.describe_loaded_models()
 
+
 @app.get('/session/logs')
 def list_session_log() -> list:
     return session.get_log_data()
 
+
 @app.get('/models')
 def list_active_models():
     return session.describe_loaded_models()
 
+
 @app.get('/accessors')
 def list_accessors():
     return session.list_accessors()
 
+
 def _session_accessor(func, acc_id):
     try:
         return func(acc_id)
     except AccessorIdError as e:
         raise HTTPException(404, f'Did not find accessor with ID {acc_id}')
 
+
 @app.get('/accessors/{accessor_id}')
 def get_accessor(accessor_id: str):
     return _session_accessor(session.get_accessor_info, accessor_id)
 
+
 @app.get('/accessors/delete/{accessor_id}')
 def delete_accessor(accessor_id: str):
     return _session_accessor(session.del_accessor, accessor_id)
 
+
 @app.put('/accessors/read_from_file/{filename}')
 def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None):
     fp = session.paths['inbound_images'] / filename
@@ -86,6 +102,7 @@ def read_accessor_from_file(filename: str, accessor_id: Union[str, None] = None)
     acc = generate_file_accessor(fp)
     return session.add_accessor(acc, accessor_id=accessor_id)
 
+
 @app.put('/accessors/write_to_file/{accessor_id}')
 def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None) -> str:
     try:
@@ -93,4 +110,4 @@ def write_accessor_to_file(accessor_id: str, filename: Union[str, None] = None)
     except AccessorIdError as e:
         raise HTTPException(404, f'Did not find accessor with ID {accessor_id}')
     except WriteAccessorError as e:
-        raise HTTPException(409, str(e))
\ No newline at end of file
+        raise HTTPException(409, str(e))
diff --git a/model_server/extensions/ilastik/models.py b/model_server/extensions/ilastik/models.py
index b547deeb..d098e039 100644
--- a/model_server/extensions/ilastik/models.py
+++ b/model_server/extensions/ilastik/models.py
@@ -305,7 +305,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
     def label_instance_class(self, img: GenericImageDataAccessor, pxmap: GenericImageDataAccessor, **kwargs):
         """
         Given an image and a map of pixel probabilities of the same shape, return a map where each connected object is
-        assigned a class.
+        assigned a class
         :param img: input image
         :param pxmap: map of pixel probabilities
         :param kwargs:
@@ -318,7 +318,7 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
         if not pxmap.data.min() >= 0.0 and pxmap.data.max() <= 1.0:
             raise InvalidInputImageError('Pixel probability values must be between 0.0 and 1.0')
         pxch = kwargs.get('pixel_classification_channel', 0)
-        pxtr = kwargs('pixel_classification_threshold', 0.5)
+        pxtr = kwargs.get('pixel_classification_threshold', 0.5)
         mask = InMemoryDataAccessor(pxmap.get_one_channel_data(pxch).data > pxtr)
         obmap, _ = self.infer(img, mask)
         return obmap
diff --git a/model_server/extensions/yolov8/__init__.py b/model_server/extensions/yolov8/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/model_server/extensions/yolov8/models.py b/model_server/extensions/yolov8/models.py
deleted file mode 100644
index 1b524054..00000000
--- a/model_server/extensions/yolov8/models.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import List
-
-from pydantic import BaseModel
-from ultralytics import YOLO
-
-from ...base.accessors import GenericImageDataAccessor
-from ...base.models import InstanceSegmentationModel
-from ...base.roiset import RoiSet
-
-class YoloV8Params(BaseModel):
-    pt_file: str
-    duplicate: bool = True
-
-class YoloModel(InstanceSegmentationModel):
-
-    def __init__(self, params: YoloV8Params, autoload=True):
-        # initialize from pretrained model
-        super().__init__(autoload, params)
-
-    def load(self):
-        self.yolo = YOLO(self.params.pt_file)
-
-    def _infer_yolo_seg(self, img:GenericImageDataAccessor) -> RoiSet:
-        # only populates bounding box info
-        pass
-
-    def _infer_yolo_det(self, img:GenericImageDataAccessor) -> RoiSet:
-        # also populates segmentation masks
-        pass
-
-    def label_instance_class(
-            self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
-    ) -> GenericImageDataAccessor:
-        roiset = self._infer_yolo_seg(img)
-        return roiset.acc_obj_ids
-
-    def export(self):
-        # export pretrained model
-        pass
-
-    def train(self, roisets: List[RoiSet]):
-        coco_list = [r.serialize_coco for r in roisets]
-        self.yolo.train(coco_list)
\ No newline at end of file
-- 
GitLab