diff --git a/mwparserfromhell/wikicode.py b/mwparserfromhell/wikicode.py index ae7f4c6..ed30660 100644 --- a/mwparserfromhell/wikicode.py +++ b/mwparserfromhell/wikicode.py @@ -84,16 +84,18 @@ class Wikicode(StringMixIn): return True return False - def _do_search(self, obj, value, recursive, callback): + def _do_search(self, obj, value, recursive, callback, context=None): if recursive: - for i, node in enumerate(self.nodes): + nodes = context.nodes if context else self.nodes + for i, node in enumerate(nodes): if self._is_equivalent(obj, node): - return callback(value, i) + return callback(self, value, i) if self._contains(self._get_children(node), obj): - return self._do_search(obj, value, recursive, callback) + return self._do_search(obj, value, recursive, callback, + context=obj) raise ValueError(obj) - callback(value, self.index(obj, recursive=False)) + callback(self, value, self.index(obj, recursive=False)) def _get_tree(self, code, lines, marker=None, indent=0): def write(*args): @@ -157,15 +159,15 @@ class Wikicode(StringMixIn): self.nodes.insert(index, node) def insert_before(self, obj, value, recursive=True): - callback = lambda value, i: self.insert(i, value) + callback = lambda self, value, i: self.insert(i, value) self._do_search(obj, value, recursive, callback) def insert_after(self, obj, value, recursive=True): - callback = lambda value, i: self.insert(i + 1, value) + callback = lambda self, value, i: self.insert(i + 1, value) self._do_search(obj, value, recursive, callback) def replace(self, obj, value, recursive=True): - def callback(value, i): + def callback(self, value, i): self.nodes.pop(i) self.insert(i, value)