Browse Source

Fixes, plus don't inherit from unicode.

tags/v0.1
Ben Kurtovic 12 years ago
parent
commit
b77497e12a
3 changed files with 32 additions and 28 deletions
  1. +1
    -0
      mwparserfromhell/parser/demo.py
  2. +1
    -1
      mwparserfromhell/string_mixin.py
  3. +30
    -27
      mwparserfromhell/wikicode.py

+ 1
- 0
mwparserfromhell/parser/demo.py View File

@@ -31,6 +31,7 @@ class DemoParser(object):
return text return text


def parse(self, text): def parse(self, text):
# Ensure text is unicode!
text = u"This is a {{test}} message with a {{template|with|foo={{params}}}}." text = u"This is a {{test}} message with a {{template|with|foo={{params}}}}."


node1 = Text(u"This is a ") node1 = Text(u"This is a ")


+ 1
- 1
mwparserfromhell/string_mixin.py View File

@@ -22,7 +22,7 @@


__all__ = ["StringMixIn"] __all__ = ["StringMixIn"]


class StringMixIn(unicode):
class StringMixIn(object):
def __str__(self): def __str__(self):
return unicode(self).encode("utf8") return unicode(self).encode("utf8")




+ 30
- 27
mwparserfromhell/wikicode.py View File

@@ -55,17 +55,27 @@ class Wikicode(StringMixIn):
raise ValueError(error.format(type(value), value)) raise ValueError(error.format(type(value), value))
return nodelist return nodelist


def _get_children(self, node):
yield node
def _iterate_over_children(self, node):
yield (None, node)
if isinstance(node, Template): if isinstance(node, Template):
for child in self._get_all_nodes(node.name): for child in self._get_all_nodes(node.name):
yield child
yield (node.name, child)
for param in node.params: for param in node.params:
if param.showkey: if param.showkey:
for child in self._get_all_nodes(param.name): for child in self._get_all_nodes(param.name):
yield child
yield (param.name, child)
for child in self._get_all_nodes(param.value): for child in self._get_all_nodes(param.value):
yield child
yield (param.value, child)

def _get_children(self, node):
for context, child in self._iterate_over_children(node):
yield child

def _get_context(self, node, obj):
for context, child in self._iterate_over_children(node):
if child is obj:
return context
raise ValueError(obj)


def _get_all_nodes(self, code): def _get_all_nodes(self, code):
for node in code.nodes: for node in code.nodes:
@@ -91,18 +101,18 @@ class Wikicode(StringMixIn):
return True return True
return False return False


def _do_search(self, obj, value, recursive, callback, context=None):
def _do_search(self, obj, recursive, callback, context, *args, **kwargs):
if recursive: if recursive:
nodes = context.nodes if context else self.nodes
for i, node in enumerate(nodes):
for i, node in enumerate(context.nodes):
if self._is_equivalent(obj, node): if self._is_equivalent(obj, node):
return callback(self, value, i)
return callback(context, i, *args, **kwargs)
if self._contains(self._get_children(node), obj): if self._contains(self._get_children(node), obj):
return self._do_search(obj, value, recursive, callback,
context=node)
context = self._get_context(node, obj)
return self._do_search(obj, recursive, callback, context,
*args, **kwargs)
raise ValueError(obj) raise ValueError(obj)


callback(self, value, self.index(obj, recursive=False))
callback(context, self.index(obj, recursive=False), *args, **kwargs)


def _get_tree(self, code, lines, marker=None, indent=0): def _get_tree(self, code, lines, marker=None, indent=0):
def write(*args): def write(*args):
@@ -164,19 +174,19 @@ class Wikicode(StringMixIn):
self.nodes.insert(index, node) self.nodes.insert(index, node)


def insert_before(self, obj, value, recursive=True): def insert_before(self, obj, value, recursive=True):
callback = lambda self, value, i: self.insert(i, value)
self._do_search(obj, value, recursive, callback)
callback = lambda self, i, value: self.insert(i, value)
self._do_search(obj, recursive, callback, self, value)


def insert_after(self, obj, value, recursive=True): def insert_after(self, obj, value, recursive=True):
callback = lambda self, value, i: self.insert(i + 1, value)
self._do_search(obj, value, recursive, callback)
callback = lambda self, i, value: self.insert(i + 1, value)
self._do_search(obj, recursive, callback, self, value)


def replace(self, obj, value, recursive=True): def replace(self, obj, value, recursive=True):
def callback(self, value, i):
def callback(self, i, value):
self.nodes.pop(i) self.nodes.pop(i)
self.insert(i, value) self.insert(i, value)


self._do_search(obj, value, recursive, callback)
self._do_search(obj, recursive, callback, self, value)


def append(self, value): def append(self, value):
nodes = self._nodify(value) nodes = self._nodify(value)
@@ -184,15 +194,8 @@ class Wikicode(StringMixIn):
self.nodes.append(node) self.nodes.append(node)


def remove(self, obj, recursive=True): def remove(self, obj, recursive=True):
if recursive:
for i, node in enumerate(self.nodes):
if self._is_equivalent(obj, node):
return self.nodes.pop(i)
if self._contains(self._get_children(node), obj):
return node.remove(obj, recursive=True)
raise ValueError(obj)

return self.nodes.pop(self.index(obj))
callback = lambda self, i: self.nodes.pop(i)
self._do_search(obj, recursive, callback, self)


def ifilter(self, recursive=False, matches=None, flags=FLAGS, def ifilter(self, recursive=False, matches=None, flags=FLAGS,
forcetype=None): forcetype=None):


Loading…
Cancel
Save