diff --git a/mwparserfromhell/string_mixin.py b/mwparserfromhell/string_mixin.py index 8ddf699..74ae42f 100644 --- a/mwparserfromhell/string_mixin.py +++ b/mwparserfromhell/string_mixin.py @@ -22,13 +22,16 @@ __all__ = ["StringMixIn"] -class StringMixIn(object): # UnicodeMixIn? +class StringMixIn(object): def __str__(self): return unicode(self).encode("utf8") def __repr__(self): return repr(unicode(self)) + def __unicode__(self): + raise NotImplementedError() + def __lt__(self, other): if isinstance(other, StringMixin): return unicode(self) < unicode(other) @@ -73,4 +76,6 @@ class StringMixIn(object): # UnicodeMixIn? return unicode(self)[index] def __contains__(self, item): + if isinstance(item, StringMixIn): + return unicode(item) in unicode(self) return item in unicode(self) diff --git a/mwparserfromhell/wikicode.py b/mwparserfromhell/wikicode.py index ea97d5e..a5437be 100644 --- a/mwparserfromhell/wikicode.py +++ b/mwparserfromhell/wikicode.py @@ -65,17 +65,24 @@ class Wikicode(StringMixIn): for child in self._get_children(node): yield child - def _do_recursive_index(self, obj): - for i, node in enumerate(self.nodes): - children = self._get_children(node) - if isinstance(obj, Node): - for child in children: - if child is obj: - return i - else: - if obj in children: - return i - raise ValueError(obj) + def _is_equivalent(self, obj, node): + if isinstance(obj, Node): + if node is obj: + return True + else: + if node == obj: + return True + return False + + def _contains(self, nodes, obj): + if isinstance(obj, Node): + for node in nodes: + if node is obj: + return True + else: + if obj in nodes: + return True + return False def _show_tree(self, code, lines, marker=None, indent=0): def write(*args): @@ -123,13 +130,15 @@ class Wikicode(StringMixIn): def index(self, obj, recursive=False): if recursive: - return self._do_recursive_index() - if isinstance(obj, Node): for i, node in enumerate(self.nodes): - if node is obj: + if self._contains(self._get_children(node), obj): return i raise ValueError(obj) - return self.nodes.index(obj) + + for i, node in enumerate(self.nodes): + if self._is_equivalent(obj, node): + return i + raise ValueError(obj) def insert(self, index, value): nodes = self._nodify(value) @@ -151,8 +160,20 @@ class Wikicode(StringMixIn): for node in nodes: self.nodes.append(node) - def remove(self, node, recursive=True): - self.nodes.pop(self.index(node)) + def remove(self, obj, recursive=True): + if recursive: + for i, node in enumerate(self.nodes): + if self._is_equivalent(obj, node): + return self.nodes.pop(i) + children = self._get_children(node) + if self._contains(self._get_children(node), obj): + return node.remove(obj, recursive=True) + raise ValueError(obj) + + try: + return self.nodes.pop(self.index(obj)) + except IndexError: + raise ValueError(obj) def ifilter(self, recursive=False, matches=None, flags=FLAGS, forcetype=None):