diff --git a/model_server/base/models.py b/model_server/base/models.py index 8413f68d4420fe31add5e25818b4198bdc1a76eb..4fa0849031bcd318a66f49c8843763b0a99c2279 100644 --- a/model_server/base/models.py +++ b/model_server/base/models.py @@ -3,7 +3,7 @@ from math import floor import numpy as np -from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor +from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack class Model(ABC): @@ -88,6 +88,16 @@ class InstanceSegmentationModel(ImageToImageModel): if not img.shape == mask.shape: raise InvalidInputImageError('Expect input image and mask to be the same shape') + def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs): + """ + Iterative over a patch stack, call inference on each patch + :return: PatchStack of same shape in input + """ + res_data = np.zeros(img.shape, dtype='uint16') + for i in range(0, img.count): # interpret as PYXCZ + res_data[i, :, :, :, :] = self.label_instance_class(img.iat(i), mask.iat(i), **kwargs).data + return PatchStack(res_data) + class DummySemanticSegmentationModel(SemanticSegmentationModel): diff --git a/model_server/base/roiset.py b/model_server/base/roiset.py index c44023db6622237b5fe40dd82b90d2f161187517..6ef46ed0779e306d0b6b1e4137f065f394d21e57 100644 --- a/model_server/base/roiset.py +++ b/model_server/base/roiset.py @@ -279,7 +279,7 @@ class RoiSet(object): def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): # do this on a patch basis, i.e. only one object per frame - obmap_patches = object_classification_model.label_instance_class( + obmap_patches = object_classification_model.label_patch_stack( self.get_raw_patches(channel=channel), self.get_patch_masks() )