diff --git a/model_server/api.py b/model_server/api.py index 2289fea41fde7f44c2aa42e3e594b740d06a892f..47471e3216d9f7cdd8553c827b8d881533ecb225 100644 --- a/model_server/api.py +++ b/model_server/api.py @@ -47,29 +47,4 @@ def infer_img(model_id: str, imgf: str, channel: int = None) -> dict: channel=channel ) session.record_workflow_run(record) - return record - - # model = session.models[model_id] - # - # # read image file into memory - # # maybe this isn't accurate if e.g. czifile loads lazily - # t0 = time() - # img = generate_file_accessor(session.inbound / imgf) - # dt_fi = time() - t0 - # - # # run model inference - # outdata, record = model.infer(img.data, channel=channel) - # dt_inf = time() - t0 - # - # # write output to file - # outpath = session.outbound / img.fpath.stem / '.tif' - # WriteableTiffFileAccessor(outpath).write(outdata) - # dt_fo = time() - t0 - # - # record['output_file'] = outpath - # record['times'] = { - # 'file_input': dt_fi, - # 'inference': dt_inf - dt_fi, - # 'file_output': dt_fo - dt_fi - dt_inf - # } - # session.register_inference(record) \ No newline at end of file + return record \ No newline at end of file diff --git a/model_server/model.py b/model_server/model.py index 416c0055930fc7b75cd3e4f19097744ef4238769..c9c3140b697bd4543e603b760408b718e8c45914 100644 --- a/model_server/model.py +++ b/model_server/model.py @@ -1,11 +1,11 @@ +from abc import ABC, abstractmethod from math import floor import numpy as np from model_server.image import GenericImageFileAccessor -# TODO: properly abstractify base class -class Model(object): +class Model(ABC): def __init__(self, autoload=True): """ @@ -15,10 +15,11 @@ class Model(object): """ self.autoload = autoload - # abstract + @abstractmethod def load(self): pass + @abstractmethod def infer(self, img: GenericImageFileAccessor, channel: int = None @@ -33,7 +34,16 @@ class Model(object): def reload(self): self.load() -class ImageToImageModel(Model): # receives an image and returns an image of the same size +class ImageToImageModel(Model): + + def __init__(self, **kwargs): + """ + Abstract class for models that receive an image and return an image of the same size + :param kwargs: variable length keyword arguments + """ + return super(**kwargs) + + @abstractmethod def infer(self, img, channel=None) -> (np.ndarray, dict): super().infer(img, channel) diff --git a/tests/test_model.py b/tests/test_model.py index 435689979bb5cc96804b36aa0cbcd445139d3dab..1920ea3b4d128a86c4909068404276276d9a8590 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,6 +7,10 @@ class TestCziImageFileAccess(unittest.TestCase): def setUp(self) -> None: self.cf = CziImageFileAccessor(czifile['path']) + def test_instantiate_model_with_nondefault_kwarg(self): + model = DummyImageToImageModel(autoload=False) + self.assertFalse(model.autoload, 'Could not override autoload flag in subclass of Model.') + def test_czifile_is_correct_shape(self): model = DummyImageToImageModel() img, _ = model.infer(self.cf, channel=1)