Skip to content
Snippets Groups Projects
Commit 36a0ec66 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

RoiSet now classifies objects on a patch basis via .label_patch_stack method

parent 98be6bb7
No related branches found
No related tags found
2 merge requests!16Completed (de)serialization of RoiSet,!13Patch stack one channel
......@@ -3,7 +3,7 @@ from math import floor
import numpy as np
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack
class Model(ABC):
......@@ -88,6 +88,16 @@ class InstanceSegmentationModel(ImageToImageModel):
if not img.shape == mask.shape:
raise InvalidInputImageError('Expect input image and mask to be the same shape')
def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs):
"""
Iterative over a patch stack, call inference on each patch
:return: PatchStack of same shape in input
"""
res_data = np.zeros(img.shape, dtype='uint16')
for i in range(0, img.count): # interpret as PYXCZ
res_data[i, :, :, :, :] = self.label_instance_class(img.iat(i), mask.iat(i), **kwargs).data
return PatchStack(res_data)
class DummySemanticSegmentationModel(SemanticSegmentationModel):
......
......@@ -279,7 +279,7 @@ class RoiSet(object):
def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ):
# do this on a patch basis, i.e. only one object per frame
obmap_patches = object_classification_model.label_instance_class(
obmap_patches = object_classification_model.label_patch_stack(
self.get_raw_patches(channel=channel),
self.get_patch_masks()
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment