#!/usr/bin/env python

import sys
import struct
import binascii

class Unbuffered:
        """Stream wrapper that flushes the wrapped stream after every write.

        Any attribute other than ``stream`` and ``write`` is looked up on
        the wrapped stream itself.
        """

        def __init__(self, stream):
                # The object that receives the actual writes.
                self.stream = stream

        def __getattr__(self, attr):
                # Only called for names not found on this wrapper; delegate.
                return getattr(self.stream, attr)

        def write(self, data):
                """Write data to the wrapped stream and flush it right away."""
                target = self.stream
                target.write(data)
                target.flush()

# important pdb header offsets
# (byte positions inside the file header that precedes the record table)
id_seed = 68              # 4-byte field; insertsection() rewrites it as 2*(nsec+1)+1
num_records = 76          # 2-byte big-endian count of records (read with 'H')
first_pdb_record = 78     # start of the record table; entries are 8 bytes each
book_length = 4           # not referenced in the visible code
book_record_count = 8     # not referenced in the visible code

# important rec0 offsets
# (byte positions inside section 0 of the file)
mobi_header_base = 16     # added to the value at mobi_header_length to locate EXTH
length_of_book = 4        # not referenced in the visible code
mobi_type = 24            # not referenced in the visible code
mobi_header_length = 20   # 4-byte value read by get_exth_params()
first_non_text = 80       # not referenced in the visible code
title_offset = 84         # 4-byte value adjusted whenever the EXTH block changes size
first_image_record = 108  # 4-byte value read by kfdel
srcs = 212                # not referenced in the visible code
srcs_count = 216          # not referenced in the visible code
flis_index = 208          # same offset kfdel shifts (there as the literal 208)
fcis_index = 200          # same offset kfdel shifts (there as the literal 200)
first_content_index = 192 # not referenced in the visible code
last_content_index = 194  # 2-byte value read and shifted by kfdel
primary_index = 244       # not referenced in the visible code



def getint(datain,ofs,len='L'):
        """Read one big-endian integer from datain at byte offset ofs.

        len is a struct format character ('L' = 4-byte unsigned by default,
        'H' = 2-byte unsigned); note that it shadows the builtin inside
        this function.
        """
        fmt = '>' + len
        return struct.unpack_from(fmt, datain, ofs)[0]

def writeint(datain,ofs,n,len='L'):
        """Return a copy of datain with n stored big-endian at offset ofs.

        With len == 'L' four bytes are replaced by an unsigned 32-bit value;
        any other len replaces two bytes with an unsigned 16-bit value.
        """
        if len == 'L':
                fmt, width = '>L', 4
        else:
                fmt, width = '>H', 2
        return datain[:ofs] + struct.pack(fmt, n) + datain[ofs+width:]

def getsecaddr(datain,secno):
        """Return (start, end) byte offsets of section secno inside datain.

        The section count is the 16-bit value at num_records and each
        record-table entry is 8 bytes starting at first_pdb_record, with the
        section's start offset in its first 4 bytes.  A section ends where
        the next one starts; the last section ends at the end of datain.

        Raises AssertionError when secno is not in range(nsec).
        """
        print("getsecaddr %d" % secno)
        nsec = getint(datain,num_records,'H')
        # The previous form `secno>=0 & secno<nsec` bound `&` tighter than the
        # comparisons and therefore never checked the upper bound.
        assert 0 <= secno < nsec,'secno out of range'
        secstart = getint(datain,first_pdb_record+secno*8)
        if secno == nsec-1:
                secend = len(datain)
        else:
                secend = getint(datain,first_pdb_record+(secno+1)*8)
        print("getsecaddr %x %x" % (secstart, secend))
        return secstart,secend

def readsection(datain,secno):
        """Return the bytes of section secno, as delimited by getsecaddr()."""
        bounds = getsecaddr(datain,secno)
        return datain[slice(*bounds)]

def writesection(datain,secno,secdata):
        """Return a copy of datain with section secno replaced by secdata.

        When the new content has a different size, the start offsets of all
        later sections in the record table are shifted by the difference.
        """
        secstart, secend = getsecaddr(datain,secno)
        print("write section (%d) offset=%x" % (secno,secstart))
        patched = datain[:secstart] + secdata + datain[secend:]
        delta = len(secdata) - (secend - secstart)
        if delta:
                total = getint(datain,num_records,'H')
                # Empty range when secno is the last section.
                for later in range(secno+1,total):
                        entry = first_pdb_record + later*8
                        moved = getint(patched,entry) + delta
                        patched = patched[:entry] + struct.pack('>L',moved) + patched[entry+4:]
        return patched

def nullsection(datain,secno):
        """Return a copy of datain with the content of section secno removed.

        The section keeps its record-table entry (it becomes zero-length);
        the start offsets of all later sections move back by the number of
        bytes removed.
        """
        secstart, secend = getsecaddr(datain,secno)
        print("null section (%d) offset=%x" % (secno,secstart))
        removed = secend - secstart
        shrunk = datain[:secstart] + datain[secend:]
        if removed:
                total = getint(datain,num_records,'H')
                # Empty range when secno is the last section.
                for later in range(secno+1,total):
                        entry = first_pdb_record + later*8
                        moved = getint(shrunk,entry) - removed
                        shrunk = shrunk[:entry] + struct.pack('>L',moved) + shrunk[entry+4:]
        return shrunk

def deletesectionrange(datain,firstsec,lastsec):
        """Return a copy of datain without sections firstsec..lastsec (inclusive).

        Both the section contents and their 8-byte record-table entries are
        removed, the record count at num_records is reduced, and the start
        offsets of the surviving sections are adjusted for the bytes that
        disappeared in front of them.
        """
        print("Deleting records %d - %d" % (firstsec,lastsec))
        firstsecstart,_ = getsecaddr(datain,firstsec)
        _,lastsecend = getsecaddr(datain,lastsec)
        ndeleted = lastsec-firstsec+1
        dif = lastsecend - firstsecstart
        dataout = datain[:firstsecstart]+datain[lastsecend:]
        nsec = getint(datain,num_records,'H')
        dataout = writeint(dataout,num_records,nsec-ndeleted,'H')
        # The record table loses 8 bytes per deleted entry, so every section
        # moves towards the start of the file by that amount.
        for i in range(0,nsec):
                entry = first_pdb_record+i*8
                ofs = getint(dataout,entry)-8*ndeleted
                dataout = writeint(dataout,entry,ofs)
        # Sections after the deleted range additionally move back by the size
        # of the removed content; their second table field is rewritten too.
        for i in range(lastsec+1,nsec):
                entry = first_pdb_record+i*8
                ofs = getint(dataout,entry)-dif
                # NOTE(review): the id is derived from the OLD index i, not the
                # index the entry has after deletion - confirm this is intended.
                it = 2*i
                # Both fields must be packed big-endian with standard sizes:
                # a bare 'L' is native-sized (8 bytes on LP64) and native-endian,
                # which overflowed the 8-byte table entry.
                dataout = dataout[:entry]+\
                          struct.pack('>L',ofs)+struct.pack('>L',it)+\
                          dataout[entry+8:]
        # Finally cut the table entries of the deleted sections themselves.
        dataout = dataout[:first_pdb_record+firstsec*8]+dataout[first_pdb_record+(lastsec+1)*8:]
        return dataout

def insertsection(datain,secno,secdata):
        """Return a copy of datain with secdata inserted as new section secno.

        secno == nsec appends the section at the end of the file; otherwise
        the new section is placed in front of the current section secno.
        The header and record table are rebuilt: the 4 bytes at id_seed
        become 2*(nsec+1)+1, the record count becomes nsec+1 and every table
        entry is written as (start offset, 2*index).  All offsets grow by 8
        for the additional table entry, and sections at or after the
        insertion point also by len(secdata).
        """
        nsec = getint(datain,num_records,'H')
        print("Inserting before %d, nsec=%d" % (secno,nsec))
        if secno == nsec:
                newsecstart = len(datain)
                print("append section (%d) offset=%x" % (secno,newsecstart))
        else:
                newsecstart, _ = getsecaddr(datain,secno)
                print("insert section (%d) offset=%x" % (secno,newsecstart))
        parts = [datain[:id_seed],
                 struct.pack('>L',2*(nsec+1)+1),
                 datain[id_seed+4:first_pdb_record-2],
                 struct.pack('>H',nsec+1)]
        for i in range(0,secno):
                ofs = getint(datain,first_pdb_record+i*8)+8
                parts.append(struct.pack('>L',ofs)+struct.pack('>L',2*i))
        parts.append(struct.pack('>L',newsecstart+8)+struct.pack('>L',2*secno))
        for i in range(secno,nsec):
                ofs = getint(datain,first_pdb_record+i*8)+len(secdata)+8
                # These entries now live at index i+1; using 2*i here gave the
                # old section secno the same id as the newly inserted one.
                parts.append(struct.pack('>L',ofs)+struct.pack('>L',2*(i+1)))
        dataout = b''.join(parts)
        # Read the new start of section 0 back from the table just built.
        r0start,r0end=getsecaddr(dataout,0)
        # Reproduce the gap between the end of the table and section 0.
        dataout += b'\0' * (r0start-(first_pdb_record+8*(nsec+1)))
        # r0start-8 is where section 0 started in the original data.
        dataout += datain[r0start-8:newsecstart]+secdata+datain[newsecstart:]
        return dataout

def insertsectionrange(sectionsource,firstsec,lastsec,sectiontarget,targetsec):
        """Copy sections firstsec..lastsec of sectionsource into sectiontarget.

        The sections are inserted one by one at index targetsec, walking the
        source range from lastsec down to firstsec so that they end up in
        their original order.  Returns the resulting data.
        """
        print("insertsectionrange %d,%d at %d" % (firstsec,lastsec,targetsec))
        merged = sectiontarget
        idx = lastsec
        while idx >= firstsec:
                print("inserting %d" % idx)
                payload = readsection(sectionsource,idx)
                merged = insertsection(merged,targetsec,payload)
                idx -= 1
        return merged
        
        

def get_exth_params(rec0):
        """Return (ebase, elen, enum) describing the EXTH block of rec0.

        ebase is mobi_header_base plus the 32-bit value stored at
        mobi_header_length; elen and enum are the 32-bit values found at
        ebase+4 and ebase+8.
        """
        header_len = getint(rec0,mobi_header_length)
        ebase = mobi_header_base + header_len
        elen = getint(rec0,ebase+4)
        return ebase, elen, getint(rec0,ebase+8)

def add_exth(rec0,exth_num,exth_bytes):
        """Return rec0 with a new EXTH record (exth_num, exth_bytes) added.

        The record is placed in front of the existing records, the EXTH
        length and record count are increased accordingly and the value at
        title_offset is advanced by the size of the new record.
        """
        ebase,elen,enum = get_exth_params(rec0)
        newrecsize = 8+len(exth_bytes)
        counters = struct.pack('>LL',elen+newrecsize,enum+1)
        record = struct.pack('>LL',exth_num,newrecsize)+exth_bytes
        grown = rec0[:ebase+4]+counters+record+rec0[ebase+12:]
        old_title = getint(grown,title_offset)
        return writeint(grown,title_offset,old_title+newrecsize)
        
def read_exth(rec0,exth_num):
        """Return the payload of the first EXTH record with id exth_num.

        Records start at ebase+12; each one is an id (4 bytes), its total
        length (4 bytes) and the payload.  Returns '' when no record matches.
        """
        ebase,elen,enum = get_exth_params(rec0)
        cursor = ebase+12
        remaining = enum
        while remaining>0:
                if getint(rec0,cursor) == exth_num:
                        reclen = getint(rec0,cursor+4)
                        return rec0[cursor+8:cursor+reclen]
                remaining -= 1
                # Skip over this record using its stored length.
                cursor += getint(rec0,cursor+4)
        return ''

def write_exth(rec0,exth_num,exth_bytes):
        """Return rec0 with the payload of EXTH record exth_num replaced.

        Only the first matching record is changed.  The EXTH length field
        and the value at title_offset are adjusted by the size difference.
        rec0 is returned unchanged when no record has id exth_num.
        """
        ebase,elen,enum = get_exth_params(rec0)
        cursor = ebase+12
        remaining = enum
        while remaining>0:
                if getint(rec0,cursor) == exth_num:
                        oldlen = getint(rec0,cursor+4)
                        newlen = len(exth_bytes)+8
                        dif = newlen-oldlen
                        head = rec0
                        if dif != 0:
                                head = writeint(rec0,title_offset,getint(rec0,title_offset)+dif)
                        # Only the leading part comes from the title-adjusted copy;
                        # everything behind the EXTH counters is taken from rec0.
                        return head[:ebase+4]+struct.pack('>L',elen+dif)+\
                               struct.pack('>L',enum)+rec0[ebase+12:cursor+4]+\
                               struct.pack('>L',newlen)+exth_bytes+\
                               rec0[cursor+oldlen:]
                remaining -= 1
                cursor += getint(rec0,cursor+4)
        return rec0

def del_exth(rec0,exth_num):
        """Return rec0 with the first EXTH record whose id is exth_num removed.

        The EXTH length shrinks by the record's size, the record count drops
        by one and the value at title_offset is reduced by the same size.
        rec0 is returned unchanged when no record matches.
        """
        ebase,elen,enum = get_exth_params(rec0)
        cursor = ebase+12
        remaining = enum
        while remaining>0:
                if getint(rec0,cursor) == exth_num:
                        reclen = getint(rec0,cursor+4)
                        shifted = writeint(rec0,title_offset,getint(rec0,title_offset)-reclen)
                        trimmed = shifted[:cursor]+shifted[cursor+reclen:]
                        counters = struct.pack('>L',elen-reclen)+struct.pack('>L',enum-1)
                        return trimmed[:ebase+4]+counters+trimmed[ebase+12:]
                remaining -= 1
                cursor += getint(rec0,cursor+4)
        return rec0
        
                
class kfdel:
        """Strip the leading sections of a file and keep a range of them.

        The 32-bit value in EXTH record 121 of section 0 gives the index of
        the section that becomes the new section 0 (presumably the KF8
        boundary of a combined MOBI/KF8 file - confirm).  Sections in front
        of it are deleted, then the sections between the first-image index
        and the last-content index of the old section 0 are re-inserted at
        the first-image index recorded in the new section 0.
        """

        def __init__(self, datain):
                old_rec0 = readsection(datain,0)
                boundary = struct.unpack_from('>L',read_exth(old_rec0,121),0)[0]
                new_rec0 = readsection(datain,boundary)

                firstimage = getint(old_rec0,first_image_record)
                lastimage = getint(old_rec0,last_content_index,'H')

                book = deletesectionrange(datain,0,boundary-1)
                target = getint(new_rec0,first_image_record)
                book = insertsectionrange(datain,firstimage,lastimage,book,target)

                # Re-read section 0 and move the indices stored at these
                # offsets past the sections that were just inserted.
                new_rec0 = readsection(book,0)
                inserted = lastimage-firstimage+1
                for ofs,sz in ((last_content_index,'H'),(fcis_index,'L'),(flis_index,'L')):
                        new_rec0 = writeint(new_rec0,ofs,getint(new_rec0,ofs,sz)+inserted,sz)
                self.result_file = writesection(book,0,new_rec0)

        def getResult(self):
                """Return the rewritten file contents."""
                return self.result_file


if __name__ == "__main__":
        # Command-line entry point: read <infile1>, run kfdel on its contents
        # and write the result to <outfile>.

        # Exactly two arguments are documented; an extra one used to be
        # accepted and silently ignored.
        if len(sys.argv) != 3:
                # print(...) with a single argument matches the calls used in
                # the rest of the file and parses on Python 2 and 3 alike.
                print("Usage:")
                print("    %s <infile1> <outfile>" % sys.argv[0])
                sys.exit(1)

        infile1 = sys.argv[1]
        outfile = sys.argv[2]
        # Context managers close both files even if processing fails.
        with open(infile1, 'rb') as source:
                data_file1 = source.read()

        mergedFile = kfdel(data_file1)
        with open(outfile, 'wb') as sink:
                sink.write(mergedFile.getResult())
        sys.exit(0)
