from __future__ import (unicode_literals, division, absolute_import, print_function)


from PIL import Image
import io

try:
    import PyPDF2
except ImportError:
    try:
        from . import PyPDF2
    except ImportError:
        PyPDF2 = None


from .ion import (ion_type, IonAnnotation, IonList, IonSExp, IonString, IonStruct, IonSymbol)
from .message_logging import log
from .utilities import (convert_jxr_to_tiff, disable_debug_log)
from .yj_container import YJFragmentKey
from .yj_structure import SYMBOL_FORMATS

__license__ = "GPL v3"
__copyright__ = "2016-2022, John Howell <jhowell@acm.org>"

# When true, an image resource is replaced by one of its "$635" variants if the
# variant is larger in both width and height.
USE_HIGHEST_RESOLUTION_IMAGE_VARIANT = True

# Lower/upper bounds of the JPEG quality binary search used when re-encoding a
# combined tiled image.
MIN_JPEG_QUALITY = 80
MAX_JPEG_QUALITY = 100

# Target size of a combined JPEG, as a multiple of the summed size of its tiles.
COMBINED_TILE_SIZE_FACTOR = 1.2

# With DEBUG_TILES enabled, report combined JPEGs whose size misses the target
# by more than this many percent.
TILE_SIZE_REPORT_PERCENTAGE = 20

# Passed as the `optimize` option when saving a combined non-JPEG image.
OPTIMIZE_PNG = False

# Diagnostic logging switches for tile combining and variant selection.
DEBUG_TILES = False
DEBUG_VARIANTS = False

# Color modes in ascending rank; a combined tiled image uses the highest-ranked
# mode found among its tiles.
IMAGE_COLOR_MODES = ["1", "L", "P", "RGB"]

# Mode suffix that marks an alpha channel (e.g. "LA", "RGBA").
IMAGE_OPACITY_MODE = "A"


class ImageResource(object):
    """Simple record describing one media resource pulled from the book.

    Attributes:
        location: id of the raw-media fragment the data came from.
        image_format: format symbol of the resource (e.g. "$548").
        height, width: dimensions recorded for the resource.
        raw_media: the encoded media bytes (image data, or PDF data for "$565").
    """

    def __init__(self, location, image_format, height, width, raw_media):
        self.location, self.image_format = location, image_format
        self.height, self.width = height, width
        self.raw_media = raw_media


class KFX_PDF(object):
    """Collect a KFX book's page media and turn it into PDF data.

    ``book`` is expected to provide ``fragments`` (with ``get``, ``get_all``
    and item lookup by ``YJFragmentKey``), ``ordered_section_names()``,
    ``is_kpf_prepub`` and ``has_illustrated_layout_conditional_page_template``.
    """

    def __init__(self, book):
        self.book = book

    def extract_pdf_resources(self):
        """Return the book's "$565" (PDF) resources as a single PDF.

        Returns:
            The raw bytes of the only PDF resource when there is exactly one,
            a PyPDF2-merged document when there are several, or None when there
            are none, PyPDF2 is unavailable, or merging fails (errors are logged).
        """
        ordered_pdfs = self.get_ordered_images(["$565"], remove_duplicates=True, include_unreferenced=True)

        if not ordered_pdfs:
            return None

        if len(ordered_pdfs) == 1:
            return ordered_pdfs[0].raw_media

        if PyPDF2 is None:
            log.error("PyPDF2 package is missing. Unable to combine PDF resources")
            return None

        try:
            merger = PyPDF2.PdfFileMerger()
            try:
                for single_pdf in ordered_pdfs:
                    merger.append(fileobj=io.BytesIO(single_pdf.raw_media))

                merged_file = io.BytesIO()
                merger.write(merged_file)
                pdf_data = merged_file.getvalue()
                merged_file.close()
            finally:
                # release the merger's input/output state even if merging failed
                merger.close()
        except Exception as e:
            log.error("PdfFileMerger error: %s" % repr(e))
            return None

        log.info("Combined %d PDF resources into a single file" % len(ordered_pdfs))
        return pdf_data

    def convert_image_resources(self):
        """Return the referenced "$286"/"$285"/"$548"/"$284" image resources as PDF bytes (or None)."""
        ordered_images = self.get_ordered_images(["$286", "$285", "$548", "$284"])
        return convert_images_to_pdf_data(ordered_images)

    def get_ordered_images(self, formats, remove_duplicates=False, include_unreferenced=False):
        """Return ImageResource objects for resources of ``formats`` in reading order.

        Args:
            formats: resource format symbols ("$161" values) to include.
            remove_duplicates: skip an image whose location was already returned.
            include_unreferenced: also append "$164" resources of those formats
                that no section references; each one is reported via log.error.

        Returns:
            list of ImageResource; resources whose media is missing are skipped.
        """
        referenced_resources = self.collect_ordered_image_references(formats, False)

        if include_unreferenced:
            # set mirror of referenced_resources for O(1) membership tests
            known_resources = set(referenced_resources)
            for fragment in self.book.fragments.get_all("$164"):
                resource_format = fragment.value.get("$161")
                if resource_format in formats and fragment.fid not in known_resources:
                    log.error("Found unreferenced resource: %s" % fragment.fid)
                    referenced_resources.append(fragment.fid)
                    known_resources.add(fragment.fid)

        ordered_images = []
        image_locations = set()

        for fid in referenced_resources:
            image = self.get_resource_image(fid)
            if image is None:
                continue

            if remove_duplicates and image.location in image_locations:
                continue

            ordered_images.append(image)
            image_locations.add(image.location)

        return ordered_images

    def collect_ordered_image_references(self, formats, remove_duplicates):
        """Return ids of "$164" resources of ``formats`` referenced by the book, in reading order.

        Every section from ``book.ordered_section_names()`` is walked, along with
        the content and stories it refers to. Stories reached through "$176"
        are walked at most once across the whole book.

        Args:
            formats: resource format symbols ("$161" values) to collect.
            remove_duplicates: when true, a resource referenced from several
                sections is listed only once; otherwise it is listed once per
                referencing section.

        Returns:
            list of resource ids in encounter order.
        """
        processed_story_names = set()
        ordered_image_resources = []

        def collect_section_info(section_name):
            # stories deferred until the section itself has been walked
            pending_story_names = []
            section_image_resources = set()
            section_image_types = set()

            def add_section_resource(resource_name, image_type):
                if resource_name is not None and resource_name not in section_image_resources:
                    fragment = self.book.fragments.get(ftype="$164", fid=resource_name)
                    if fragment is not None:
                        resource = fragment.value
                        if resource.get("$161") in formats:
                            section_image_resources.add(resource_name)
                            section_image_types.add(image_type)

                            if not (remove_duplicates and resource_name in ordered_image_resources):
                                ordered_image_resources.append(resource_name)

            def walk_content(data, content_key):
                data_type = ion_type(data)

                if data_type is IonAnnotation:
                    walk_content(data.value, content_key)

                elif data_type is IonList:
                    for fc in data:
                        # in KPF prepub books these lists may hold symbols naming "$608" fragments
                        if content_key in {"$146", "$274"} and self.book.is_kpf_prepub and ion_type(fc) is IonSymbol:
                            fc = self.book.fragments[YJFragmentKey(ftype="$608", fid=fc)]

                        walk_content(fc, content_key)

                elif data_type is IonSExp:
                    for fc in data:
                        walk_content(fc, content_key)

                elif data_type is IonStruct:
                    annot_type = data.get("$687")
                    typ = data.get("$159")

                    if typ == "$271":
                        add_section_resource(data.get("$175"), "foreground")

                    if "$479" in data:
                        add_section_resource(data["$479"], "background")

                    if "$141" in data:
                        for pt in data["$141"]:
                            if isinstance(pt, IonAnnotation):
                                pt = pt.value

                            walk_content(pt, "$141")

                    if "$683" in data:
                        walk_content(data["$683"], "$683")

                    if "$749" in data:
                        walk_content(self.book.fragments[YJFragmentKey(ftype="$259", fid=data["$749"])], "$259")

                    if "$146" in data:
                        walk_content(data["$146"], "$274" if typ == "$274" else "$146")

                    if "$145" in data and annot_type not in ["$584", "$690"]:
                        fv = data["$145"]
                        if ion_type(fv) is not IonStruct:
                            walk_content(fv, "$145")

                    if "$176" in data and content_key != "$259":
                        fv = data["$176"]

                        if self.book.has_illustrated_layout_conditional_page_template:
                            if fv not in pending_story_names:
                                pending_story_names.append(fv)
                        else:
                            if fv not in processed_story_names:
                                walk_content(self.book.fragments[YJFragmentKey(ftype="$259", fid=fv)], "$259")
                                processed_story_names.add(fv)

                    if "$157" in data:
                        walk_content(self.book.fragments[YJFragmentKey(ftype="$157", fid=data["$157"])], "$157")

                    # recurse into all remaining non-string fields not handled explicitly above
                    for fk, fv in data.items():
                        if ion_type(fv) != IonString and fk not in {
                                "$749", "$584", "$683", "$145",
                                "$146", "$141", "$702", "$250", "$176",
                                "yj.dictionary.term", "yj.dictionary.unnormalized_term"}:
                            walk_content(fv, fk)

            walk_content(self.book.fragments[YJFragmentKey(ftype="$260", fid=section_name)], "$260")

            for story_name in pending_story_names:
                if story_name not in processed_story_names:
                    walk_content(self.book.fragments[YJFragmentKey(ftype="$259", fid=story_name)], "$259")
                    processed_story_names.add(story_name)

            # Pre-format messages like every other log call in this file; the project's
            # log wrapper may not accept logging-style positional arguments.
            if len(section_image_resources) > 2:
                log.error("Section %s contains more than two images" % section_name)

            if len(section_image_types) > 1:
                log.error("Section %s contains both background and foreground images" % section_name)

        for section_name in self.book.ordered_section_names():
            collect_section_info(section_name)

        return ordered_image_resources

    def get_resource_image(self, resource_name, ignore_variants=False):
        """Build an ImageResource for the "$164" resource fragment ``resource_name``.

        Tiled resources (those containing "$636") are reassembled with
        combine_image_tiles(). Unless ``ignore_variants`` is set, when
        USE_HIGHEST_RESOLUTION_IMAGE_VARIANT is enabled each "$635" variant that
        is larger in both dimensions than the best image so far replaces it.

        Returns:
            ImageResource, or None if the fragment or its raw media is missing.
        """
        fragment = self.book.fragments.get(ftype="$164", fid=resource_name)
        if fragment is None:
            return None

        resource = fragment.value
        resource_format = resource.get("$161")
        resource_height = resource.get("$423", None) or resource.get("$67", None)
        resource_width = resource.get("$422", None) or resource.get("$66", None)

        if "$636" in resource:
            yj_tiles = resource.get("$636")
            tile_height = resource.get("$638")
            tile_width = resource.get("$637")
            tile_padding = resource.get("$797", 0)
            # tile ids share a common prefix before "-tile"; use it as the image location
            location = yj_tiles[0][0].partition("-tile")[0]

            tiles_raw_media = []
            for row in yj_tiles:
                for tile_location in row:
                    tile_raw_media_frag = self.book.fragments.get(ftype="$417", fid=tile_location)
                    tiles_raw_media.append(None if tile_raw_media_frag is None else tile_raw_media_frag.value)

            raw_media = combine_image_tiles(
                resource_name, resource_height, resource_width, resource_format, tile_height, tile_width, tile_padding,
                yj_tiles, tiles_raw_media)
        else:
            location = resource.get("$165")
            if location is not None:
                raw_media = self.book.fragments.get(ftype="$417", fid=location)
                if raw_media is not None:
                    raw_media = raw_media.value
            else:
                raw_media = None

        if not ignore_variants:
            for rr in resource.get("$635", []):
                variant = self.get_resource_image(rr, ignore_variants=True)

                if (USE_HIGHEST_RESOLUTION_IMAGE_VARIANT and variant is not None and
                        variant.width > resource_width and variant.height > resource_height):
                    if DEBUG_VARIANTS:
                        log.info("Replacing image %s (%dx%d) with variant %s (%dx%d)" % (
                                location, resource_width, resource_height, variant.location, variant.width, variant.height))

                    location, raw_media, resource_width, resource_height = variant.location, variant.raw_media, variant.width, variant.height

        if raw_media is None:
            return None

        return ImageResource(location, resource_format, resource_height, resource_width, raw_media)


def convert_images_to_pdf_data(ordered_images):
    """Render image resources as the pages of a single PDF.

    Args:
        ordered_images: ImageResource objects in page order; their
            ``raw_media``, ``image_format`` and ``location`` are used.

    Returns:
        The PDF bytes, or None when ``ordered_images`` is empty.

    "$548" (JPEG-XR) data is first converted to TIFF with convert_jxr_to_tiff().
    If that conversion raises, the error is logged and the unconverted data is
    passed to Image.open as-is. Every image is converted to RGB before saving.
    """
    if not ordered_images:
        return None

    image_list = []
    try:
        for image_resource in ordered_images:
            image_data = image_resource.raw_media

            if image_resource.image_format == "$548":
                try:
                    image_data = convert_jxr_to_tiff(image_data, image_resource.location)
                except Exception as e:
                    log.error("Exception during conversion of JPEG-XR '%s' to TIFF: %s" % (
                        image_resource.location, repr(e)))

            with disable_debug_log():
                source_image = Image.open(io.BytesIO(image_data))
                try:
                    rgb_image = source_image.convert("RGB")
                finally:
                    # convert() returns a new image, so the decoded source can be released
                    source_image.close()

            image_list.append(rgb_image)

        pdf_file = io.BytesIO()
        with disable_debug_log():
            image_list[0].save(pdf_file, "pdf", save_all=True, append_images=image_list[1:])

        pdf_data = pdf_file.getvalue()
        pdf_file.close()
    finally:
        # close every converted image even if decoding or saving failed part way
        with disable_debug_log():
            for image in image_list:
                image.close()

    return pdf_data


def combine_image_tiles(
        resource_name, resource_height, resource_width, resource_format, tile_height, tile_width, tile_padding,
        yj_tiles, tiles_raw_media):
    """Reassemble a tiled image resource into one encoded image.

    Args:
        resource_name: resource id, used only in log messages.
        resource_height, resource_width: size of the combined image in pixels.
        resource_format: format symbol, mapped through SYMBOL_FORMATS to the
            output encoding ("jpg", or a format name handed to PIL's save).
        tile_height, tile_width: size of a tile excluding padding.
        tile_padding: overlap pixels a tile carries on each side it shares
            with a neighbouring tile; it is cropped off before pasting.
        yj_tiles: list of rows of tile locations, top row first.
        tiles_raw_media: encoded tile data in the same row-by-row order as
            ``yj_tiles``; None marks a missing tile (logged, left blank).

    Returns:
        The encoded bytes of the combined image. For "jpg" the JPEG quality is
        chosen by binary search between MIN_JPEG_QUALITY and MAX_JPEG_QUALITY
        so the output size is close to COMBINED_TILE_SIZE_FACTOR times the
        summed tile sizes (with a floor of 1024 bytes).
    """
    if DEBUG_TILES:
        # yj_tiles is a list of rows, so its length is the row count
        nrows = len(yj_tiles)
        ncols = len(yj_tiles[0])
        log.warning("tiled image %dx%d: %s" % (ncols, nrows, resource_name))

    with disable_debug_log():
        tile_images = []
        separate_tiles_size = tile_count = 0
        full_image_color_mode = IMAGE_COLOR_MODES[0]
        full_image_opacity_mode = ""

        # Pass 1: decode every tile and find the highest-ranked mode needed.
        tile_num = 0
        missing_tiles = []
        for y, row in enumerate(yj_tiles):
            for x, tile_location in enumerate(row):
                tile_raw_media = tiles_raw_media[tile_num]
                if tile_raw_media is not None:
                    tile_count += 1
                    separate_tiles_size += len(tile_raw_media)
                    tile = Image.open(io.BytesIO(tile_raw_media))

                    if tile.mode.endswith(IMAGE_OPACITY_MODE):
                        tile_color_mode = tile.mode[:-1]
                        full_image_opacity_mode = IMAGE_OPACITY_MODE
                    else:
                        tile_color_mode = tile.mode

                    if tile_color_mode not in IMAGE_COLOR_MODES:
                        log.error("Resource %s tile %s has unexpected image mode %s" % (resource_name, tile_location, tile.mode))
                    elif IMAGE_COLOR_MODES.index(tile_color_mode) > IMAGE_COLOR_MODES.index(full_image_color_mode):
                        full_image_color_mode = tile_color_mode
                else:
                    tile = None
                    missing_tiles.append((x, y))

                tile_images.append(tile)
                tile_num += 1

        if missing_tiles:
            log.error("Resource %s is missing tiles: %s" % (resource_name, repr(missing_tiles)))

        full_image = Image.new(full_image_color_mode + full_image_opacity_mode, (resource_width, resource_height))

        # Pass 2: strip each tile's padding and paste it at its grid position.
        remaining_tiles = iter(tile_images)
        for y, row in enumerate(yj_tiles):
            top_padding = 0 if y == 0 else tile_padding
            # Negative for a short final row: that tile is expected to be smaller than tile_height.
            bottom_padding = min(tile_padding, resource_height - tile_height * (y + 1))

            for x in range(len(row)):
                left_padding = 0 if x == 0 else tile_padding
                right_padding = min(tile_padding, resource_width - tile_width * (x + 1))

                tile = next(remaining_tiles)
                if tile is None:
                    continue

                twidth, theight = tile.size
                if twidth != tile_width + left_padding + right_padding or theight != tile_height + top_padding + bottom_padding:
                    log.error("Resource %s tile %d, %d size (%d, %d) does not have padding %d of expected size (%d, %d)" % (
                        resource_name, x, y, twidth, theight, tile_padding, tile_width, tile_height))
                    log.info("tile padding ltrb: %d, %d, %d, %d" % (left_padding, top_padding, right_padding, bottom_padding))

                crop = (left_padding, top_padding, tile_width + left_padding, tile_height + top_padding)
                cropped_tile = tile.crop(crop)
                full_image.paste(cropped_tile, (x * tile_width, y * tile_height))
                cropped_tile.close()
                # the decoded source tile is no longer needed once its cropped copy is pasted
                tile.close()

        if full_image.size != (resource_width, resource_height):
            log.error("Resource %s combined tiled image size is (%d, %d) but should be (%d, %d)" % (
                    resource_name, full_image.size[0], full_image.size[1], resource_width, resource_height))

        fmt = SYMBOL_FORMATS[resource_format]

        if fmt == "jpg":
            desired_combined_size = max(int(separate_tiles_size * COMBINED_TILE_SIZE_FACTOR), 1024)
            min_quality = MIN_JPEG_QUALITY
            max_quality = MAX_JPEG_QUALITY
            best_size_diff = best_quality = raw_media = None

            # binary search for the quality whose output size is closest to the target
            while max_quality >= min_quality:
                quality = (max_quality + min_quality) // 2
                outfile = io.BytesIO()
                full_image.save(outfile, "jpeg", quality=quality)
                test_raw_media = outfile.getvalue()
                outfile.close()

                size_diff = len(test_raw_media) - desired_combined_size

                if best_size_diff is None or abs(size_diff) < abs(best_size_diff):
                    best_size_diff = size_diff
                    best_quality = quality
                    raw_media = test_raw_media

                if len(test_raw_media) < desired_combined_size:
                    min_quality = quality + 1
                else:
                    max_quality = quality - 1

            diff_percentage = (best_size_diff * 100) // desired_combined_size
            if DEBUG_TILES and abs(diff_percentage) > TILE_SIZE_REPORT_PERCENTAGE:
                log.warning("Image resource %s has %d tiles %d bytes combined into quality %d %s JPEG %d bytes (%+d%%)" % (
                    resource_name, tile_count, separate_tiles_size, best_quality, full_image.mode, len(raw_media), diff_percentage))
        else:
            outfile = io.BytesIO()
            full_image.save(outfile, fmt, optimize=OPTIMIZE_PNG)
            raw_media = outfile.getvalue()
            outfile.close()

        full_image.close()

    return raw_media
