Skip to content
Snippets Groups Projects
Commit f8dd4aa2 authored by Christopher Randolph Rhodes's avatar Christopher Randolph Rhodes
Browse files

Abstractified Model base class

parent ca9709e6
No related branches found
No related tags found
No related merge requests found
......@@ -47,29 +47,4 @@ def infer_img(model_id: str, imgf: str, channel: int = None) -> dict:
channel=channel
)
session.record_workflow_run(record)
return record
# model = session.models[model_id]
#
# # read image file into memory
# # maybe this isn't accurate if e.g. czifile loads lazily
# t0 = time()
# img = generate_file_accessor(session.inbound / imgf)
# dt_fi = time() - t0
#
# # run model inference
# outdata, record = model.infer(img.data, channel=channel)
# dt_inf = time() - t0
#
# # write output to file
# outpath = session.outbound / img.fpath.stem / '.tif'
# WriteableTiffFileAccessor(outpath).write(outdata)
# dt_fo = time() - t0
#
# record['output_file'] = outpath
# record['times'] = {
# 'file_input': dt_fi,
# 'inference': dt_inf - dt_fi,
# 'file_output': dt_fo - dt_fi - dt_inf
# }
# session.register_inference(record)
\ No newline at end of file
return record
\ No newline at end of file
from abc import ABC, abstractmethod
from math import floor
import numpy as np
from model_server.image import GenericImageFileAccessor
# TODO: properly abstractify base class
class Model(object):
class Model(ABC):
def __init__(self, autoload=True):
"""
......@@ -15,10 +15,11 @@ class Model(object):
"""
self.autoload = autoload
# abstract
@abstractmethod
def load(self):
pass
@abstractmethod
def infer(self,
img: GenericImageFileAccessor,
channel: int = None
......@@ -33,7 +34,16 @@ class Model(object):
def reload(self):
self.load()
class ImageToImageModel(Model): # receives an image and returns an image of the same size
class ImageToImageModel(Model):
def __init__(self, **kwargs):
"""
Abstract class for models that receive an image and return an image of the same size
:param kwargs: variable length keyword arguments
"""
return super(**kwargs)
@abstractmethod
def infer(self, img, channel=None) -> (np.ndarray, dict):
super().infer(img, channel)
......
......@@ -7,6 +7,10 @@ class TestCziImageFileAccess(unittest.TestCase):
def setUp(self) -> None:
self.cf = CziImageFileAccessor(czifile['path'])
def test_instantiate_model_with_nondefault_kwarg(self):
model = DummyImageToImageModel(autoload=False)
self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.')
def test_czifile_is_correct_shape(self):
model = DummyImageToImageModel()
img, _ = model.infer(self.cf, channel=1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment