diff --git a/mwparserfromhell/wikicode.py b/mwparserfromhell/wikicode.py index b814ee5..5c4d66a 100644 --- a/mwparserfromhell/wikicode.py +++ b/mwparserfromhell/wikicode.py @@ -60,19 +60,6 @@ class Wikicode(StringMixIn): for context, child in node.__iternodes__(self._get_all_nodes): yield child - def _get_context(self, node, obj): - """Return a ``Wikicode`` that contains *obj* in its descendants. - - The closest (shortest distance from *node*) suitable ``Wikicode`` will - be returned, or ``None`` if the *obj* is the *node* itself. - - Raises ``ValueError`` if *obj* is not within *node*. - """ - for context, child in node.__iternodes__(self._get_all_nodes): - if self._is_equivalent(obj, child): - return context - raise ValueError(obj) - def _get_all_nodes(self, code): """Iterate over all of our descendant nodes. @@ -105,26 +92,54 @@ class Wikicode(StringMixIn): return False return obj in nodes - def _do_search(self, obj, recursive, callback, context, *args, **kwargs): - """Look within *context* for *obj*, executing *callback* if found. + def _do_search(self, obj, recursive, context=None, literal=None): + """Return some info about the location of *obj* within *context*. - If *recursive* is ``True``, we'll look within *context* and its - descendants, otherwise we'll just execute *callback*. We raise - :py:exc:`ValueError` if *obj* isn't in our node list or context. If - found, *callback* is passed the context, the index of the node within - the context, and whatever were passed as ``*args`` and ``**kwargs``. + If *recursive* is ``True``, we'll look within *context* (``self`` by + default) and its descendants, otherwise just *context*. We raise + :py:exc:`ValueError` if *obj* isn't found. The return data is a list of + 3-tuples (*type*, *context*, *data*) where *type* is *obj*\ 's best + type resolution (either ``Node``, ``Wikicode``, or ``str``), *context* + is the closest ``Wikicode`` encompassing it, and *data* is either a + ``Node``, a list of ``Node``\ s, or ``None`` depending on *type*. """ - if recursive: - for i, node in enumerate(context.nodes): - if self._is_equivalent(obj, node): - return callback(context, i, *args, **kwargs) - if self._contains(self._get_children(node), obj): - context = self._get_context(node, obj) - return self._do_search(obj, recursive, callback, context, - *args, **kwargs) - raise ValueError(obj) + if not context: + context = self + literal = isinstance(obj, (Node, Wikicode)) + obj = parse_anything(obj) + if not obj or obj not in self: + raise ValueError(obj) + if len(obj.nodes) == 1: + obj = obj.get(0) + + compare = lambda a, b: (a is b) if literal else (a == b) + results = [] + i = 0 + while i < len(context.nodes): + node = context.get(i) + if isinstance(obj, Node) and compare(obj, node): + results.append((Node, context, node)) + elif isinstance(obj, Wikicode) and compare(obj.get(0), node): + for j in range(1, len(obj.nodes)): + if not compare(obj.get(j), context.get(i + j)): + break + else: + nodes = list(context.nodes[i:i + len(obj.nodes)]) + results.append((Wikicode, context, nodes)) + i += len(obj.nodes) - 1 + elif recursive: + contexts = node.__iternodes__(self._get_all_nodes) + for code in {ctx for ctx, child in contexts}: + if code and obj in code: + search = self._do_search(obj, recursive, code, literal) + results.extend(search) + i += 1 - callback(context, self.index(obj, recursive=False), *args, **kwargs) + if not results and not literal and recursive: + results.append((str, context, None)) + if not results and context is self: + raise ValueError(obj) + return results def _get_tree(self, code, lines, marker, indent): """Build a tree to illustrate the way the Wikicode object was parsed. @@ -253,41 +268,64 @@ class Wikicode(StringMixIn): def insert_before(self, obj, value, recursive=True): """Insert *value* immediately before *obj* in the list of nodes. - *obj* can be either a string or a :py:class:`~.Node`. *value* can be - anything parasable by :py:func:`.parse_anything`. If *recursive* is - ``True``, we will try to find *obj* within our child nodes even if it - is not a direct descendant of this :py:class:`~.Wikicode` object. If - *obj* is not in the node list, :py:exc:`ValueError` is raised. + *obj* can be either a string, a :py:class:`~.Node`, or other + :py:class:`~.Wikicode` object (as created by :py:meth:`get_sections`, + for example). *value* can be anything parasable by + :py:func:`.parse_anything`. If *recursive* is ``True``, we will try to + find *obj* within our child nodes even if it is not a direct descendant + of this :py:class:`~.Wikicode` object. If *obj* is not found, + :py:exc:`ValueError` is raised. """ - callback = lambda self, i, value: self.insert(i, value) - self._do_search(obj, recursive, callback, self, value) + for restype, context, data in self._do_search(obj, recursive): + if restype in (Node, Wikicode): + i = context.index(data if restype is Node else data[0], False) + context.insert(i, value) + else: + obj = str(obj) + context.nodes = str(context).replace(obj, str(value) + obj) def insert_after(self, obj, value, recursive=True): """Insert *value* immediately after *obj* in the list of nodes. - *obj* can be either a string or a :py:class:`~.Node`. *value* can be - anything parasable by :py:func:`.parse_anything`. If *recursive* is - ``True``, we will try to find *obj* within our child nodes even if it - is not a direct descendant of this :py:class:`~.Wikicode` object. If - *obj* is not in the node list, :py:exc:`ValueError` is raised. + *obj* can be either a string, a :py:class:`~.Node`, or other + :py:class:`~.Wikicode` object (as created by :py:meth:`get_sections`, + for example). *value* can be anything parasable by + :py:func:`.parse_anything`. If *recursive* is ``True``, we will try to + find *obj* within our child nodes even if it is not a direct descendant + of this :py:class:`~.Wikicode` object. If *obj* is not found, + :py:exc:`ValueError` is raised. """ - callback = lambda self, i, value: self.insert(i + 1, value) - self._do_search(obj, recursive, callback, self, value) + for restype, context, data in self._do_search(obj, recursive): + if restype in (Node, Wikicode): + i = context.index(data if restype is Node else data[-1], False) + context.insert(i + 1, value) + else: + obj = str(obj) + context.nodes = str(context).replace(obj, obj + str(value)) def replace(self, obj, value, recursive=True): """Replace *obj* with *value* in the list of nodes. - *obj* can be either a string or a :py:class:`~.Node`. *value* can be - anything parasable by :py:func:`.parse_anything`. If *recursive* is - ``True``, we will try to find *obj* within our child nodes even if it - is not a direct descendant of this :py:class:`~.Wikicode` object. If - *obj* is not in the node list, :py:exc:`ValueError` is raised. + *obj* can be either a string, a :py:class:`~.Node`, or other + :py:class:`~.Wikicode` object (as created by :py:meth:`get_sections`, + for example). *value* can be anything parasable by + :py:func:`.parse_anything`. If *recursive* is ``True``, we will try to + find *obj* within our child nodes even if it is not a direct descendant + of this :py:class:`~.Wikicode` object. If *obj* is not found, + :py:exc:`ValueError` is raised. """ - def callback(self, i, value): - self.nodes.pop(i) - self.insert(i, value) - - self._do_search(obj, recursive, callback, self, value) + for restype, context, data in self._do_search(obj, recursive): + if restype is Node: + i = context.index(data, False) + context.nodes.pop(i) + context.insert(i, value) + elif restype is Wikicode: + i = context.index(data[0], False) + for _ in data: + context.nodes.pop(i) + context.insert(i, value) + else: + context.nodes = str(context).replace(str(obj), str(value)) def append(self, value): """Insert *value* at the end of the list of nodes. @@ -301,13 +339,22 @@ class Wikicode(StringMixIn): def remove(self, obj, recursive=True): """Remove *obj* from the list of nodes. - *obj* can be either a string or a :py:class:`~.Node`. If *recursive* is - ``True``, we will try to find *obj* within our child nodes even if it - is not a direct descendant of this :py:class:`~.Wikicode` object. If - *obj* is not in the node list, :py:exc:`ValueError` is raised. + *obj* can be either a string, a :py:class:`~.Node`, or other + :py:class:`~.Wikicode` object (as created by :py:meth:`get_sections`, + for example). If *recursive* is ``True``, we will try to find *obj* + within our child nodes even if it is not a direct descendant of this + :py:class:`~.Wikicode` object. If *obj* is not found, + :py:exc:`ValueError` is raised. """ - callback = lambda self, i: self.nodes.pop(i) - self._do_search(obj, recursive, callback, self) + for restype, context, data in self._do_search(obj, recursive): + if restype is Node: + context.nodes.pop(context.index(data, False)) + elif restype is Wikicode: + i = context.index(data[0], False) + for _ in data: + context.nodes.pop(i) + else: + context.nodes = str(context).replace(str(obj), "") def matches(self, other): """Do a loose equivalency test suitable for comparing page names.