Skip to content
Snippets Groups Projects
Commit 2095889d authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

(1) slicing to a single-channel image now returns its own accessor object; (2)...

(1) slicing to a single-channel image now returns its own accessor object; (2) channel slicing now happened at image accessor level, i.e. model is agnostic; (3) TIFF export is now a utility method, not a subclass of accessor
parent 3d5e1685
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,7 @@ from ilastik.workflows.objectClassification.objectClassificationWorkflow import
import numpy as np
from model_server.model import ImageToImageModel, ParameterExpectedError
from model_server.model import GenericImageDataAccessor, ImageToImageModel, ParameterExpectedError
class IlastikImageToImageModel(ImageToImageModel):
......@@ -61,10 +61,10 @@ class IlastikImageToImageModel(ImageToImageModel):
class IlastikPixelClassifierModel(IlastikImageToImageModel):
workflow = PixelClassificationWorkflow
def infer(self, input_img: np.ndarray, channel=None) -> (np.ndarray, dict):
def infer(self, input_img: GenericImageDataAccessor, channel=None) -> (np.ndarray, dict):
dsi = [
{
'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img),
'Raw Data': PreloadedArrayDatasetInfo(preloaded_array=input_img.data),
}
]
pxmap = self.shell.workflow.batchProcessingApplet.run_export(dsi, export_to_array=True)
......
......@@ -26,7 +26,7 @@ class GenericImageDataAccessor(ABC):
return True if self.shape_dict['Z'] > 1 else False
def get_one_channel_data (self, channel: int):
return self.data[:, :, channel, :]
return InMemoryDataAccessor(self.data[:, :, channel, :])
@property
def data(self): # XYCZ enforced
......@@ -36,6 +36,10 @@ class GenericImageDataAccessor(ABC):
def shape_dict(self):
return dict(zip(('X', 'Y', 'C', 'Z'), self.data.shape))
class InMemoryDataAccessor(GenericImageDataAccessor):
def __init__(self, data):
self._data = data
class GenericImageFileAccessor(GenericImageDataAccessor): # image data is loaded from a file
def __init__(self, fpath: Path):
"""
......@@ -80,15 +84,13 @@ class CziImageFileAccessor(GenericImageFileAccessor):
def __del__(self):
self.czifile.close()
class WriteableTiffFileAccessor(GenericImageFileAccessor):
def __init__(self, fpath: Path):
self.fpath = fpath
def write(self, data):
try:
tifffile.imwrite(self.fpath, data)
except:
raise FileWriteError(f'Unable to write data to file')
def write_accessor_data_to_file(fpath: Path, accessor: GenericImageDataAccessor) -> bool:
try:
tifffile.imwrite(fpath, accessor.data)
except:
raise FileWriteError(f'Unable to write data to file')
return True
def generate_file_accessor(fpath):
......
......@@ -45,17 +45,10 @@ class Model(ABC):
pass
@abstractmethod
def infer(self,
img: GenericImageDataAccessor,
channel: int = None
) -> (np.ndarray, dict): # return json describing inference result
def infer(self, img: GenericImageDataAccessor) -> (np.ndarray, dict): # return json describing inference result
if self.autoload:
self.load()
if channel and channel >= img.chroma:
raise ChannelTooHighError(f'Requested channel {channel} but image contains only {img.chroma} channels')
def reload(self):
self.load()
......@@ -66,8 +59,8 @@ class ImageToImageModel(Model):
"""
@abstractmethod
def infer(self, img, channel=None) -> (np.ndarray, dict):
super().infer(img, channel)
def infer(self, img) -> (np.ndarray, dict):
super().infer(img)
class DummyImageToImageModel(ImageToImageModel):
......@@ -76,8 +69,8 @@ class DummyImageToImageModel(ImageToImageModel):
def load(self):
return True
def infer(self, img: GenericImageDataAccessor, channel=None) -> (np.ndarray, dict):
super().infer(img, channel)
def infer(self, img: GenericImageDataAccessor) -> (np.ndarray, dict):
super().infer(img)
w = img.shape_dict['X']
h = img.shape_dict['Y']
result = np.zeros([h, w], dtype='uint8')
......@@ -88,9 +81,6 @@ class DummyImageToImageModel(ImageToImageModel):
class Error(Exception):
pass
class ChannelTooHighError(Error):
pass
class CouldNotLoadModelError(Error):
pass
......
......@@ -5,7 +5,7 @@ Implementation of image analysis work behind API endpoints, without knowledge of
from time import time
from typing import Dict
from model_server.image import generate_file_accessor, WriteableTiffFileAccessor
from model_server.image import generate_file_accessor, write_accessor_data_to_file
from pydantic import BaseModel
......@@ -22,20 +22,20 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
# read image file into memory
# maybe this isn't accurate if e.g. czifile loads lazily
t0 = time()
img = generate_file_accessor(fpi)
ch = kwargs.get('channel')
img = generate_file_accessor(fpi).get_one_channel_data(ch)
dt_fi = time() - t0
# run model inference
# TODO: call this async / await and report out infer status to optional callback
ch = kwargs.get('channel')
outdata, messages = model.infer(img, channel=ch)
outdata, messages = model.infer(img)
dt_inf = time() - t0
# TODO: assert outdata format
# write output to file
outpath = where_output / (img.fpath.stem + '.tif')
WriteableTiffFileAccessor(outpath).write(outdata)
outpath = where_output / (fpi.stem + '.tif')
write_accessor_data_to_file(outpath, outdata)
dt_fo = time() - t0
# TODO: smoother step-timing e.g. w/ decorate
......
......@@ -20,4 +20,6 @@ class TestIlastikPixelClassification(unittest.TestCase):
def test_instantiate_pixel_classifier(self):
model = IlastikPixelClassifierModel({'project_file': ilastik['pixel_classifier']})
cf = CziImageFileAccessor(czifile['path'])
model.infer(cf.get_one_channel_data(2))
import unittest
from conf.testing import czifile, output_path
from model_server.image import CziImageFileAccessor, WriteableTiffFileAccessor
from model_server.image import CziImageFileAccessor, write_accessor_data_to_file
class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
......@@ -17,10 +17,12 @@ class TestCziImageFileAccess(unittest.TestCase):
def test_write_single_channel_tif(self):
ch = 2
cf = CziImageFileAccessor(czifile['path'])
of = WriteableTiffFileAccessor(
output_path / f'{cf.fpath.stem}_ch{ch}.tif'
)
mono = cf.get_one_channel_data(2)
of.write(mono)
self.assertTrue(
write_accessor_data_to_file(
output_path / f'{cf.fpath.stem}_ch{ch}.tif',
mono
)
)
self.assertEqual(cf.data.shape[0:2], mono.data.shape[0:2])
self.assertEqual(cf.data.shape[3], mono.data.shape[2])
......@@ -20,14 +20,12 @@ class TestCziImageFileAccess(unittest.TestCase):
def load(self):
return False
self.assertRaises(
CouldNotLoadModelError,
mi=UnloadableDummyImageToImageModel,
)
with self.assertRaises(CouldNotLoadModelError):
mi = UnloadableDummyImageToImageModel()
def test_czifile_is_correct_shape(self):
model = DummyImageToImageModel()
img, _ = model.infer(self.cf, channel=1)
img, _ = model.infer(self.cf)
w = czifile['w']
h = czifile['h']
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment