diff --git a/mwparserfromhell/wikicode.py b/mwparserfromhell/wikicode.py index a5437be..ae7f4c6 100644 --- a/mwparserfromhell/wikicode.py +++ b/mwparserfromhell/wikicode.py @@ -84,7 +84,18 @@ class Wikicode(StringMixIn): return True return False - def _show_tree(self, code, lines, marker=None, indent=0): + def _do_search(self, obj, value, recursive, callback): + if recursive: + for i, node in enumerate(self.nodes): + if self._is_equivalent(obj, node): + return callback(value, i) + if self._contains(self._get_children(node), obj): + return self._do_search(obj, value, recursive, callback) + raise ValueError(obj) + + callback(value, self.index(obj, recursive=False)) + + def _get_tree(self, code, lines, marker=None, indent=0): def write(*args): if lines and lines[-1] is marker: # Continue from the last line lines.pop() # Remove the marker @@ -96,14 +107,14 @@ class Wikicode(StringMixIn): for node in code.nodes: if isinstance(node, Template): write("{{", ) - self._show_tree(node.name, lines, marker, indent + 1) + self._get_tree(node.name, lines, marker, indent + 1) for param in node.params: write(" | ") lines.append(marker) # Continue from this line - self._show_tree(param.name, lines, marker, indent + 1) + self._get_tree(param.name, lines, marker, indent + 1) write(" = ") lines.append(marker) # Continue from this line - self._show_tree(param.value, lines, marker, indent + 1) + self._get_tree(param.value, lines, marker, indent + 1) write("}}") elif isinstance(node, Text): write(unicode(node)) @@ -146,14 +157,19 @@ class Wikicode(StringMixIn): self.nodes.insert(index, node) def insert_before(self, obj, value, recursive=True): - if obj not in self.nodes: - raise KeyError(obj) - self.insert(self.index(obj), value) + callback = lambda value, i: self.insert(i, value) + self._do_search(obj, value, recursive, callback) def insert_after(self, obj, value, recursive=True): - if obj not in self.nodes: - raise KeyError(obj) - self.insert(self.index(obj) + 1, value) + callback = lambda value, i: self.insert(i + 1, value) + self._do_search(obj, value, recursive, callback) + + def replace(self, obj, value, recursive=True): + def callback(value, i): + self.nodes.pop(i) + self.insert(i, value) + + self._do_search(obj, value, recursive, callback) def append(self, value): nodes = self._nodify(value) @@ -165,15 +181,11 @@ class Wikicode(StringMixIn): for i, node in enumerate(self.nodes): if self._is_equivalent(obj, node): return self.nodes.pop(i) - children = self._get_children(node) if self._contains(self._get_children(node), obj): return node.remove(obj, recursive=True) raise ValueError(obj) - try: - return self.nodes.pop(self.index(obj)) - except IndexError: - raise ValueError(obj) + return self.nodes.pop(self.index(obj)) def ifilter(self, recursive=False, matches=None, flags=FLAGS, forcetype=None): @@ -208,4 +220,4 @@ class Wikicode(StringMixIn): def get_tree(self): marker = object() # Random object we can find with certainty in a list - return "\n".join(self._show_tree(self, [], marker)) + return "\n".join(self._get_tree(self, [], marker))