Skip to content
Snippets Groups Projects
Commit 6942c044 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Needed to modify model wrapper to handle patch stacks

parent 1a60e61d
No related branches found
No related tags found
No related merge requests found
......@@ -5,10 +5,51 @@ import json
import numpy as np
import pandas as pd
import uuid
import vigra
from extensions.chaeo.util import autonumber_new_file
from extensions.ilastik.models import IlastikObjectClassifierFromSegmentationModel
from model_server.accessors import generate_file_accessor, write_accessor_data_to_file
from model_server.accessors import generate_file_accessor, GenericImageDataAccessor, InMemoryDataAccessor, write_accessor_data_to_file
class PatchStackObjectClassifier(IlastikObjectClassifierFromSegmentationModel):
@staticmethod
def make_tczyx(acc):
assert acc.chroma == 1
tyx = np.moveaxis(
acc.data[:, :, 0, :], # YX(C)Z
[2, 0, 1],
[0, 1, 2]
)
return np.expand_dims(tyx, (1, 2))
# return tyx
def infer(self, input_img: GenericImageDataAccessor, segmentation_img: GenericImageDataAccessor) -> (np.ndarray, dict):
assert segmentation_img.is_mask()
assert input_img.chroma == 1
tagged_input_data = vigra.taggedView(self.make_tczyx(input_img), 'tczyx')
tagged_seg_data = vigra.taggedView(self.make_tczyx(segmentation_img), 'tczyx')
dsi = [
{
'Raw Data': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_input_data),
'Segmentation Image': self.PreloadedArrayDatasetInfo(preloaded_array=tagged_seg_data),
}
]
obmaps = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True) # [z x h x w x n]
assert len(obmaps) == 1, 'ilastik generated more than one object map'
assert obmaps[0].shape == (input_img.nz, 1, input_img.hw[0], input_img.hw[1], 1) # z(1)yx(1)
yxcz = np.moveaxis(
obmaps[0][:, :, :, :, 0],
[2, 3, 1, 0],
[0, 1, 2, 3]
)
assert yxcz.shape == input_img.shape
return InMemoryDataAccessor(data=yxcz), {'success': True}
def get_dataset_info(h5, lane=0):
lns = f'{lane:04d}'
......@@ -154,8 +195,8 @@ if __name__ == '__main__':
train_zstack_mask = generate_file_accessor(where_patch_stack / 'zstack_train_mask.tif')
new_ilp = root / 'exp0014/test_obj_from_seg.ilp'
mod = IlastikObjectClassifierFromSegmentationModel({'project_file': new_ilp})
mod = PatchStackObjectClassifier({'project_file': new_ilp})
result = mod.infer(train_zstack_raw, train_zstack_mask)
write_accessor_data_to_file(where_patch_stack / 'result.tif', result)
print(mod.project_file_abspath)
\ No newline at end of file
result_acc, _ = mod.infer(train_zstack_raw, train_zstack_mask)
write_accessor_data_to_file(where_patch_stack / 'result.tif', result_acc)
print(where_patch_stack / 'result.tif')
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment