import requests
import unittest

import numpy as np

import conf.testing
from model_server.image import CziImageFileAccessor, InMemoryDataAccessor, write_accessor_data_to_file
from model_server.ilastik import IlastikObjectClassifierModel, IlastikPixelClassifierModel
from model_server.workflow import infer_image_to_image
from tests.test_api import TestServerBaseClass

class TestIlastikPixelClassification(unittest.TestCase):
    """Run ilastik pixel/object classifier models in-process on test data."""

    def setUp(self) -> None:
        # Accessor for the CZI test image configured in conf.testing; shared by
        # the tests below instead of re-opening the file in each one.
        self.cf = CziImageFileAccessor(conf.testing.czifile['path'])


    def test_faulthandler(self): # recreate error that is messing up ilastik
        """Expect faulthandler.enable() to fail when pointed at sys.stdout.

        faulthandler.enable() needs a file with a usable file descriptor; this
        test expects the current sys.stdout not to provide one and therefore
        to raise io.UnsupportedOperation.
        NOTE(review): whether that happens depends on how the runner sets up
        sys.stdout (e.g. output capturing) — confirm in the target environment.
        """
        import io
        import sys
        import faulthandler

        with self.assertRaises(io.UnsupportedOperation):
            faulthandler.enable(file=sys.stdout)


    def test_raise_error_if_autoload_disabled(self):
        """Inference on a model built with autoload=False raises AttributeError."""
        model = IlastikPixelClassifierModel(
            {'project_file': conf.testing.ilastik['pixel_classifier']},
            autoload=False
        )
        w = 512
        h = 256

        # Same (h, w, 1, 1) layout as test_run_pixel_classifier_on_random_data.
        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))

        with self.assertRaises(AttributeError):
            model.infer(input_img)


    def test_run_pixel_classifier_on_random_data(self):
        """Pixel classifier maps an (h, w, 1, 1) input to an (h, w, 2, 1) map."""
        model = IlastikPixelClassifierModel(
            {'project_file': conf.testing.ilastik['pixel_classifier']},
        )
        w = 512
        h = 256

        input_img = InMemoryDataAccessor(data=np.random.rand(h, w, 1, 1))

        pxmap, _ = model.infer(input_img)
        self.assertEqual(pxmap.shape, (h, w, 2, 1))


    def test_run_pixel_classifier(self):
        """Classify one channel of the CZI test image and write the pixel map.

        Stores the mono-channel input and the resulting pixel map on
        ``self.mono_image`` / ``self.pxmap``; test_run_object_classifier
        calls this method to obtain them.
        """
        channel = 0
        model = IlastikPixelClassifierModel(
            {'project_file': conf.testing.ilastik['pixel_classifier']}
        )
        cf = self.cf  # opened once in setUp
        mono_image = cf.get_one_channel_data(channel)

        self.assertEqual(mono_image.shape_dict['X'], conf.testing.czifile['w'])
        self.assertEqual(mono_image.shape_dict['Y'], conf.testing.czifile['h'])
        self.assertEqual(mono_image.shape_dict['C'], 1)
        self.assertEqual(mono_image.shape_dict['Z'], 1)

        pxmap, _ = model.infer(mono_image)

        self.assertEqual(pxmap.shape[0:2], cf.shape[0:2])
        self.assertEqual(pxmap.shape_dict['C'], 2)
        self.assertEqual(pxmap.shape_dict['Z'], 1)
        self.assertTrue(
            write_accessor_data_to_file(
                conf.testing.output_path / f'pxmap_{cf.fpath.stem}_ch{channel}.tif',
                pxmap
            )
        )

        # Exposed for test_run_object_classifier, which reuses these results.
        self.mono_image = mono_image
        self.pxmap = pxmap

    def test_run_object_classifier(self):
        """Run the object classifier on the pixel map from test_run_pixel_classifier."""
        # Populates self.mono_image and self.pxmap (and runs its assertions).
        self.test_run_pixel_classifier()
        fp = conf.testing.czifile['path']
        model = IlastikObjectClassifierModel(
            {'project_file': conf.testing.ilastik['object_classifier']}
        )
        objmap, _ = model.infer(self.mono_image, self.pxmap)

        self.assertTrue(
            write_accessor_data_to_file(
                conf.testing.output_path / f'obmap_{fp.stem}.tif',
                objmap,
            )
        )
        # NOTE(review): 3 is presumably the highest object-class label the
        # test project produces on this image — confirm against the .ilp file.
        self.assertEqual(objmap.data.max(), 3)

    def test_ilastik_pixel_classification_as_workflow(self):
        """Run pixel classification through infer_image_to_image and check timing."""
        result = infer_image_to_image(
            conf.testing.czifile['path'],
            IlastikPixelClassifierModel(
                {'project_file': conf.testing.ilastik['pixel_classifier']}
            ),
            conf.testing.output_path,
            channel=0,
        )
        self.assertTrue(result.success)
        self.assertGreater(result.timer_results['inference'], 1.0)

class TestIlastikOverApi(TestServerBaseClass):
    """Exercise the ilastik model endpoints of the model server over HTTP.

    ``self.uri`` and ``copy_input_file_to_server`` come from
    TestServerBaseClass.
    """

    def _load_model(self, endpoint, project_file, expected_class):
        """PUT ``project_file`` to ``endpoint`` and return the new model's ID.

        Asserts that the load succeeded and that the ``models`` listing reports
        ``expected_class`` for the returned model ID.
        """
        resp_load = requests.put(
            self.uri + endpoint,
            params={'project_file': str(project_file)},
        )
        # Check the status before reading the body: on a failed load the
        # response may lack 'model_id' (or not be JSON at all), which would
        # otherwise raise and mask the server's actual error message.
        self.assertEqual(resp_load.status_code, 200, resp_load.content.decode())
        model_id = resp_load.json()['model_id']

        resp_list = requests.get(self.uri + 'models')
        self.assertEqual(resp_list.status_code, 200)
        rj = resp_list.json()
        self.assertEqual(rj[model_id]['class'], expected_class)
        return model_id

    def test_httpexception_if_incorrect_project_file_loaded(self):
        """Loading a nonexistent project file yields HTTP 404."""
        resp_load = requests.put(
            self.uri + 'models/ilastik/pixel_classification/load/',
            params={'project_file': 'improper.ilp'},
        )
        self.assertEqual(resp_load.status_code, 404)


    def test_load_ilastik_pixel_model(self):
        """Load the test pixel classifier; returns its model ID for reuse."""
        return self._load_model(
            'models/ilastik/pixel_classification/load/',
            conf.testing.ilastik['pixel_classifier'],
            'IlastikPixelClassifierModel',
        )

    def test_load_another_ilastik_pixel_model(self):
        """duplicate=True adds a second model; duplicate=False does not add a third."""
        self.test_load_ilastik_pixel_model()
        resp_list_1st = requests.get(self.uri + 'models').json()
        self.assertEqual(len(resp_list_1st), 1, resp_list_1st)
        resp_load_2nd = requests.put(
            self.uri + 'models/ilastik/pixel_classification/load/',
            params={
                'project_file': str(conf.testing.ilastik['pixel_classifier']),
                'duplicate': True,
            },
        )
        self.assertEqual(resp_load_2nd.status_code, 200, resp_load_2nd.content.decode())
        resp_list_2nd = requests.get(self.uri + 'models').json()
        self.assertEqual(len(resp_list_2nd), 2, resp_list_2nd)
        # Only the effect on the model count is checked for the non-duplicate load.
        requests.put(
            self.uri + 'models/ilastik/pixel_classification/load/',
            params={
                'project_file': str(conf.testing.ilastik['pixel_classifier']),
                'duplicate': False,
            },
        )
        resp_list_3rd = requests.get(self.uri + 'models').json()
        self.assertEqual(len(resp_list_3rd), 2, resp_list_3rd)


    def test_load_ilastik_object_model(self):
        """Load the test object classifier; returns its model ID for reuse."""
        return self._load_model(
            'models/ilastik/object_classification/load/',
            conf.testing.ilastik['object_classifier'],
            'IlastikObjectClassifierModel',
        )

    def test_ilastik_infer_pixel_probability(self):
        """Run pixel-classifier inference on channel 0 of the test image file."""
        self.copy_input_file_to_server()
        model_id = self.test_load_ilastik_pixel_model()

        resp_infer = requests.put(
            self.uri + 'infer/from_image_file',
            params={
                'model_id': model_id,
                'input_filename': conf.testing.czifile['filename'],
                'channel': 0,
            },
        )
        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())

    def test_ilastik_infer_px_then_ob(self):
        """Run the combined pixel-then-object classification endpoint."""
        self.copy_input_file_to_server()
        px_model_id = self.test_load_ilastik_pixel_model()
        ob_model_id = self.test_load_ilastik_object_model()

        resp_infer = requests.put(
            self.uri + 'models/ilastik/pixel_then_object_classification/infer/',
            params={
                'px_model_id': px_model_id,
                'ob_model_id': ob_model_id,
                'input_filename': conf.testing.czifile['filename'],
                'channel': 0,
            }
        )
        self.assertEqual(resp_infer.status_code, 200, resp_infer.content.decode())