diff --git a/mwparserfromhell/wikicode.py b/mwparserfromhell/wikicode.py index 28b3524..68af8da 100644 --- a/mwparserfromhell/wikicode.py +++ b/mwparserfromhell/wikicode.py @@ -24,7 +24,7 @@ import re import mwparserfromhell from mwparserfromhell.node import Node -from mwparserfromhell.string_mixin import StringMixin +from mwparserfromhell.string_mixin import StringMixIn from mwparserfromhell.template import Template from mwparserfromhell.text import Text @@ -32,7 +32,7 @@ __all__ = ["Wikicode"] FLAGS = re.I | re.S | re.U -class Wikicode(StringMixin): +class Wikicode(StringMixIn): def __init__(self, nodes): self._nodes = nodes @@ -49,18 +49,34 @@ class Wikicode(StringMixin): error = "Needs string, Node, or Wikicode object, but got {0}: {1}" raise ValueError(error.format(type(value), value)) + def _get_children(self, node): + yield node + if isinstance(node, Template): + for child in self._get_all_nodes(node.name): + yield child + for param in node.params: + if param.showkey: + for child in self._get_all_nodes(param.name): + yield child + for child in self._get_all_nodes(param.value): + yield child + def _get_all_nodes(self, code): for node in code.nodes: - yield node - if isinstance(node, Template): - for child in self._get_all_nodes(node.name): - yield child - for param in node.params: - if param.showkey: - for child in self._get_all_nodes(param.name): - yield child - for child in self._get_all_nodes(param.value): - yield child + for child in self._get_children(node): + yield child + + def _do_recursive_index(self, obj): + for i, node in enumerate(self.nodes): + children = self._get_children(node) + if isinstance(obj, Node): + for child in children: + if child is obj: + return i + else: + if obj in children: + return i + raise ValueError(obj) def _show_tree(self, code, lines, marker=None, indent=0): def write(*args): @@ -106,9 +122,9 @@ class Wikicode(StringMixin): if nodes: self.nodes[index] = nodes[0] - def index(self, obj): - if obj not in self.nodes: - raise ValueError(obj) + def index(self, obj, recursive=False): + if recursive: + return self._do_recursive_index() if isinstance(obj, Node): for i, node in enumerate(self.nodes): if node is obj: @@ -116,17 +132,17 @@ class Wikicode(StringMixin): raise ValueError(obj) return self.nodes.index(obj) - def insert(self, index, value, recursive=True): + def insert(self, index, value): nodes = self._nodify(value) for node in reversed(nodes): self.nodes.insert(index, node) - def insert_before(self, obj, value): + def insert_before(self, obj, value, recursive=True): if obj not in self.nodes: raise KeyError(obj) self.insert(self.index(obj), value) - def insert_after(self, obj, value): + def insert_after(self, obj, value, recursive=True): if obj not in self.nodes: raise KeyError(obj) self.insert(self.index(obj) + 1, value) @@ -136,7 +152,7 @@ class Wikicode(StringMixin): for node in nodes: self.nodes.append(node) - def remove(self, node): + def remove(self, node, recursive=True): self.nodes.pop(self.index(node)) def ifilter(self, recursive=False, matches=None, flags=FLAGS,