import numpy as np
from PIL import Image, ImageDraw, ImageFont

from model_server.process import rescale

def _load_label_font(font_size, font_name='arial.ttf'):
    """Return a font for box labels, falling back to Pillow's built-in font.

    ``ImageFont.truetype`` raises ``OSError`` when the requested font file
    cannot be found (e.g. ``arial.ttf`` is not installed). Rather than failing
    the whole drawing call, fall back to Pillow's default font.
    """
    try:
        return ImageFont.truetype(font=font_name, size=font_size)
    except OSError:
        try:
            # Newer Pillow releases accept a size for the default font.
            return ImageFont.load_default(size=font_size)
        except TypeError:
            # Older Pillow: fixed-size default font, no size argument.
            return ImageFont.load_default()


def draw_boxes_on_2d_image(yx_img, boxes, **kwargs):
    """Draw white bounding boxes (and optional text labels) on an image.

    Args:
        yx_img: image array indexed (y, x); it is copied, never modified.
        boxes: iterable of dicts whose ``'info'`` entry exposes ``x0``, ``y0``,
            ``x1``, ``y1`` (rectangle corners in pixel coordinates) and, when
            labels are drawn, integer ``zi`` and ``label`` values.
        **kwargs:
            font_size (int, default 18): label font size.
            linewidth (int, default 4): rectangle outline width.
            add_label: labels ``Z<zi>-L<label>`` are drawn centred above each
                box only when this is exactly ``True``.
            font (str, default ``'arial.ttf'``): TrueType font used for
                labels; Pillow's default font is used if it cannot be loaded.
            Any other keys are ignored.

    Returns:
        PIL.Image.Image: the annotated copy of ``yx_img``.
    """
    pilimg = Image.fromarray(np.copy(yx_img))  # drawing modifies array in-place
    draw = ImageDraw.Draw(pilimg)
    linewidth = kwargs.get('linewidth', 4)
    add_label = kwargs.get('add_label') is True

    # Only load a font when labels are requested, so plain box drawing never
    # depends on a particular font file being installed.
    font = (
        _load_label_font(kwargs.get('font_size', 18), kwargs.get('font', 'arial.ttf'))
        if add_label
        else None
    )

    for box in boxes:
        info = box['info']
        x0, y0, x1, y1 = info.x0, info.y0, info.x1, info.y1
        draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth)

        if add_label:
            xm = round((x0 + x1) / 2)
            # anchor 'mb' = middle-bottom: the text sits centred on top of the box.
            draw.text(
                (xm, y0),
                f'Z{info.zi:04d}-L{info.label:04d}',
                fill='white',
                anchor='mb',
                font=font,
            )

    return pilimg


def draw_boxes_on_3d_image(yxcz_img, boxes, draw_full_depth=False, **kwargs):
    """Draw boxes on every z-slice of a single-channel (y, x, c, z) stack.

    Args:
        yxcz_img: 4-D array of shape (ny, nx, 1, nz).
        boxes: box dicts as accepted by ``draw_boxes_on_2d_image``; each
            ``box['info'].zi`` selects the z-slice the box is drawn on.
        draw_full_depth: if True, draw every box on every slice instead of
            only on the slice matching its ``zi``.
        **kwargs: forwarded to ``draw_boxes_on_2d_image``. Additionally, when
            ``rescale_clip`` is given (a value in [0, 1], including 0.0), the
            annotated stack is passed through ``rescale(..., clip=rescale_clip)``.

    Returns:
        np.ndarray: the annotated stack, same shape as ``yxcz_img``; the dtype
        matches the input unless ``rescale`` changes it.

    Raises:
        AssertionError: if the input is not 4-D with a singleton channel axis,
            or ``rescale_clip`` lies outside [0, 1].
    """
    assert len(yxcz_img.shape) == 4
    assert yxcz_img.shape[2] == 1
    nz = yxcz_img.shape[3]

    # Compare against None rather than truthiness: 0.0 is a valid clip value
    # and must not silently disable rescaling. Validate before the (possibly
    # expensive) per-slice drawing.
    clip = kwargs.get('rescale_clip')
    if clip is not None:
        assert 0.0 <= clip <= 1.0

    # Group boxes by slice once instead of scanning the full list per slice.
    boxes_by_z = {}
    if not draw_full_depth:
        for bb in boxes:
            boxes_by_z.setdefault(bb['info'].zi, []).append(bb)

    annotated = np.zeros(yxcz_img.shape, dtype=yxcz_img.dtype)
    for zi in range(nz):
        zi_boxes = boxes if draw_full_depth else boxes_by_z.get(zi, [])
        slice_img = draw_boxes_on_2d_image(yxcz_img[:, :, 0, zi], zi_boxes, **kwargs)
        annotated[:, :, 0, zi] = np.asarray(slice_img)

    if clip is not None:
        annotated = rescale(annotated, clip=clip)

    return annotated

def draw_box_on_patch(patch, bbox, linewidth=1):
    """Return a copy of a 2-D patch with a white rectangle drawn on it.

    Args:
        patch: 2-D image array; it is not modified.
        bbox: ``((x0, y0), (x1, y1))`` rectangle corners in pixel coordinates.
        linewidth: outline width in pixels.

    Returns:
        np.ndarray: a new array containing the patch with the rectangle drawn.

    Raises:
        AssertionError: if ``patch`` is not 2-D.
    """
    assert len(patch.shape) == 2
    (x0, y0), (x1, y1) = bbox
    # Copy before wrapping: drawing may write through to the wrapped array,
    # and the caller's patch must stay untouched.
    pilimg = Image.fromarray(np.copy(patch))
    draw = ImageDraw.Draw(pilimg)
    draw.rectangle([(x0, y0), (x1, y1)], outline='white', width=linewidth)
    return np.array(pilimg)

def draw_contours_on_patch(patch, contours, linewidth=1):
    """Return a copy of a 2-D patch with each contour drawn as a polyline.

    Args:
        patch: 2-D image array; it is not modified.
        contours: iterable of point sequences. Each point is indexed
            ``(row, col)``: ``p[0]`` is the y coordinate and ``p[1]`` the x
            coordinate. They are swapped to Pillow's (x, y) order below.
        linewidth: line width in pixels.

    Returns:
        np.ndarray: a new array containing the patch with the contours drawn.

    Raises:
        AssertionError: if ``patch`` is not 2-D.
    """
    assert len(patch.shape) == 2
    # Copy before wrapping: drawing may write through to the wrapped array,
    # and the caller's patch must stay untouched.
    pilimg = Image.fromarray(np.copy(patch))
    draw = ImageDraw.Draw(pilimg)
    for co in contours:
        xy_points = [(p[1], p[0]) for p in co]
        # NOTE(review): no fill is passed, so Pillow's default ink for the
        # image mode is used (draw_box_on_patch uses 'white' explicitly);
        # confirm this is intended for non-uint8 patches.
        draw.line(xy_points, width=linewidth, joint='curve')
    return np.array(pilimg)