diff --git a/CHANGELOG b/CHANGELOG index 289c413..7da4968 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -4,6 +4,10 @@ v0.4 (unreleased): - Added a script to do releases in scripts/release.sh. - skip_style_tags can now be passed to mwparserfromhell.parse() (previously, only Parser().parse() allowed it). +- The 'recursive' argument to Wikicode's filter methods now accepts a third + option, RECURSE_OTHERS, which recurses over all children except instances of + 'forcetype' (for example, `code.filter_templates(code.RECURSE_OTHERS)` + returns all un-nested templates). - Fixed a parser bug involving nested tags. - Updated and fixed some documentation. diff --git a/docs/changelog.rst b/docs/changelog.rst index 21f0629..8416204 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -11,6 +11,11 @@ Unreleased - Added a script to do releases in :file:`scripts/release.sh`. - *skip_style_tags* can now be passed to :py:func:`mwparserfromhell.parse() <.parse_anything>` (previously, only :py:meth:`.Parser.parse` allowed it). +- The *recursive* argument to :py:class:`Wikicode's <.Wikicode>` + :py:meth:`.filter` methods now accepts a third option, ``RECURSE_OTHERS``, + which recurses over all children except instances of *forcetype* (for + example, ``code.filter_templates(code.RECURSE_OTHERS)`` returns all un-nested + templates). - Fixed a parser bug involving nested tags. - Updated and fixed some documentation. diff --git a/mwparserfromhell/wikicode.py b/mwparserfromhell/wikicode.py index f728248..d7736ff 100644 --- a/mwparserfromhell/wikicode.py +++ b/mwparserfromhell/wikicode.py @@ -44,6 +44,7 @@ class Wikicode(StringMixIn): ` series of functions is very useful for extracting and iterating over, for example, all of the templates in the object. 
""" + RECURSE_OTHERS = 2 def __init__(self, nodes): super(Wikicode, self).__init__() @@ -53,12 +54,15 @@ class Wikicode(StringMixIn): return "".join([str(node) for node in self.nodes]) @staticmethod - def _get_children(node, contexts=False, parent=None): + def _get_children(node, contexts=False, restrict=None, parent=None): """Iterate over all child :py:class:`.Node`\ s of a given *node*.""" yield (parent, node) if contexts else node + if restrict and isinstance(node, restrict): + return for code in node.__children__(): for child in code.nodes: - for result in Wikicode._get_children(child, contexts, code): + sub = Wikicode._get_children(child, contexts, restrict, code) + for result in sub: yield result @staticmethod @@ -79,7 +83,7 @@ class Wikicode(StringMixIn): if matches: if callable(matches): return matches - return lambda obj: re.search(matches, str(obj), flags) # r + return lambda obj: re.search(matches, str(obj), flags) return lambda obj: True def _indexed_ifilter(self, recursive=True, matches=None, flags=FLAGS, @@ -93,8 +97,9 @@ class Wikicode(StringMixIn): """ match = self._build_matcher(matches, flags) if recursive: + restrict = forcetype if recursive == self.RECURSE_OTHERS else None def getter(i, node): - for ch in self._get_children(node): + for ch in self._get_children(node, restrict=restrict): yield (i, ch) inodes = chain(*(getter(i, n) for i, n in enumerate(self.nodes))) else: @@ -222,10 +227,10 @@ class Wikicode(StringMixIn): This is equivalent to :py:meth:`{1}` with *forcetype* set to :py:class:`~{2.__module__}.{2.__name__}`. 
""" - make_ifilter = lambda ftype: (lambda self, **kw: - self.ifilter(forcetype=ftype, **kw)) - make_filter = lambda ftype: (lambda self, **kw: - self.filter(forcetype=ftype, **kw)) + make_ifilter = lambda ftype: (lambda self, *a, **kw: + self.ifilter(forcetype=ftype, *a, **kw)) + make_filter = lambda ftype: (lambda self, *a, **kw: + self.filter(forcetype=ftype, *a, **kw)) for name, ftype in (meths.items() if py3k else meths.iteritems()): ifilter = make_ifilter(ftype) filter = make_filter(ftype) @@ -435,27 +440,36 @@ class Wikicode(StringMixIn): forcetype=None): """Iterate over nodes in our list matching certain conditions. - If *recursive* is ``True``, we will iterate over our children and all - of their descendants, otherwise just our immediate children. If - *forcetype* is given, only nodes that are instances of this type are - yielded. *matches* can be used to further restrict the nodes, either as - a function (taking a single :py:class:`.Node` and returning a boolean) - or a regular expression (matched against the node's string - representation with :py:func:`re.search`). If *matches* is a regex, the - flags passed to :py:func:`re.search` are :py:const:`re.IGNORECASE`, + If *forcetype* is given, only nodes that are instances of this type (or + tuple of types) are yielded. Setting *recursive* to ``True`` will + iterate over all children and their descendants. ``RECURSE_OTHERS`` + will only iterate over children that are not instances of + *forcetype*. ``False`` will only iterate over immediate children. 
+ + ``RECURSE_OTHERS`` can be used to iterate over all un-nested templates, + even if they are inside of HTML tags, like so: + + >>> code = mwparserfromhell.parse("{{foo}}{{foo|{{bar}}}}") + >>> code.filter_templates(code.RECURSE_OTHERS) + ["{{foo}}", "{{foo|{{bar}}}}"] + + *matches* can be used to further restrict the nodes, either as a + function (taking a single :py:class:`.Node` and returning a boolean) or + a regular expression (matched against the node's string representation + with :py:func:`re.search`). If *matches* is a regex, the flags passed + to :py:func:`re.search` are :py:const:`re.IGNORECASE`, :py:const:`re.DOTALL`, and :py:const:`re.UNICODE`, but custom flags can be specified by passing *flags*. """ - return (node for i, node in - self._indexed_ifilter(recursive, matches, flags, forcetype)) + gen = self._indexed_ifilter(recursive, matches, flags, forcetype) + return (node for i, node in gen) - def filter(self, recursive=True, matches=None, flags=FLAGS, - forcetype=None): + def filter(self, *args, **kwargs): """Return a list of nodes within our list matching certain conditions. This is equivalent to calling :py:func:`list` on :py:meth:`ifilter`. 
""" - return list(self.ifilter(recursive, matches, flags, forcetype)) + return list(self.ifilter(*args, **kwargs)) def get_sections(self, levels=None, matches=None, flags=FLAGS, flat=False, include_lead=None, include_headings=True): diff --git a/tests/test_wikicode.py b/tests/test_wikicode.py index 9ff5949..a7c3eb3 100644 --- a/tests/test_wikicode.py +++ b/tests/test_wikicode.py @@ -319,11 +319,14 @@ class TestWikicode(TreeEqualityTestCase): self.assertEqual(["{{baz}}", "{{bz}}"], func(matches=r"^{{b.*?z")) self.assertEqual(["{{baz}}"], func(matches=r"^{{b.+?z}}")) - self.assertEqual(["{{a|{{b}}|{{c|d={{f}}{{h}}}}}}"], - code2.filter_templates(recursive=False)) - self.assertEqual(["{{a|{{b}}|{{c|d={{f}}{{h}}}}}}", "{{b}}", - "{{c|d={{f}}{{h}}}}", "{{f}}", "{{h}}"], - code2.filter_templates(recursive=True)) + exp_rec = ["{{a|{{b}}|{{c|d={{f}}{{h}}}}}}", "{{b}}", + "{{c|d={{f}}{{h}}}}", "{{f}}", "{{h}}"] + exp_unrec = ["{{a|{{b}}|{{c|d={{f}}{{h}}}}}}"] + self.assertEqual(exp_rec, code2.filter_templates()) + self.assertEqual(exp_unrec, code2.filter_templates(recursive=False)) + self.assertEqual(exp_rec, code2.filter_templates(recursive=True)) + self.assertEqual(exp_rec, code2.filter_templates(True)) + self.assertEqual(exp_unrec, code2.filter_templates(False)) self.assertEqual(["{{foobar}}"], code3.filter_templates( matches=lambda node: node.name.matches("Foobar"))) @@ -332,9 +335,15 @@ class TestWikicode(TreeEqualityTestCase): self.assertEqual([], code3.filter_tags(matches=r"^{{b.*?z")) self.assertEqual([], code3.filter_tags(matches=r"^{{b.*?z", flags=0)) - self.assertRaises(TypeError, code.filter_templates, 100) self.assertRaises(TypeError, code.filter_templates, a=42) self.assertRaises(TypeError, code.filter_templates, forcetype=Template) + self.assertRaises(TypeError, code.filter_templates, 1, 0, 0, Template) + + code4 = parse("{{foo}}{{foo|{{bar}}}}") + actual1 = code4.filter_templates(recursive=code4.RECURSE_OTHERS) + actual2 = 
code4.filter_templates(code4.RECURSE_OTHERS) + self.assertEqual(["{{foo}}", "{{foo|{{bar}}}}"], actual1) + self.assertEqual(["{{foo}}", "{{foo|{{bar}}}}"], actual2) def test_get_sections(self): """test Wikicode.get_sections()"""