diff --git a/mwparserfromhell/parser/demo.py b/mwparserfromhell/parser/demo.py index c98205a..95e4e76 100644 --- a/mwparserfromhell/parser/demo.py +++ b/mwparserfromhell/parser/demo.py @@ -31,6 +31,7 @@ class DemoParser(object): return text def parse(self, text): + # Ensure text is unicode! text = u"This is a {{test}} message with a {{template|with|foo={{params}}}}." node1 = Text(u"This is a ") diff --git a/mwparserfromhell/string_mixin.py b/mwparserfromhell/string_mixin.py index 2b3af46..74ae42f 100644 --- a/mwparserfromhell/string_mixin.py +++ b/mwparserfromhell/string_mixin.py @@ -22,7 +22,7 @@ __all__ = ["StringMixIn"] -class StringMixIn(unicode): +class StringMixIn(object): def __str__(self): return unicode(self).encode("utf8") diff --git a/mwparserfromhell/wikicode.py b/mwparserfromhell/wikicode.py index 9cc9533..9243b9a 100644 --- a/mwparserfromhell/wikicode.py +++ b/mwparserfromhell/wikicode.py @@ -55,17 +55,27 @@ class Wikicode(StringMixIn): raise ValueError(error.format(type(value), value)) return nodelist - def _get_children(self, node): - yield node + def _iterate_over_children(self, node): + yield (None, node) if isinstance(node, Template): for child in self._get_all_nodes(node.name): - yield child + yield (node.name, child) for param in node.params: if param.showkey: for child in self._get_all_nodes(param.name): - yield child + yield (param.name, child) for child in self._get_all_nodes(param.value): - yield child + yield (param.value, child) + + def _get_children(self, node): + for context, child in self._iterate_over_children(node): + yield child + + def _get_context(self, node, obj): + for context, child in self._iterate_over_children(node): + if child is obj: + return context + raise ValueError(obj) def _get_all_nodes(self, code): for node in code.nodes: @@ -91,18 +101,18 @@ class Wikicode(StringMixIn): return True return False - def _do_search(self, obj, value, recursive, callback, context=None): + def _do_search(self, obj, recursive, callback, context, *args, **kwargs): if recursive: - nodes = context.nodes if context else self.nodes - for i, node in enumerate(nodes): + for i, node in enumerate(context.nodes): if self._is_equivalent(obj, node): - return callback(self, value, i) + return callback(context, i, *args, **kwargs) if self._contains(self._get_children(node), obj): - return self._do_search(obj, value, recursive, callback, - context=node) + context = self._get_context(node, obj) + return self._do_search(obj, recursive, callback, context, + *args, **kwargs) raise ValueError(obj) - callback(self, value, self.index(obj, recursive=False)) + callback(context, self.index(obj, recursive=False), *args, **kwargs) def _get_tree(self, code, lines, marker=None, indent=0): def write(*args): @@ -164,19 +174,19 @@ class Wikicode(StringMixIn): self.nodes.insert(index, node) def insert_before(self, obj, value, recursive=True): - callback = lambda self, value, i: self.insert(i, value) - self._do_search(obj, value, recursive, callback) + callback = lambda self, i, value: self.insert(i, value) + self._do_search(obj, recursive, callback, self, value) def insert_after(self, obj, value, recursive=True): - callback = lambda self, value, i: self.insert(i + 1, value) - self._do_search(obj, value, recursive, callback) + callback = lambda self, i, value: self.insert(i + 1, value) + self._do_search(obj, recursive, callback, self, value) def replace(self, obj, value, recursive=True): - def callback(self, value, i): + def callback(self, i, value): self.nodes.pop(i) self.insert(i, value) - self._do_search(obj, value, recursive, callback) + self._do_search(obj, recursive, callback, self, value) def append(self, value): nodes = self._nodify(value) @@ -184,15 +194,8 @@ class Wikicode(StringMixIn): self.nodes.append(node) def remove(self, obj, recursive=True): - if recursive: - for i, node in enumerate(self.nodes): - if self._is_equivalent(obj, node): - return self.nodes.pop(i) - if self._contains(self._get_children(node), obj): - return node.remove(obj, recursive=True) - raise ValueError(obj) - - return self.nodes.pop(self.index(obj)) + callback = lambda self, i: self.nodes.pop(i) + self._do_search(obj, recursive, callback, self) def ifilter(self, recursive=False, matches=None, flags=FLAGS, forcetype=None):