Skip to content
Snippets Groups Projects
Commit 6debd9c3 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Use patch stack accessor in a batch script

parent 4d16baa5
No related branches found
No related tags found
No related merge requests found
import numpy as np
from model_server.accessors import InMemoryDataAccessor
from model_server.accessors import generate_file_accessor, InMemoryDataAccessor
class MonoPatchStack(InMemoryDataAccessor):
......@@ -44,6 +44,11 @@ class MonoPatchStack(InMemoryDataAccessor):
return [self.data[:, :, 0, zi] for zi in range(0, n)]
class MonoPatchStackFromFile(MonoPatchStack):
def __init__(self, fpath):
super().__init__(generate_file_accessor(fpath).data[:, :, 0, :])
class Error(Exception):
pass
......
......@@ -8,6 +8,7 @@ import skimage
import uuid
import vigra
from extensions.chaeo.accessors import MonoPatchStack, MonoPatchStackFromFile
from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
from model_server.accessors import generate_file_accessor, GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
......@@ -17,22 +18,13 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
as time-series images where each frame contains only one object.
"""
@staticmethod
def make_tczyx(acc: GenericImageDataAccessor):
assert acc.chroma == 1
tyx = np.moveaxis(
acc.data[:, :, 0, :], # YX(C)Z
[2, 0, 1],
[0, 1, 2]
)
return np.expand_dims(tyx, (1, 2))
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
assert segmentation_img.is_mask()
assert input_img.chroma == 1
def infer(self, input_acc: MonoPatchStack, segmentation_acc: MonoPatchStack) -> (np.ndarray, dict):
assert segmentation_acc.is_mask()
assert input_acc.chroma == 1
tagged_input_data = vigra.taggedView(self.make_tczyx(input_img), 'tczyx')
tagged_seg_data = vigra.taggedView(self.make_tczyx(segmentation_img), 'tczyx')
tagged_input_data = vigra.taggedView(input_acc.make_tczyx(), 'tczyx')
tagged_seg_data = vigra.taggedView(segmentation_acc.make_tczyx(), 'tczyx')
dsi = [
{
......@@ -46,14 +38,14 @@ class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
assert len(obmaps) == 1, 'ilastik generated more than one object map'
# for some reason ilastik scrambles these axes to Z(1)YX(1)
assert obmaps[0].shape == (input_img.nz, 1, input_img.hw[0], input_img.hw[1], 1)
assert obmaps[0].shape == (input_acc.nz, 1, input_acc.hw[0], input_acc.hw[1], 1)
yxcz = np.moveaxis(
obmaps[0][:, :, :, :, 0],
[2, 3, 1, 0],
[0, 1, 2, 3]
)
assert yxcz.shape == input_img.shape
assert yxcz.shape == input_acc.shape
return InMemoryDataAccessor(data=yxcz), {'success': True}
def get_dataset_info(h5: h5py.File, lane : int = 0):
......@@ -236,16 +228,16 @@ if __name__ == '__main__':
def infer_and_compare_training_set(ilp, suffix):
# infer object labels from the same data used to train the classifier
train_zstack_raw = generate_file_accessor(where_patch_stack / 'zstack_train_raw.tif')
train_zstack_mask = generate_file_accessor(where_patch_stack / 'zstack_train_mask.tif')
train_truth_labels = generate_file_accessor(where_patch_stack / f'zstack_train_label.tif')
train_zstack_raw = MonoPatchStackFromFile(where_patch_stack / 'zstack_train_raw.tif')
train_zstack_mask = MonoPatchStackFromFile(where_patch_stack / 'zstack_train_mask.tif')
train_truth_labels = MonoPatchStackFromFile(where_patch_stack / f'zstack_train_label.tif')
infer_and_compare(ilp, 'train', suffix, train_zstack_raw, train_zstack_mask, train_truth_labels)
def infer_and_compare_test_set(ilp, suffix):
# infer object labels from test dataset
test_zstack_raw = generate_file_accessor(where_patch_stack / 'zstack_test_raw.tif')
test_zstack_mask = generate_file_accessor(where_patch_stack / 'zstack_test_mask.tif')
test_truth_labels = generate_file_accessor(where_patch_stack / f'zstack_test_label.tif')
test_zstack_raw = MonoPatchStackFromFile(where_patch_stack / 'zstack_test_raw.tif')
test_zstack_mask = MonoPatchStackFromFile(where_patch_stack / 'zstack_test_mask.tif')
test_truth_labels = MonoPatchStackFromFile(where_patch_stack / f'zstack_test_label.tif')
infer_and_compare(ilp, 'test', suffix, test_zstack_raw, test_zstack_mask, test_truth_labels)
def infer_and_compare(ilp, prefix, suffix, raw, mask, labels):
......
......@@ -15,13 +15,15 @@ class TestCziImageFileAccess(unittest.TestCase):
h = 512
n = 4
acc = MonoPatchStack(np.random.rand(h, w, n))
assert acc.count == n
assert acc.hw == (h, w)
self.assertEqual(acc.count, n)
self.assertEqual(acc.hw, (h, w))
self.assertEqual(acc.make_tczyx().shape, (n, 1, 1, h, w))
def test_make_patch_stack_from_3d_array(self):
w = 256
h = 512
n = 4
acc = MonoPatchStack([np.random.rand(h, w) for _ in range(0, 4)])
assert acc.count == n
assert acc.hw == (h, w)
\ No newline at end of file
self.assertEqual(acc.count, n)
self.assertEqual(acc.hw, (h, w))
self.assertEqual(acc.make_tczyx().shape, (n, 1, 1, h, w))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment