diff --git a/mwparserfromhell/nodes/__init__.py b/mwparserfromhell/nodes/__init__.py index c03785b..2209cd6 100644 --- a/mwparserfromhell/nodes/__init__.py +++ b/mwparserfromhell/nodes/__init__.py @@ -25,7 +25,10 @@ from mwparserfromhell.string_mixin import StringMixIn __all__ = ["Node"] class Node(StringMixIn): - def __strip__(self, normalize=True, collapse=True): + def __iternodes__(self, getter): + yield None, self + + def __strip__(self, normalize, collapse): return None from mwparserfromhell.nodes import extras diff --git a/mwparserfromhell/nodes/heading.py b/mwparserfromhell/nodes/heading.py index 1f26df6..b213824 100644 --- a/mwparserfromhell/nodes/heading.py +++ b/mwparserfromhell/nodes/heading.py @@ -32,6 +32,11 @@ class Heading(Node): def __unicode__(self): return ("=" * self.level) + self.title + ("=" * self.level) + def __iternodes__(self, getter): + yield None, self + for child in getter(self.title): + yield self.title, child + def __strip__(self, normalize, collapse): return self.title diff --git a/mwparserfromhell/nodes/tag.py b/mwparserfromhell/nodes/tag.py index faaaa54..0f260f0 100644 --- a/mwparserfromhell/nodes/tag.py +++ b/mwparserfromhell/nodes/tag.py @@ -100,6 +100,20 @@ class Tag(Node): result += "" return result + def __iternodes__(self, getter): + yield None, self + if self.showtag: + for child in getter(self.tag): + yield self.tag, tag + for attr in self.attrs: + for child in getter(attr.name): + yield attr.name, child + if attr.value: + for child in getter(attr.value): + yield attr.value, child + for child in getter(self.contents): + yield self.contents, child + def __strip__(self, normalize, collapse): if self.type in self.TAGS_VISIBLE: return self.contents.strip_code(normalize, collapse) diff --git a/mwparserfromhell/nodes/template.py b/mwparserfromhell/nodes/template.py index 5f635b0..bdef285 100644 --- a/mwparserfromhell/nodes/template.py +++ b/mwparserfromhell/nodes/template.py @@ -46,6 +46,17 @@ class Template(Node): else: return "{{" + unicode(self.name) + "}}" + def __iternodes__(self, getter): + yield None, self + for child in getter(self.name): + yield self.name, child + for param in self.params: + if param.showkey: + for child in getter(param.name): + yield param.name, child + for child in getter(param.value): + yield param.value, child + def _surface_escape(self, code, char): replacement = HTMLEntity(value=ord(char)) for node in code.filter_text(recursive=False): diff --git a/mwparserfromhell/wikicode.py b/mwparserfromhell/wikicode.py index 72195be..9c54927 100644 --- a/mwparserfromhell/wikicode.py +++ b/mwparserfromhell/wikicode.py @@ -39,39 +39,12 @@ class Wikicode(StringMixIn): def __unicode__(self): return "".join([unicode(node) for node in self.nodes]) - def _iterate_over_children(self, node): - yield (None, node) - if isinstance(node, Heading): - for child in self._get_all_nodes(node.title): - yield (node.title, child) - elif isinstance(node, Tag): - if node.showtag: - for child in self._get_all_nodes(node.tag): - yield (node.tag, tag) - for attr in node.attrs: - for child in self._get_all_nodes(attr.name): - yield (attr.name, child) - if attr.value: - for child in self._get_all_nodes(attr.value): - yield (attr.value, child) - for child in self._get_all_nodes(node.contents): - yield (node.contents, child) - elif isinstance(node, Template): - for child in self._get_all_nodes(node.name): - yield (node.name, child) - for param in node.params: - if param.showkey: - for child in self._get_all_nodes(param.name): - yield (param.name, child) - for child in self._get_all_nodes(param.value): - yield (param.value, child) - def _get_children(self, node): - for context, child in self._iterate_over_children(node): + for context, child in node.__iternodes__(self._get_all_nodes): yield child def _get_context(self, node, obj): - for context, child in self._iterate_over_children(node): + for context, child in node.__iternodes__(self._get_all_nodes): if child is obj: return context raise ValueError(obj)