Skip to content
Snippets Groups Projects
Commit cce43209 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Merge branch 'patch_stack_one_channel' into 'staging'

Patch stack one channel

See merge request !13
parents d318df9a 36a0ec66
No related branches found
No related tags found
2 merge requests!16Completed (de)serialization of RoiSet,!13Patch stack one channel
...@@ -3,7 +3,7 @@ from math import floor ...@@ -3,7 +3,7 @@ from math import floor
import numpy as np import numpy as np
from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor from model_server.base.accessors import GenericImageDataAccessor, InMemoryDataAccessor, PatchStack
class Model(ABC): class Model(ABC):
...@@ -88,6 +88,16 @@ class InstanceSegmentationModel(ImageToImageModel): ...@@ -88,6 +88,16 @@ class InstanceSegmentationModel(ImageToImageModel):
if not img.shape == mask.shape: if not img.shape == mask.shape:
raise InvalidInputImageError('Expect input image and mask to be the same shape') raise InvalidInputImageError('Expect input image and mask to be the same shape')
def label_patch_stack(self, img: PatchStack, mask: PatchStack, **kwargs):
"""
Iterative over a patch stack, call inference on each patch
:return: PatchStack of same shape in input
"""
res_data = np.zeros(img.shape, dtype='uint16')
for i in range(0, img.count): # interpret as PYXCZ
res_data[i, :, :, :, :] = self.label_instance_class(img.iat(i), mask.iat(i), **kwargs).data
return PatchStack(res_data)
class DummySemanticSegmentationModel(SemanticSegmentationModel): class DummySemanticSegmentationModel(SemanticSegmentationModel):
......
...@@ -279,7 +279,7 @@ class RoiSet(object): ...@@ -279,7 +279,7 @@ class RoiSet(object):
def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ): def classify_by(self, name: str, channel: int, object_classification_model: InstanceSegmentationModel, ):
# do this on a patch basis, i.e. only one object per frame # do this on a patch basis, i.e. only one object per frame
obmap_patches = object_classification_model.label_instance_class( obmap_patches = object_classification_model.label_patch_stack(
self.get_raw_patches(channel=channel), self.get_raw_patches(channel=channel),
self.get_patch_masks() self.get_patch_masks()
) )
......
...@@ -252,7 +252,12 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag ...@@ -252,7 +252,12 @@ class IlastikObjectClassifierFromPixelPredictionsModel(IlastikModel, ImageToImag
def label_instance_class( def label_instance_class(
self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs self, img: GenericImageDataAccessor, mask: GenericImageDataAccessor, **kwargs
) -> GenericImageDataAccessor: ) -> GenericImageDataAccessor:
return super().label_instance_class(img, mask, pixel_classification_channel=px_ch) if mask.dtype == 'bool':
norm_mask = 1.0 * mask.data
else:
norm_mask = mask.data / np.iinfo(mask.dtype).max
norm_mask_acc = InMemoryDataAccessor(norm_mask.astype('float32'))
return super().label_instance_class(img, norm_mask_acc, pixel_classification_channel=px_ch)
return _Mod(params={'project_file': self.project_file}) return _Mod(params={'project_file': self.project_file})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment