# Copyright (c) 2007-2008 PediaPress GmbH
# See README.txt for additional licensing information.

"""
The parse tree generated by the parser is a 1:1 representation of the mw-markup.
Unfortunately, these trees have some flaws when used to generate derived documents.

This module seeks to rebuild the parse tree
to be:
 * more logical markup
 * clean up the parse tree
 * make it more accessible
 * allow for validity checks
 * implement rebuilding strategies

Useful documentation:
http://en.wikipedia.org/wiki/Wikipedia:Don%27t_use_line_breaks
http://meta.wikimedia.org/wiki/Help:Advanced_editing
"""

import weakref
from mwlib.parser import Magic, Math,  _VListNode, Ref, Link, URL, NamedURL # not used but imported
from mwlib.parser import CategoryLink, SpecialLink, Caption, LangLink # not used but imported
from mwlib.parser import Item, ItemList,  Node, Table, Row, Cell, Paragraph, PreFormatted
from mwlib.parser import Section, Style, TagNode, Text, Timeline
from mwlib.parser import  ImageLink, Article, Book, Chapter
import copy
from mwlib.log import Log

# module logger; fixTagNodes and fixStyle report unknown tags/styles through it
log = Log("advtree")


def _idIndex(lst, el):
    """
    return the position of the first item of lst that *is* el
    (identity, not equality), or -1 if el is not in lst
    """
    # all compared objects are alive at the same time, so equal ids
    # mean identical objects
    try:
        return [id(item) for item in lst].index(id(el))
    except ValueError:
        return -1

class AdvancedNode:
    """
    MixIn class that extends Nodes so they become easier to access.

    It allows traversing the tree in any direction (parent, siblings,
    children) and builds derived convenience functions and properties
    on top of that.

    The parent is stored as a weak reference in _parentref, so a child
    does not keep its parent alive. Parent checks compare against None
    explicitly instead of relying on the truthiness of a node.
    """
    _parentref = None # weak reference to parent element (None: no parent)
    isblocknode = False

    def copy(self):
        "return a copy of this node and all its children, detached from any parent"
        n = copy.copy(self)
        n.children = []
        n._parentref = None
        for c in self:
            n.appendChild(c.copy())
        return n


    def moveto(self, targetnode, prefix=False):
        """
        move this node directly after targetnode, making it a sibling of it
        if prefix is true, move it in front of targetnode instead

        targetnode must have a parent.
        """
        if self.parent is not None:
            self.parent.removeChild(self)
        tp = targetnode.parent
        idx = _idIndex(tp.children, targetnode)
        if not prefix:
            idx += 1
        tp.children = tp.children[:idx] + [self] + tp.children[idx:]
        self._parentref = weakref.ref(tp)

    def appendChild(self, c):
        "append c as last child of this node and point its parent link here"
        self.children.append(c)
        c._parentref = weakref.ref(self)

    def remove(self):
        """
        detach this node from its parent

        returns 0 if the node was removed, 1 if it has no parent or is
        not listed among its parent's children. On success the node's
        parent link is cleared.
        """
        parent = self.parent
        if parent is None:
            return 1
        for (idx, n) in enumerate(parent.children):
            if n is self:
                parent.children = parent.children[:idx] + parent.children[idx+1:]
                # don't keep a stale link: next/previous would otherwise be
                # computed against a parent that no longer lists this node
                self._parentref = None
                return 0
        return 1

    def removeChild(self, c):
        "remove child c (located by identity)"
        self.replaceChild(c, [])

    def replaceChild(self, c, newchildren=None):
        """
        replace child c by the nodes in newchildren (nothing: just remove c)

        c is located by identity; ValueError is raised if it is not a child
        of this node. The new children are re-parented to this node.
        """
        idx = _idIndex(self.children, c)
        if idx < 0:
            raise ValueError("replaceChild: node is not a child of this node")
        # delete by position so that exactly the identical node goes away
        del self.children[idx]
        c._parentref = None
        if newchildren:
            newchildren = list(newchildren)
            self.children = self.children[:idx] + newchildren + self.children[idx:]
            for nc in newchildren:
                nc._parentref = weakref.ref(self)

    def getParents(self):
        "return the list of ancestors, outermost first"
        parent = self.parent
        if parent is None:
            return []
        return parent.getParents() + [parent]

    def getParent(self):
        """
        return the parent node, or None if this node has no parent

        raises ReferenceError if the parent has already been garbage collected
        """
        if self._parentref is None:
            return None
        x = self._parentref()
        if x is None:
            raise ReferenceError("parent node has been garbage collected")
        return x

    def getLevel(self):
        "returns the number of nodes of same class in parents"
        return [p.__class__ for p in self.getParents()].count(self.__class__)


    def getParentNodesByClass(self, klass):
        "returns parents w/ klass (exact class match, outermost first)"
        return [p for p in self.parents if p.__class__ == klass]

    def getChildNodesByClass(self, klass):
        "returns all descendants w/ klass (exact class match, depth-first pre-order)"
        return [p for p in self.getAllChildren() if p.__class__ == klass]

    def getAllChildren(self):
        """
        yield all descendants, depth-first in pre-order
        don't confuse w/ Node.allchildren() which returns allchildren + self
        """
        for c in self.children:
            yield c
            for x in c.getAllChildren():
                yield x

    def getSiblings(self):
        "return all siblings except this node"
        return [c for c in self.getAllSiblings() if c is not self]

    def getAllSiblings(self):
        "all siblings plus me my self and i (the parent's children list, [] without parent)"
        parent = self.parent
        if parent is not None:
            return parent.children
        return []

    def getPrevious(self):
        "return previous sibling or None"
        s = self.getAllSiblings()
        idx = _idIndex(s, self)
        if idx <= 0: # first child, or not listed in the parent at all
            return None
        return s[idx-1]

    def getNext(self):
        "return next sibling or None"
        s = self.getAllSiblings()
        idx = _idIndex(s, self)
        if idx < 0 or idx + 1 >= len(s): # not listed in the parent, or last child
            return None
        return s[idx+1]

    def getLast(self):
        "return last sibling (possibly this node), or None without parent"
        s = self.getAllSiblings()
        if s:
            return s[-1]
        return None

    def getFirst(self):
        "return first sibling (possibly this node), or None without parent"
        s = self.getAllSiblings()
        if s:
            return s[0]
        return None

    def getLastChild(self):
        "return last child of this node, or None"
        if self.children:
            return self.children[-1]
        return None

    def getFirstChild(self):
        "return first child of this node, or None"
        if self.children:
            return self.children[0]
        return None

    def getAllDisplayText(self, amap = None):
        """
        return all text that is intended for display

        amap maps node classes to the name of the attribute holding their
        display text; by default the caption of Text/URL/Math/ImageLink and
        the target of Link nodes are used. The nodes visited are those
        returned by allchildren() (provided by the parser Node).
        """
        if not amap:
            amap = {Text:"caption", Link:"target", URL:"caption", Math:"caption", ImageLink:"caption" }
        text = []
        for n in self.allchildren():
            access = amap.get(n.__class__, "")
            if access:
                text.append( getattr(n, access) )
        alltext = [t for t in text if t]
        if alltext:
            return u''.join(alltext)
        else:
            return ''

    parent = property(getParent)
    parents = property(getParents)
    next = property(getNext)
    previous = property(getPrevious)
    siblings = property(getSiblings)
    last = property(getLast)
    first = property(getFirst)
    lastchild = property(getLastChild)
    firstchild = property(getFirstChild)
    


# --------------------------------------------------------------------------
# MixinClasses w/ special behaviour
# -------------------------------------------------------------------------

class AdvancedTable(AdvancedNode):
    """Table mixin giving access to the rows and the column count."""

    @property
    def rows(self):
        "list of the direct children that are Row nodes"
        return [child for child in self if child.__class__ == Row]

    @property
    def numcols(self):
        "largest number of Cell children in any row (0 if there are no rows)"
        widest = 0
        for row in self.rows:
            ncells = len([child for child in row if child.__class__ == Cell])
            if ncells > widest:
                widest = ncells
        return widest

class AdvancedRow(AdvancedNode):
    """Row mixin giving access to the cells of the row."""

    @property
    def cells(self):
        "list of the direct children that are Cell nodes"
        return [child for child in self if child.__class__ == Cell]
    

class AdvancedSection(AdvancedNode):
    """Section mixin that derives a section level from nesting."""
    h_level = 0 # heading level if it originates from an H1, H2, ... TagNode; 0 otherwise

    def getSectionLevel(self):
        "return 1 + the number of ancestors of the same class as this section"
        nesting = self.getLevel()
        return nesting + 1

class AdvancedImageLink(AdvancedNode):
    """ImageLink mixin: an image is a block node unless it is inline."""

    @property
    def isblocknode(self):
        return not self.isInline()
    
class AdvancedMath(AdvancedNode):
    """Math mixin: align/alignat environments are treated as block nodes."""

    @property
    def isblocknode(self):
        "True if the stripped TeX source starts with an align or alignat environment"
        source = self.caption.strip()
        for env in ("\\begin{align}", "\\begin{alignat}"):
            if source.startswith(env):
                return True
        return False

       

# --------------------------------------------------------------------------
# Missing as Classes derived from parser.Style
# -------------------------------------------------------------------------

    
class Emphasized(Style, AdvancedNode):
    """EM: emphasized text (fixStyle maps the '' style to it)"""

class Strong(Style, AdvancedNode):
    """STRONG: strongly emphasized text (fixStyle maps the ''' style to it)"""

class DefinitionList(Style, AdvancedNode):
    """DL: definition list holding DefinitionTerm/DefinitionDescription children"""

class DefinitionTerm(Style, AdvancedNode):
    """DT: term of a definition list (built from the ';' style)"""

class DefinitionDescription(Style, AdvancedNode):
    """DD: description inside a definition list (built from the ':' style)"""

class Blockquote(Style, AdvancedNode):
    """block quotation: margins to the left and to the right"""

class Indented(Style, AdvancedNode):
    """margin to the left (fixStyle uses it for ':' styles outside a definition list)"""

class Overline(Style, AdvancedNode):
    """text with a line above it (style caption 'overline')"""
    _style = 'overline'

class Underline(Style, AdvancedNode):
    """underlined text (style caption 'u')"""
    _style = 'u'

class Sub(Style, AdvancedNode):
    """subscript (style caption 'sub')"""
    _style = 'sub'

class Sup(Style, AdvancedNode):
    """superscript (style caption 'sup')"""
    _style = 'sup'

class Small(Style, AdvancedNode):
    """smaller text (style caption 'small')"""
    _style = 'small'

class Big(Style, AdvancedNode):
    """bigger text (style caption 'big')"""
    _style = 'big'

class Cite(Style, AdvancedNode):
    """citation (style caption 'cite'); listed among the block nodes"""
    _style = 'cite'


# style caption (e.g. "sub") -> Style subclass that fixStyle converts it to
_styleNodeMap = dict(
    (cls._style, cls)
    for cls in (Overline, Underline, Sub, Sup, Small, Big, Cite)
)

# --------------------------------------------------------------------------
# Missing as Classes derived from parser.TagNode
# -------------------------------------------------------------------------


class Code(TagNode, AdvancedNode):
    """<code> tag"""
    _tag = 'code'

class BreakingReturn(TagNode, AdvancedNode):
    """<br> tag; listed among the block nodes"""
    _tag = 'br'

class HorizontalRule(TagNode, AdvancedNode):
    """<hr> tag; listed among the block nodes"""
    _tag = 'hr'

class Index(TagNode, AdvancedNode):
    """<index> tag"""
    _tag = 'index'

class Teletyped(TagNode, AdvancedNode):
    """<tt> tag (teletype text)"""
    _tag = 'tt'

class Reference(TagNode, AdvancedNode):
    """<ref> tag"""
    _tag = 'ref'

class ReferenceList(TagNode, AdvancedNode):
    """<references> tag; listed among the block nodes"""
    _tag = 'references'

class Gallery(TagNode, AdvancedNode):
    """<gallery> tag; listed among the block nodes"""
    _tag = 'gallery'

class Center(TagNode, AdvancedNode):
    """<center> tag"""
    _tag = 'center'

class Div(TagNode, AdvancedNode):
    """<div> tag; listed among the block nodes"""
    _tag = 'div'

class Span(TagNode, AdvancedNode):
    """<span> tag; treated as an inline node, which is correct in theory"""
    _tag = 'span'

class Strike(TagNode, AdvancedNode):
    """<strike> tag (the deprecated 's' style is mapped onto it as well)"""
    _tag = 'strike'

class ImageMap(TagNode, AdvancedNode):
    """<imagemap> tag"""
    # NOTE(review): not listed in _blockNodesMap, so isblocknode is False
    # (inline); it may deserve to be a block node -- confirm
    _tag = 'imagemap'
    
# TagNode caption (the tag name, e.g. "br") -> specialised TagNode subclass
_tagNodeMap = dict(
    (cls._tag, cls)
    for cls in (Code, BreakingReturn, HorizontalRule, Index, Teletyped,
                Reference, ReferenceList, Gallery, Center, Div, Span,
                Strike, ImageMap)
)
# the deprecated "s" style is mapped onto the Strike tag node
_styleNodeMap["s"] = Strike


# --------------------------------------------------------------------------
# BlockNode separation for AdvancedNode.isblocknode
# -------------------------------------------------------------------------

"""
For writers it is useful to know whether elements are inline (within a paragraph) or not.
We define a list of block node classes, which is exposed through AdvancedNode as:

AdvancedNode.isblocknode

For ImageLink this depends on the result of ImageLink.isInline() (see AdvancedImageLink above).

Open Issues: Math, Magic, (unknown) TagNode 

"""
# node classes that are always rendered as blocks (never inside a paragraph)
_blockNodesMap = (
    Book, Chapter, Article, Section, Paragraph, Div,
    PreFormatted, Cell, Row, Table, Item, BreakingReturn,
    ItemList, Timeline, Cite, HorizontalRule, Gallery, Indented,
    DefinitionList, DefinitionTerm, DefinitionDescription, ReferenceList,
)

# set as a class attribute, so subclasses inherit the flag
for _blockcls in _blockNodesMap:
    _blockcls.isblocknode = True
del _blockcls



# --------------------------------------------------------------------------
# funcs for extending the nodes
# -------------------------------------------------------------------------

def MixIn(pyClass, mixInClass, makeFirst=False):
    """
    add mixInClass to the bases of pyClass; no-op if it already is a direct base

    The mixin is appended after the existing bases, or placed before
    them if makeFirst is true.
    """
    if mixInClass in pyClass.__bases__:
        return
    if makeFirst:
        pyClass.__bases__ = (mixInClass,) + pyClass.__bases__
    else:
        pyClass.__bases__ = pyClass.__bases__ + (mixInClass,)

def extendClasses(node):
    """
    (re)build the weak parent link of every descendant of node

    Despite the name no classes are touched here; the link of node itself
    is left as it is.
    """
    pending = [node]
    while pending:
        parent = pending.pop()
        for child in parent.children[:]:
            child._parentref = weakref.ref(parent)
            pending.append(child)

# Parser classes that get a specialised mixin (defined above) in addition to AdvancedNode
_advancedNodesMap = {
    Section: AdvancedSection,
    ImageLink: AdvancedImageLink,
    Math: AdvancedMath,
    Row: AdvancedRow,
    Table: AdvancedTable,
}

# every parser Node gets the generic helpers, the classes above their
# specialised mixin as well.
# NOTE(review): the mixins are appended to the bases; if Node is a classic
# (old-style) class, attributes reachable through Node -> AdvancedNode
# (e.g. isblocknode) would shadow the specialised ones -- confirm
MixIn(Node, AdvancedNode)
for parsercls, mixin in _advancedNodesMap.items():
    MixIn(parsercls, mixin)
    
# --------------------------------------------------------------------------
# funcs for repairing the tree
# -------------------------------------------------------------------------


def fixTagNodes(node):
    """
    detect known TagNode(s) and associate appropriate Nodes

    Walks all descendants of node. Every plain TagNode whose caption is a
    key of _tagNodeMap gets that class. <h1>..<h6> become Section nodes
    whose heading level (1-6) is stored in h_level, the attribute that
    AdvancedSection declares, and whose caption is cleared. Other tags are
    logged and left unchanged.
    """
    for c in node.children:
        if c.__class__ == TagNode:
            if c.caption in _tagNodeMap:
                c.__class__ = _tagNodeMap[c.caption]
            elif c.caption in ("h1", "h2", "h3", "h4", "h5", "h6"): # FIXME
                # NEED TO MOVE NODE IF IT REALLY STARTS A SECTION
                c.__class__ = Section
                MixIn(c.__class__, AdvancedSection)
                level = int(c.caption[1])
                c.h_level = level  # the attribute AdvancedSection declares
                c._h_level = level # kept for code that may read the old name
                c.caption = ""
            else:
                # TODO: unknown tags might deserve an exception instead
                log.warn("fixTagNodes, unknowntagnode %r" % c)
        fixTagNodes(c)


def fixStyle(node):
    """
    parser.Style Nodes are mapped to logical markup

    The class of node is swapped in place according to its caption:
      ''     -> Emphasized
      '''    -> Strong
      '''''  -> Strong containing one Emphasized that takes over the children
      ;      -> DefinitionTerm moved into a directly preceding DefinitionList,
                otherwise a new DefinitionList wrapping one DefinitionTerm
      :...   -> DefinitionDescription moved into a directly preceding
                DefinitionList, otherwise Indented
      other  -> the class registered in _styleNodeMap (e.g. "sub" -> Sub)
    The caption of a converted node is cleared, except for Indented nodes,
    which keep their ':' caption. Unknown styles are logged and kept.

    Detection of DefinitionList depends on removeNodes and removeNewlines,
    since it looks at node.previous.

    Returns node, or None if node is not a plain Style node.
    """
    if node.__class__ != Style:
        return
    # replace this node by a more appropriate one
    if node.caption == "''":
        node.__class__ = Emphasized
        node.caption = ""
    elif node.caption == "'''''":
        node.__class__ = Strong
        node.caption = ""
        # same empty caption as every other Emphasized node
        em = Emphasized("")
        for c in node.children:
            em.appendChild(c)
        node.children = []
        node.appendChild(em)
    elif node.caption == "'''":
        node.__class__ = Strong
        node.caption = ""
    elif node.caption == ";":
        # this starts a definition list ? DL [DT->DD, ...]
        # check if previous node is DefinitionList, if not create one
        node.caption = ""
        prev = node.previous
        if prev.__class__ == DefinitionList:
            node.__class__ = DefinitionTerm
            node.moveto(prev.lastchild)
        else:
            node.__class__ = DefinitionList
            dt = DefinitionTerm()
            for c in node.children:
                dt.appendChild(c)
            node.children = []
            node.appendChild(dt)
    elif node.caption.startswith(":"):
        prev = node.previous
        if prev.__class__ == DefinitionList:
            node.__class__ = DefinitionDescription
            node.moveto(prev.lastchild)
            node.caption = ""
        else:
            # caption left as is; presumably the run of ':' encodes the depth
            node.__class__ = Indented
    elif node.caption in _styleNodeMap:
        node.__class__ = _styleNodeMap[node.caption]
        node.caption = ""
    else:
        # TODO: unknown styles might deserve an exception instead
        log.warn("fixStyle, unknownstyle %r" % node)
    return node

def fixStyles(node):
    """
    apply fixStyle to every plain Style node in the subtree rooted at node

    Nodes are visited depth-first in pre-order. Each node's children are
    iterated from a snapshot taken after the node itself was fixed, because
    fixStyle may move nodes around.
    """
    if node.__class__ is Style:
        fixStyle(node)
    snapshot = list(node.children)
    for child in snapshot:
        fixStyles(child)


def removeNodes(node):
    """
    the parser generates plain Node elements that do nothing but group
    other nodes. we remove them here by splicing their children into the
    parent at the same position.

    Kept are the first child of a Section (it groups the heading text) and
    a plain Node without a parent (e.g. a root Node), which has nowhere to
    splice its children into.
    """
    if node.__class__ == Node:
        parent = node.parent
        # first child of section groups heading text - grouping Node must not be removed
        isheading = node.previous is None and parent.__class__ == Section
        if parent is not None and not isheading:
            parent.replaceChild(node, node.children)
    for c in node.children[:]:
        removeNodes(c)

def removeNewlines(node):
    """
    remove newlines, tabs, spaces if we are next to a blockNode

    Outside PreFormatted nodes, a whitespace-only Text node is dropped when
    its neighbour on either side is a block node or missing. The left
    neighbour is the previous sibling, or the parent for a first child. The
    right neighbour is the next sibling, or the parent's next sibling for a
    last child. In every remaining Text node newlines are replaced by spaces.
    Neighbours are compared against None explicitly, not by truthiness.
    """
    if node.__class__ == Text and not node.getParentNodesByClass(PreFormatted):
        parent = node.parent
        if parent is not None and node.caption.strip() == u"":
            prev = node.previous
            if prev is None:
                prev = parent
            following = node.next
            if following is None:
                following = parent.next
            if following is None or following.isblocknode or prev.isblocknode:
                parent.removeChild(node)
        node.caption = node.caption.replace("\n", " ")

    for c in node.children[:]:
        removeNewlines(c)

        


def buildAdvancedTree(root): # USE WITH CARE
    """
    extend and clean the parse tree below root in place

    do not use these funcs without knowing whether these Node
    modifications fit your problem. The passes run in a fixed order:
    parent links, tag nodes, grouping Nodes, whitespace, then styles
    (fixStyle's definition list detection relies on the earlier passes).
    """
    passes = (extendClasses, fixTagNodes, removeNodes, removeNewlines, fixStyles)
    for fix in passes:
        fix(root)

def getAdvTree(fn):
    """
    parse the UTF-8 encoded wiki markup file fn and return the advanced tree

    The article is parsed with a DummyDB as wiki database and title fn,
    then run through buildAdvancedTree.
    """
    from mwlib.dummydb import DummyDB
    from mwlib.uparser import parseString
    db = DummyDB()
    f = open(fn)
    try:
        raw = unicode(f.read(), 'utf8')
    finally:
        f.close() # close explicitly instead of relying on garbage collection
    r = parseString(title=fn, raw=raw, wikidb=db)
    buildAdvancedTree(r)
    return r



