diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e4bb1f5..88898ca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.2 + rev: v0.6.8 hooks: - id: ruff args: [--fix] - id: ruff-format - repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.377 + rev: v1.1.383 hooks: - id: pyright diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..68ec720 --- /dev/null +++ b/Makefile @@ -0,0 +1,24 @@ +MAKEJS := uglifyjs --compress +MAKECSS := postcss -u cssnano --no-map + +.PHONY: all js css + +.INTERMEDIATE: static/style.tmp.css + +all: js css + +js: static/script.min.js + +css: static/style.min.css static/api.min.css + +static/script.min.js: static/script.js + $(MAKEJS) -o $@ -- $^ + +static/style.tmp.css: static/css/*.css + cat $^ > $@ + +static/style.min.css: static/style.tmp.css + $(MAKECSS) -o $@ $^ + +static/api.min.css: static/api.css + $(MAKECSS) -o $@ $^ diff --git a/README.md b/README.md index 0496500..a6158ce 100644 --- a/README.md +++ b/README.md @@ -18,13 +18,13 @@ Installation this should be in `~/www/python/venv`, otherwise it can be in a subdirectory of the git project named `venv`: - python3 -m venv venv - . venv/bin/activate - pip install -e . + python3 -m venv venv + . venv/bin/activate + pip install -e . - If you intend to modify CSS or JS, install the frontend dependencies: - npm install -g uglify-js cssnano postcss postcss-cli + npm install -g uglify-js cssnano postcss postcss-cli - Create an SQL database with the tables defined by `schema.sql`. @@ -40,7 +40,7 @@ Installation Running ======= -- Run `./build.py` to minify JS and CSS files after making any frontend - changes. +- Run `make` to minify JS and CSS files after making any frontend changes. -- Start your WSGI server pointing to app:app. +- Start your WSGI server pointing to app:app. For production, uWSGI or + Gunicorn are likely good options. For development, use `flask run`. diff --git a/app.py b/app.py index b45c145..de91a0a 100755 --- a/app.py +++ b/app.py @@ -1,23 +1,26 @@ #! /usr/bin/env python +import functools +import hashlib +import json import logging -from functools import wraps -from hashlib import md5 -from json import dumps +import os +import time +import traceback +from collections.abc import Callable from logging.handlers import TimedRotatingFileHandler -from os import path -from time import asctime -from traceback import format_exc +from typing import Any, ParamSpec -from earwigbot.bot import Bot from earwigbot.wiki.copyvios import globalize -from flask import Flask, g, make_response, request +from flask import Flask, Response, make_response, request from flask_mako import MakoTemplates, TemplateError, render_template from copyvios.api import format_api_error, handle_api_request -from copyvios.checker import do_check -from copyvios.cookies import parse_cookies -from copyvios.misc import cache, get_notice +from copyvios.cache import cache +from copyvios.checker import CopyvioCheckError, do_check +from copyvios.cookies import get_new_cookies +from copyvios.misc import get_notice +from copyvios.query import CheckQuery from copyvios.settings import process_settings from copyvios.sites import update_sites @@ -27,13 +30,17 @@ MakoTemplates(app) hand = TimedRotatingFileHandler("logs/app.log", when="midnight", backupCount=7) hand.setLevel(logging.DEBUG) app.logger.addHandler(hand) -app.logger.info("Flask server started " + asctime()) -app._hash_cache = {} +app.logger.info(f"Flask server started {time.asctime()}") +globalize(num_workers=8) -def catch_errors(func): - @wraps(func) - def inner(*args, **kwargs): +AnyResponse = Response | str | bytes +P = ParamSpec("P") + + +def catch_errors(func: Callable[P, AnyResponse]) -> Callable[P, AnyResponse]: + @functools.wraps(func) + def inner(*args: P.args, **kwargs: P.kwargs) -> AnyResponse: try: return func(*args, **kwargs) except TemplateError as exc: @@ -41,69 +48,42 @@ def catch_errors(func): return render_template("error.mako", traceback=exc.text) except Exception: app.logger.exception("Caught exception:") - return render_template("error.mako", traceback=format_exc()) + return render_template("error.mako", traceback=traceback.format_exc()) return inner -@app.before_first_request -def setup_app(): - cache.bot = Bot(".earwigbot", 100) - cache.langs, cache.projects = [], [] - cache.last_sites_update = 0 - cache.background_data = {} - cache.last_background_updates = {} - - globalize(num_workers=8) - - -@app.before_request -def prepare_request(): - g._db = None - g.cookies = parse_cookies( - request.script_root or "/", request.environ.get("HTTP_COOKIE") - ) - g.new_cookies = [] - - @app.after_request -def add_new_cookies(response): - for cookie in g.new_cookies: +def add_new_cookies(response: Response) -> Response: + for cookie in get_new_cookies(): response.headers.add("Set-Cookie", cookie) return response @app.after_request -def write_access_log(response): - msg = "%s %s %s %s -> %s" +def write_access_log(response: Response) -> Response: app.logger.debug( - msg, - asctime(), - request.method, - request.path, - request.values.to_dict(), - response.status_code, + f"{time.asctime()} {request.method} {request.path} " + f"{request.values.to_dict()} -> {response.status_code}" ) return response -@app.teardown_appcontext -def close_databases(error): - if g._db: - g._db.close() +@functools.lru_cache +def _get_hash(path: str, mtime: float) -> str: + # mtime is used as part of the cache key + with open(path, "rb") as fp: + return hashlib.sha1(fp.read()).hexdigest() -def external_url_handler(error, endpoint, values): +def external_url_handler( + error: Exception, endpoint: str, values: dict[str, Any] +) -> str: if endpoint == "static" and "file" in values: - fpath = path.join(app.static_folder, values["file"]) - mtime = path.getmtime(fpath) - cache = app._hash_cache.get(fpath) - if cache and cache[0] == mtime: - hashstr = cache[1] - else: - with open(fpath, "rb") as f: - hashstr = md5(f.read()).hexdigest() - app._hash_cache[fpath] = (mtime, hashstr) + assert app.static_folder is not None + path = os.path.join(app.static_folder, values["file"]) + mtime = os.path.getmtime(path) + hashstr = _get_hash(path, mtime) return f"/static/{values['file']}?v={hashstr}" raise error @@ -113,22 +93,28 @@ app.url_build_error_handlers.append(external_url_handler) @app.route("/") @catch_errors -def index(): +def index() -> AnyResponse: notice = get_notice() update_sites() - query = do_check() + query = CheckQuery.from_get_args() + try: + result = do_check(query) + error = None + except CopyvioCheckError as exc: + result = None + error = exc return render_template( "index.mako", notice=notice, query=query, - result=query.result, - turnitin_result=query.turnitin_result, + result=result, + error=error, ) @app.route("/settings", methods=["GET", "POST"]) @catch_errors -def settings(): +def settings() -> AnyResponse: status = process_settings() if request.method == "POST" else None update_sites() default = cache.bot.wiki.get_site() @@ -142,13 +128,13 @@ def settings(): @app.route("/api") @catch_errors -def api(): +def api() -> AnyResponse: return render_template("api.mako", help=True) @app.route("/api.json") @catch_errors -def api_json(): +def api_json() -> AnyResponse: if not request.args: return render_template("api.mako", help=True) @@ -160,12 +146,12 @@ def api_json(): except Exception as exc: result = format_api_error("unhandled_exception", exc) else: - errmsg = f"Unknown format: '{format}'" + errmsg = f"Unknown format: {format!r}" result = format_api_error("unknown_format", errmsg) if format == "jsonfm": return render_template("api.mako", help=False, result=result) - resp = make_response(dumps(result)) + resp = make_response(json.dumps(result)) resp.mimetype = "application/json" resp.headers["Access-Control-Allow-Origin"] = "*" return resp diff --git a/build.py b/build.py deleted file mode 100755 index 5d3e1c2..0000000 --- a/build.py +++ /dev/null @@ -1,39 +0,0 @@ -#! /usr/bin/env python - -import os -import subprocess - - -def process(*args): - print(*args) - subprocess.run(args, check=True) - - -def main(): - root = os.path.join(os.path.dirname(__file__), "static") - for dirpath, dirnames, filenames in os.walk(root): - for filename in filenames: - name = os.path.relpath(os.path.join(dirpath, filename)) - if filename.endswith(".js") and ".min." not in filename: - process( - "uglifyjs", - "--compress", - "-o", - name.replace(".js", ".min.js"), - "--", - name, - ) - if filename.endswith(".css") and ".min." not in filename: - process( - "postcss", - "-u", - "cssnano", - "--no-map", - name, - "-o", - name.replace(".css", ".min.css"), - ) - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index b40dd06..bb5617e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,8 +14,9 @@ dependencies = [ "flask-mako >= 0.4", "mako >= 1.3.5", "requests >= 2.32.3", + "pydantic >= 2.9.2", "SQLAlchemy >= 2.0.32", - "apsw >= 3.46.1", + "mwoauth >= 0.4.0", ] [project.urls] @@ -28,11 +29,6 @@ build-backend = "setuptools.build_meta" [tool.pyright] pythonVersion = "3.11" -exclude = [ - # TODO - "src/copyvios/*", - "app.py", -] venvPath = "." venv = "venv" @@ -41,4 +37,3 @@ target-version = "py311" [tool.ruff.lint] select = ["E4", "E7", "E9", "F", "I", "UP"] -ignore = ["F403"] diff --git a/src/copyvios/api.py b/src/copyvios/api.py index 9373d9c..be8ca26 100644 --- a/src/copyvios/api.py +++ b/src/copyvios/api.py @@ -1,135 +1,142 @@ -from collections import OrderedDict +__all__ = ["format_api_error", "handle_api_request"] -from .checker import T_POSSIBLE, T_SUSPECT, do_check -from .highlighter import highlight_delta -from .misc import Query, cache -from .sites import update_sites +from typing import Any -__all__ = ["format_api_error", "handle_api_request"] +from earwigbot.wiki import Page +from earwigbot.wiki.copyvios.result import CopyvioCheckResult, CopyvioSource +from flask import g + +from .cache import cache +from .checker import T_POSSIBLE, T_SUSPECT, CopyvioCheckError, ErrorCode, do_check +from .highlighter import highlight_delta +from .query import APIQuery +from .sites import get_site, update_sites _CHECK_ERRORS = { - "no search method": "Either 'use_engine' or 'use_links' must be true", - "bad oldid": "The revision ID is invalid", - "no URL": "The parameter 'url' is required for URL comparisons", - "bad URI": "The given URI scheme is unsupported", - "no data": "No text could be found in the given URL (note that only HTML " - "and plain text pages are supported, and content generated by " - "JavaScript or found inside iframes is ignored)", - "timeout": "The given URL timed out before any data could be retrieved", - "search error": "An error occurred while using the search engine; try " - "reloading or setting 'use_engine' to 0", + ErrorCode.NO_SEARCH_METHOD: "Either 'use_engine' or 'use_links' must be true", + ErrorCode.BAD_OLDID: "The revision ID is invalid", + ErrorCode.NO_URL: "The parameter 'url' is required for URL comparisons", + ErrorCode.BAD_URI: "The given URI scheme is unsupported", + ErrorCode.NO_DATA: ( + "No text could be found in the given URL (note that only HTML and plain text " + "pages are supported, and content generated by JavaScript or found inside " + "iframes is ignored)" + ), + ErrorCode.TIMEOUT: "The given URL timed out before any data could be retrieved", + ErrorCode.SEARCH_ERROR: ( + "An error occurred while using the search engine; try reloading or setting " + "'use_engine' to 0" + ), } -def _serialize_page(page): - return OrderedDict((("title", page.title), ("url", page.url))) +def _serialize_page(page: Page) -> dict[str, Any]: + return {"title": page.title, "url": page.url} -def _serialize_source(source, show_skip=True): +def _serialize_source( + source: CopyvioSource | None, show_skip: bool = True +) -> dict[str, Any]: if not source: - return OrderedDict((("url", None), ("confidence", 0.0), ("violation", "none"))) - - conf = source.confidence - data = OrderedDict( - ( - ("url", source.url), - ("confidence", conf), - ( - "violation", - ( - "suspected" - if conf >= T_SUSPECT - else "possible" - if conf >= T_POSSIBLE - else "none" - ), - ), - ) - ) + return {"url": None, "confidence": 0.0, "violation": "none"} + + if source.confidence >= T_SUSPECT: + violation = "suspected" + elif source.confidence >= T_POSSIBLE: + violation = "possible" + else: + violation = "none" + + data = { + "url": source.url, + "confidence": source.confidence, + "violation": violation, + } if show_skip: data["skipped"] = source.skipped data["excluded"] = source.excluded return data -def _serialize_detail(result): +def _serialize_detail(result: CopyvioCheckResult) -> dict[str, Any] | None: + if not result.best: + return None source_chain, delta = result.best.chains article = highlight_delta(None, result.article_chain, delta) source = highlight_delta(None, source_chain, delta) - return OrderedDict((("article", article), ("source", source))) + return {"article": article, "source": source} -def format_api_error(code, info): - if isinstance(info, BaseException): - info = type(info).__name__ + ": " + str(info) - error_inner = OrderedDict((("code", code), ("info", info))) - return OrderedDict((("status", "error"), ("error", error_inner))) +def format_api_error(code: str, info: Exception | str) -> dict[str, Any]: + if isinstance(info, Exception): + info = f"{type(info).__name__}: {info}" + return {"status": "error", "error": {"code": code, "info": info}} -def _hook_default(query): - info = f"Unknown action: '{query.action.lower()}'" - return format_api_error("unknown_action", info) +def _hook_default(query: APIQuery) -> dict[str, Any]: + if query.action: + return format_api_error( + "unknown_action", f"Unknown action: {query.action.lower()!r}" + ) + else: + return format_api_error("missing_action", "Missing 'action' query parameter") -def _hook_check(query): - do_check(query) +def _hook_check(query: APIQuery) -> dict[str, Any]: + try: + result = do_check(query) + except CopyvioCheckError as exc: + info = _CHECK_ERRORS.get(exc.code, "An unknown error occurred") + return format_api_error(exc.code.name.lower(), info) + if not query.submitted: info = ( - "The query parameters 'project', 'lang', and either 'title' " - "or 'oldid' are required for checks" + "The query parameters 'project', 'lang', and either 'title' or 'oldid' " + "are required for checks" ) return format_api_error("missing_params", info) - if query.error: - info = _CHECK_ERRORS.get(query.error, "An unknown error occurred") - return format_api_error(query.error.replace(" ", "_"), info) - elif not query.site: + if not get_site(): info = ( - f"The given site (project={query.project}, lang={query.lang}) either doesn't exist," - " is closed, or is private" + f"The given site (project={query.project}, lang={query.lang}) either " + "doesn't exist, is closed, or is private" ) return format_api_error("bad_site", info) - elif not query.result: + if not result: if query.oldid: - info = "The revision ID couldn't be found: {0}" - return format_api_error("bad_oldid", info.format(query.oldid)) + return format_api_error( + "bad_oldid", f"The revision ID couldn't be found: {query.oldid}" + ) else: - info = "The page couldn't be found: {0}" - return format_api_error("bad_title", info.format(query.page.title)) - - result = query.result - data = OrderedDict( - ( - ("status", "ok"), - ( - "meta", - OrderedDict( - ( - ("time", result.time), - ("queries", result.queries), - ("cached", result.cached), - ("redirected", bool(query.redirected_from)), - ) - ), - ), - ("page", _serialize_page(query.page)), - ) - ) - if result.cached: - data["meta"]["cache_time"] = result.cache_time - if query.redirected_from: - data["original_page"] = _serialize_page(query.redirected_from) + assert isinstance(g.page, Page), g.page + return format_api_error( + "bad_title", f"The page couldn't be found: {g.page.title}" + ) + + assert isinstance(g.page, Page), g.page + data = { + "status": "ok", + "meta": { + "time": result.time, + "queries": result.queries, + "cached": result.metadata.cached, + "redirected": hasattr(result.metadata, "redirected_from"), + }, + "page": _serialize_page(g.page), + } + if result.metadata.cached: + data["meta"]["cache_time"] = result.metadata.cache_time + if result.metadata.redirected_from: + data["original_page"] = _serialize_page(result.metadata.redirected_from) data["best"] = _serialize_source(result.best, show_skip=False) data["sources"] = [_serialize_source(source) for source in result.sources] - if query.detail in ("1", "true"): + if query.detail: data["detail"] = _serialize_detail(result) return data -def _hook_sites(query): +def _hook_sites(query: APIQuery) -> dict[str, Any]: update_sites() - return OrderedDict( - (("status", "ok"), ("langs", cache.langs), ("projects", cache.projects)) - ) + return {"status": "ok", "langs": cache.langs, "projects": cache.projects} _HOOKS = { @@ -140,19 +147,12 @@ _HOOKS = { def handle_api_request(): - query = Query() - if query.version: - try: - query.version = int(query.version) - except ValueError: - info = f"The version string is invalid: {query.version}" - return format_api_error("invalid_version", info) - else: - query.version = 1 + query = APIQuery.from_get_args() if query.version == 1: action = query.action.lower() if query.action else "" return _HOOKS.get(action, _hook_default)(query) - - info = f"The API version is unsupported: {query.version}" - return format_api_error("unsupported_version", info) + else: + return format_api_error( + "unsupported_version", f"The API version is unsupported: {query.version}" + ) diff --git a/src/copyvios/attribution.py b/src/copyvios/attribution.py index 39ec265..f547e66 100644 --- a/src/copyvios/attribution.py +++ b/src/copyvios/attribution.py @@ -1,7 +1,7 @@ -from earwigbot.wiki import NS_TEMPLATE - __all__ = ["get_attribution_info"] +from earwigbot.wiki import NS_TEMPLATE, Page, Site + ATTRIB_TEMPLATES = { "enwiki": { "CC-notice", @@ -14,11 +14,11 @@ ATTRIB_TEMPLATES = { } -def get_attribution_info(site, page): - """Check to see if the given page has some kind of attribution info. +def get_attribution_info(site: Site, page: Page) -> tuple[str, str] | None: + """ + Check to see if the given page has some kind of attribution info. - If yes, return a tuple of (attribution template name, template URL). - If no, return None. + Return a tuple of (attribution template name, template URL) or None if no template. """ if site.name not in ATTRIB_TEMPLATES: return None @@ -32,4 +32,5 @@ def get_attribution_info(site, page): name = str(template.name).strip() title = name if ":" in name else prefix + ":" + name return name, site.get_page(title).url + return None diff --git a/src/copyvios/background.py b/src/copyvios/background.py index 6a8ec77..f9806c1 100644 --- a/src/copyvios/background.py +++ b/src/copyvios/background.py @@ -1,100 +1,162 @@ +__all__ = ["get_background"] + +import json +import logging import random import re import urllib.error import urllib.parse import urllib.request -from datetime import datetime, timedelta -from json import loads +from dataclasses import dataclass +from datetime import UTC, date, datetime, timedelta +from typing import Self from earwigbot import exceptions +from earwigbot.wiki import Site from flask import g -from .misc import cache +from .cache import cache +from .cookies import get_cookies + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BackgroundInfo: + filename: str + url: str + descurl: str + width: int + height: int + -__all__ = ["set_background"] +@dataclass(frozen=True) +class ScreenInfo: + width: int = 1024 + height: int = 768 + + @classmethod + def from_cookie(cls, value: str) -> Self: + try: + screen = json.loads(value) + screen = cls(width=int(screen["width"]), height=int(screen["height"])) + if screen.width <= 0 or screen.height <= 0: + raise ValueError() + except (ValueError, KeyError): + screen = cls() + return screen -def _get_commons_site(): +def _get_commons_site() -> Site: try: return cache.bot.wiki.get_site("commonswiki") except exceptions.SiteNotFoundError: return cache.bot.wiki.add_site(project="wikimedia", lang="commons") -def _load_file(site, filename): - data = site.api_query( - action="query", - prop="imageinfo", - iiprop="url|size|canonicaltitle", - titles="File:" + filename, +def _load_file(site: Site, filename: str) -> BackgroundInfo | None: + prefix = "File:" + try: + data = site.api_query( + action="query", + prop="imageinfo", + iiprop="url|size|canonicaltitle", + titles=prefix + filename, + ) + res = list(data["query"]["pages"].values())[0]["imageinfo"][0] + name = res["canonicaltitle"] + assert isinstance(name, str), name + except Exception: + logger.exception(f"Failed to get info for file {prefix + filename!r}") + return None + name = name.removeprefix(prefix).replace(" ", "_") + return BackgroundInfo( + name, res["url"], res["descriptionurl"], res["width"], res["height"] ) - res = list(data["query"]["pages"].values())[0]["imageinfo"][0] - name = res["canonicaltitle"][len("File:") :].replace(" ", "_") - return name, res["url"], res["descriptionurl"], res["width"], res["height"] -def _get_fresh_potd(): +def _get_fresh_from_potd() -> BackgroundInfo | None: site = _get_commons_site() - date = datetime.utcnow().strftime("%Y-%m-%d") - page = site.get_page("Template:Potd/" + date) + date = datetime.now(UTC).strftime("%Y-%m-%d") + page = site.get_page(f"Template:Potd/{date}") regex = r"\{\{Potd filename\|(?:1=)?(.*?)\|.*?\}\}" - filename = re.search(regex, page.get()).group(1) + try: + match = re.search(regex, page.get()) + except exceptions.EarwigBotError: + logger.exception(f"Failed to load today's POTD from {page.title!r}") + return None + if not match: + logger.exception(f"Failed to extract POTD from {page.title!r}") + return None + filename = match.group(1) return _load_file(site, filename) -def _get_fresh_list(): +def _get_fresh_from_list() -> BackgroundInfo | None: site = _get_commons_site() page = site.get_page("User:The Earwig/POTD") regex = r"\*\*?\s*\[\[:File:(.*?)\]\]" - filenames = re.findall(regex, page.get()) - - # Ensure all workers share the same background each day: - random.seed(datetime.utcnow().strftime("%Y%m%d")) - filename = random.choice(filenames) + try: + filenames = re.findall(regex, page.get()) + except exceptions.EarwigBotError: + logger.exception(f"Failed to load images from {page.title!r}") + return None + + # Ensure all workers share the same background each day + rand = random.Random() + rand.seed(datetime.now(UTC).strftime("%Y%m%d")) + try: + filename = rand.choice(filenames) + except IndexError: + logger.exception(f"Failed to find any images on {page.title!r}") + return None return _load_file(site, filename) -def _build_url(screen, filename, url, imgwidth, imgheight): - width = screen["width"] - if float(imgwidth) / imgheight > float(screen["width"]) / screen["height"]: - width = int(float(imgwidth) / imgheight * screen["height"]) - if width >= imgwidth: - return url - url = url.replace("/commons/", "/commons/thumb/") - return "%s/%dpx-%s" % (url, width, urllib.parse.quote(filename.encode("utf8"))) +def _build_url(screen: ScreenInfo, background: BackgroundInfo) -> str: + width = screen.width + if background.width / background.height > screen.width / screen.height: + width = int(background.width / background.height * screen.height) + if width >= background.width: + return background.url + url = background.url.replace("/commons/", "/commons/thumb/") + return f"{url}/{width}px-{urllib.parse.quote(background.filename)}" -_BACKGROUNDS = {"potd": _get_fresh_potd, "list": _get_fresh_list} +_BACKGROUNDS = { + "potd": _get_fresh_from_potd, + "list": _get_fresh_from_list, +} +_BACKGROUND_CACHE: dict[str, BackgroundInfo | None] = {} +_LAST_BACKGROUND_UPDATES: dict[str, date] = { + key: datetime.min.date() for key in _BACKGROUNDS +} -def _get_background(selected): - if not cache.last_background_updates: - for key in _BACKGROUNDS: - cache.last_background_updates[key] = datetime.min - plus_one = cache.last_background_updates[selected] + timedelta(days=1) - max_age = datetime(plus_one.year, plus_one.month, plus_one.day) - if datetime.utcnow() > max_age: - update_func = _BACKGROUNDS.get(selected, _get_fresh_list) - cache.background_data[selected] = update_func() - cache.last_background_updates[selected] = datetime.utcnow().date() - return cache.background_data[selected] +def _get_background(selected: str) -> BackgroundInfo | None: + next_day = _LAST_BACKGROUND_UPDATES[selected] + timedelta(days=1) + max_age = datetime(next_day.year, next_day.month, next_day.day, tzinfo=UTC) + if datetime.now(UTC) > max_age: + update_func = _BACKGROUNDS.get(selected, _get_fresh_from_list) + _BACKGROUND_CACHE[selected] = update_func() + _LAST_BACKGROUND_UPDATES[selected] = datetime.now(UTC).date() + return _BACKGROUND_CACHE[selected] -def set_background(selected): - if "CopyviosScreenCache" in g.cookies: - screen_cache = g.cookies["CopyviosScreenCache"].value - try: - screen = loads(screen_cache) - screen = {"width": int(screen["width"]), "height": int(screen["height"])} - if screen["width"] <= 0 or screen["height"] <= 0: - raise ValueError() - except (ValueError, KeyError): - screen = {"width": 1024, "height": 768} +def get_background(selected: str) -> str: + cookies = get_cookies() + if "CopyviosScreenCache" in cookies: + cookie = cookies["CopyviosScreenCache"].value + screen = ScreenInfo.from_cookie(cookie) else: - screen = {"width": 1024, "height": 768} + screen = ScreenInfo() - filename, url, descurl, width, height = _get_background(selected) - bg_url = _build_url(screen, filename, url, width, height) - g.descurl = descurl + background = _get_background(selected) + if background: + bg_url = _build_url(screen, background) + g.descurl = background.descurl + else: + bg_url = "" + g.descurl = None return bg_url diff --git a/src/copyvios/cache.py b/src/copyvios/cache.py new file mode 100644 index 0000000..ee1419a --- /dev/null +++ b/src/copyvios/cache.py @@ -0,0 +1,70 @@ +__all__ = ["cache"] + +import os.path +import sqlite3 +from dataclasses import dataclass, field +from typing import Any + +import sqlalchemy +from earwigbot.bot import Bot + + +@dataclass(frozen=True, order=True) +class Lang: + code: str + name: str + + +@dataclass(frozen=True, order=True) +class Project: + code: str + name: str + + +@dataclass +class AppCache: + bot: Bot + engine: sqlalchemy.Engine + langs: list[Lang] = field(default_factory=list) + projects: list[Project] = field(default_factory=list) + + +@sqlalchemy.event.listens_for(sqlalchemy.Engine, "connect") +def setup_connection(dbapi_connection: Any, connection_record: Any) -> None: + if isinstance(dbapi_connection, sqlite3.Connection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys = ON") + cursor.close() + + +def _get_engine(bot: Bot) -> sqlalchemy.Engine: + args = bot.config.wiki["copyvios"].copy() + engine_name = args.pop("engine", "mysql").lower() + + if engine_name == "mysql": + url_object = sqlalchemy.URL.create( + "mysql+pymysql", + host=args["host"], + database=args["db"], + query={ + "charset": "utf8mb4", + "read_default_file": os.path.expanduser("~/.my.cnf"), + }, + ) + return sqlalchemy.create_engine(url_object, pool_pre_ping=True) + + if engine_name == "sqlite": + dbpath = os.path.join(bot.config.root_dir, "copyvios.db") + return sqlalchemy.create_engine("sqlite:///" + dbpath) + + raise ValueError(f"Unknown engine: {engine_name}") + + +def _make_cache() -> AppCache: + bot = Bot(".earwigbot", 100) + engine = _get_engine(bot) + return AppCache(bot=bot, engine=engine) + + +# Singleton +cache = _make_cache() diff --git a/src/copyvios/checker.py b/src/copyvios/checker.py index e3856c6..6c4b165 100644 --- a/src/copyvios/checker.py +++ b/src/copyvios/checker.py @@ -1,119 +1,136 @@ +__all__ = ["T_POSSIBLE", "T_SUSPECT", "do_check"] + +import hashlib +import logging import re -from datetime import datetime, timedelta -from hashlib import sha256 -from logging import getLogger -from urllib.parse import urlparse +import typing +import urllib.parse +from datetime import UTC, datetime, timedelta +from enum import Enum from earwigbot import exceptions -from earwigbot.wiki.copyvios.markov import EMPTY, MarkovChain -from earwigbot.wiki.copyvios.parsers import ArticleTextParser +from earwigbot.wiki import Page, Site +from earwigbot.wiki.copyvios import CopyvioChecker +from earwigbot.wiki.copyvios.markov import DEFAULT_DEGREE, EMPTY from earwigbot.wiki.copyvios.result import CopyvioCheckResult, CopyvioSource +from earwigbot.wiki.copyvios.workers import CopyvioWorkspace +from flask import g +from sqlalchemy import PoolProxiedConnection -from .misc import Query, get_cursor, get_db, get_sql_error, sql_dialect +from .cache import cache +from .misc import get_sql_error, sql_dialect +from .query import CheckQuery from .sites import get_site from .turnitin import search_turnitin -__all__ = ["do_check", "T_POSSIBLE", "T_SUSPECT"] - T_POSSIBLE = 0.4 T_SUSPECT = 0.75 -_LOGGER = getLogger("copyvios.checker") +_LOGGER = logging.getLogger("copyvios.checker") -def _coerce_bool(val): - return val and val not in ("0", "false") +class ErrorCode(Enum): + BAD_ACTION = "bad action" + BAD_OLDID = "bad oldid" + BAD_URI = "bad URI" + NO_DATA = "no data" + NO_SEARCH_METHOD = "no search method" + NO_URL = "no URL" + SEARCH_ERROR = "search error" + TIMEOUT = "timeout" -def do_check(query=None): - if not query: - query = Query() - if query.lang: - query.lang = query.orig_lang = query.lang.strip().lower() - if "::" in query.lang: - query.lang, query.name = query.lang.split("::", 1) - if query.project: - query.project = query.project.strip().lower() - if query.oldid: - query.oldid = query.oldid.strip().lstrip("0") +class CopyvioCheckError(Exception): + def __init__(self, code: ErrorCode): + super().__init__(code.value) + self.code = code + - query.submitted = query.project and query.lang and (query.title or query.oldid) +def do_check(query: CheckQuery) -> CopyvioCheckResult | None: if query.submitted: - query.site = get_site(query) - if query.site: - _get_results(query, follow=not _coerce_bool(query.noredirect)) - return query + site = get_site(query) + if site: + return _get_results(query, site, follow=not query.noredirect) + return None -def _get_results(query, follow=True): +def _get_results( + query: CheckQuery, site: Site, follow: bool = True +) -> CopyvioCheckResult | None: if query.oldid: if not re.match(r"^\d+$", query.oldid): - query.error = "bad oldid" - return - page = query.page = _get_page_by_revid(query.site, query.oldid) + raise CopyvioCheckError(ErrorCode.BAD_OLDID) + page = _get_page_by_revid(site, query.oldid) if not page: - return + return None + g.page = page else: - page = query.page = query.site.get_page(query.title) + assert query.title + g.page = page = site.get_page(query.title) try: - page.get() # Make sure that the page exists before we check it! + page.get() # Make sure that the page exists before we check it except (exceptions.PageNotFoundError, exceptions.InvalidPageError): - return + return None if page.is_redirect and follow: try: query.title = page.get_redirect_target() except exceptions.RedirectError: - pass # Something's wrong. Continue checking the original page. + pass # Something's wrong; continue checking the original page else: - query.redirected_from = page - _get_results(query, follow=False) - return + result = _get_results(query, site, follow=False) + if result: + result.metadata.redirected_from = page + return result if not query.action: query.action = "compare" if query.url else "search" + if query.action == "search": - use_engine = 0 if query.use_engine in ("0", "false") else 1 - use_links = 0 if query.use_links in ("0", "false") else 1 - use_turnitin = 1 if query.turnitin in ("1", "true") else 0 - if not use_engine and not use_links and not use_turnitin: - query.error = "no search method" - return + if not query.use_engine and not query.use_links and not query.turnitin: + raise CopyvioCheckError(ErrorCode.NO_SEARCH_METHOD) - # Handle the turnitin check - if use_turnitin: - query.turnitin_result = search_turnitin(page.title, query.lang) + # Handle the Turnitin check + turnitin_result = None + if query.turnitin: + assert query.lang + turnitin_result = search_turnitin(page.title, query.lang) # Handle the copyvio check - _perform_check(query, page, use_engine, use_links) + conn = cache.engine.raw_connection() + try: + result = _perform_check(query, page, conn) + finally: + conn.close() + if turnitin_result: + result.metadata.turnitin_result = turnitin_result + elif query.action == "compare": if not query.url: - query.error = "no URL" - return - scheme = urlparse(query.url).scheme + raise CopyvioCheckError(ErrorCode.NO_URL) + scheme = urllib.parse.urlparse(query.url).scheme if not scheme and query.url[0] not in ":/": query.url = "http://" + query.url elif scheme not in ["http", "https"]: - query.error = "bad URI" - return - degree = 5 - if query.degree: - try: - degree = int(query.degree) - except ValueError: - pass + raise CopyvioCheckError(ErrorCode.BAD_URI) + + degree = query.degree or DEFAULT_DEGREE result = page.copyvio_compare( query.url, min_confidence=T_SUSPECT, max_time=10, degree=degree ) - if result.best.chains[0] is EMPTY: - query.error = "timeout" if result.time > 10 else "no data" - return - query.result = result - query.result.cached = False + result.metadata.cached = False + + if not result.best or result.best.chains[0] is EMPTY: + if result.time > 10: + raise CopyvioCheckError(ErrorCode.TIMEOUT) + else: + raise CopyvioCheckError(ErrorCode.NO_DATA) + return result + else: - query.error = "bad action" + raise CopyvioCheckError(ErrorCode.BAD_ACTION) -def _get_page_by_revid(site, revid): +def _get_page_by_revid(site: Site, revid: str) -> Page | None: try: res = site.api_query( action="query", @@ -140,104 +157,118 @@ def _get_page_by_revid(site, revid): return page -def _perform_check(query, page, use_engine, use_links): - conn = get_db() +def _perform_check( + query: CheckQuery, page: Page, conn: PoolProxiedConnection +) -> CopyvioCheckResult: sql_error = get_sql_error() - mode = f"{use_engine}:{use_links}:" + mode = f"{query.use_engine}:{query.use_links}:" + result: CopyvioCheckResult | None = None - if not _coerce_bool(query.nocache): + if not query.nocache: try: - query.result = _get_cached_results( - page, conn, mode, _coerce_bool(query.noskip) - ) + result = _get_cached_results(page, conn, mode, query.noskip) except sql_error: _LOGGER.exception("Failed to retrieve cached results") - if not query.result: + if not result: try: - query.result = page.copyvio_check( + result = page.copyvio_check( min_confidence=T_SUSPECT, max_queries=8, max_time=30, - no_searches=not use_engine, - no_links=not use_links, + no_searches=not query.use_engine, + no_links=not query.use_links, short_circuit=not query.noskip, ) except exceptions.SearchQueryError as exc: - query.error = "search error" - query.exception = exc - return - query.result.cached = False + raise CopyvioCheckError(ErrorCode.SEARCH_ERROR) from exc + result.metadata.cached = False try: - _cache_result(page, query.result, conn, mode) + _cache_result(page, result, conn, mode) except sql_error: _LOGGER.exception("Failed to cache results") + return result + + +def _get_cache_id(page: Page, mode: str) -> bytes: + return hashlib.sha256((mode + page.get()).encode("utf8")).digest() -def _get_cached_results(page, conn, mode, noskip): - query1 = """SELECT cache_time, cache_queries, cache_process_time, - cache_possible_miss - FROM cache - WHERE cache_id = ?""" - query2 = """SELECT cdata_url, cdata_confidence, cdata_skipped, cdata_excluded - FROM cache_data - WHERE cdata_cache_id = ?""" - cache_id = sha256(mode + page.get().encode("utf8")).digest() +def _get_cached_results( + page: Page, conn: PoolProxiedConnection, mode: str, noskip: bool +) -> CopyvioCheckResult | None: + cache_id = _get_cache_id(page, mode) cursor = conn.cursor() - cursor.execute(query1, (cache_id,)) + cursor.execute( + """SELECT cache_time, cache_queries, cache_process_time, cache_possible_miss + FROM cache + WHERE cache_id = ?""", + (cache_id,), + ) results = cursor.fetchall() + if not results: return None cache_time, queries, check_time, possible_miss = results[0] if possible_miss and noskip: return None + if not isinstance(cache_time, datetime): - cache_time = datetime.utcfromtimestamp(cache_time) - if datetime.utcnow() - cache_time > timedelta(days=3): + cache_time = datetime.fromtimestamp(cache_time, tz=UTC) + elif cache_time.tzinfo is None: + cache_time = cache_time.replace(tzinfo=UTC) + if datetime.now(UTC) - cache_time > timedelta(days=3): return None - cursor.execute(query2, (cache_id,)) + + cursor.execute( + """SELECT cdata_url, cdata_confidence, cdata_skipped, cdata_excluded + FROM cache_data + WHERE cdata_cache_id = ?""", + (cache_id,), + ) data = cursor.fetchall() if not data: # TODO: do something less hacky for this edge case - article_chain = MarkovChain(ArticleTextParser(page.get()).strip()) + article_chain = CopyvioChecker(page).article_chain result = CopyvioCheckResult( False, [], queries, check_time, article_chain, possible_miss ) - result.cached = True - result.cache_time = cache_time.strftime("%b %d, %Y %H:%M:%S UTC") - result.cache_age = _format_date(cache_time) + result.metadata.cached = True + result.metadata.cache_time = cache_time.strftime("%b %d, %Y %H:%M:%S UTC") + result.metadata.cache_age = _format_date(cache_time) return result - url, confidence, skipped, excluded = data.pop(0) + url, confidence, skipped, excluded = data[0] if skipped: # Should be impossible: data must be bad; run a new check return None result = page.copyvio_compare(url, min_confidence=T_SUSPECT, max_time=10) if abs(result.confidence - confidence) >= 0.0001: return None - for url, confidence, skipped, excluded in data: + for url, confidence, skipped, excluded in data[1:]: if noskip and skipped: return None - source = CopyvioSource(None, url) + source = CopyvioSource(typing.cast(CopyvioWorkspace, None), url) source.confidence = confidence source.skipped = bool(skipped) source.excluded = bool(excluded) result.sources.append(source) + result.queries = queries result.time = check_time result.possible_miss = possible_miss - result.cached = True - result.cache_time = cache_time.strftime("%b %d, %Y %H:%M:%S UTC") - result.cache_age = _format_date(cache_time) + result.metadata.cached = True + result.metadata.cache_time = cache_time.strftime("%b %d, %Y %H:%M:%S UTC") + result.metadata.cache_age = _format_date(cache_time) return result -def _format_date(cache_time): - def formatter(n, w): - return "{} {}{}".format(n, w, "" if n == 1 else "s") +def _format_date(cache_time: datetime) -> str: + def formatter(val: float, unit: str): + return f"{int(val)} {unit}{'' if val == 1 else 's'}" - diff = datetime.utcnow() - cache_time + diff = datetime.now(UTC) - cache_time total_seconds = diff.days * 86400 + diff.seconds if total_seconds > 3600: return formatter(total_seconds / 3600, "hour") @@ -246,19 +277,14 @@ def _format_date(cache_time): return formatter(total_seconds, "second") -def _cache_result(page, result, conn, mode): +def _cache_result( + page: Page, result: CopyvioCheckResult, conn: PoolProxiedConnection, mode: str +) -> None: expiry = sql_dialect( mysql="DATE_SUB(CURRENT_TIMESTAMP, INTERVAL 3 DAY)", sqlite="STRFTIME('%s', 'now', '-3 days')", ) - query1 = "DELETE FROM cache WHERE cache_id = ?" - query2 = f"DELETE FROM cache WHERE cache_time < {expiry}" - query3 = """INSERT INTO cache (cache_id, cache_queries, cache_process_time, - cache_possible_miss) VALUES (?, ?, ?, ?)""" - query4 = """INSERT INTO cache_data (cdata_cache_id, cdata_url, - cdata_confidence, cdata_skipped, - cdata_excluded) VALUES (?, ?, ?, ?, ?)""" - cache_id = sha256(mode + page.get().encode("utf8")).digest() + cache_id = _get_cache_id(page, mode) data = [ ( cache_id, @@ -269,10 +295,29 @@ def _cache_result(page, result, conn, mode): ) for source in result.sources ] - with get_cursor(conn) as cursor: - cursor.execute(query1, (cache_id,)) - cursor.execute(query2) - cursor.execute( - query3, (cache_id, result.queries, result.time, result.possible_miss) + + # TODO: Switch to proper SQLAlchemy + cur = conn.cursor() + try: + cur.execute("DELETE FROM cache WHERE cache_id = ?", (cache_id,)) + cur.execute(f"DELETE FROM cache WHERE cache_time < {expiry}") + cur.execute( + """INSERT INTO cache ( + cache_id, cache_queries, cache_process_time, cache_possible_miss + ) VALUES (?, ?, ?, ?)""", + (cache_id, result.queries, result.time, result.possible_miss), + ) + cur.executemany( + """INSERT INTO cache_data ( + cdata_cache_id, cdata_url, cdata_confidence, cdata_skipped, + cdata_excluded + ) VALUES (?, ?, ?, ?, ?)""", + data, ) - cursor.executemany(query4, data) + except Exception: + conn.rollback() + raise + else: + conn.commit() + finally: + cur.close() diff --git a/src/copyvios/cookies.py b/src/copyvios/cookies.py index 5daf798..359ec90 100644 --- a/src/copyvios/cookies.py +++ b/src/copyvios/cookies.py @@ -1,59 +1,85 @@ +__all__ = [ + "delete_cookie", + "get_cookies", + "get_new_cookies", + "parse_cookies", + "set_cookie", +] + import base64 -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from http.cookies import CookieError, SimpleCookie -from flask import g - -__all__ = ["parse_cookies", "set_cookie", "delete_cookie"] +from flask import g, request -class _CookieManager(SimpleCookie): +class CookieManager(SimpleCookie): MAGIC = "--cpv2" - def __init__(self, path, cookies): + def __init__(self, path: str, cookies: str | None) -> None: self._path = path try: super().__init__(cookies) except CookieError: super().__init__() for cookie in list(self.keys()): - if self[cookie].value is False: + if not self[cookie].value: del self[cookie] - def value_decode(self, value): - unquoted = super().value_decode(value)[0] + def value_decode(self, val: str) -> tuple[str, str]: + unquoted = super().value_decode(val)[0] try: - decoded = base64.b64decode(unquoted).decode("utf8") - except (TypeError, UnicodeDecodeError): - return False, "False" + decoded = base64.b64decode(unquoted).decode() + except (TypeError, ValueError): + return "", "" if decoded.startswith(self.MAGIC): - return decoded[len(self.MAGIC) :], value - return False, "False" + return decoded[len(self.MAGIC) :], val + return "", "" - def value_encode(self, value): - encoded = base64.b64encode(self.MAGIC + value.encode("utf8")) + def value_encode(self, val: str) -> tuple[str, str]: + encoded = base64.b64encode((self.MAGIC + val).encode()).decode() quoted = super().value_encode(encoded)[1] - return value, quoted + return val, quoted @property - def path(self): + def path(self) -> str: return self._path -def parse_cookies(path, cookies): - return _CookieManager(path, cookies) +def parse_cookies(path: str, cookies: str | None) -> CookieManager: + return CookieManager(path, cookies) + + +def get_cookies() -> CookieManager: + if "cookies" not in g: + g.cookies = parse_cookies( + request.script_root or "/", request.environ.get("HTTP_COOKIE") + ) + assert isinstance(g.cookies, CookieManager), g.cookies + return g.cookies -def set_cookie(key, value, days=0): - g.cookies[key] = value +def get_new_cookies() -> list[str]: + if "new_cookies" not in g: + g.new_cookies = [] + assert isinstance(g.new_cookies, list), g.new_cookies + return g.new_cookies + + +def set_cookie(key: str, value: str, days: float = 0) -> None: + cookies = get_cookies() + cookies[key] = value if days: - expire_dt = datetime.utcnow() + timedelta(days=days) + expire_dt = datetime.now(UTC) + timedelta(days=days) expires = expire_dt.strftime("%a, %d %b %Y %H:%M:%S GMT") - g.cookies[key]["expires"] = expires - g.cookies[key]["path"] = g.cookies.path - g.new_cookies.append(g.cookies[key].OutputString()) + cookies[key]["expires"] = expires + cookies[key]["path"] = cookies.path + + new_cookies = get_new_cookies() + new_cookies.append(cookies[key].OutputString()) -def delete_cookie(key): +def delete_cookie(key: str) -> None: + cookies = get_cookies() set_cookie(key, "", days=-1) - del g.cookies[key] + del cookies[key] diff --git a/src/copyvios/highlighter.py b/src/copyvios/highlighter.py index 009d93b..c953401 100644 --- a/src/copyvios/highlighter.py +++ b/src/copyvios/highlighter.py @@ -1,20 +1,28 @@ -from collections import deque -from re import UNICODE, sub +__all__ = ["highlight_delta"] -from earwigbot.wiki.copyvios.markov import EMPTY_INTERSECTION -from markupsafe import escape +import re +from collections import deque +from typing import Literal -__all__ = ["highlight_delta"] +import markupsafe +from earwigbot.wiki.copyvios.markov import ( + EMPTY_INTERSECTION, + MarkovChain, + MarkovChainIntersection, + Sentinel, +) -def highlight_delta(context, chain, delta): +def highlight_delta( + context, chain: MarkovChain, delta: MarkovChainIntersection | None +) -> str: degree = chain.degree - 1 highlights = [False] * degree - block = deque([chain.START] * degree) + block: deque[str | Sentinel] = deque([Sentinel.START] * degree) if not delta: delta = EMPTY_INTERSECTION - for word in chain.text.split() + ([chain.END] * degree): - word = _strip_word(chain, word) + for word in chain.text.split() + ([Sentinel.END] * degree): + word = _strip_word(word) block.append(word) if tuple(block) in delta.chain: highlights[-1 * degree :] = [True] * degree @@ -25,7 +33,7 @@ def highlight_delta(context, chain, delta): i = degree numwords = len(chain.text.split()) - result = [] + result: list[str] = [] paragraphs = deque(chain.text.split("\n")) while paragraphs: words = [] @@ -37,15 +45,15 @@ def highlight_delta(context, chain, delta): last = i - degree + 1 == numwords words.append(_highlight_word(word, before, after, first, last)) else: - words.append(str(escape(word))) + words.append(str(markupsafe.escape(word))) result.append(" ".join(words)) i += 1 return "

".join(result) -def _get_next(paragraphs): - body = [] +def _get_next(paragraphs: deque[str]) -> list[str]: + body: list[str] = [] while paragraphs and not body: body = paragraphs.popleft().split() if body and len(body) <= 3: @@ -59,44 +67,46 @@ def _get_next(paragraphs): return body -def _highlight_word(word, before, after, first, last): +def _highlight_word( + word: str, before: bool, after: bool, first: bool, last: bool +) -> str: if before and after: - # Word is in the middle of a highlighted block: - res = str(escape(word)) + # Word is in the middle of a highlighted block + res = str(markupsafe.escape(word)) if first: res = '' + res if last: res += "" elif after: - # Word is the first in a highlighted block: + # Word is the first in a highlighted block res = '' + _fade_word(word, "in") if last: res += "" elif before: - # Word is the last in a highlighted block: + # Word is the last in a highlighted block res = _fade_word(word, "out") + "" if first: res = '' + res else: - res = str(escape(word)) + res = str(markupsafe.escape(word)) return res -def _fade_word(word, dir): +def _fade_word(word: str, dir: Literal["in", "out"]) -> str: if len(word) <= 4: - word = str(escape(word)) + word = str(markupsafe.escape(word)) return f'{word}' if dir == "out": - before, after = str(escape(word[:-4])), str(escape(word[-4:])) - base = '{0}{1}' - return base.format(before, after) + before = str(markupsafe.escape(word[:-4])) + after = str(markupsafe.escape(word[-4:])) + return f'{before}{after}' else: - before, after = str(escape(word[:4])), str(escape(word[4:])) - base = '{0}{1}' - return base.format(before, after) + before = str(markupsafe.escape(word[:4])) + after = str(markupsafe.escape(word[4:])) + return f'{before}{after}' -def _strip_word(chain, word): - if word == chain.START or word == chain.END: +def _strip_word(word: str | Sentinel) -> str | Sentinel: + if word == Sentinel.START or word == Sentinel.END: return word - return sub("[^\w\s-]", "", word.lower(), flags=UNICODE) + return re.sub(r"[^\w\s-]", "", word.lower()) diff --git a/src/copyvios/misc.py b/src/copyvios/misc.py index c924fa6..bd962ae 100644 --- a/src/copyvios/misc.py +++ b/src/copyvios/misc.py @@ -1,120 +1,66 @@ -import datetime -from contextlib import contextmanager -from os.path import expanduser, join - -import apsw -import oursql -from flask import g, request -from sqlalchemy.pool import manage - -oursql = manage(oursql) - -__all__ = ["Query", "cache", "get_db", "get_notice", "httpsfix", "urlstrip"] - - -class Query: - def __init__(self, method="GET"): - self.query = {} - data = request.form if method == "POST" else request.args - for key in data: - self.query[key] = data.getlist(key)[-1] - - def __getattr__(self, key): - return self.query.get(key) - - def __setattr__(self, key, value): - if key == "query": - super().__setattr__(key, value) - else: - self.query[key] = value - - -class _AppCache: - def __init__(self): - super().__setattr__("_data", {}) - - def __getattr__(self, key): - return self._data[key] - - def __setattr__(self, key, value): - self._data[key] = value - - -cache = _AppCache() - - -def _connect_to_db(engine, args): - if engine == "mysql": - args["read_default_file"] = expanduser("~/.my.cnf") - args["autoping"] = True - args["autoreconnect"] = True - return oursql.connect(**args) - if engine == "sqlite": - dbpath = join(cache.bot.config.root_dir, "copyvios.db") - conn = apsw.Connection(dbpath) - conn.cursor().execute("PRAGMA foreign_keys = ON") - return conn - raise ValueError(f"Unknown engine: {engine}") +__all__ = [ + "get_notice", + "get_sql_error", + "httpsfix", + "parse_wiki_timestamp", + "sql_dialect", + "urlstrip", +] +import datetime +import os +import sqlite3 +from typing import TypeVar -def get_db(): - if not g._db: - args = cache.bot.config.wiki["copyvios"].copy() - g._engine = engine = args.pop("engine", "mysql").lower() - g._db = _connect_to_db(engine, args) - return g._db +import pymysql +from .cache import cache -@contextmanager -def get_cursor(conn): - if g._engine == "mysql": - with conn.cursor() as cursor: - yield cursor - elif g._engine == "sqlite": - with conn: - yield conn.cursor() - else: - raise ValueError(f"Unknown engine: {g._engine}") +T = TypeVar("T") -def get_sql_error(): - if g._engine == "mysql": - return oursql.Error - if g._engine == "sqlite": - return apsw.Error - raise ValueError(f"Unknown engine: {g._engine}") +def get_sql_error() -> type[Exception]: + match cache.engine.dialect.name: + case "mysql": + return pymysql.Error + case "sqlite": + return sqlite3.Error + case dialect: + raise ValueError(f"Unknown engine: {dialect}") -def sql_dialect(mysql, sqlite): - if g._engine == "mysql": - return mysql - if g._engine == "sqlite": - return sqlite - raise ValueError(f"Unknown engine: {g._engine}") +def sql_dialect(mysql: T, sqlite: T) -> T: + match cache.engine.dialect.name: + case "mysql": + return mysql + case "sqlite": + return sqlite + case dialect: + raise ValueError(f"Unknown engine: {dialect}") -def get_notice(): +def get_notice() -> str | None: try: - with open(expanduser("~/copyvios_notice.html")) as fp: - lines = fp.read().decode("utf8").strip().splitlines() - if lines[0] == "": + with open(os.path.expanduser("~/copyvios_notice.html")) as fp: + lines = fp.read().strip().splitlines() + if lines and lines[0] == "": return "\n".join(lines[1:]) return None except OSError: return None -def httpsfix(context, url): +def httpsfix(context, url: str) -> str: if url.startswith("http://"): url = url[len("http:") :] return url -def parse_wiki_timestamp(timestamp): +def parse_wiki_timestamp(timestamp: str) -> datetime.datetime: return datetime.datetime.strptime(timestamp, "%Y%m%d%H%M%S") -def urlstrip(context, url): +def urlstrip(context, url: str) -> str: if url.startswith("http://"): url = url[7:] if url.startswith("https://"): diff --git a/src/copyvios/query.py b/src/copyvios/query.py new file mode 100644 index 0000000..e9221fb --- /dev/null +++ b/src/copyvios/query.py @@ -0,0 +1,87 @@ +__all__ = ["APIQuery", "CheckQuery", "SettingsQuery"] + +from typing import Any, Literal, Self + +from flask import request +from pydantic import BaseModel, field_validator, model_validator +from werkzeug.datastructures import MultiDict + + +class BaseQuery(BaseModel): + @classmethod + def from_multidict(cls, args: MultiDict[str, str]) -> Self: + query = {key: args.getlist(key)[-1] for key in args} + return cls.model_validate(query) + + @classmethod + def from_get_args(cls) -> Self: + return cls.from_multidict(request.args) + + @classmethod + def from_post_data(cls) -> Self: + return cls.from_multidict(request.form) + + +class CheckQuery(BaseQuery): + action: str | None = None + lang: str | None = None + project: str | None = None + title: str | None = None + oldid: str | None = None + url: str | None = None + use_engine: bool = True + use_links: bool = True + turnitin: bool = False + nocache: bool = False + noredirect: bool = False + noskip: bool = False + degree: int | None = None + + # Derived parameters + orig_lang: str | None = None + name: str | None = None + + @field_validator("project") + @classmethod + def validate_project(cls, project: Any) -> str | None: + if not isinstance(project, str): + return project + return project.strip().lower() + + @field_validator("oldid") + @classmethod + def validate_oldid(cls, oldid: Any) -> str | None: + if not isinstance(oldid, str): + return oldid + return oldid.strip().lstrip("0") + + @model_validator(mode="after") + def validate_lang(self) -> Self: + self.orig_lang = self.name = None + if self.lang: + self.lang = self.orig_lang = self.lang.strip().lower() + if "::" in self.lang: + self.lang, self.name = self.lang.split("::", 1) + return self + + @property + def submitted(self) -> bool: + return bool(self.project and self.lang and (self.title or self.oldid)) + + +class APIQuery(CheckQuery): + version: int = 1 + detail: bool = False + + +class SettingsQuery(BaseQuery): + action: Literal["set", "delete"] | None = None + + # With action=set: + lang: str | None = None + project: str | None = None + background: Literal["list", "potd", "plain"] | None = None + + # With action=delete: + cookie: str | None = None + all: bool | None = None diff --git a/src/copyvios/settings.py b/src/copyvios/settings.py index 00217d1..c5548fa 100644 --- a/src/copyvios/settings.py +++ b/src/copyvios/settings.py @@ -1,54 +1,58 @@ -from flask import g -from markupsafe import escape +__all__ = ["process_settings"] -from .cookies import delete_cookie, set_cookie -from .misc import Query +import typing -__all__ = ["process_settings"] +import markupsafe + +from .cookies import delete_cookie, get_cookies, set_cookie +from .query import SettingsQuery + +COOKIE_EXPIRY = 3 * 365 # Days -def process_settings(): - query = Query(method="POST") - if query.action == "set": - status = _do_set(query) - elif query.action == "delete": - status = _do_delete(query) - else: - status = None - return status +def process_settings() -> str | None: + query = SettingsQuery.from_post_data() + match query.action: + case "set": + return _do_set(query) + case "delete": + return _do_delete(query) + case None: + return None + case _: + typing.assert_never(query.action) -def _do_set(query): - cookies = g.cookies - changes = set() +def _do_set(query: SettingsQuery) -> str | None: + cookies = get_cookies() + changes: set[str] = set() if query.lang: key = "CopyviosDefaultLang" if key not in cookies or cookies[key].value != query.lang: - set_cookie(key, query.lang, 1095) + set_cookie(key, query.lang, COOKIE_EXPIRY) changes.add("site") if query.project: key = "CopyviosDefaultProject" if key not in cookies or cookies[key].value != query.project: - set_cookie(key, query.project, 1095) + set_cookie(key, query.project, COOKIE_EXPIRY) changes.add("site") if query.background: key = "CopyviosBackground" if key not in cookies or cookies[key].value != query.background: - set_cookie(key, query.background, 1095) - delete_cookie("EarwigBackgroundCache") + set_cookie(key, query.background, COOKIE_EXPIRY) + delete_cookie("EarwigBackgroundCache") # Old name changes.add("background") if changes: - changes = ", ".join(sorted(list(changes))) - return f"Updated {changes}." + return f"Updated {', '.join(sorted(changes))}." return None -def _do_delete(query): - cookies = g.cookies - if query.cookie in cookies: - delete_cookie(query.cookie.encode("utf8")) - template = 'Deleted cookie {0}.' - return template.format(escape(query.cookie)) +def _do_delete(query: SettingsQuery) -> str | None: + cookies = get_cookies() + cookie = query.cookie + if cookie and cookie in cookies: + delete_cookie(cookie) + return f'Deleted cookie {markupsafe.escape(cookie)}.' elif query.all: number = len(cookies) for cookie in list(cookies.values()): diff --git a/src/copyvios/sites.py b/src/copyvios/sites.py index 3dc5706..ede1e70 100644 --- a/src/copyvios/sites.py +++ b/src/copyvios/sites.py @@ -1,40 +1,53 @@ -from time import time -from urllib.parse import urlparse +__all__ = ["get_site", "update_sites"] + +import urllib.parse +from datetime import UTC, datetime, timedelta from earwigbot import exceptions +from earwigbot.wiki import Site +from flask import g -from .misc import cache +from .cache import Lang, Project, cache +from .query import CheckQuery -__all__ = ["get_site", "update_sites"] +_LAST_SITES_UPDATE = datetime.min.replace(tzinfo=UTC) -def get_site(query): - lang, project, name = query.lang, query.project, query.name - wiki = cache.bot.wiki - if project not in [proj[0] for proj in cache.projects]: +def _get_site(query: CheckQuery) -> Site | None: + if not any(proj.code == query.project for proj in cache.projects): return None - if project == "wikimedia" and name: # Special sites: - try: - return wiki.get_site(name=name) - except exceptions.SiteNotFoundError: - return _add_site(lang, project) try: - return wiki.get_site(lang=lang, project=project) + if query.project == "wikimedia" and query.name: # Special sites + return cache.bot.wiki.get_site(name=query.name) + else: + return cache.bot.wiki.get_site(lang=query.lang, project=query.project) except exceptions.SiteNotFoundError: - return _add_site(lang, project) + assert query.lang and query.project, (query.lang, query.project) + return _add_site(query.lang, query.project) + +def get_site(query: CheckQuery | None = None) -> Site | None: + if "site" not in g: + assert query is not None, "get_site() called with no cached site nor query" + g.site = _get_site(query) + assert g.site is None or isinstance(g.site, Site), g.site + return g.site -def update_sites(): - if time() - cache.last_sites_update > 60 * 60 * 24 * 7: + +def update_sites() -> None: + global _LAST_SITES_UPDATE + + now = datetime.now(UTC) + if now - _LAST_SITES_UPDATE > timedelta(days=1): cache.langs, cache.projects = _load_sites() - cache.last_sites_update = time() + _LAST_SITES_UPDATE = now -def _add_site(lang, project): +def _add_site(lang: str, project: str) -> Site | None: update_sites() - if not any(project == item[0] for item in cache.projects): + if not any(project == proj.code for proj in cache.projects): return None - if lang != "www" and not any(lang == item[0] for item in cache.langs): + if lang != "www" and not any(lang == item.code for item in cache.langs): return None try: return cache.bot.wiki.add_site(lang=lang, project=project) @@ -42,34 +55,38 @@ def _add_site(lang, project): return None -def _load_sites(): +def _load_sites() -> tuple[list[Lang], list[Project]]: site = cache.bot.wiki.get_site() matrix = site.api_query(action="sitematrix")["sitematrix"] del matrix["count"] - langs, projects = set(), set() + langs: set[Lang] = set() + projects: set[Project] = set() + for site in matrix.values(): if isinstance(site, list): # Special sites bad_sites = ["closed", "private", "fishbowl"] for special in site: - if all([key not in special for key in bad_sites]): - full = urlparse(special["url"]).netloc - if full.count(".") == 1: # No subdomain, so use "www" - lang, project = "www", full.split(".")[0] - else: - lang, project = full.rsplit(".", 2)[:2] - code = "{}::{}".format(lang, special["dbname"]) - name = special["code"].capitalize() - langs.add((code, f"{lang} ({name})")) - projects.add((project, project.capitalize())) + if any(key in special for key in bad_sites): + continue + full = urllib.parse.urlparse(special["url"]).netloc + if full.count(".") == 1: # No subdomain, so use "www" + lang, project = "www", full.split(".")[0] + else: + lang, project = full.rsplit(".", 2)[:2] + langcode = f"{lang}::{special['dbname']}" + langname = special["code"].capitalize() + langs.add(Lang(langcode, f"{lang} ({langname})")) + projects.add(Project(project, project.capitalize())) else: - this = set() + this: set[Project] = set() for web in site["site"]: if "closed" in web: continue proj = "wikipedia" if web["code"] == "wiki" else web["code"] - this.add((proj, proj.capitalize())) + this.add(Project(proj, proj.capitalize())) if this: code = site["code"] - langs.add((code, "{} ({})".format(code, site["name"]))) + langs.add(Lang(code, f"{code} ({site['name']})")) projects |= this - return list(sorted(langs)), list(sorted(projects)) + + return sorted(langs), sorted(projects) diff --git a/src/copyvios/turnitin.py b/src/copyvios/turnitin.py index 6026c72..c12d4cb 100644 --- a/src/copyvios/turnitin.py +++ b/src/copyvios/turnitin.py @@ -1,29 +1,30 @@ +from __future__ import annotations + +__all__ = ["search_turnitin", "TURNITIN_API_ENDPOINT"] + +import ast import re -from ast import literal_eval +from dataclasses import dataclass +from datetime import datetime import requests from .misc import parse_wiki_timestamp -__all__ = ["search_turnitin", "TURNITIN_API_ENDPOINT"] - TURNITIN_API_ENDPOINT = "https://eranbot.toolforge.org/plagiabot/api.py" -def search_turnitin(page_title, lang): - """Search the Plagiabot database for Turnitin reports for a page. - - Keyword arguments: - page_title -- string containing the page title - lang -- string containing the page's project language code - - Return a TurnitinResult (contains a list of TurnitinReports). +def search_turnitin(page_title: str, lang: str) -> TurnitinResult: + """ + Search the Plagiabot database for Turnitin reports for a page. """ return TurnitinResult(_make_api_request(page_title, lang)) -def _make_api_request(page_title, lang): - """Query the plagiabot API for Turnitin reports for a given page.""" +def _make_api_request(page_title: str, lang: str) -> list[dict]: + """ + Query the plagiabot API for Turnitin reports for a given page. + """ stripped_page_title = page_title.replace(" ", "_") api_parameters = { "action": "suspected_diffs", @@ -35,40 +36,40 @@ def _make_api_request(page_title, lang): result = requests.get(TURNITIN_API_ENDPOINT, params=api_parameters, verify=False) # use literal_eval to *safely* parse the resulting dict-containing string try: - parsed_api_result = literal_eval(result.text) + parsed_api_result = ast.literal_eval(result.text) except (SyntaxError, ValueError): parsed_api_result = [] return parsed_api_result +@dataclass class TurnitinResult: - """Container class for TurnitinReports. Each page may have zero or - more reports of plagiarism. The list will have multiple - TurnitinReports if plagiarism has been detected for more than one - revision. + """ + Container class for TurnitinReports. - TurnitinResult.reports -- list containing >= 0 TurnitinReport items + Each page may have zero or more reports of plagiarism. The list will have multiple + TurnitinReports if plagiarism has been detected for more than one revision. """ - def __init__(self, turnitin_data): + reports: list[TurnitinReport] + + def __init__(self, turnitin_data: list[dict]) -> None: """ Keyword argument: turnitin_data -- plagiabot API result """ - self.reports = [] - for item in turnitin_data: - report = TurnitinReport( - item["diff_timestamp"], item["diff"], item["report"] - ) - self.reports.append(report) - - def __repr__(self): - return str(self.__dict__) + self.reports = [ + TurnitinReport(item["diff_timestamp"], item["diff"], item["report"]) + for item in turnitin_data + ] +@dataclass class TurnitinReport: - """Contains data for each Turnitin report (one on each potentially - plagiarized revision). + """ + Contains data for each Turnitin report. + + There is one report for each potentially plagiarized revision. TurnitinReport.reportid -- Turnitin report ID, taken from plagiabot TurnitinReport.diffid -- diff ID from Wikipedia database @@ -79,30 +80,33 @@ class TurnitinReport: url -- url for the possibly-plagiarized source """ - def __init__(self, timestamp, diffid, report): + reportid: str + diffid: str + time_posted: datetime + sources: list[dict] + + def __init__(self, timestamp: str, diffid: str, report: str) -> None: """ Keyword argument: timestamp -- diff timestamp from Wikipedia database diffid -- diff ID from Wikipedia database report -- Turnitin report from the plagiabot database """ - self.report_data = self._parse_report(report) - self.reportid = self.report_data[0] + self.reportid, results = self._parse_report(report) self.diffid = diffid self.time_posted = parse_wiki_timestamp(timestamp) self.sources = [] - for item in self.report_data[1]: + for item in results: source = {"percent": item[0], "words": item[1], "url": item[2]} self.sources.append(source) - def __repr__(self): - return str(self.__dict__) - - def _parse_report(self, report_text): + def _parse_report(self, report_text: str) -> tuple[str, list[str]]: # extract report ID report_id_pattern = re.compile(r"\?rid=(\d*)") - report_id = report_id_pattern.search(report_text).groups()[0] + report_id_match = report_id_pattern.search(report_text) + assert report_id_match, report_text + report_id = report_id_match.group(1) # extract percent match, words, and URL for each source in the report extract_info_pattern = re.compile(r"\n\* \w\s+(\d*)\% (\d*) words at \[(.*?) ") diff --git a/static/style.css b/static/css/style.css similarity index 100% rename from static/style.css rename to static/css/style.css diff --git a/templates/index.mako b/templates/index.mako index 4210730..22e3d0b 100644 --- a/templates/index.mako +++ b/templates/index.mako @@ -1,7 +1,8 @@ <%! - from flask import g, request + from flask import request from copyvios.attribution import get_attribution_info from copyvios.checker import T_POSSIBLE, T_SUSPECT + from copyvios.cookies import get_cookies from copyvios.misc import cache %>\ <% @@ -10,6 +11,7 @@ titleparts.append(query.page.title) titleparts.append("Earwig's Copyvio Detector") title = " | ".join(titleparts) + cookies = get_cookies() %>\ <%include file="/support/header.mako" args="title=title, splash=not result"/> <%namespace module="copyvios.highlighter" import="highlight_delta"/>\ @@ -37,7 +39,7 @@ % elif query.error == "timeout": The URL ${query.url | h} timed out before any data could be retrieved. % elif query.error == "search error": - An error occurred while using the search engine (${query.exception}). Note: there is a daily limit on the number of search queries the tool is allowed to make. You may repeat the check without using the search engine. + An error occurred while using the search engine (${query.error.__cause__}). Note: there is a daily limit on the number of search queries the tool is allowed to make. You may repeat the check without using the search engine. % else: An unknown error occurred. % endif @@ -64,7 +66,7 @@
- <% selected_project = query.project if query.project else g.cookies["CopyviosDefaultProject"].value if "CopyviosDefaultProject" in g.cookies else cache.bot.wiki.get_site().project %>\ + <% selected_project = query.project if query.project else cookies["CopyviosDefaultProject"].value if "CopyviosDefaultProject" in cookies else cache.bot.wiki.get_site().project %>\ % for code, name in cache.projects: % if code == selected_project: diff --git a/templates/settings.mako b/templates/settings.mako index 0654007..c0998b2 100644 --- a/templates/settings.mako +++ b/templates/settings.mako @@ -1,7 +1,11 @@ <%! from json import dumps, loads - from flask import g, request - from copyvios.misc import cache + from flask import request + from copyvios.cookies import get_cookies + from copyvios.cache import cache +%>\ +<% + cookies = get_cookies() %>\ <%include file="/support/header.mako" args="title='Settings | Earwig\'s Copyvio Detector', splash=True"/> % if status: @@ -20,7 +24,7 @@
- <% selected_project = g.cookies["CopyviosDefaultProject"].value if "CopyviosDefaultProject" in g.cookies else default_project %>\ + <% selected_project = cookies["CopyviosDefaultProject"].value if "CopyviosDefaultProject" in cookies else default_project %>\ % for code, name in cache.projects: % if code == selected_project: @@ -55,7 +59,7 @@ ("potd", 'Use the current Commons Picture of the Day, unfiltered. Certain POTDs may be unsuitable as backgrounds due to their aspect ratio or subject matter.'), ("plain", "Use a plain background."), ] - selected = g.cookies["CopyviosBackground"].value if "CopyviosBackground" in g.cookies else "list" + selected = cookies["CopyviosBackground"].value if "CopyviosBackground" in cookies else "list" %>\
diff --git a/templates/support/footer.mako b/templates/support/footer.mako index 153f07c..e318b83 100644 --- a/templates/support/footer.mako +++ b/templates/support/footer.mako @@ -11,7 +11,7 @@
  • Maintained by Ben Kurtovic
  • API
  • Source code
  • - % if ("CopyviosBackground" in g.cookies and g.cookies["CopyviosBackground"].value in ["potd", "list"]) or "CopyviosBackground" not in g.cookies: + % if g.descurl:
  • Background image
  • % endif diff --git a/templates/support/header.mako b/templates/support/header.mako index d6e7a76..ccab414 100644 --- a/templates/support/header.mako +++ b/templates/support/header.mako @@ -1,7 +1,11 @@ <%page args="title, splash=False"/>\ <%! - from flask import g, request, url_for - from copyvios.background import set_background + from flask import request, url_for + from copyvios.background import get_background + from copyvios.cookies import get_cookies +%>\ +<% + cookies = get_cookies() %>\ @@ -15,11 +19,11 @@ -<% selected = g.cookies["CopyviosBackground"].value if "CopyviosBackground" in g.cookies else "list" %>\ +<% selected = cookies["CopyviosBackground"].value if "CopyviosBackground" in cookies else "list" %>\ % if selected == "plain": % else: - + % endif