diff --git a/model_server/extensions/yolov8/__init__.py b/model_server/extensions/yolov8/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model_server/extensions/yolov8/models.py b/model_server/extensions/yolov8/models.py new file mode 100644 index 0000000000000000000000000000000000000000..1b5240545b2a5f795661296e31e8a9465c0f62ef --- /dev/null +++ b/model_server/extensions/yolov8/models.py @@ -0,0 +1,43 @@ +from typing import List + +from pydantic import BaseModel +from ultralytics import YOLO + +from ...base.accessors import GenericImageDataAccessor +from ...base.models import InstanceSegmentationModel +from ...base.roiset import RoiSet + +class YoloV8Params(BaseModel): + pt_file: str + duplicate: bool = True + +class YoloModel(InstanceSegmentationModel): + + def __init__(self, params: YoloV8Params, autoload=True): + # initialize from pretrained model + super().__init__(autoload, params) + + def load(self): + self.yolo = YOLO(self.params.pt_file) + + def _infer_yolo_seg(self, img:GenericImageDataAccessor) -> RoiSet: + # only populates bounding box info + pass + + def _infer_yolo_det(self, img:GenericImageDataAccessor) -> RoiSet: + # also populates segmentation masks + pass + + def label_instance_class( + self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs + ) -> GenericImageDataAccessor: + roiset = self._infer_yolo_seg(img) + return roiset.acc_obj_ids + + def export(self): + # export pretrained model + pass + + def train(self, roisets: List[RoiSet]): + coco_list = [r.serialize_coco for r in roisets] + self.yolo.train(coco_list) \ No newline at end of file