diff --git a/tests/_test_tree_equality.py b/tests/_test_tree_equality.py index a12bd68..6d9b26a 100644 --- a/tests/_test_tree_equality.py +++ b/tests/_test_tree_equality.py @@ -32,6 +32,15 @@ from mwparserfromhell.wikicode import Wikicode wrap = lambda L: Wikicode(SmartList(L)) wraptext = lambda t: wrap([Text(t)]) +def getnodes(code): + """Iterate over all child nodes of a given parent node. + + Imitates Wikicode._get_all_nodes(). + """ + for node in code.nodes: + for context, child in node.__iternodes__(getnodes): + yield child + class TreeEqualityTestCase(TestCase): """A base test case with support for comparing the equality of node trees. diff --git a/tests/test_argument.py b/tests/test_argument.py index 3a959b6..a9469d4 100644 --- a/tests/test_argument.py +++ b/tests/test_argument.py @@ -26,7 +26,7 @@ import unittest from mwparserfromhell.compat import str from mwparserfromhell.nodes import Argument, Text -from ._test_tree_equality import TreeEqualityTestCase, wrap +from ._test_tree_equality import TreeEqualityTestCase, getnodes, wrap class TestArgument(TreeEqualityTestCase): """Test cases for the Argument node.""" @@ -38,6 +38,23 @@ class TestArgument(TreeEqualityTestCase): node2 = Argument(wrap([Text("foo")]), wrap([Text("bar")])) self.assertEqual("{{{foo|bar}}}", str(node2)) + def test_iternodes(self): + """test Argument.__iternodes__()""" + node1n1 = Text("foobar") + node2n1, node2n2, node2n3 = Text("foo"), Text("bar"), Text("baz") + node1 = Argument(wrap([node1n1])) + node2 = Argument(wrap([node2n1]), wrap([node2n2, node2n3])) + gen1 = node1.__iternodes__(getnodes) + gen2 = node2.__iternodes__(getnodes) + self.assertEqual((None, node1), next(gen1)) + self.assertEqual((None, node2), next(gen2)) + self.assertEqual((node1.name, node1n1), next(gen1)) + self.assertEqual((node2.name, node2n1), next(gen2)) + self.assertEqual((node2.default, node2n2), next(gen2)) + self.assertEqual((node2.default, node2n3), next(gen2)) + self.assertRaises(StopIteration, next, gen1) + self.assertRaises(StopIteration, next, gen2) + def test_strip(self): """test Argument.__strip__()""" node = Argument(wrap([Text("foobar")])) diff --git a/tests/test_comment.py b/tests/test_comment.py index a7a3c4d..44225a2 100644 --- a/tests/test_comment.py +++ b/tests/test_comment.py @@ -36,6 +36,13 @@ class TestComment(TreeEqualityTestCase): node = Comment("foobar") self.assertEqual("", str(node)) + def test_iternodes(self): + """test Comment.__iternodes__()""" + node = Comment("foobar") + gen = node.__iternodes__(None) + self.assertEqual((None, node), next(gen)) + self.assertRaises(StopIteration, next, gen) + def test_strip(self): """test Comment.__strip__()""" node = Comment("foobar") diff --git a/tests/test_heading.py b/tests/test_heading.py index 79b0ebf..38f6545 100644 --- a/tests/test_heading.py +++ b/tests/test_heading.py @@ -26,7 +26,7 @@ import unittest from mwparserfromhell.compat import str from mwparserfromhell.nodes import Heading, Text -from ._test_tree_equality import TreeEqualityTestCase, wrap +from ._test_tree_equality import TreeEqualityTestCase, getnodes, wrap class TestHeading(TreeEqualityTestCase): """Test cases for the Heading node.""" @@ -38,6 +38,16 @@ class TestHeading(TreeEqualityTestCase): node2 = Heading(wrap([Text(" zzz ")]), 5) self.assertEqual("===== zzz =====", str(node2)) + def test_iternodes(self): + """test Heading.__iternodes__()""" + text1, text2 = Text("foo"), Text("bar") + node = Heading(wrap([text1, text2]), 3) + gen = node.__iternodes__(getnodes) + self.assertEqual((None, node), next(gen)) + self.assertEqual((node.title, text1), next(gen)) + self.assertEqual((node.title, text2), next(gen)) + self.assertRaises(StopIteration, next, gen) + def test_strip(self): """test Heading.__strip__()""" node = Heading(wrap([Text("foobar")]), 3) diff --git a/tests/test_html_entity.py b/tests/test_html_entity.py index d3d23bf..d38e5ec 100644 --- a/tests/test_html_entity.py +++ b/tests/test_html_entity.py @@ -42,6 +42,13 @@ class TestHTMLEntity(TreeEqualityTestCase): self.assertEqual("k", str(node3)) self.assertEqual("l", str(node4)) + def test_iternodes(self): + """test HTMLEntity.__iternodes__()""" + node = HTMLEntity("nbsp", named=True, hexadecimal=False) + gen = node.__iternodes__(None) + self.assertEqual((None, node), next(gen)) + self.assertRaises(StopIteration, next, gen) + def test_strip(self): """test HTMLEntity.__strip__()""" node1 = HTMLEntity("nbsp", named=True, hexadecimal=False) diff --git a/tests/test_template.py b/tests/test_template.py index 81b7382..28592df 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -26,7 +26,7 @@ import unittest from mwparserfromhell.compat import str from mwparserfromhell.nodes import HTMLEntity, Template, Text from mwparserfromhell.nodes.extras import Parameter -from ._test_tree_equality import TreeEqualityTestCase, wrap, wraptext +from ._test_tree_equality import TreeEqualityTestCase, getnodes, wrap, wraptext pgens = lambda k, v: Parameter(wraptext(k), wraptext(v), showkey=True) pgenh = lambda k, v: Parameter(wraptext(k), wraptext(v), showkey=False) @@ -42,6 +42,30 @@ class TestTemplate(TreeEqualityTestCase): [pgenh("1", "bar"), pgens("abc", "def")]) self.assertEqual("{{foo|bar|abc=def}}", str(node2)) + def test_iternodes(self): + """test Template.__iternodes__()""" + node1n1 = Text("foobar") + node2n1, node2n2, node2n3 = Text("foo"), Text("bar"), Text("abc") + node2n4, node2n5 = Text("def"), Text("ghi") + node2p1 = Parameter(wraptext("1"), wrap([node2n2]), showkey=False) + node2p2 = Parameter(wrap([node2n3]), wrap([node2n4, node2n5]), + showkey=True) + node1 = Template(wrap([node1n1])) + node2 = Template(wrap([node2n1]), [node2p1, node2p2]) + + gen1 = node1.__iternodes__(getnodes) + gen2 = node2.__iternodes__(getnodes) + self.assertEqual((None, node1), next(gen1)) + self.assertEqual((None, node2), next(gen2)) + self.assertEqual((node1.name, node1n1), next(gen1)) + self.assertEqual((node2.name, node2n1), next(gen2)) + self.assertEqual((node2.params[0].value, node2n2), next(gen2)) + self.assertEqual((node2.params[1].name, node2n3), next(gen2)) + self.assertEqual((node2.params[1].value, node2n4), next(gen2)) + self.assertEqual((node2.params[1].value, node2n5), next(gen2)) + self.assertRaises(StopIteration, next, gen1) + self.assertRaises(StopIteration, next, gen2) + def test_strip(self): """test Template.__strip__()""" node1 = Template(wraptext("foobar")) diff --git a/tests/test_text.py b/tests/test_text.py index f3649dd..35ac340 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -36,6 +36,13 @@ class TestText(unittest.TestCase): node2 = Text("fóóbar") self.assertEqual("fóóbar", str(node2)) + def test_iternodes(self): + """test Text.__iternodes__()""" + node = Text("foobar") + gen = node.__iternodes__(None) + self.assertEqual((None, node), next(gen)) + self.assertRaises(StopIteration, next, gen) + def test_strip(self): """test Text.__strip__()""" node = Text("foobar") diff --git a/tests/test_wikilink.py b/tests/test_wikilink.py index 09ca5b3..d4319c1 100644 --- a/tests/test_wikilink.py +++ b/tests/test_wikilink.py @@ -26,7 +26,7 @@ import unittest from mwparserfromhell.compat import str from mwparserfromhell.nodes import Text, Wikilink -from ._test_tree_equality import TreeEqualityTestCase, wrap +from ._test_tree_equality import TreeEqualityTestCase, getnodes, wrap class TestWikilink(TreeEqualityTestCase): """Test cases for the Wikilink node.""" @@ -38,6 +38,23 @@ class TestWikilink(TreeEqualityTestCase): node2 = Wikilink(wrap([Text("foo")]), wrap([Text("bar")])) self.assertEqual("[[foo|bar]]", str(node2)) + def test_iternodes(self): + """test Wikilink.__iternodes__()""" + node1n1 = Text("foobar") + node2n1, node2n2, node2n3 = Text("foo"), Text("bar"), Text("baz") + node1 = Wikilink(wrap([node1n1])) + node2 = Wikilink(wrap([node2n1]), wrap([node2n2, node2n3])) + gen1 = node1.__iternodes__(getnodes) + gen2 = node2.__iternodes__(getnodes) + self.assertEqual((None, node1), next(gen1)) + self.assertEqual((None, node2), next(gen2)) + self.assertEqual((node1.title, node1n1), next(gen1)) + self.assertEqual((node2.title, node2n1), next(gen2)) + self.assertEqual((node2.text, node2n2), next(gen2)) + self.assertEqual((node2.text, node2n3), next(gen2)) + self.assertRaises(StopIteration, next, gen1) + self.assertRaises(StopIteration, next, gen2) + def test_strip(self): """test Wikilink.__strip__()""" node = Wikilink(wrap([Text("foobar")]))