from __future__ import (unicode_literals, division, absolute_import, print_function)

import decimal
import hashlib
import struct

from .ion import (ion_type, IonAnnotation, IonBLOB, IonBool, IonCLOB, IonDecimal,
            IonFloat, IonInt, IonList, IonNop, IonNull, IonSExp, IonString, IonStruct, IonSymbol, IonTimestamp,
            IonTimestampTZ, ION_TIMESTAMP_Y, ION_TIMESTAMP_YM, ION_TIMESTAMP_YMD, ION_TIMESTAMP_YMDHM,
            ION_TIMESTAMP_YMDHMS, ION_TIMESTAMP_YMDHMSF)
from .ion_text import IonSerial
from .misc import (gunzip, hex_string)

__license__   = "GPL v3"
__copyright__ = "2018, John Howell <jhowell@acm.org>"

# When true, IonBinary emits extra self.log.debug output: a hex dump of the
# input stream, each value's descriptor byte, and each decoded struct field.
DEBUG = False

class IonBinary(IonSerial):
    """Binary Ion 1.0 serializer/deserializer.

    Every value is a one-byte type descriptor (type signature in the high
    nibble, length or flag in the low nibble), optionally followed by a
    VarUInt length and then the value's payload.

    ``self.log`` and ``self.symtab`` are used throughout but are not defined
    in this class; presumably ``IonSerial`` provides them (TODO confirm).
    """

    major_version = 1
    minor_version = 0

    # First byte of the Ion Binary Version Marker (BVM).
    version_marker = 0xe0

    # Full 4-byte BVM: E0 <major> <minor> EA.
    signature = chr(version_marker) + chr(major_version) + chr(minor_version) + chr(0xea)

    # gzip magic number followed by the "deflate" compression-method byte.
    gzip_signature = b"\x1f\x8b\x08"
    # NOTE(review): not referenced in this file; presumably identifies
    # DRM-wrapped ("DRMION") Ion data elsewhere. Confirm with callers.
    drmion_signature = b"\xeaDRMION\xee"

    def deserialize_multiple_values(self, data, import_symbols=False, unzip=False):
        """Decode every top-level value in a binary Ion stream.

        :param data: serialized data that must start with the 4-byte BVM.
        :param import_symbols: when true, top-level ``$ion_symbol_table`` and
            ``$ion_shared_symbol_table`` annotated values are loaded into
            ``self.symtab`` as they are read.
        :param unzip: when true, ``data`` is passed through ``gunzip`` first.
        :returns: list of decoded values, with NOP padding dropped.
        """
        if unzip:
            data = gunzip(data)

        return self.deserialize_multiple_values_(data, import_symbols)

    # Low-nibble values with special meaning in a type descriptor.
    SORTED_STRUCT_FLAG = 1      # struct whose VarUInt length follows (sorted fields)
    VARIABLE_LEN_FLAG = 14      # payload length follows as a VarUInt
    NULL_FLAG = 15              # typed null

    def serialize_multiple_values_(self, values):
        """Encode ``values`` as a binary Ion stream prefixed with the BVM."""
        serial = Serializer()
        serial.append(IonBinary.signature)

        for value in values:
            serial.append(self.serialize_value(value))

        return serial.serialize()

    def deserialize_multiple_values_(self, data, import_symbols):
        """Decode a BVM-prefixed stream into a list of top-level values.

        Version markers embedded between values are validated and skipped.
        When ``import_symbols`` is set, annotated symbol-table values update
        the symbol table immediately, so later values can use the new symbols.

        :raises Exception: if the leading or an embedded BVM is incorrect.
        """
        if DEBUG: self.log.debug("decoding: %s" % hex_string(data[:1000]))

        self.import_symbols = import_symbols

        ion_signature = data[:4]
        if ion_signature != IonBinary.signature:
            raise Exception("Ion signature is incorrect (%s)" % hex_string(ion_signature))

        serial = Deserializer(data[4:])
        result = []
        while len(serial):
            # Peek at the next byte as an integer so it can be compared with
            # the integer version_marker. A 1-byte slice would never compare equal.
            if serial.unpack("B", advance=False) == IonBinary.version_marker:
                ion_signature = serial.unpack("4s")
                if ion_signature != IonBinary.signature:
                    raise Exception("Embedded Ion signature is incorrect (%s)" % hex_string(ion_signature))
            else:
                value = self.deserialize_value(serial)

                if self.import_symbols and isinstance(value, IonAnnotation):
                    if value.annotations[0] == "$ion_symbol_table":
                        self.symtab.create(value.value)
                    elif value.annotations[0] == "$ion_shared_symbol_table":
                        self.symtab.catalog.create_shared_symbol_table(value.value)

                if not isinstance(value, IonNop):
                    result.append(value)

        return result

    def serialize_value(self, value):
        """Encode one value as descriptor byte, optional VarUInt length, and payload.

        Type handlers return ``(signature, data)``. A ``None`` signature means
        the handler already produced the complete encoding (null, bool).
        """
        handler = IonBinary.ION_TYPE_HANDLERS[ion_type(value)]
        signature, data = handler(self, value)

        if signature is None:
            return data

        length = len(data)

        # Payloads shorter than 14 bytes store their length in the low nibble.
        if length < IonBinary.VARIABLE_LEN_FLAG:
            return descriptor(signature, length) + data

        return descriptor(signature, IonBinary.VARIABLE_LEN_FLAG) + serialize_vluint(length) + data

    def deserialize_value(self, serial):
        """Decode the next value from ``serial`` (a Deserializer).

        A typed null (flag 15 on a non-null type) is logged and decoded as a
        plain null.

        :raises Exception: if the next byte is a version marker.
        """
        type_byte = serial.unpack("B")
        if type_byte == IonBinary.version_marker:
            raise Exception("Unexpected Ion version marker within data stream")

        signature = type_byte >> 4
        flag = type_byte & 0x0f
        if DEBUG: self.log.debug("IonBinary 0x%02x: signature=%d flag=%d data=%s" % (
                    type_byte, signature, flag, hex_string(serial.extract(advance=False)[:16])))

        extract_data, deserializer, name = IonBinary.VALUE_DESERIALIZERS[signature]

        if flag == IonBinary.NULL_FLAG and signature != IonBinary.null_value_signature:
            self.log.error("IonBinary: Deserialized null of type %s" % name)
            extract_data, deserializer, name = IonBinary.VALUE_DESERIALIZERS[IonBinary.null_value_signature]

        if extract_data:
            # These deserializers receive just the payload bytes.
            length = deserialize_vluint(serial) if flag == IonBinary.VARIABLE_LEN_FLAG else flag
            return deserializer(self, serial.extract(length))

        # Null, bool and struct interpret the flag and stream themselves.
        return deserializer(self, flag, serial)

    null_value_signature = 0

    def serialize_null_value(self, value):
        """Encode null as the single descriptor byte 0x0F."""
        return (None, descriptor(IonBinary.null_value_signature, IonBinary.NULL_FLAG))

    def deserialize_null_value(self, flag, serial):
        """Decode type 0: ``None`` for flag 15, otherwise NOP padding.

        Padding bytes are consumed and an ``IonNop`` marker is returned so
        callers can discard it.
        """
        if flag == IonBinary.NULL_FLAG:
            return None

        length = deserialize_vluint(serial) if flag == IonBinary.VARIABLE_LEN_FLAG else flag
        serial.extract(length)
        return IonNop()

    bool_value_signature = 1

    def serialize_bool_value(self, value):
        """Encode a bool entirely in the descriptor (flag 1 = true, 0 = false)."""
        return (None, descriptor(IonBinary.bool_value_signature, 1 if value else 0))

    def deserialize_bool_value(self, flag, serial):
        """Decode type 1 from its flag.

        :raises Exception: for any flag other than 0 or 1.
        """
        if flag > 1:
            raise Exception("BinaryIonBool: Unknown IonBool flag value: %d" % flag)

        return flag != 0

    def serialize_int_value(self, value):
        """Encode an int as a positive (type 2) or negative (type 3) magnitude."""
        if value >= 0:
            return (IonBinary.posint_value_signature, serialize_unsignedint(value))

        return (IonBinary.negint_value_signature, serialize_unsignedint(-value))

    posint_value_signature = 2

    def deserialize_posint_value(self, data):
        """Decode type 2: a big-endian unsigned magnitude."""
        return deserialize_unsignedint(data)

    negint_value_signature = 3

    def deserialize_negint_value(self, data):
        """Decode type 3: a big-endian magnitude that is negated.

        An empty payload is logged and decodes as 0, instead of crashing on
        ``data[0]``.
        """
        if len(data) == 0:
            self.log.error("BinaryIonNegInt has no data")
        elif ord(data[0]) == 0:
            self.log.error("BinaryIonNegInt data starts with 0x00: %s" % hex_string(data))

        return -deserialize_unsignedint(data)

    float_value_signature = 4

    def serialize_float_value(self, value):
        """Encode a float as 8-byte big-endian IEEE-754, or an empty payload for +0.0.

        An empty payload decodes as positive zero (see deserialize_float_value),
        so it is used only when every packed byte is zero. -0.0 compares
        equal to 0.0 but has its sign bit set, so it is written in full.
        """
        data = struct.pack(b">d", value)
        if data == b"\x00" * 8:
            data = b""

        return (IonBinary.float_value_signature, data)

    def deserialize_float_value(self, data):
        """Decode type 4: empty payload (0.0), 4-byte or 8-byte big-endian float.

        :raises Exception: for any other payload length.
        """
        if len(data) == 0:
            return float(0.0)

        if len(data) == 4:
            return struct.unpack_from(b">f", data)[0]

        if len(data) == 8:
            return struct.unpack_from(b">d", data)[0]

        raise Exception("IonFloat unexpected data length: %s" % hex_string(data))

    decimal_value_signature = 5

    def serialize_decimal_value(self, value):
        """Encode a Decimal as a VarInt exponent followed by a signed Int coefficient.

        Any zero (``is_zero()``) is written as an empty payload.
        NOTE(review): this drops the exponent and sign of zeros such as 0.00.
        """
        if value.is_zero():
            return (IonBinary.decimal_value_signature, b"")

        vt = value.as_tuple()
        return (IonBinary.decimal_value_signature, serialize_vlsint(vt.exponent) +
                serialize_signedint(combine_decimal_digits(vt.digits, vt.sign)))

    def deserialize_decimal_value(self, data):
        """Decode type 5: empty means 0; otherwise coefficient * 10 ** exponent."""
        if len(data) == 0:
            return decimal.Decimal(0)

        serial = Deserializer(data)
        exponent = deserialize_vlsint(serial)
        magnitude = deserialize_signedint(serial.extract())
        return decimal.Decimal(magnitude) * (decimal.Decimal(10) ** exponent)

    timestamp_value_signature = 6

    def serialize_timestamp_value(self, value):
        """Encode a datetime as an Ion timestamp.

        With an ``IonTimestampTZ`` tzinfo, its offset, format length and
        fraction length decide which fields are written. Otherwise all fields
        down to milliseconds are written, with the offset taken from
        ``utcoffset()``. A ``None`` offset is written as VarInt -0.
        """
        serial = Serializer()

        if isinstance(value.tzinfo, IonTimestampTZ):
            offset_minutes = value.tzinfo.offset_minutes()
            format_len = len(value.tzinfo.format())
            fraction_exponent = -value.tzinfo.fraction_len()
        else:
            utc_offset = value.utcoffset()
            offset_minutes = int(utc_offset.total_seconds()) // 60 if utc_offset is not None else None
            format_len = len(ION_TIMESTAMP_YMDHMSF)
            fraction_exponent = -3

        serial.append(serialize_vlsint(offset_minutes))
        serial.append(serialize_vluint(value.year))

        # Each nested level adds the next-finer field for longer formats.
        if format_len >= len(ION_TIMESTAMP_YM):
            serial.append(serialize_vluint(value.month))

            if format_len >= len(ION_TIMESTAMP_YMD):
                serial.append(serialize_vluint(value.day))

                if format_len >= len(ION_TIMESTAMP_YMDHM):
                    serial.append(serialize_vluint(value.hour))
                    serial.append(serialize_vluint(value.minute))

                    if format_len >= len(ION_TIMESTAMP_YMDHMS):
                        serial.append(serialize_vluint(value.second))

                        if format_len >= len(ION_TIMESTAMP_YMDHMSF):
                            serial.append(serialize_vlsint(fraction_exponent))
                            serial.append(serialize_signedint(
                                    (value.microsecond * int(10 ** -fraction_exponent)) // 1000000))

        return (IonBinary.timestamp_value_signature, serial.serialize())

    def deserialize_timestamp_value(self, data):
        """Decode type 6 into an ``IonTimestamp``.

        Fields after the year are optional. The precision actually present is
        recorded as the ``IonTimestampTZ`` format, and missing fields default
        to their minimum. The offset is discarded for date-only precision. An
        out-of-range fractional second is logged and dropped.
        """
        serial = Deserializer(data)

        offset_minutes = deserialize_vlsint(serial, allow_minus_zero=True)
        year = deserialize_vluint(serial)
        month = deserialize_vluint(serial) if len(serial) > 0 else None
        day = deserialize_vluint(serial) if len(serial) > 0 else None
        hour = deserialize_vluint(serial) if len(serial) > 0 else None
        minute = deserialize_vluint(serial) if len(serial) > 0 else None
        second = deserialize_vluint(serial) if len(serial) > 0 else None

        if len(serial) > 0:
            fraction_exponent = deserialize_vlsint(serial)

            fraction_coefficient = deserialize_signedint(serial.extract()) if len(serial) > 0 else 0

            if fraction_coefficient == 0 and fraction_exponent > -1:
                microsecond = None
            else:
                if fraction_exponent < -6 or fraction_exponent > -1:
                    self.log.error("Unexpected IonTimestamp fraction exponent %d coefficient %d: %s" % (
                            fraction_exponent, fraction_coefficient, hex_string(data)))

                if fraction_exponent < 0:
                    microsecond = (fraction_coefficient * 1000000) // (10 ** -fraction_exponent)
                else:
                    # A non-negative exponent with a non-zero coefficient gives a
                    # fraction outside [0, 1). Scale up so the range check below
                    # rejects it; int(10 ** -exp) would be 0 here and divide by zero.
                    microsecond = fraction_coefficient * 1000000 * (10 ** fraction_exponent)

                if microsecond < 0 or microsecond > 999999:
                    self.log.error("Incorrect IonTimestamp fraction %d usec: %s" % (microsecond, hex_string(data)))
                    microsecond = None
                    fraction_exponent = 0
        else:
            microsecond = None
            fraction_exponent = 0

        if month is None:
            ts_format = ION_TIMESTAMP_Y
            offset_minutes = None
        elif day is None:
            ts_format = ION_TIMESTAMP_YM
            offset_minutes = None
        elif hour is None:
            ts_format = ION_TIMESTAMP_YMD
            offset_minutes = None
        elif second is None:
            ts_format = ION_TIMESTAMP_YMDHM
        elif microsecond is None:
            ts_format = ION_TIMESTAMP_YMDHMS
        else:
            ts_format = ION_TIMESTAMP_YMDHMSF

        return IonTimestamp(year,
                    month if month is not None else 1,
                    day if day is not None else 1,
                    hour if hour is not None else 0,
                    minute if minute is not None else 0,
                    second if second is not None else 0,
                    microsecond if microsecond is not None else 0,
                    IonTimestampTZ(offset_minutes, ts_format, -fraction_exponent))

    symbol_value_signature = 7

    def serialize_symbol_value(self, value):
        """Encode a symbol as its symbol-table id.

        :raises Exception: if the symbol table returns a falsy id (this includes 0).
        """
        symbol_id = self.symtab.get_id(value)
        if not symbol_id:
            raise Exception("attempt to serialize undefined symbol %s" % repr(value))

        return (IonBinary.symbol_value_signature, serialize_unsignedint(symbol_id))

    def deserialize_symbol_value(self, data):
        """Decode type 7: look the unsigned id up in the symbol table."""
        return self.symtab.get_symbol(deserialize_unsignedint(data))

    string_value_signature = 8

    def serialize_string_value(self, value):
        """Encode a string as UTF-8 bytes."""
        return (IonBinary.string_value_signature, value.encode("utf-8"))

    def deserialize_string_value(self, data):
        """Decode type 8 from UTF-8."""
        return data.decode("utf-8")

    clob_value_signature = 9

    def serialize_clob_value(self, value):
        """Encode a CLOB as its raw ``str`` content (logged as unexpected)."""
        self.log.error("Serialize CLOB")
        return (IonBinary.clob_value_signature, str(value))

    def deserialize_clob_value(self, data):
        """Decode type 9 into an ``IonCLOB`` (logged as unexpected)."""
        self.log.error("Deserialize CLOB")
        return IonCLOB(data)

    blob_value_signature = 10

    def serialize_blob_value(self, value):
        """Encode a BLOB as its raw ``str`` content."""
        return (IonBinary.blob_value_signature, str(value))

    def deserialize_blob_value(self, data):
        """Decode type 10 into an ``IonBLOB``."""
        return IonBLOB(data)

    list_value_signature = 11

    def serialize_list_value(self, value):
        """Encode each element in order as the list payload."""
        serial = Serializer()
        for val in value:
            serial.append(self.serialize_value(val))

        return (IonBinary.list_value_signature, serial.serialize())

    def deserialize_list_value(self, data, top_level=False):
        """Decode a sequence of values, dropping NOP padding.

        ``top_level`` is accepted for compatibility but is unused.
        """
        serial = Deserializer(data)
        result = []
        while len(serial):
            value = self.deserialize_value(serial)

            if not isinstance(value, IonNop):
                result.append(value)

        return result

    sexp_value_signature = 12

    def serialize_sexp_value(self, value):
        """Encode an S-expression with the same payload layout as a list."""
        return (IonBinary.sexp_value_signature, self.serialize_list_value(list(value))[1])

    def deserialize_sexp_value(self, data):
        """Decode type 12 into an ``IonSExp``."""
        return IonSExp(self.deserialize_list_value(data))

    struct_value_signature = 13

    def serialize_struct_value(self, value):
        """Encode each (field id VarUInt, value) pair in the dict's item order."""
        serial = Serializer()

        for key, val in value.items():
            serial.append(serialize_vluint(self.symtab.get_id(key)))
            serial.append(self.serialize_value(val))

        return (IonBinary.struct_value_signature, serial.serialize())

    def deserialize_struct_value(self, flag, serial):
        """Decode type 13, reading its own length from the flag or a VarUInt.

        A sorted struct (flag 1) is logged and its VarUInt length is read.
        NOP values are skipped. A duplicate field name is logged, and the
        later value replaces the earlier one.
        """
        if flag == IonBinary.SORTED_STRUCT_FLAG:

            self.log.error("BinaryIonStruct: Sorted IonStruct encountered")
            flag = IonBinary.VARIABLE_LEN_FLAG

        serial2 = Deserializer(serial.extract(deserialize_vluint(serial) if flag == IonBinary.VARIABLE_LEN_FLAG else flag))
        result = IonStruct()

        while len(serial2):
            id_symbol = self.symtab.get_symbol(deserialize_vluint(serial2))

            value = self.deserialize_value(serial2)
            if DEBUG: self.log.debug("IonStruct: %s = %s" % (repr(id_symbol), repr(value)))

            if not isinstance(value, IonNop):
                if id_symbol in result:

                    self.log.error("BinaryIonStruct: Duplicate field name %s" % id_symbol)

                result[id_symbol] = value

        return result

    annotation_value_signature = 14

    def serialize_annotation_value(self, value):
        """Encode annotation-id length, annotation ids, then the wrapped value.

        :raises Exception: if ``value`` has no annotations.
        """
        if not value.annotations: raise Exception("Serializing IonAnnotation without annotations")

        serial = Serializer()

        annotation_data = Serializer()
        for annotation in value.annotations:
            annotation_data.append(serialize_vluint(self.symtab.get_id(annotation)))

        serial.append(serialize_vluint(len(annotation_data)))
        serial.append(annotation_data.serialize())

        serial.append(self.serialize_value(value.value))

        return (IonBinary.annotation_value_signature, serial.serialize())

    def deserialize_annotation_value(self, data):
        """Decode type 14: annotation symbol ids followed by exactly one value.

        More than one annotation is logged but accepted.

        :raises Exception: on trailing bytes after the value or if there are no annotations.
        """
        serial = Deserializer(data)

        annotation_length = deserialize_vluint(serial)
        annotation_data = Deserializer(serial.extract(annotation_length))

        ion_value = self.deserialize_value(serial)
        if len(serial): raise Exception("IonAnnotation has excess data: %s" % hex_string(serial.extract()))

        annotations = []
        while len(annotation_data):
            annotations.append(self.symtab.get_symbol(deserialize_vluint(annotation_data)))

        if len(annotations) == 0:
            raise Exception("IonAnnotation has no annotations")

        if len(annotations) != 1:
            self.log.error("IonAnnotation has %d annotations" % len(annotations))

        return IonAnnotation(annotations, ion_value)

    reserved_value_signature = 15

    def deserialize_reserved_value(self, data):
        """Reject type 15, which has no defined meaning here.

        :raises Exception: always.
        """
        raise Exception("Deserialize reserved ion value signature %d" % IonBinary.reserved_value_signature)

    # signature -> (payload pre-extracted?, deserializer, type name for logs)
    VALUE_DESERIALIZERS = {
        null_value_signature: (False, deserialize_null_value, "null"),
        bool_value_signature: (False, deserialize_bool_value, "bool"),
        posint_value_signature: (True, deserialize_posint_value, "int"),
        negint_value_signature: (True, deserialize_negint_value, "int"),
        float_value_signature: (True, deserialize_float_value, "float"),
        decimal_value_signature: (True, deserialize_decimal_value, "decimal"),
        timestamp_value_signature: (True, deserialize_timestamp_value, "timestamp"),
        symbol_value_signature: (True, deserialize_symbol_value, "symbol"),
        string_value_signature: (True, deserialize_string_value, "string"),
        clob_value_signature: (True, deserialize_clob_value, "clob"),
        blob_value_signature: (True, deserialize_blob_value, "blob"),
        list_value_signature: (True, deserialize_list_value, "list"),
        sexp_value_signature: (True, deserialize_sexp_value, "sexp"),
        struct_value_signature: (False, deserialize_struct_value, "struct"),
        annotation_value_signature: (True, deserialize_annotation_value, "annotation"),
        reserved_value_signature: (True, deserialize_reserved_value, "reserved"),
        }

    # ion_type(value) -> serializer returning (signature or None, payload)
    ION_TYPE_HANDLERS = {
        IonAnnotation: serialize_annotation_value,
        IonBLOB: serialize_blob_value,
        IonBool: serialize_bool_value,
        IonCLOB: serialize_clob_value,
        IonDecimal: serialize_decimal_value,
        IonFloat: serialize_float_value,
        IonInt: serialize_int_value,
        IonList: serialize_list_value,
        IonNull: serialize_null_value,
        IonSExp: serialize_sexp_value,
        IonString: serialize_string_value,
        IonStruct: serialize_struct_value,
        IonSymbol: serialize_symbol_value,
        IonTimestamp: serialize_timestamp_value,
        }

def descriptor(signature, flag):
    """Build the one-character Ion type descriptor for ``signature``/``flag``.

    The type signature occupies the high nibble and ``flag`` (a length or
    special flag) the low nibble.

    :raises Exception: if ``flag`` is outside 0..15.
    """
    if not 0 <= flag <= 0x0f:
        raise Exception("Serialize bad descriptor flag: %d" % flag)

    high_nibble = signature << 4
    return chr(high_nibble | flag)

def serialize_unsignedint(value):
    """Encode a non-negative int (< 2**64) as big-endian bytes without leading zeros.

    Zero encodes as an empty string.
    """
    full_width = struct.pack(b">Q", value)
    return ltrim0(full_width)

def deserialize_unsignedint(data):
    """Decode big-endian unsigned-magnitude bytes into an int.

    An empty payload decodes as 0. Payloads longer than 8 bytes cannot be
    decoded (lpad0 raises).

    :raises Exception: if the first byte is 0x00 (padding is not allowed).
    """
    if data and ord(data[0]) == 0:
        raise Exception("BinaryIonInt data padded with 0x00")

    eight_bytes = lpad0(data, 8)
    return struct.unpack_from(b">Q", eight_bytes)[0]

def serialize_signedint(value):
    """Encode an int in sign-and-magnitude form (top bit of first byte = sign).

    Leading zero bytes are trimmed, except that one is kept when the next
    byte's 0x80 bit would otherwise be read as the sign. Zero encodes as an
    empty string.
    """
    magnitude = ltrim0x(struct.pack(b">Q", abs(value)))
    return or_first_byte(magnitude, 0x80) if value < 0 else magnitude

def deserialize_signedint(data):
    """Decode sign-and-magnitude bytes into an int; an empty payload is 0."""
    if len(data) == 0: return 0

    negative = (ord(data[0]) & 0x80) != 0
    if negative:
        # Clear the sign bit, leaving only the magnitude.
        data = and_first_byte(data, 0x7f)

    magnitude = struct.unpack_from(b">Q", lpad0(data, 8))[0]
    return -magnitude if negative else magnitude

def serialize_vluint(value):
    """Encode a non-negative int as an Ion VarUInt.

    Seven value bits per byte, most significant group first. The 0x80 bit is
    set only on the final byte to mark the end.

    :raises Exception: if ``value`` is negative.
    """
    if value < 0:
        raise Exception("Cannot serialize negative value as IonVLUInt: %d" % value)

    encoded = chr((value & 0x7f) | 0x80)    # final byte carries the stop bit
    remaining = value >> 7
    while remaining != 0:
        encoded = chr(remaining & 0x7f) + encoded
        remaining >>= 7

    return encoded

def deserialize_vluint(serial):
    """Read an Ion VarUInt from ``serial`` (any object with ``unpack("B")``).

    :raises Exception: if the first byte is 0x00 (padding), or if the value
        exceeds 0x7fffffffffffff before a byte with the 0x80 stop bit appears.
    """
    result = 0
    while True:
        octet = serial.unpack("B")
        result = (result << 7) | (octet & 0x7f)

        if octet & 0x80:
            return result

        if not result:
            raise Exception("IonVLUInt padded with 0x00")

        if result > 0x7fffffffffffff:
            raise Exception("IonVLUInt data value is too large, missing terminator")

def serialize_vlsint(value):
    """Encode an int as an Ion VarInt; ``None`` encodes as negative zero (0xC0).

    The first byte's 0x40 bit is the sign. If the VarUInt magnitude already
    uses that bit, a 0x00 byte is prepended to make room for it.
    """
    if value is None:
        return chr(0xc0)

    encoded = serialize_vluint(abs(value))
    if ord(encoded[0]) & 0x40:
        encoded = chr(0) + encoded

    return or_first_byte(encoded, 0x40) if value < 0 else encoded

def deserialize_vlsint(serial, allow_minus_zero=False):
    """Read an Ion VarInt from ``serial``.

    The 0x40 bit of the first byte is the sign. The remaining bits form a
    VarUInt magnitude. Negative zero yields ``None`` when ``allow_minus_zero``
    is true and raises otherwise.
    """
    first = serial.unpack("B")
    current = first & 0xbf          # first byte with the sign bit cleared
    # A first byte that is zero once the sign is removed is a leading zero.
    # Leave it out so deserialize_vluint does not reject it as 0x00 padding.
    magnitude_bytes = b"" if current == 0 else chr(current)

    while not current & 0x80:
        current = serial.unpack("B")
        magnitude_bytes += chr(current)

    value = deserialize_vluint(Deserializer(magnitude_bytes))

    if not first & 0x40:
        return value

    if value:
        return -value

    if allow_minus_zero:
        return None

    raise Exception("deserialize_vlsint unexpected -0 value")

def lpad0(data, size):
    """Left-pad ``data`` with 0x00 bytes to exactly ``size`` bytes.

    Longer input is accepted only when the surplus leading bytes are all
    zero. They are dropped and the trailing ``size`` bytes are returned.

    :raises Exception: if ``data`` is longer than ``size`` and any surplus
        leading byte is non-zero.
    """
    extra = len(data) - size
    if extra > 0:
        # The surplus is at the front: validate data[:extra] and keep data[extra:].
        if data[:extra] != chr(0) * extra:
            raise Exception("lpad0, length (%d) > max (%d): %s" % (len(data), size, hex_string(data)))

        return data[extra:]

    return chr(0) * (-extra) + data

def ltrim0(data):
    """Return ``data`` with all leading 0x00 bytes removed (possibly empty)."""
    start = 0
    while start < len(data) and ord(data[start]) == 0:
        start += 1

    return data[start:]

def ltrim0x(data):
    """Strip leading 0x00 bytes, but keep one if the following byte has 0x80 set.

    This keeps the top bit of the first byte clear for the sign in
    serialize_signedint.
    """
    start = 0
    while start < len(data) and ord(data[start]) == 0:
        following = start + 1
        if following < len(data) and ord(data[following]) & 0x80:
            break
        start = following

    return data[start:]

def combine_decimal_digits(digits, sign_negative):
    """Combine base-10 digits (most significant first) into an int.

    The result is negated when ``sign_negative`` is truthy. The arguments
    correspond to the ``digits`` and ``sign`` fields of
    ``Decimal.as_tuple()``, as passed by serialize_decimal_value.
    """
    total = 0
    for d in digits:
        total *= 10
        total += d

    if not sign_negative:
        return total

    return -total

def and_first_byte(data, mask):
    """Return ``data`` with its first byte bitwise-ANDed with ``mask``."""
    masked = ord(data[0]) & mask
    return chr(masked) + data[1:]

def or_first_byte(data, mask):
    """Return ``data`` with its first byte bitwise-ORed with ``mask``."""
    merged = ord(data[0]) | mask
    return chr(merged) + data[1:]

class Serializer(object):
    """Collect byte chunks and join them on demand.

    ``buffers`` holds the appended chunks in order and ``length`` their total
    size. ``pack``/``repack`` let a fixed-size field be patched after later
    data has been appended.
    """

    def __init__(self):
        self.buffers = []
        self.length = 0

    def pack(self, fmt, *values):
        """Append ``struct.pack(fmt, *values)`` and return a handle for repack()."""
        encoded_fmt = fmt.encode("ascii")
        handle = (encoded_fmt, len(self.buffers))
        self.append(struct.pack(encoded_fmt, *values))
        return handle

    def repack(self, fmt_pos, *values):
        """Overwrite the chunk recorded by pack() with newly packed values.

        ``length`` is left unchanged because the same format, and hence the
        same size, is reused.
        """
        packed_fmt, index = fmt_pos
        self.buffers[index] = struct.pack(packed_fmt, *values)

    def append(self, buf):
        """Add ``buf`` to the end; empty (falsy) chunks are ignored."""
        if not buf:
            return

        self.buffers.append(buf)
        self.length += len(buf)

    def extend(self, serializer):
        """Add every chunk held by another Serializer."""
        self.buffers.extend(serializer.buffers)
        self.length += serializer.length

    def __len__(self):
        return self.length

    def serialize(self):
        """Return all chunks joined into one byte string."""
        return b"".join(self.buffers)

    def sha1(self):
        """Return the SHA-1 digest of the content without joining the chunks."""
        digest = hashlib.sha1()
        for chunk in self.buffers:
            digest.update(chunk)

        return digest.digest()

class Deserializer(object):
    """Sequential reader over a byte string that tracks a moving offset."""

    def __init__(self, data):
        self.buffer = data
        self.offset = 0

    def unpack(self, fmt, advance=True):
        """Unpack the first field of ``fmt`` at the current offset.

        Unless ``advance`` is false, the offset then moves past the full
        ``struct.calcsize(fmt)`` bytes.
        """
        encoded_fmt = fmt.encode("ascii")
        value = struct.unpack_from(encoded_fmt, self.buffer, self.offset)[0]

        if advance:
            self.offset += struct.calcsize(encoded_fmt)

        return value

    def extract(self, size=None, upto=None, advance=True):
        """Return the next ``size`` bytes.

        When ``size`` is omitted, read up to the absolute offset ``upto``, or
        everything remaining if that is omitted too. The offset moves forward
        unless ``advance`` is false.

        :raises Exception: if ``size`` is negative or not enough bytes remain.
        """
        if size is None:
            if upto is None:
                size = len(self)
            else:
                size = upto - self.offset

        start = self.offset
        chunk = self.buffer[start:start + size]

        if size < 0 or len(chunk) < size:
            raise Exception("Deserializer: Insufficient data (need %d bytes, have %d bytes)" % (size, len(chunk)))

        if advance:
            self.offset = start + size

        return chunk

    def __len__(self):
        """Number of bytes not yet consumed."""
        return len(self.buffer) - self.offset

