diff --git a/pyproject.toml b/pyproject.toml index 10a3734..4a01eb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,10 +59,6 @@ requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [tool.pyright] -exclude = [ - # TODO - "src/earwigbot/wiki/copyvios" -] pythonVersion = "3.11" venvPath = "." venv = "venv" diff --git a/src/earwigbot/wiki/copyvios/__init__.py b/src/earwigbot/wiki/copyvios/__init__.py index 4602e28..b1715e4 100644 --- a/src/earwigbot/wiki/copyvios/__init__.py +++ b/src/earwigbot/wiki/copyvios/__init__.py @@ -18,208 +18,142 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +__all__ = [ + "DEFAULT_DEGREE", + "CopyvioChecker", + "CopyvioCheckResult", + "globalize", + "localize", +] + +import functools +import logging import time -from urllib.request import build_opener +from collections.abc import Callable -from earwigbot import exceptions -from earwigbot.wiki.copyvios.markov import MarkovChain -from earwigbot.wiki.copyvios.parsers import ArticleTextParser -from earwigbot.wiki.copyvios.search import SEARCH_ENGINES +from earwigbot.wiki.copyvios.exclusions import ExclusionsDB +from earwigbot.wiki.copyvios.markov import DEFAULT_DEGREE, MarkovChain +from earwigbot.wiki.copyvios.parsers import ArticleParser, ParserArgs +from earwigbot.wiki.copyvios.result import CopyvioCheckResult +from earwigbot.wiki.copyvios.search import SearchEngine, get_search_engine from earwigbot.wiki.copyvios.workers import CopyvioWorkspace, globalize, localize +from earwigbot.wiki.page import Page -__all__ = ["CopyvioMixIn", "globalize", "localize"] - -class CopyvioMixIn: +class CopyvioChecker: """ - **EarwigBot: Wiki Toolset: Copyright Violation MixIn** + Manages the lifecycle of a copyvio check or comparison. - This is a mixin that provides two public methods, :py:meth:`copyvio_check` - and :py:meth:`copyvio_compare`. The former checks the page for copyright - violations using a search engine API, and the latter compares the page - against a given URL. Credentials for the search engine API are stored in - the :py:class:`~earwigbot.wiki.site.Site`'s config. + Created by :py:class:`~earwigbot.wiki.page.Page` and handles the implementation + details of running a check. """ - def __init__(self, site): - self._search_config = site._search_config - self._exclusions_db = self._search_config.get("exclusions_db") - self._addheaders = [ - ("User-Agent", site.user_agent), + def __init__( + self, + page: Page, + *, + min_confidence: float = 0.75, + max_time: float = 30, + degree: int = DEFAULT_DEGREE, + logger: logging.Logger | None = None, + ) -> None: + self._page = page + self._site = page.site + self._config = page.site._search_config + self._min_confidence = min_confidence + self._max_time = max_time + self._degree = degree + self._logger = logger or logging.getLogger("earwigbot.wiki") + + self._headers = [ + ("User-Agent", page.site.user_agent), ("Accept-Encoding", "gzip"), ] - def _get_search_engine(self): - """Return a function that can be called to do web searches. - - The function takes one argument, a search query, and returns a list of - URLs, ranked by importance. The underlying logic depends on the - *engine* argument within our config; for example, if *engine* is - "Yahoo! BOSS", we'll use YahooBOSSSearchEngine for querying. - - Raises UnknownSearchEngineError if the 'engine' listed in our config is - unknown to us, and UnsupportedSearchEngineError if we are missing a - required package or module, like oauth2 for "Yahoo! BOSS". 
- """ - engine = self._search_config["engine"] - if engine not in SEARCH_ENGINES: - raise exceptions.UnknownSearchEngineError(engine) - - klass = SEARCH_ENGINES[engine] - credentials = self._search_config["credentials"] - opener = build_opener() - opener.addheaders = self._addheaders - - for dep in klass.requirements(): - try: - __import__(dep).__name__ - except (ModuleNotFoundError, AttributeError): - e = "Missing a required dependency ({}) for the {} engine" - e = e.format(dep, engine) - raise exceptions.UnsupportedSearchEngineError(e) - - return klass(credentials, opener) - - def copyvio_check( - self, - min_confidence=0.75, - max_queries=15, - max_time=-1, - no_searches=False, - no_links=False, - short_circuit=True, - degree=5, - ): - """Check the page for copyright violations. - - Returns a :class:`.CopyvioCheckResult` object with information on the - results of the check. - - *min_confidence* is the minimum amount of confidence we must have in - the similarity between a source text and the article in order for us to - consider it a suspected violation. This is a number between 0 and 1. - - *max_queries* is self-explanatory; we will never make more than this - number of queries in a given check. - - *max_time* can be set to prevent copyvio checks from taking longer than - a set amount of time (generally around a minute), which can be useful - if checks are called through a web server with timeouts. We will stop - checking new URLs as soon as this limit is reached. - - Setting *no_searches* to ``True`` will cause only URLs in the wikitext - of the page to be checked; no search engine queries will be made. - Setting *no_links* to ``True`` will cause the opposite to happen: URLs - in the wikitext will be ignored; search engine queries will be made - only. Setting both of these to ``True`` is pointless. - - Normally, the checker will short-circuit if it finds a URL that meets - *min_confidence*. This behavior normally causes it to skip any - remaining URLs and web queries, but setting *short_circuit* to - ``False`` will prevent this. - - Raises :exc:`.CopyvioCheckError` or subclasses - (:exc:`.UnknownSearchEngineError`, :exc:`.SearchQueryError`, ...) on - errors. 
- """ - log = "Starting copyvio check for [[{0}]]" - self._logger.info(log.format(self.title)) - searcher = self._get_search_engine() - parser = ArticleTextParser( - self.get(), - args={"nltk_dir": self._search_config["nltk_dir"], "lang": self._site.lang}, + self._parser = ArticleParser( + self._page.get(), + lang=self._site.lang, + nltk_dir=self._config["nltk_dir"], ) - article = MarkovChain(parser.strip(), degree=degree) - parser_args = {} + self._article = MarkovChain(self._parser.strip(), degree=self._degree) - if self._exclusions_db: - self._exclusions_db.sync(self.site.name) + @functools.cached_property + def _searcher(self) -> SearchEngine: + return get_search_engine(self._config, self._headers) - def exclude(u): - return self._exclusions_db.check(self.site.name, u) + @property + def _exclusions_db(self) -> ExclusionsDB | None: + return self._config.get("exclusions_db") - parser_args["mirror_hints"] = self._exclusions_db.get_mirror_hints(self) - else: - exclude = None + def _get_exclusion_callback(self) -> Callable[[str], bool] | None: + if not self._exclusions_db: + return None + return functools.partial(self._exclusions_db.check, self._site.name) + + def run_check( + self, + *, + max_queries: int = 15, + no_searches: bool = False, + no_links: bool = False, + short_circuit: bool = True, + ) -> CopyvioCheckResult: + parser_args: ParserArgs = {} + if self._exclusions_db: + self._exclusions_db.sync(self._site.name) + mirror_hints = self._exclusions_db.get_mirror_hints(self._page) + parser_args["mirror_hints"] = mirror_hints workspace = CopyvioWorkspace( - article, - min_confidence, - max_time, - self._logger, - self._addheaders, + self._article, + min_confidence=self._min_confidence, + max_time=self._max_time, + logger=self._logger, + headers=self._headers, short_circuit=short_circuit, parser_args=parser_args, - exclude_check=exclude, - config=self._search_config, - degree=degree, + exclusion_callback=self._get_exclusion_callback(), + config=self._config, + degree=self._degree, ) - if article.size < 20: # Auto-fail very small articles - result = workspace.get_result() - self._logger.info(result.get_log_message(self.title)) - return result + if self._article.size < 20: # Auto-fail very small articles + return workspace.get_result() if not no_links: - workspace.enqueue(parser.get_links()) + workspace.enqueue(self._parser.get_links()) num_queries = 0 if not no_searches: - chunks = parser.chunk(max_queries) + chunks = self._parser.chunk(max_queries) for chunk in chunks: if short_circuit and workspace.finished: workspace.possible_miss = True break - log = "[[{0}]] -> querying {1} for {2!r}" - self._logger.debug(log.format(self.title, searcher.name, chunk)) - workspace.enqueue(searcher.search(chunk)) + self._logger.debug( + f"[[{self._page.title}]] -> querying {self._searcher.name} " + f"for {chunk!r}" + ) + workspace.enqueue(self._searcher.search(chunk)) num_queries += 1 - time.sleep(1) + time.sleep(1) # TODO: Check whether this is needed workspace.wait() - result = workspace.get_result(num_queries) - self._logger.info(result.get_log_message(self.title)) - return result - - def copyvio_compare(self, urls, min_confidence=0.75, max_time=30, degree=5): - """Check the page like :py:meth:`copyvio_check` against specific URLs. 
- - This is essentially a reduced version of :meth:`copyvio_check` - a - copyivo comparison is made using Markov chains and the result is - returned in a :class:`.CopyvioCheckResult` object - but without using a - search engine, since the suspected "violated" URL is supplied from the - start. - - Its primary use is to generate a result when the URL is retrieved from - a cache, like the one used in EarwigBot's Tool Labs site. After a - search is done, the resulting URL is stored in a cache for 72 hours so - future checks against that page will not require another set of - time-and-money-consuming search engine queries. However, the comparison - itself (which includes the article's and the source's content) cannot - be stored for data retention reasons, so a fresh comparison is made - using this function. - - Since no searching is done, neither :exc:`.UnknownSearchEngineError` - nor :exc:`.SearchQueryError` will be raised. - """ - if not isinstance(urls, list): - urls = [urls] - log = "Starting copyvio compare for [[{0}]] against {1}" - self._logger.info(log.format(self.title, ", ".join(urls))) - article = MarkovChain(ArticleTextParser(self.get()).strip(), degree=degree) + return workspace.get_result(num_queries) + + def run_compare(self, urls: list[str]) -> CopyvioCheckResult: workspace = CopyvioWorkspace( - article, - min_confidence, - max_time, - self._logger, - self._addheaders, - max_time, + self._article, + min_confidence=self._min_confidence, + max_time=self._max_time, + logger=self._logger, + headers=self._headers, + url_timeout=self._max_time, num_workers=min(len(urls), 8), short_circuit=False, - config=self._search_config, - degree=degree, + config=self._config, + degree=self._degree, ) + workspace.enqueue(urls) workspace.wait() - result = workspace.get_result() - self._logger.info(result.get_log_message(self.title)) - return result + return workspace.get_result() diff --git a/src/earwigbot/wiki/copyvios/exclusions.py b/src/earwigbot/wiki/copyvios/exclusions.py index 6634cf0..f576620 100644 --- a/src/earwigbot/wiki/copyvios/exclusions.py +++ b/src/earwigbot/wiki/copyvios/exclusions.py @@ -18,15 +18,24 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from __future__ import annotations + +__all__ = ["ExclusionsDB"] + +import logging import re import sqlite3 import threading import time +import typing import urllib.parse from earwigbot import exceptions -__all__ = ["ExclusionsDB"] +if typing.TYPE_CHECKING: + from earwigbot.wiki.page import Page + from earwigbot.wiki.site import Site + from earwigbot.wiki.sitesdb import SitesDB DEFAULT_SOURCES = { "all": [ # Applies to all, but located on enwiki @@ -52,26 +61,28 @@ class ExclusionsDB: """ **EarwigBot: Wiki Toolset: Exclusions Database Manager** - Controls the :file:`exclusions.db` file, which stores URLs excluded from - copyright violation checks on account of being known mirrors, for example. + Controls the :file:`exclusions.db` file, which stores URLs excluded from copyright + violation checks on account of being known mirrors, for example. 
""" - def __init__(self, sitesdb, dbfile, logger): + def __init__(self, sitesdb: SitesDB, dbfile: str, logger: logging.Logger) -> None: self._sitesdb = sitesdb self._dbfile = dbfile self._logger = logger self._db_access_lock = threading.Lock() - def __repr__(self): + def __repr__(self) -> str: """Return the canonical string representation of the ExclusionsDB.""" - res = "ExclusionsDB(sitesdb={0!r}, dbfile={1!r}, logger={2!r})" - return res.format(self._sitesdb, self._dbfile, self._logger) + return ( + f"ExclusionsDB(sitesdb={self._sitesdb!r}, dbfile={self._dbfile!r}, " + f"logger={self._logger!r})" + ) - def __str__(self): + def __str__(self) -> str: """Return a nice string representation of the ExclusionsDB.""" return f"" - def _create(self): + def _create(self) -> None: """Initialize the exclusions database with its necessary tables.""" script = """ CREATE TABLE sources (source_sitename, source_page); @@ -79,7 +90,7 @@ class ExclusionsDB: CREATE TABLE exclusions (exclusion_sitename, exclusion_url); """ query = "INSERT INTO sources VALUES (?, ?);" - sources = [] + sources: list[tuple[str, str]] = [] for sitename, pages in DEFAULT_SOURCES.items(): for page in pages: sources.append((sitename, page)) @@ -88,9 +99,9 @@ class ExclusionsDB: conn.executescript(script) conn.executemany(query, sources) - def _load_source(self, site, source): + def _load_source(self, site: Site, source: str) -> set[str]: """Load from a specific source and return a set of URLs.""" - urls = set() + urls: set[str] = set() try: data = site.get_page(source, follow_redirects=True).get() except exceptions.PageNotFoundError: @@ -123,7 +134,7 @@ class ExclusionsDB: urls.add(url) return urls - def _update(self, sitename): + def _update(self, sitename: str) -> None: """Update the database from listed sources in the index.""" query1 = "SELECT source_page FROM sources WHERE source_sitename = ?" query2 = "SELECT exclusion_url FROM exclusions WHERE exclusion_sitename = ?" @@ -140,7 +151,7 @@ class ExclusionsDB: else: site = self._sitesdb.get_site(sitename) with self._db_access_lock, sqlite3.connect(self._dbfile) as conn: - urls = set() + urls: set[str] = set() for (source,) in conn.execute(query1, (sitename,)): urls |= self._load_source(site, source) for (url,) in conn.execute(query2, (sitename,)): @@ -154,7 +165,7 @@ class ExclusionsDB: else: conn.execute(query7, (sitename, int(time.time()))) - def _get_last_update(self, sitename): + def _get_last_update(self, sitename: str) -> int: """Return the UNIX timestamp of the last time the db was updated.""" query = "SELECT update_time FROM updates WHERE update_sitename = ?" with self._db_access_lock, sqlite3.connect(self._dbfile) as conn: @@ -165,28 +176,34 @@ class ExclusionsDB: return 0 return result[0] if result else 0 - def sync(self, sitename, force=False): - """Update the database if it hasn't been updated recently. + def sync(self, sitename: str, force: bool = False) -> None: + """ + Update the database if it hasn't been updated recently. This updates the exclusions database for the site *sitename* and "all". - Site-specific lists are considered stale after 48 hours; global lists - after 12 hours. + Site-specific lists are considered stale after 48 hours; global lists after + 12 hours. 
""" max_staleness = 60 * 60 * (12 if sitename == "all" else 48) time_since_update = int(time.time() - self._get_last_update(sitename)) if force or time_since_update > max_staleness: - log = "Updating stale database: {0} (last updated {1} seconds ago)" - self._logger.info(log.format(sitename, time_since_update)) + self._logger.info( + f"Updating stale database: {sitename} (last updated " + f"{time_since_update} seconds ago)" + ) self._update(sitename) else: - log = "Database for {0} is still fresh (last updated {1} seconds ago)" - self._logger.debug(log.format(sitename, time_since_update)) + self._logger.debug( + f"Database for {sitename} is still fresh (last updated " + f"{time_since_update} seconds ago)" + ) if sitename != "all": self.sync("all", force=force) - def check(self, sitename, url): - """Check whether a given URL is in the exclusions database. + def check(self, sitename: str, url: str) -> bool: + """ + Check whether a given URL is in the exclusions database. Return ``True`` if the URL is in the database, or ``False`` otherwise. """ @@ -216,19 +233,18 @@ class ExclusionsDB: else: matches = normalized.startswith(excl) if matches: - log = "Exclusion detected in {0} for {1}" - self._logger.debug(log.format(sitename, url)) + self._logger.debug(f"Exclusion detected in {sitename} for {url}") return True - log = f"No exclusions in {sitename} for {url}" - self._logger.debug(log) + self._logger.debug(f"No exclusions in {sitename} for {url}") return False - def get_mirror_hints(self, page, try_mobile=True): - """Return a list of strings that indicate the existence of a mirror. + def get_mirror_hints(self, page: Page, try_mobile: bool = True) -> list[str]: + """ + Return a list of strings that indicate the existence of a mirror. - The source parser checks for the presence of these strings inside of - certain HTML tag attributes (``"href"`` and ``"src"``). + The source parser checks for the presence of these strings inside of certain + HTML tag attributes (``"href"`` and ``"src"``). """ site = page.site path = urllib.parse.urlparse(page.url).path @@ -238,10 +254,10 @@ class ExclusionsDB: if try_mobile: fragments = re.search(r"^([\w]+)\.([\w]+).([\w]+)$", site.domain) if fragments: - roots.append("{}.m.{}.{}".format(*fragments.groups())) + roots.append(f"{fragments[1]}.m.{fragments[2]}.{fragments[3]}") general = [ - root + site._script_path + "/" + script + root + site.script_path + "/" + script for root in roots for script in scripts ] diff --git a/src/earwigbot/wiki/copyvios/markov.py b/src/earwigbot/wiki/copyvios/markov.py index 5cf7a7f..f08195c 100644 --- a/src/earwigbot/wiki/copyvios/markov.py +++ b/src/earwigbot/wiki/copyvios/markov.py @@ -18,29 +18,44 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +__all__ = [ + "DEFAULT_DEGREE", + "EMPTY", + "EMPTY_INTERSECTION", + "MarkovChain", + "MarkovChainIntersection", +] + import re +from collections.abc import Iterable +from enum import Enum -__all__ = ["EMPTY", "EMPTY_INTERSECTION", "MarkovChain", "MarkovChainIntersection"] +DEFAULT_DEGREE = 5 -class MarkovChain: - """Implements a basic ngram Markov chain of words.""" - +class Sentinel(Enum): START = -1 END = -2 - def __init__(self, text, degree=5): + +RawChain = dict[tuple[str | Sentinel, ...], int] + + +class MarkovChain: + """Implements a basic ngram Markov chain of words.""" + + def __init__(self, text: str, degree: int = DEFAULT_DEGREE) -> None: self.text = text self.degree = degree # 2 for bigrams, 3 for trigrams, etc. 
self.chain = self._build() self.size = self._get_size() - def _build(self): + def _build(self) -> RawChain: """Build and return the Markov chain from the input text.""" padding = self.degree - 1 - words = re.sub(r"[^\w\s-]", "", self.text.lower(), flags=re.UNICODE).split() - words = ([self.START] * padding) + words + ([self.END] * padding) - chain = {} + words = re.sub(r"[^\w\s-]", "", self.text.lower()).split() + words = ([Sentinel.START] * padding) + words + ([Sentinel.END] * padding) + chain: RawChain = {} for i in range(len(words) - self.degree + 1): phrase = tuple(words[i : i + self.degree]) @@ -50,15 +65,15 @@ class MarkovChain: chain[phrase] = 1 return chain - def _get_size(self): + def _get_size(self) -> int: """Return the size of the Markov chain: the total number of nodes.""" return sum(self.chain.values()) - def __repr__(self): + def __repr__(self) -> str: """Return the canonical string representation of the MarkovChain.""" return f"MarkovChain(text={self.text!r})" - def __str__(self): + def __str__(self) -> str: """Return a nice string representation of the MarkovChain.""" return f"" @@ -66,61 +81,60 @@ class MarkovChain: class MarkovChainIntersection(MarkovChain): """Implements the intersection of two chains (i.e., their shared nodes).""" - def __init__(self, mc1, mc2): + def __init__(self, mc1: MarkovChain, mc2: MarkovChain) -> None: self.mc1, self.mc2 = mc1, mc2 self.chain = self._build() self.size = self._get_size() - def _build(self): + def _build(self) -> RawChain: """Build and return the Markov chain from the input chains.""" c1 = self.mc1.chain c2 = self.mc2.chain - chain = {} + chain: RawChain = {} for phrase in c1: if phrase in c2: chain[phrase] = min(c1[phrase], c2[phrase]) return chain - def __repr__(self): + def __repr__(self) -> str: """Return the canonical string representation of the intersection.""" - res = "MarkovChainIntersection(mc1={0!r}, mc2={1!r})" - return res.format(self.mc1, self.mc2) + return f"MarkovChainIntersection(mc1={self.mc1!r}, mc2={self.mc2!r})" - def __str__(self): + def __str__(self) -> str: """Return a nice string representation of the intersection.""" - res = "" - return res.format(self.size, self.mc1, self.mc2) + return ( + f"" + ) class MarkovChainUnion(MarkovChain): """Implemented the union of multiple chains.""" - def __init__(self, chains): + def __init__(self, chains: Iterable[MarkovChain]) -> None: self.chains = list(chains) self.chain = self._build() self.size = self._get_size() - def _build(self): + def _build(self) -> RawChain: """Build and return the Markov chain from the input chains.""" - union = {} + union: RawChain = {} for chain in self.chains: - for phrase, count in chain.chain.iteritems(): + for phrase, count in chain.chain.items(): if phrase in union: union[phrase] += count else: union[phrase] = count return union - def __repr__(self): + def __repr__(self) -> str: """Return the canonical string representation of the union.""" - res = "MarkovChainUnion(chains={!r})" - return res.format(self.chains) + return f"MarkovChainUnion(chains={self.chains!r})" - def __str__(self): + def __str__(self) -> str: """Return a nice string representation of the union.""" - res = "" - return res.format(self.size, "| ".join(str(chain) for chain in self.chains)) + chains = " | ".join(str(chain) for chain in self.chains) + return f"" EMPTY = MarkovChain("") diff --git a/src/earwigbot/wiki/copyvios/parsers.py b/src/earwigbot/wiki/copyvios/parsers.py index 09553e6..dc8fcad 100644 --- a/src/earwigbot/wiki/copyvios/parsers.py +++ 
b/src/earwigbot/wiki/copyvios/parsers.py @@ -18,44 +18,34 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +from __future__ import annotations + +__all__ = ["ArticleParser", "get_parser"] + import io import json import os.path import re +import typing import urllib.parse import urllib.request +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any, ClassVar, Literal, TypedDict import mwparserfromhell from earwigbot.exceptions import ParserExclusionError, ParserRedirectError -__all__ = ["ArticleTextParser", "get_parser"] - - -class _BaseTextParser: - """Base class for a parser that handles text.""" - - TYPE = None - - def __init__(self, text, url=None, args=None): - self.text = text - self.url = url - self._args = args or {} - - def __repr__(self): - """Return the canonical string representation of the text parser.""" - return f"{self.__class__.__name__}(text={self.text!r})" +if typing.TYPE_CHECKING: + import bs4 - def __str__(self): - """Return a nice string representation of the text parser.""" - name = self.__class__.__name__ - return f"<{name} of text with size {len(self.text)}>" + from earwigbot.wiki.copyvios.workers import OpenedURL -class ArticleTextParser(_BaseTextParser): +class ArticleParser: """A parser that can strip and chunk wikicode article text.""" - TYPE = "Article" TEMPLATE_MERGE_THRESHOLD = 35 NLTK_DEFAULT = "english" NLTK_LANGS = { @@ -78,7 +68,18 @@ class ArticleTextParser(_BaseTextParser): "tr": "turkish", } - def _merge_templates(self, code): + def __init__(self, text: str, lang: str, nltk_dir: str) -> None: + self.text = text + self._lang = lang + self._nltk_dir = nltk_dir + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(text={self.text!r})" + + def __str__(self) -> str: + return f"<{self.__class__.__name__} of text with size {len(self.text)}>" + + def _merge_templates(self, code: mwparserfromhell.wikicode.Wikicode) -> None: """Merge template contents in to wikicode when the values are long.""" for template in code.filter_templates(recursive=code.RECURSE_OTHERS): chunks = [] @@ -92,23 +93,25 @@ class ArticleTextParser(_BaseTextParser): else: code.remove(template) - def _get_tokenizer(self): + def _get_tokenizer(self) -> Any: """Return a NLTK punctuation tokenizer for the article's language.""" import nltk - def datafile(lang): + def datafile(lang: str) -> str: return "file:" + os.path.join( - self._args["nltk_dir"], "tokenizers", "punkt", lang + ".pickle" + self._nltk_dir, "tokenizers", "punkt", lang + ".pickle" ) - lang = self.NLTK_LANGS.get(self._args.get("lang"), self.NLTK_DEFAULT) + lang = self.NLTK_LANGS.get(self._lang, self.NLTK_DEFAULT) try: nltk.data.load(datafile(self.NLTK_DEFAULT)) except LookupError: - nltk.download("punkt", self._args["nltk_dir"]) + nltk.download("punkt", self._nltk_dir) return nltk.data.load(datafile(lang)) - def _get_sentences(self, min_query, max_query, split_thresh): + def _get_sentences( + self, min_query: int, max_query: int, split_thresh: int + ) -> list[str]: """Split the article text into sentences of a certain length.""" def cut_sentence(words): @@ -138,24 +141,27 @@ class ArticleTextParser(_BaseTextParser): sentences.extend(cut_sentence(sentence.split())) return [sen for sen in sentences if len(sen) >= min_query] - def strip(self): - """Clean the page's raw text by removing templates and formatting. + def strip(self) -> str: + """ + Clean the page's raw text by removing templates and formatting. 
- Return the page's text with all HTML and wikicode formatting removed, - including templates, tables, and references. It retains punctuation - (spacing, paragraphs, periods, commas, (semi)-colons, parentheses, - quotes), original capitalization, and so forth. HTML entities are - replaced by their unicode equivalents. + Return the page's text with all HTML and wikicode formatting removed, including + templates, tables, and references. It retains punctuation (spacing, paragraphs, + periods, commas, (semi)-colons, parentheses, quotes), original capitalization, + and so forth. HTML entities are replaced by their unicode equivalents. The actual stripping is handled by :py:mod:`mwparserfromhell`. """ - def remove(code, node): - """Remove a node from a code object, ignoring ValueError. + def remove( + code: mwparserfromhell.wikicode.Wikicode, node: mwparserfromhell.nodes.Node + ) -> None: + """ + Remove a node from a code object, ignoring ValueError. - Sometimes we will remove a node that contains another node we wish - to remove, and we fail when we try to remove the inner one. Easiest - solution is to just ignore the exception. + Sometimes we will remove a node that contains another node we wish to + remove, and we fail when we try to remove the inner one. Easiest solution + is to just ignore the exception. """ try: code.remove(node) @@ -181,26 +187,32 @@ class ArticleTextParser(_BaseTextParser): self.clean = re.sub(r"\n\n+", "\n", clean).strip() return self.clean - def chunk(self, max_chunks, min_query=8, max_query=128, split_thresh=32): - """Convert the clean article text into a list of web-searchable chunks. - - No greater than *max_chunks* will be returned. Each chunk will only be - a sentence or two long at most (no more than *max_query*). The idea is - to return a sample of the article text rather than the whole, so we'll - pick and choose from parts of it, especially if the article is large - and *max_chunks* is low, so we don't end up just searching for just the - first paragraph. - - This is implemented using :py:mod:`nltk` (https://nltk.org/). A base - directory (*nltk_dir*) is required to store nltk's punctuation - database, and should be passed as an argument to the constructor. It is - typically located in the bot's working directory. + def chunk( + self, + max_chunks: int, + min_query: int = 8, + max_query: int = 128, + split_thresh: int = 32, + ) -> list[str]: + """ + Convert the clean article text into a list of web-searchable chunks. + + No greater than *max_chunks* will be returned. Each chunk will only be a + sentence or two long at most (no more than *max_query*). The idea is to return + a sample of the article text rather than the whole, so we'll pick and choose + from parts of it, especially if the article is large and *max_chunks* is low, + so we don't end up just searching for just the first paragraph. + + This is implemented using :py:mod:`nltk` (https://nltk.org/). A base directory + (*nltk_dir*) is required to store nltk's punctuation database, and should be + passed as an argument to the constructor. It is typically located in the bot's + working directory. 
""" sentences = self._get_sentences(min_query, max_query, split_thresh) if len(sentences) <= max_chunks: return sentences - chunks = [] + chunks: list[str] = [] while len(chunks) < max_chunks: if len(chunks) % 5 == 0: chunk = sentences.pop(0) # Pop from beginning @@ -216,7 +228,8 @@ class ArticleTextParser(_BaseTextParser): return chunks def get_links(self): - """Return a list of all external links in the article. + """ + Return a list of all external links in the article. The list is restricted to things that we suspect we can parse: i.e., those with schemes of ``http`` and ``https``. @@ -226,14 +239,42 @@ class ArticleTextParser(_BaseTextParser): return [str(link.url) for link in links if link.url.startswith(schemes)] -class _HTMLParser(_BaseTextParser): +class ParserArgs(TypedDict, total=False): + mirror_hints: list[str] + open_url: Callable[[str], OpenedURL | None] + + +class SourceParser(ABC): + """Base class for a parser that handles text.""" + + TYPE: ClassVar[str] + + def __init__(self, text: bytes, url: str, args: ParserArgs | None = None) -> None: + self.text = text + self.url = url + self._args = args or {} + + def __repr__(self) -> str: + """Return the canonical string representation of the text parser.""" + return f"{self.__class__.__name__}(text={self.text!r})" + + def __str__(self) -> str: + """Return a nice string representation of the text parser.""" + return f"<{self.__class__.__name__} of text with size {len(self.text)}>" + + @abstractmethod + def parse(self) -> str: ... + + +class HTMLParser(SourceParser): """A parser that can extract the text from an HTML document.""" TYPE = "HTML" hidden_tags = ["script", "style"] - def _fail_if_mirror(self, soup): - """Look for obvious signs that the given soup is a wiki mirror. + def _fail_if_mirror(self, soup: bs4.BeautifulSoup) -> None: + """ + Look for obvious signs that the given soup is a wiki mirror. If so, raise ParserExclusionError, which is caught in the workers and causes this source to excluded. @@ -242,13 +283,14 @@ class _HTMLParser(_BaseTextParser): return def func(attr): + assert "mirror_hints" in self._args return attr and any(hint in attr for hint in self._args["mirror_hints"]) if soup.find_all(href=func) or soup.find_all(src=func): raise ParserExclusionError() @staticmethod - def _get_soup(text): + def _get_soup(text: bytes) -> bs4.BeautifulSoup: """Parse some text using BeautifulSoup.""" import bs4 @@ -257,11 +299,11 @@ class _HTMLParser(_BaseTextParser): except ValueError: return bs4.BeautifulSoup(text) - def _clean_soup(self, soup): + def _clean_soup(self, soup: bs4.element.Tag) -> str: """Clean a BeautifulSoup tree of invisible tags.""" import bs4 - def is_comment(text): + def is_comment(text: bs4.element.Tag) -> bool: return isinstance(text, bs4.element.Comment) for comment in soup.find_all(text=is_comment): @@ -272,7 +314,7 @@ class _HTMLParser(_BaseTextParser): return "\n".join(s.replace("\n", " ") for s in soup.stripped_strings) - def _open(self, url, **kwargs): + def _open(self, url: str, **kwargs: Any) -> bytes | None: """Try to read a URL. 
Return None if it couldn't be read.""" opener = self._args.get("open_url") if not opener: @@ -280,13 +322,13 @@ class _HTMLParser(_BaseTextParser): result = opener(url, **kwargs) return result.content if result else None - def _load_from_blogspot(self, url): + def _load_from_blogspot(self, url: urllib.parse.ParseResult) -> str: """Load dynamic content from Blogger Dynamic Views.""" - match = re.search(r"'postId': '(\d+)'", self.text) + match = re.search(rb"'postId': '(\d+)'", self.text) if not match: return "" post_id = match.group(1) - url = f"https://{url.netloc}/feeds/posts/default/{post_id}?" + feed_url = f"https://{url.netloc}/feeds/posts/default/{post_id}?" params = { "alt": "json", "v": "2", @@ -294,7 +336,7 @@ class _HTMLParser(_BaseTextParser): "rewriteforssl": "true", } raw = self._open( - url + urllib.parse.urlencode(params), + feed_url + urllib.parse.urlencode(params), allow_content_types=["application/json"], ) if raw is None: @@ -308,19 +350,24 @@ class _HTMLParser(_BaseTextParser): except KeyError: return "" soup = self._get_soup(text) + if not soup.body: + return "" return self._clean_soup(soup.body) - def parse(self): - """Return the actual text contained within an HTML document. + def parse(self) -> str: + """ + Return the actual text contained within an HTML document. Implemented using :py:mod:`BeautifulSoup ` - (https://www.crummy.com/software/BeautifulSoup/). + (https://pypi.org/project/beautifulsoup4/). """ + import bs4 + url = urllib.parse.urlparse(self.url) if self.url else None soup = self._get_soup(self.text) if not soup.body: - # No tag present in HTML -> - # no scrapable content (possibly JS or
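A minimal usage sketch of the refactored API introduced by this patch, assuming a Page already obtained from a configured Site; the check() helper and its cached_urls parameter are hypothetical, added here only to show how run_check and run_compare differ.

import logging

from earwigbot.wiki.copyvios import CopyvioChecker


def check(page, cached_urls=None):
    # Run a full search-engine-backed check, or a direct comparison if we
    # already know which URLs to compare against (e.g. from a cache).
    checker = CopyvioChecker(
        page,
        min_confidence=0.75,
        max_time=30,
        logger=logging.getLogger("earwigbot.wiki"),
    )
    if cached_urls:
        return checker.run_compare(cached_urls)
    return checker.run_check(max_queries=15, short_circuit=True)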