Skip to content
Snippets Groups Projects
Commit ca9709e6 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Extended pydantic data validation, evaluated output image in several tests

parent 4414ae82
No related branches found
No related tags found
No related merge requests found
from pathlib import Path
from time import time
from fastapi import FastAPI, HTTPException
from image import generate_file_accessor, WriteableTiffFileAccessor
from session import Session
from workflow import infer_image_to_image
......@@ -41,12 +39,15 @@ def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
detail=f'Model {model_id} has not been loaded'
)
return infer_image_to_image(
# TODO: try block workflow, catch and redirect HTTP
record = infer_image_to_image(
session.inbound / imgf,
session.models[model_id],
session.outbound,
channel=channel
)
session.record_workflow_run(record)
return record
# model = session.models[model_id]
#
......
import os
from pathlib import Path
import numpy as np
import czifile
import tifffile
......@@ -15,7 +17,7 @@ class CziImageFileAccessor(GenericImageFileAccessor):
self.fpath = fpath
try:
cf = czifile.CziFile(fpath)
self.data = cf.asarray()
self.czifile = cf
except:
raise FileAccessorError(f'Unable to access data in {fpath}')
......@@ -26,6 +28,15 @@ class CziImageFileAccessor(GenericImageFileAccessor):
self.shape_dict = sd
self.chroma = sd['C']
self.data = np.moveaxis(
cf.asarray(),
[cf.axes.index(ch) for ch in ['X', 'Y', 'C', 'Z']],
[0, 1, 2, 3]
)
def __del__(self):
self.czifile.close()
class WriteableTiffFileAccessor(GenericImageFileAccessor):
def __init__(self, fpath: Path):
self.fpath = fpath
......
......@@ -4,15 +4,34 @@ import numpy as np
from model_server.image import GenericImageFileAccessor
# TODO: properly abstractify base class
class Model(object):
def __init__(self, autoload=True):
"""
Abstract base class for an inference model that uses image data as an input.
:param autoload: automatically load model and dependencies into memory if True
"""
self.autoload = autoload
# abstract
def load(self):
pass
def infer(self, img: GenericImageFileAccessor, channel:int=None) -> (np.ndarray, dict): # return json describing inference result
def infer(self,
img: GenericImageFileAccessor,
channel: int = None
) -> (np.ndarray, dict): # return json describing inference result
if self.autoload:
self.load()
if channel and channel >= img.chroma:
raise ChannelTooHighError(f'Requested channel {channel} but image contains only {img.chroma} channels')
def reload(self):
pass
self.load()
class ImageToImageModel(Model): # receives an image and returns an image of the same size
def infer(self, img, channel=None) -> (np.ndarray, dict):
......@@ -22,6 +41,9 @@ class IlastikImageToImageModel(ImageToImageModel):
pass
class DummyImageToImageModel(Model):
model_id = 'dummy_make_white_square'
def load(self):
self.loaded = True
......
......@@ -6,13 +6,14 @@ from time import strftime, localtime
from conf.server import paths
from model_server.share import SharedImageDirectory
from model_server.workflow import WorkflowRunRecord
def create_manifest_json():
pass
class Session(object):
"""
Singleton class for persisting data between API calls as a server session
Singleton class for a server session that persists data between API calls
"""
inbound = SharedImageDirectory(paths['images']['inbound'])
outbound = SharedImageDirectory(paths['images']['outbound'])
......@@ -31,7 +32,8 @@ class Session(object):
self.session_log = self.where_records / f'{self.session_id}.log'
self.log_event('Initialized session')
self.manifest_json = self.where_records / f'{self.session_id}-manifest.json'
self.record_inference(None) # instantiate empty file
open(self.manifest_json, 'w').close() # instantiate empty json file
@staticmethod
def create_session_id(look_where: Path) -> str:
......@@ -52,12 +54,12 @@ class Session(object):
with open(self.session_log, 'w+') as fh:
fh.write(f'{timestamp} -- {event}')
def record_inference(self, record: dict):
def record_workflow_run(self, record: WorkflowRunRecord or None):
"""
Append a JSON describing inference data to this session's manifest
"""
with open(self.manifest_json, 'w+') as fh:
json.dump(record, fh)
json.dump(record.dict(), fh)
def restart(self):
self.__init__()
......
......@@ -3,9 +3,21 @@ Implementation of image analysis work behind API endpoints, without knowledge of
"""
from time import time
from typing import Dict
from model_server.image import generate_file_accessor, WriteableTiffFileAccessor
from pydantic import BaseModel
# TODO: timer decorator
class WorkflowRunRecord(BaseModel):
model_id: str
input_filepath: str
output_filepath: str
success: bool
timer_results: Dict[str, str]
def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
# read image file into memory
# maybe this isn't accurate if e.g. czifile loads lazily
......@@ -13,22 +25,35 @@ def infer_image_to_image(fpi, model, where_output, **kwargs) -> dict:
img = generate_file_accessor(fpi)
dt_fi = time() - t0
# TODO: assert indata format
assert(img.shape_dict['T'] == 1)
assert (img.shape_dict['T'] == 1)
# run model inference
ch = kwargs.get('channel')
outdata, record = model.infer(img, channel=ch)
outdata, messages = model.infer(img, channel=ch)
dt_inf = time() - t0
# TODO: assert outdata format
# write output to file
outpath = where_output / (img.fpath.stem + '.tif')
WriteableTiffFileAccessor(outpath).write(outdata)
dt_fo = time() - t0
record['output_file'] = outpath
# TODO: smoother step-timing e.g. w/ decorate
record['times'] = {
timer_results = {
'file_input': dt_fi,
'inference': dt_inf - dt_fi,
'file_output': dt_fo - dt_fi - dt_inf
}
record = WorkflowRunRecord(
model_id=model.model_id,
input_filepath=str(fpi),
output_filepath=str(outpath),
success=messages['success'],
timer_results=timer_results
)
return record
\ No newline at end of file
......@@ -57,6 +57,7 @@ dependencies:
- greenlet=1.1.3=py37hf2a7229_0
- grpcio=1.41.1=py37h04d2302_1
- grpcio-tools=1.41.1=py37hf2a7229_1
- h11=0.14.0=pyhd8ed1ab_0
- h5py=3.7.0=nompi_py37h24adfc3_101
- hdf5=1.12.2=nompi_h2a0e4a3_101
- hytra=1.1.5=py_0_gfd7342d
......@@ -210,6 +211,7 @@ dependencies:
- ucrt=10.0.22621.0=h57928b3_0
- unicodedata2=14.0.0=py37hcc03f2d_1
- urllib3=2.0.2=pyhd8ed1ab_0
- uvicorn=0.19.0=py37h03978a9_0
- vc=14.3=hb25d44b_16
- vc14_runtime=14.34.31931=h5081d32_16
- vigra=1.11.1=py37hc3ed208_1033
......
import unittest
from conf.testing import czifile_attr, paths
from conf.testing import czifile
from model_server.image import CziImageFileAccessor
class TestCziImageFileAccess(unittest.TestCase):
......@@ -7,8 +7,8 @@ class TestCziImageFileAccess(unittest.TestCase):
pass
def test_czifile_is_correct_shape(self):
cf = CziImageFileAccessor(paths['czifile'])
self.assertEqual(cf.shape_dict['Y'], czifile_attr['h'])
self.assertEqual(cf.shape_dict['X'], czifile_attr['w'])
self.assertEqual(cf.chroma, czifile_attr['c'])
cf = CziImageFileAccessor(czifile['path'])
self.assertEqual(cf.shape_dict['Y'], czifile['h'])
self.assertEqual(cf.shape_dict['X'], czifile['w'])
self.assertEqual(cf.chroma, czifile['c'])
self.assertFalse(cf.is_3d)
\ No newline at end of file
......@@ -9,6 +9,25 @@ class TestCziImageFileAccess(unittest.TestCase):
def test_czifile_is_correct_shape(self):
model = DummyImageToImageModel()
model.infer(self.cf, channel=1)
img, _ = model.infer(self.cf, channel=1)
# TODO: check that result is a white rectangle
\ No newline at end of file
w = czifile['w']
h = czifile['h']
self.assertEqual(
img.shape,
(h, w),
'Inferred image is not the expected shape'
)
self.assertEqual(
img[int(w/2), int(h/2)],
255,
'Middle pixel is not white as expected'
)
self.assertEqual(
img[0, 0],
0,
'First pixel is not black as expected'
)
\ No newline at end of file
......@@ -9,6 +9,10 @@ class TestGetSessionObject(unittest.TestCase):
sesh = Session()
self.assertIs(sesh, Session(), 'Re-initializing Session class returned a new object')
from os.path import exists
self.assertTrue(exists(sesh.session_log), 'Session did not create a log file in the correct place')
self.assertTrue(exists(sesh.manifest_json), 'Session did not create a manifest JSON file in the correct place')
def test_restart_session(self):
sesh = Session()
logfile1 = sesh.session_log
......@@ -16,12 +20,15 @@ class TestGetSessionObject(unittest.TestCase):
logfile2 = sesh.session_log
self.assertIsNot(logfile1, logfile2, 'Restarting session does not generate new logfile')
def test_session_records_inference(self):
def test_session_records_workflow(self):
import json
from model_server.workflow import WorkflowRunRecord
sesh = Session()
di = {'model_id': 'test_model'}
sesh.record_inference(di)
di = WorkflowRunRecord(
model_id='test_model',
)
sesh.record_workflow_run(di)
with open(sesh.manifest_json, 'r') as fh:
do = json.load(fh)
self.assertEqual(di, do)
self.assertEqual(di.dict(), do, 'Manifest record is not correct')
......@@ -11,4 +11,27 @@ class TestGetSessionObject(unittest.TestCase):
def test_single_session_instance(self):
result = infer_image_to_image(czifile['path'], self.model, output_path)
self.assertEqual(result['success'], True)
self.assertTrue(result.success)
import tifffile
img = tifffile.imread(result.output_filepath)
w = czifile['w']
h = czifile['h']
self.assertEqual(
img.shape,
(h, w),
'Inferred image is not the expected shape'
)
self.assertEqual(
img[int(w/2), int(h/2)],
255,
'Middle pixel is not white as expected'
)
self.assertEqual(
img[0, 0],
0,
'First pixel is not black as expected'
)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment