@@ -1,3 +1,25 @@
v0.3 (released March 24, 2019):

- Added various new features to the WikiProjectTagger task.
- Copyvio detector: improved sentence splitting algorithm; many performance
  improvements.
- Improved config file command/task exclusion logic.
- Wiki: Added logging for warnings.
- Wiki: Added OAuth support.
- Wiki: Switched to requests from urllib2.
- Wiki: Updated some deprecated API calls.
- Wiki: Fixed Page.toggle_talk() behavior on mainspace titles with colons.
- IRC > !cidr: Added; new command for calculating range blocks.
- IRC > !notes: Improved help and added aliases.
- IRC > !remind: Added !remind all. Fixed multithreading efficiency issues.
  Improved time detection and argument parsing. Newly expired reminders are now
  triggered on bot startup.
- IRC > !stalk: Allow regular expressions as page titles or usernames.
- IRC: Added a per-channel quiet config setting.
- IRC: Try not to join channels before NickServ auth has completed.
- IRC: Improved detection of maximum IRC message length.
- IRC: Improved some help commands.

v0.2 (released November 8, 2015):

- Added a new command syntax allowing the caller to redirect replies to another

@@ -1,4 +1,4 @@
Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com>
Copyright (C) 2009-2017 Ben Kurtovic <ben.kurtovic@gmail.com>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

@@ -10,7 +10,7 @@ History
-------

Development began, based on the `Pywikipedia framework`_, in early 2009.
Approval for its fist task, a `copyright violation detector`_, was carried out
Approval for its first task, a `copyright violation detector`_, was carried out
in May, and the bot has been running consistently ever since (with the
exception of Jan/Feb 2011). It currently handles `several ongoing tasks`_
ranging from statistics generation to category cleanup, and on-demand tasks

@@ -36,7 +36,7 @@ setup.py test`` from the project's root directory. Note that some
tests require an internet connection, and others may take a while to run.
Coverage is currently rather incomplete.

Latest release (v0.2)
Latest release (v0.3)
~~~~~~~~~~~~~~~~~~~~~

EarwigBot is available from the `Python Package Index`_, so you can install the

@@ -47,7 +47,7 @@ some header files. For example, on Ubuntu, see `this StackOverflow post`_.
You can also install it from source [1]_ directly::

    curl -Lo earwigbot.tgz https://github.com/earwig/earwigbot/tarball/v0.2
    curl -Lo earwigbot.tgz https://github.com/earwig/earwigbot/tarball/v0.3
    tar -xf earwigbot.tgz
    cd earwig-earwigbot-*
    python setup.py install
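
For comparison, the PyPI route mentioned above is a single command (assuming
``pip`` is available)::

    pip install earwigbot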
@@ -41,16 +41,16 @@ master_doc = 'index'

# General information about the project.
project = u'EarwigBot'
copyright = u'2009-2015 Ben Kurtovic'
copyright = u'2009-2016 Ben Kurtovic'

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.2'
version = '0.3'
# The full version, including alpha/beta/rc tags.
release = '0.2'
release = '0.3'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.

@@ -174,6 +174,13 @@ The bot has a wide selection of built-in commands and plugins to act as sample
code and/or to give ideas. Start with test_, and then check out chanops_ and
afc_status_ for some more complicated scripts.

By default, the bot loads every built-in and custom command available. You can
disable *all* built-in commands with the config entry
:py:attr:`config.commands["disable"]` set to ``True``, or a subset of commands
by setting it to a list of command class names or module names. If using the
former method, you can specifically enable certain built-in commands with
:py:attr:`config.commands["enable"]` set to a list of command module names.
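
For example, here is a sketch of the relevant config section (shown as YAML;
the keys match the attributes described above, but the exact file layout is an
assumption)::

    commands:
        disable: true    # turn off every built-in command...
        enable:          # ...except these module names
            - test
            - help

Setting ``disable`` to a list such as ``[chanops, afc_status]`` instead
disables only those commands and leaves the rest loaded.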
Custom bot tasks
----------------

@@ -1,4 +1,4 @@
EarwigBot v0.2 Documentation
EarwigBot v0.3 Documentation
============================

EarwigBot_ is a Python_ robot that edits Wikipedia_ and interacts with people

@@ -13,7 +13,7 @@ It's recommended to run the bot's unit tests before installing. Run
some tests require an internet connection, and others may take a while to run.
Coverage is currently rather incomplete.

Latest release (v0.2)
Latest release (v0.3)
---------------------

EarwigBot is available from the `Python Package Index`_, so you can install the

@@ -24,7 +24,7 @@ some header files. For example, on Ubuntu, see `this StackOverflow post`_.
You can also install it from source [1]_ directly::

    curl -Lo earwigbot.tgz https://github.com/earwig/earwigbot/tarball/v0.2
    curl -Lo earwigbot.tgz https://github.com/earwig/earwigbot/tarball/v0.3
    tar -xf earwigbot.tgz
    cd earwig-earwigbot-*
    python setup.py install

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com>
# Copyright (C) 2009-2019 Ben Kurtovic <ben.kurtovic@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal

@@ -30,9 +30,9 @@ details. This documentation is also available `online
"""

__author__ = "Ben Kurtovic"
__copyright__ = "Copyright (C) 2009-2015 Ben Kurtovic"
__copyright__ = "Copyright (C) 2009-2019 Ben Kurtovic"
__license__ = "MIT License"
__version__ = "0.2"
__version__ = "0.3"
__email__ = "ben.kurtovic@gmail.com"
__release__ = False

@@ -45,7 +45,7 @@ if not __release__:
        commit_id = Repo(path).head.object.hexsha
        return commit_id[:8]
    try:
        __version__ += "+git-" + _get_git_commit_id()
        __version__ += "+" + _get_git_commit_id()
    except Exception:
        pass
    finally:
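
With ``__release__`` set to ``False``, the short commit hash is appended as a
local version suffix, so a development checkout reports something like (hash
illustrative)::

    >>> import earwigbot
    >>> earwigbot.__version__
    '0.3+8a2f9c1b'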
@@ -150,7 +150,8 @@ class Bot(object):
        component_names = self.config.components.keys()
        skips = component_names + ["MainThread", "reminder", "irc:quit"]
        for thread in enumerate_threads():
            if thread.name not in skips and thread.is_alive():
            if thread.is_alive() and not any(
                    thread.name.startswith(skip) for skip in skips):
                tasks.append(thread.name)
        if tasks:
            log = "The following commands or tasks will be killed: {0}"

@@ -0,0 +1,179 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2016 Ben Kurtovic <ben.kurtovic@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from collections import namedtuple
import re
import socket
from socket import AF_INET, AF_INET6

from earwigbot.commands import Command

_IP = namedtuple("_IP", ["family", "ip", "size"])
_Range = namedtuple("_Range", [
    "family", "range", "low", "high", "size", "addresses"])

class CIDR(Command):
    """Calculates the smallest CIDR range that encompasses a list of IP
    addresses. Used to make range blocks."""
    name = "cidr"
    commands = ["cidr", "range", "rangeblock", "rangecalc", "blockcalc",
                "iprange", "cdir"]

    # https://www.mediawiki.org/wiki/Manual:$wgBlockCIDRLimit
    LIMIT_IPv4 = 16
    LIMIT_IPv6 = 19

    def process(self, data):
        if not data.args:
            msg = ("Specify a list of IP addresses to calculate a CIDR range "
                   "for. For example, \x0306!{0} 192.168.0.3 192.168.0.15 "
                   "192.168.1.4\x0F or \x0306!{0} 2500:1:2:3:: "
                   "2500:1:2:3:dead:beef::\x0F.")
            self.reply(data, msg.format(data.command))
            return

        try:
            ips = [self._parse_ip(arg) for arg in data.args]
        except ValueError as exc:
            msg = "Can't parse IP address \x0302{0}\x0F."
            self.reply(data, msg.format(exc.message))
            return

        if any(ip.family == AF_INET for ip in ips) and any(
                ip.family == AF_INET6 for ip in ips):
            msg = "Can't calculate a range for both IPv4 and IPv6 addresses."
            self.reply(data, msg)
            return

        cidr = self._calculate_range(ips[0].family, ips)
        descr = self._describe(cidr.family, cidr.size)

        msg = ("Smallest CIDR range is \x02{0}\x0F, covering {1} from "
               "\x0305{2}\x0F to \x0305{3}\x0F{4}.")
        self.reply(data, msg.format(
            cidr.range, cidr.addresses, cidr.low, cidr.high,
            " (\x0304{0}\x0F)".format(descr) if descr else ""))

    def _parse_ip(self, arg):
        """Converts an argument into an IP address object."""
        arg = self._parse_arg(arg)
        oldarg = arg
        size = None

        if "/" in arg:
            arg, size = arg.split("/", 1)
            try:
                size = int(size, 10)
            except ValueError:
                raise ValueError(oldarg)
            if size < 0 or size > 128:
                raise ValueError(oldarg)

        try:
            ip = _IP(AF_INET, socket.inet_pton(AF_INET, arg), size)
        except socket.error:
            try:
                return _IP(AF_INET6, socket.inet_pton(AF_INET6, arg), size)
            except socket.error:
                raise ValueError(oldarg)
        if size > 32:
            raise ValueError(oldarg)
        return ip

    def _parse_arg(self, arg):
        """Converts an argument into an IP address string."""
        if "[[" in arg and "]]" in arg:
            regex = r"\[\[\s*(?:User(?:\stalk)?:)?(.*?)(?:\|.*?)?\s*\]\]"
            match = re.search(regex, arg, re.I)
            if not match:
                raise ValueError(arg)
            arg = match.group(1)

        if re.match(r"https?://", arg):
            if "target=" in arg:
                regex = r"target=(.*?)(?:&|$)"
            elif "page=" in arg:
                regex = r"page=(?:User(?:(?:\s|_)talk)?(?::|%3A))?(.*?)(?:&|$)"
            elif re.search(r"Special(:|%3A)Contributions/", arg, re.I):
                regex = r"Special(?:\:|%3A)Contributions/(.*?)(?:\&|\?|$)"
            elif re.search(r"User((\s|_)talk)?(:|%3A)", arg, re.I):
                regex = r"User(?:(?:\s|_)talk)?(?:\:|%3A)(.*?)(?:\&|\?|$)"
            else:
                raise ValueError(arg)
            match = re.search(regex, arg, re.I)
            if not match:
                raise ValueError(arg)
            arg = match.group(1)
        return arg

    def _calculate_range(self, family, ips):
        """Calculate the smallest CIDR range encompassing a list of IPs."""
        bin_ips = ["".join(
            bin(ord(octet))[2:].zfill(8) for octet in ip.ip) for ip in ips]
        for i, ip in enumerate(ips):
            if ip.size is not None:
                suffix = "X" * (len(bin_ips[i]) - ip.size)
                bin_ips[i] = bin_ips[i][:ip.size] + suffix

        size = len(bin_ips[0])
        for i in xrange(len(bin_ips[0])):
            if any(ip[i] == "X" for ip in bin_ips) or (
                    any(ip[i] == "0" for ip in bin_ips) and
                    any(ip[i] == "1" for ip in bin_ips)):
                size = i
                break

        bin_low = bin_ips[0][:size].ljust(len(bin_ips[0]), "0")
        bin_high = bin_ips[0][:size].ljust(len(bin_ips[0]), "1")
        low = self._format_bin(family, bin_low)
        high = self._format_bin(family, bin_high)

        return _Range(
            family, low + "/" + str(size), low, high, size,
            self._format_count(2 ** (len(bin_ips[0]) - size)))

    @staticmethod
    def _format_bin(family, binary):
        """Convert an IP's binary representation to presentation format."""
        return socket.inet_ntop(family, "".join(
            chr(int(binary[i:i + 8], 2)) for i in xrange(0, len(binary), 8)))

    @staticmethod
    def _format_count(count):
        """Nicely format a number of addresses affected by a range block."""
        if count == 1:
            return "1 address"
        if count > 2 ** 32:
            base = "{0:.2E} addresses".format(count)
            if count == 2 ** 64:
                return base + " (1 /64 subnet)"
            if count > 2 ** 96:
                return base + " ({0:.2E} /64 subnets)".format(count >> 64)
            if count > 2 ** 63:
                return base + " ({0:,} /64 subnets)".format(count >> 64)
            return base
        return "{0:,} addresses".format(count)

    def _describe(self, family, size):
        """Return an optional English description of a range."""
        if (family == AF_INET and size < self.LIMIT_IPv4) or (
                family == AF_INET6 and size < self.LIMIT_IPv6):
            return "too large to block"
@@ -37,7 +37,7 @@ class Lag(Command):
            msg = base.format(site.name, self.get_replag(site))
        elif data.command == "maxlag":
            base = "\x0302{0}\x0F: {1}."
            msg = base.format(site.name, self.get_maxlag(site).capitalize())
            msg = base.format(site.name, self.get_maxlag(site))
        else:
            base = "\x0302{0}\x0F: {1}; {2}."
            msg = base.format(site.name, self.get_replag(site),

@@ -45,10 +45,10 @@ class Lag(Command):
        self.reply(data, msg)

    def get_replag(self, site):
        return "replag is {0}".format(self.time(site.get_replag()))
        return "SQL replag is {0}".format(self.time(site.get_replag()))

    def get_maxlag(self, site):
        return "database maxlag is {0}".format(self.time(site.get_maxlag()))
        return "API maxlag is {0}".format(self.time(site.get_maxlag()))

    def get_site(self, data):
        if data.kwargs and "project" in data.kwargs and "lang" in data.kwargs:
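
For background, the reworded replies distinguish lag measured directly on the
SQL replicas ("SQL replag") from the lag MediaWiki's API reports through its
maxlag protocol ("API maxlag"). Passing ``maxlag=-1`` makes the API refuse the
request and report the current lag, which is a quick way to see the value the
command quotes (a sketch; the endpoint is just for illustration)::

    import requests

    resp = requests.get("https://en.wikipedia.org/w/api.php", params={
        "action": "query", "maxlag": "-1", "format": "json"})
    # The error object carries e.g. "Waiting for <host>: 0.5 seconds lagged."
    print(resp.json()["error"]["info"])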
@@ -32,7 +32,19 @@ class Notes(Command):
    """A mini IRC-based wiki for storing notes, tips, and reminders."""
    name = "notes"
    commands = ["notes", "note", "about"]
    version = 2
    version = "2.1"

    aliases = {
        "all": "list",
        "show": "read",
        "get": "read",
        "add": "edit",
        "write": "edit",
        "change": "edit",
        "modify": "edit",
        "move": "rename",
        "remove": "delete"
    }

    def setup(self):
        self._dbfile = path.join(self.config.root_dir, "notes.db")

@@ -50,14 +62,13 @@ class Notes(Command):
        }
        if not data.args:
            msg = ("\x0302The Earwig Mini-Wiki\x0F: running v{0}. Subcommands "
                   "are: {1}. You can get help on any with '!{2} help subcommand'.")
            cmnds = ", ".join((commands))
            self.reply(data, msg.format(self.version, cmnds, data.command))
            self.do_help(data)
            return

        command = data.args[0].lower()
        if command in commands:
            commands[command](data)
        elif command in self.aliases:
            commands[self.aliases[command]](data)
        else:
            msg = "Unknown subcommand: \x0303{0}\x0F.".format(command)
            self.reply(data, msg)

@@ -83,8 +94,13 @@ class Notes(Command):
        try:
            command = data.args[1]
        except IndexError:
            self.reply(data, "Please specify a subcommand to get help on.")
            msg = ("\x0302The Earwig Mini-Wiki\x0F: running v{0}. Subcommands "
                   "are: {1}. You can get help on any with '!{2} help subcommand'.")
            cmnds = ", ".join((info.keys()))
            self.reply(data, msg.format(self.version, cmnds, data.command))
            return
        if command in self.aliases:
            command = self.aliases[command]
        try:
            help_ = re.sub(r"\s\s+", " ", info[command].replace("\n", ""))
            self.reply(data, "\x0303{0}\x0F: ".format(command) + help_)

@@ -113,7 +129,7 @@ class Notes(Command):
                   INNER JOIN revisions ON entry_revision = rev_id
                   WHERE entry_slug = ?"""
        try:
            slug = self.slugify(data.args[1])
            slug = self._slugify(data.args[1])
        except IndexError:
            self.reply(data, "Please specify an entry to read from.")
            return

@@ -141,7 +157,7 @@ class Notes(Command):
        query3 = "INSERT INTO entries VALUES (?, ?, ?, ?)"
        query4 = "UPDATE entries SET entry_revision = ? WHERE entry_id = ?"
        try:
            slug = self.slugify(data.args[1])
            slug = self._slugify(data.args[1])
        except IndexError:
            self.reply(data, "Please specify an entry to edit.")
            return

@@ -157,17 +173,17 @@ class Notes(Command):
                create = False
            except sqlite.OperationalError:
                id_, title, author = 1, data.args[1].decode("utf8"), data.host
                self.create_db(conn)
                self._create_db(conn)
            except TypeError:
                id_ = self.get_next_entry(conn)
                id_ = self._get_next_entry(conn)
                title, author = data.args[1].decode("utf8"), data.host
            permdb = self.config.irc["permissions"]
            if author != data.host and not permdb.is_admin(data):
                msg = "You must be an author or a bot admin to edit this entry."
                self.reply(data, msg)
                return
            revid = self.get_next_revision(conn)
            userid = self.get_user(conn, data.host)
            revid = self._get_next_revision(conn)
            userid = self._get_user(conn, data.host)
            now = datetime.utcnow().strftime("%b %d, %Y %H:%M:%S")
            conn.execute(query2, (revid, id_, userid, now, content))
            if create:

@@ -185,7 +201,7 @@ class Notes(Command):
                   INNER JOIN users ON rev_user = user_id
                   WHERE entry_slug = ?"""
        try:
            slug = self.slugify(data.args[1])
            slug = self._slugify(data.args[1])
        except IndexError:
            self.reply(data, "Please specify an entry to get info on.")
            return

@@ -221,7 +237,7 @@ class Notes(Command):
        query2 = """UPDATE entries SET entry_slug = ?, entry_title = ?
                    WHERE entry_id = ?"""
        try:
            slug = self.slugify(data.args[1])
            slug = self._slugify(data.args[1])
        except IndexError:
            self.reply(data, "Please specify an entry to rename.")
            return

@@ -246,7 +262,7 @@ class Notes(Command):
            msg = "You must be an author or a bot admin to rename this entry."
            self.reply(data, msg)
            return

        args = (self._slugify(newtitle), newtitle.decode("utf8"), id_)
        conn.execute(query2, args)

        msg = "Entry \x0302{0}\x0F renamed to \x0302{1}\x0F."

@@ -261,7 +277,7 @@ class Notes(Command):
        query2 = "DELETE FROM entries WHERE entry_id = ?"
        query3 = "DELETE FROM revisions WHERE rev_entry = ?"
        try:
            slug = self._slugify(data.args[1])
        except IndexError:
            self.reply(data, "Please specify an entry to delete.")
            return

@@ -283,11 +299,11 @@ class Notes(Command):
        self.reply(data, "Entry \x0302{0}\x0F deleted.".format(data.args[1]))

    def slugify(self, name):
    def _slugify(self, name):
        """Convert *name* into an identifier for storing in the database."""
        return name.lower().replace("_", "").replace("-", "").decode("utf8")

    def create_db(self, conn):
    def _create_db(self, conn):
        """Initialize the notes database with its necessary tables."""
        script = """
            CREATE TABLE entries (entry_id, entry_slug, entry_title,

@@ -298,19 +314,19 @@ class Notes(Command):
        """
        conn.executescript(script)

    def get_next_entry(self, conn):
    def _get_next_entry(self, conn):
        """Get the next entry ID."""
        query = "SELECT MAX(entry_id) FROM entries"
        later = conn.execute(query).fetchone()[0]
        return later + 1 if later else 1

    def get_next_revision(self, conn):
    def _get_next_revision(self, conn):
        """Get the next revision ID."""
        query = "SELECT MAX(rev_id) FROM revisions"
        later = conn.execute(query).fetchone()[0]
        return later + 1 if later else 1

    def get_user(self, conn, host):
    def _get_user(self, conn, host):
        """Get the user ID corresponding to a hostname, or make one."""
        query1 = "SELECT user_id FROM users WHERE user_host = ?"
        query2 = "SELECT MAX(user_id) FROM users"

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com>
# Copyright (C) 2009-2016 Ben Kurtovic <ben.kurtovic@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -21,7 +21,6 @@
# SOFTWARE.

import ast
from contextlib import contextmanager
from itertools import chain
import operator
import random

@@ -31,13 +30,24 @@ import time
from earwigbot.commands import Command
from earwigbot.irc import Data

DISPLAY = ["display", "show", "list", "info", "details"]
DISPLAY = ["display", "show", "info", "details"]
CANCEL = ["cancel", "stop", "delete", "del", "unremind", "forget",
          "disregard"]
SNOOZE = ["snooze", "delay", "reset", "adjust", "modify", "change"]
SNOOZE_ONLY = ["snooze", "delay", "reset"]

def _format_time(epoch):
    """Format a UNIX timestamp nicely."""
    lctime = time.localtime(epoch)
    if lctime.tm_year == time.localtime().tm_year:
        return time.strftime("%b %d %H:%M:%S %Z", lctime)
    else:
        return time.strftime("%b %d, %Y %H:%M:%S %Z", lctime)

class Remind(Command):
    """Set a message to be repeated to you in a certain amount of time."""
    """Set a message to be repeated to you in a certain amount of time. See
    usage with !remind help."""
    name = "remind"
    commands = ["remind", "reminder", "reminders", "snooze", "cancel",
                "unremind", "forget"]

@@ -49,8 +59,10 @@ class Remind(Command):
            return "display"
        if command in CANCEL:
            return "cancel"
        if command in SNOOZE:
        if command in SNOOZE_ONLY:
            return "snooze"
        if command in SNOOZE:  # "adjust" == snoozing active reminders
            return "adjust"

    @staticmethod
    def _parse_time(arg):

@@ -77,26 +89,18 @@ class Remind(Command):
            else:
                raise ValueError(node)

        if arg and arg[-1] in time_units:
            factor, arg = time_units[arg[-1]], arg[:-1]
        else:
            factor = 1
        for unit, factor in time_units.iteritems():
            arg = arg.replace(unit, "*" + str(factor))

        try:
            parsed = int(_evaluate(ast.parse(arg, mode="eval").body) * factor)
            parsed = int(_evaluate(ast.parse(arg, mode="eval").body))
        except (SyntaxError, KeyError):
            raise ValueError(arg)
        if parsed <= 0:
            raise ValueError(parsed)
        return parsed

    @contextmanager
    def _db(self):
        """Return a threadsafe context manager for the permissions database."""
        with self._db_lock:
            yield self.config.irc["permissions"]

    def _really_get_reminder_by_id(self, user, rid):
    def _get_reminder_by_id(self, user, rid):
        """Return the _Reminder object that corresponds to a particular ID.

        Raises IndexError on failure.
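
The new ``_parse_time`` turns an argument like ``2h+30m`` into the arithmetic
expression ``2*3600+30*60`` via unit substitution, then evaluates it with a
restricted ``ast`` walker rather than ``eval``. A self-contained sketch of
that evaluation step (the operator table is assumed to mirror the real one)::

    import ast
    import operator

    _OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
            ast.Mult: operator.mul, ast.Div: operator.truediv}

    def safe_eval(expr):
        """Evaluate plain arithmetic without the dangers of eval()."""
        def _evaluate(node):
            if isinstance(node, ast.Num):    # a numeric literal
                return node.n
            if isinstance(node, ast.BinOp):  # e.g. 2*3600 + 30*60
                return _OPS[type(node.op)](_evaluate(node.left),
                                           _evaluate(node.right))
            raise ValueError(node)           # anything else is rejected
        return _evaluate(ast.parse(expr, mode="eval").body)

    print(safe_eval("2*3600+30*60"))  # 9000 seconds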
@@ -106,17 +110,6 @@ class Remind(Command):
            raise IndexError(rid)
        return [robj for robj in self.reminders[user] if robj.id == rid][0]

    def _get_reminder_by_id(self, user, rid, data):
        """Return the _Reminder object that corresponds to a particular ID.

        Sends an error message to the user on failure.
        """
        try:
            return self._really_get_reminder_by_id(user, rid)
        except IndexError:
            msg = "Couldn't find a reminder for \x0302{0}\x0F with ID \x0303{1}\x0F."
            self.reply(data, msg.format(user, rid))

    def _get_new_id(self):
        """Get a free ID for a new reminder."""
        taken = set(robj.id for robj in chain(*self.reminders.values()))

@@ -125,13 +118,13 @@ class Remind(Command):
    def _start_reminder(self, reminder, user):
        """Start the given reminder object for the given user."""
        reminder.start()
        if user in self.reminders:
            self.reminders[user].append(reminder)
        else:
            self.reminders[user] = [reminder]
        self._thread.add(reminder)

    def _create_reminder(self, data, user):
    def _create_reminder(self, data):
        """Create a new reminder for the given user."""
        try:
            wait = self._parse_time(data.args[0])

@@ -144,7 +137,6 @@ class Remind(Command):
            msg = "Given time \x02{0}\x0F is too large. Keep it reasonable."
            return self.reply(data, msg.format(data.args[0]))

        end = time.time() + wait
        message = " ".join(data.args[1:])
        try:
            rid = self._get_new_id()

@@ -152,8 +144,8 @@ class Remind(Command):
            msg = "Couldn't set a new reminder: no free IDs available."
            return self.reply(data, msg)

        reminder = _Reminder(rid, user, wait, end, message, data, self)
        self._start_reminder(reminder, user)
        reminder = _Reminder(rid, data.host, wait, message, data, self)
        self._start_reminder(reminder, data.host)
        msg = "Set reminder \x0303{0}\x0F ({1})."
        self.reply(data, msg.format(rid, reminder.end_time))

@@ -164,104 +156,85 @@ class Remind(Command):
                       reminder.message)
        self.reply(data, msg)

    def _cancel_reminder(self, data, user, reminder):
    def _cancel_reminder(self, data, reminder):
        """Cancel a pending reminder."""
        reminder.stop()
        self.reminders[user].remove(reminder)
        if not self.reminders[user]:
            del self.reminders[user]
        self._thread.remove(reminder)
        self.unstore_reminder(reminder.id)
        self.reminders[data.host].remove(reminder)
        if not self.reminders[data.host]:
            del self.reminders[data.host]
        msg = "Reminder \x0303{0}\x0F canceled."
        self.reply(data, msg.format(reminder.id))

    def _snooze_reminder(self, data, reminder, arg=None):
        """Snooze a reminder to be re-triggered after a period of time."""
        verb = "snoozed" if reminder.end < time.time() else "adjusted"
        if arg:
            try:
                duration = self._parse_time(data.args[arg])
                reminder.wait = duration
            except (IndexError, ValueError):
                pass

        reminder.end = time.time() + reminder.wait
        reminder.start()
        end = time.strftime("%b %d %H:%M:%S %Z", time.localtime(reminder.end))
        verb = "snoozed" if reminder.expired else "adjusted"
        try:
            duration = self._parse_time(arg) if arg else None
        except ValueError:
            duration = None

        reminder.reset(duration)
        end = _format_time(reminder.end)
        msg = "Reminder \x0303{0}\x0F {1} until {2}."
        self.reply(data, msg.format(reminder.id, verb, end))

    def _load_reminders(self):
        """Load previously made reminders from the database."""
        with self._db() as permdb:
            try:
                database = permdb.get_attr("command:remind", "data")
            except KeyError:
                return
            permdb.set_attr("command:remind", "data", "[]")
        permdb = self.config.irc["permissions"]
        try:
            database = permdb.get_attr("command:remind", "data")
        except KeyError:
            return
        permdb.set_attr("command:remind", "data", "[]")

        connect_wait = 30
        for item in ast.literal_eval(database):
            rid, user, wait, end, message, data = item
            if end < time.time():
                continue
            if end < time.time() + connect_wait:
                # Make reminders that have expired while the bot was offline
                # trigger shortly after startup
                end = time.time() + connect_wait
            data = Data.unserialize(data)
            reminder = _Reminder(rid, user, wait, end, message, data, self)
            reminder = _Reminder(rid, user, wait, message, data, self, end)
            self._start_reminder(reminder, user)

    def _handle_command(self, command, data, user, reminder, arg=None):
        """Handle a reminder-processing subcommand."""
        if command in DISPLAY:
            self._display_reminder(data, reminder)
        elif command in CANCEL:
            self._cancel_reminder(data, user, reminder)
        elif command in SNOOZE:
            self._snooze_reminder(data, reminder, arg)
        else:
            msg = "Unknown action \x02{0}\x0F for reminder \x0303{1}\x0F."
            self.reply(data, msg.format(command, reminder.id))

    def _show_reminders(self, data, user):
    def _show_reminders(self, data):
        """Show all of a user's current reminders."""
        shorten = lambda s: (s[:37] + "..." if len(s) > 40 else s)
        tmpl = '\x0303{0}\x0F ("{1}", {2})'
        fmt = lambda robj: tmpl.format(robj.id, shorten(robj.message),
                                       robj.end_time)
        if user in self.reminders:
            rlist = ", ".join(fmt(robj) for robj in self.reminders[user])
            msg = "Your reminders: {0}.".format(rlist)
        else:
            msg = ("You have no reminders. Set one with \x0306!remind [time] "
                   "[message]\x0F. See also: \x0306!remind help\x0F.")
        self.reply(data, msg)

    def _process_snooze_command(self, data, user):
        """Process the !snooze command."""
        if not data.args:
            if user not in self.reminders:
                self.reply(data, "You have no reminders to snooze.")
            elif len(self.reminders[user]) == 1:
                self._snooze_reminder(data, self.reminders[user][0])
            else:
                msg = "You have {0} reminders. Snooze which one?"
                self.reply(data, msg.format(len(self.reminders[user])))
        if data.host not in self.reminders:
            self.reply(data, "You have no reminders. Set one with "
                             "\x0306!remind [time] [message]\x0F. See also: "
                             "\x0306!remind help\x0F.")
            return
        reminder = self._get_reminder_by_id(user, data.args[0], data)
        if reminder:
            self._snooze_reminder(data, reminder, 1)

    def _process_cancel_command(self, data, user):
        """Process the !cancel, !unremind, and !forget commands."""
        if not data.args:
            if user not in self.reminders:
                self.reply(data, "You have no reminders to cancel.")
            elif len(self.reminders[user]) == 1:
                self._cancel_reminder(data, user, self.reminders[user][0])
            else:
                msg = "You have {0} reminders. Cancel which one?"
                self.reply(data, msg.format(len(self.reminders[user])))

        shorten = lambda s: (s[:37] + "..." if len(s) > 40 else s)
        dest = lambda data: (
            "privately" if data.is_private else "in {0}".format(data.chan))
        fmt = lambda robj: '\x0303{0}\x0F ("{1}" {2}, {3})'.format(
            robj.id, shorten(robj.message), dest(robj.data), robj.end_time)

        rlist = ", ".join(fmt(robj) for robj in self.reminders[data.host])
        self.reply(data, "Your reminders: {0}.".format(rlist))

    def _show_all_reminders(self, data):
        """Show all reminders to bot admins."""
        if not self.config.irc["permissions"].is_admin(data):
            self.reply(data, "You must be a bot admin to view other users' "
                             "reminders. View your own with "
                             "\x0306!reminders\x0F.")
            return
        reminder = self._get_reminder_by_id(user, data.args[0], data)
        if reminder:
            self._cancel_reminder(data, user, reminder)

        if not self.reminders:
            self.reply(data, "There are no active reminders.")
            return

        dest = lambda data: (
            "privately" if data.is_private else "in {0}".format(data.chan))
        fmt = lambda robj, user: '\x0303{0}\x0F (for {1} {2}, {3})'.format(
            robj.id, user, dest(robj.data), robj.end_time)

        rlist = (fmt(rem, user) for user, rems in self.reminders.iteritems()
                 for rem in rems)
        self.reply(data, "All reminders: {0}.".format(", ".join(rlist)))

    def _show_help(self, data):
        """Reply to the user with help for all major subcommands."""
@@ -271,122 +244,205 @@ class Remind(Command):
            ("Get info", "!remind [id]"),
            ("Cancel", "!remind cancel [id]"),
            ("Adjust", "!remind adjust [id] [time]"),
            ("Restart", "!snooze [id]")
            ("Restart", "!snooze [id] [time]"),
            ("Admin", "!remind all")
        ]
        extra = "In most cases, \x0306[id]\x0F can be omitted if you have only one reminder."
        extra = "The \x0306[id]\x0F can be omitted if you have only one reminder."
        joined = " ".join("{0}: \x0306{1}\x0F.".format(k, v) for k, v in parts)
        self.reply(data, joined + " " + extra)

    def setup(self):
        self.reminders = {}
        self._db_lock = RLock()
        self._load_reminders()

    def process(self, data):
        if data.command == "snooze":
            return self._process_snooze_command(data, data.host)
        if data.command in ["cancel", "unremind", "forget"]:
            return self._process_cancel_command(data, data.host)
        if not data.args:
            return self._show_reminders(data, data.host)

    def _dispatch_command(self, data, command, args):
        """Handle a reminder-processing subcommand."""
        user = data.host
        if len(data.args) == 1:
            command = data.args[0]
            if command == "help":
                return self._show_help(data)
            if command in DISPLAY + CANCEL + SNOOZE:
                if user not in self.reminders:
                    msg = "You have no reminders to {0}."
                    self.reply(data, msg.format(self._normalize(command)))
                elif len(self.reminders[user]) == 1:
                    reminder = self.reminders[user][0]
                    self._handle_command(command, data, user, reminder)
                else:
                    msg = "You have {0} reminders. {1} which one?"
                    num = len(self.reminders[user])
                    command = self._normalize(command).capitalize()
                    self.reply(data, msg.format(num, command))

        reminder = None
        if args and args[0].upper().startswith("R"):
            try:
                reminder = self._get_reminder_by_id(user, args[0])
            except IndexError:
                msg = "Couldn't find a reminder for \x0302{0}\x0F with ID \x0303{1}\x0F."
                self.reply(data, msg.format(user, args[0]))
                return
            reminder = self._get_reminder_by_id(user, data.args[0], data)
            if reminder:
                self._display_reminder(data, reminder)
            args.pop(0)
        elif user not in self.reminders:
            msg = "You have no reminders to {0}."
            self.reply(data, msg.format(self._normalize(command)))
            return
        elif len(self.reminders[user]) == 1:
            reminder = self.reminders[user][0]
        elif command in SNOOZE_ONLY:  # Select most recent expired reminder
            rmds = [rmd for rmd in self.reminders[user] if rmd.expired]
            rmds.sort(key=lambda rmd: rmd.end)
            if len(rmds) > 0:
                reminder = rmds[-1]
        elif command in SNOOZE or command in CANCEL:  # Select only active one
            rmds = [rmd for rmd in self.reminders[user] if not rmd.expired]
            if len(rmds) == 1:
                reminder = rmds[0]

        if not reminder:
            msg = "You have {0} reminders. {1} which one?"
            num = len(self.reminders[user])
            command = self._normalize(command).capitalize()
            self.reply(data, msg.format(num, command))
            return

        if command in DISPLAY:
            self._display_reminder(data, reminder)
        elif command in CANCEL:
            self._cancel_reminder(data, reminder)
        elif command in SNOOZE:
            self._snooze_reminder(data, reminder, args[0] if args else None)
        else:
            msg = "Unknown action \x02{0}\x0F for reminder \x0303{1}\x0F."
            self.reply(data, msg.format(command, reminder.id))

    def _process(self, data):
        """Main entry point."""
        if data.command in SNOOZE + CANCEL:
            return self._dispatch_command(data, data.command, data.args)
        if not data.args:
            return self._show_reminders(data)

        if data.args[0] == "help":
            return self._show_help(data)
        if data.args[0] == "list":
            return self._show_reminders(data)
        if data.args[0] == "all":
            return self._show_all_reminders(data)
        if data.args[0] in DISPLAY + CANCEL + SNOOZE:
            reminder = self._get_reminder_by_id(user, data.args[1], data)
            if reminder:
                self._handle_command(data.args[0], data, user, reminder, 2)
            return
            return self._dispatch_command(data, data.args[0], data.args[1:])

        try:
            reminder = self._really_get_reminder_by_id(user, data.args[0])
            self._get_reminder_by_id(data.host, data.args[0])
        except IndexError:
            return self._create_reminder(data, user)
            return self._create_reminder(data)
        if len(data.args) == 1:
            return self._dispatch_command(data, "display", data.args)
        self._dispatch_command(
            data, data.args[1], [data.args[0]] + data.args[2:])
        self._handle_command(data.args[1], data, user, reminder, 2)

    @property
    def lock(self):
        """Return the reminder modification/access lock."""
        return self._lock

    def setup(self):
        self.reminders = {}
        self._lock = RLock()
        self._thread = _ReminderThread(self._lock)
        self._load_reminders()

    def process(self, data):
        with self.lock:
            self._process(data)

    def unload(self):
        for reminder in chain(*self.reminders.values()):
            reminder.stop(delete=False)
        self._thread.stop()

    def store_reminder(self, reminder):
        """Store a serialized reminder into the database."""
        with self._db() as permdb:
            try:
                dump = permdb.get_attr("command:remind", "data")
            except KeyError:
                dump = "[]"
        permdb = self.config.irc["permissions"]
        try:
            dump = permdb.get_attr("command:remind", "data")
        except KeyError:
            dump = "[]"

            database = ast.literal_eval(dump)
            database.append(reminder)
            permdb.set_attr("command:remind", "data", str(database))
        database = ast.literal_eval(dump)
        database.append(reminder)
        permdb.set_attr("command:remind", "data", str(database))

    def unstore_reminder(self, rid):
        """Remove a reminder from the database by ID."""
        with self._db() as permdb:
            try:
                dump = permdb.get_attr("command:remind", "data")
            except KeyError:
                dump = "[]"
        permdb = self.config.irc["permissions"]
        try:
            dump = permdb.get_attr("command:remind", "data")
        except KeyError:
            dump = "[]"
class _ReminderThread(object):
    """A single thread that handles reminders."""

    def __init__(self, lock):
        self._thread = None
        self._abort = False
        self._active = {}
        self._lock = lock

    def _running(self):
        """Return if the thread should still be running."""
        return self._active and not self._abort

    def _get_soonest(self):
        """Get the soonest reminder to trigger."""
        return min(self._active.values(), key=lambda robj: robj.end)

    def _get_ready_reminder(self):
        """Block until a reminder is ready to be triggered."""
        while self._running():
            if self._get_soonest().end <= time.time():
                return self._get_soonest()
            self._lock.release()
            time.sleep(0.25)
            self._lock.acquire()

    def _callback(self):
        """Internal callback function to be executed by the reminder thread."""
        with self._lock:
            while True:
                reminder = self._get_ready_reminder()
                if not reminder:
                    break
                if reminder.trigger():
                    del self._active[reminder.id]
        self._thread = None

    def _start(self):
        """Start the thread."""
        self._thread = Thread(target=self._callback, name="reminder")
        self._thread.daemon = True
        self._thread.start()
        self._abort = False

    def add(self, reminder):
        """Add a reminder to the table of active reminders."""
        self._active[reminder.id] = reminder
        if not self._thread:
            self._start()

    def remove(self, reminder):
        """Remove a reminder from the table of active reminders."""
        if reminder.id in self._active:
            del self._active[reminder.id]
        if not self._active:
            self.stop()

    def stop(self):
        """Stop the thread."""
        if not self._thread:
            return
        self._abort = True
        self._thread = None

        database = ast.literal_eval(dump)
        database = [item for item in database if item[0] != rid]
        permdb.set_attr("command:remind", "data", str(database))

class _Reminder(object):
    """Represents a single reminder."""

    def __init__(self, rid, user, wait, end, message, data, cmdobj):
    def __init__(self, rid, user, wait, message, data, cmdobj, end=None):
        self.id = rid
        self.wait = wait
        self.end = end
        self.end = time.time() + wait if end is None else end
        self.message = message
        self._user = user
        self._data = data
        self._cmdobj = cmdobj
        self._thread = None
        self._expired = False

    def _callback(self):
        """Internal callback function to be executed by the reminder thread."""
        thread = self._thread
        while time.time() < thread.end:
            time.sleep(1)
            if thread.abort:
                return
        self._cmdobj.reply(self._data, self.message)
        self._delete()
        for i in xrange(60):
            time.sleep(1)
            if thread.abort:
                return
        try:
            self._cmdobj.reminders[self._user].remove(self)
            if not self._cmdobj.reminders[self._user]:
                del self._cmdobj.reminders[self._user]
        except (KeyError, ValueError):  # Already canceled by the user
            pass

    def _save(self):
        """Save this reminder to the database."""

@@ -394,37 +450,54 @@ class _Reminder(object):
        item = (self.id, self._user, self.wait, self.end, self.message, data)
        self._cmdobj.store_reminder(item)

    def _delete(self):
        """Remove this reminder from the database."""
    def _fire(self):
        """Activate the reminder for the user."""
        self._cmdobj.reply(self._data, self.message)
        self._cmdobj.unstore_reminder(self.id)
        self.end = time.time() + (60 * 60 * 24)
        self._expired = True

    def _finalize(self):
        """Clean up after a reminder has been expired for too long."""
        try:
            self._cmdobj.reminders[self._user].remove(self)
            if not self._cmdobj.reminders[self._user]:
                del self._cmdobj.reminders[self._user]
        except (KeyError, ValueError):  # Already canceled by the user
            pass

    @property
    def data(self):
        """Return the IRC data object associated with this reminder."""
        return self._data

    @property
    def end_time(self):
        """Return a string representing the end time of a reminder."""
        if self.end >= time.time():
            lctime = time.localtime(self.end)
            if lctime.tm_year == time.localtime().tm_year:
                ends = time.strftime("%b %d %H:%M:%S %Z", lctime)
            else:
                ends = time.strftime("%b %d, %Y %H:%M:%S %Z", lctime)
            return "ends {0}".format(ends)
        return "expired"

    def start(self):
        """Start the reminder timer thread. Stops it if already running."""
        self.stop()
        self._thread = Thread(target=self._callback, name="remind-" + self.id)
        self._thread.end = self.end
        self._thread.daemon = True
        self._thread.abort = False
        self._thread.start()
        if self._expired or self.end < time.time():
            return "expired"
        return "ends {0}".format(_format_time(self.end))

    @property
    def expired(self):
        """Return whether the reminder is expired."""
        return self._expired

    def reset(self, wait=None):
        """Reactivate a reminder."""
        if wait is not None:
            self.wait = wait
        self.end = self.wait + time.time()
        self._expired = False
        self._cmdobj.unstore_reminder(self.id)
        self._save()

    def stop(self, delete=True):
        """Stop a currently running reminder."""
        if not self._thread:
            return
        if delete:
            self._delete()
        self._thread.abort = True
        self._thread = None

    def trigger(self):
        """Hook run by the reminder thread."""
        if not self._expired:
            self._fire()
            return False
        else:
            self._finalize()
            return True
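
Persistence for reminders is deliberately simple: each reminder is stored as a
plain Python tuple, and the whole database is the ``str()`` of a list, read
back with ``ast.literal_eval``. A tiny round-trip sketch (field values are
illustrative; the real tuple is built in ``_save`` above)::

    import ast

    stored = "[]"                      # what get_attr() returns initially
    item = (1, "user@host", 3600, 1553400000.0, "check the oven", {})
    database = ast.literal_eval(stored)
    database.append(item)
    stored = str(database)             # what set_attr() writes back
    assert ast.literal_eval(stored)[0] == item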
@@ -21,13 +21,14 @@
# SOFTWARE.

from ast import literal_eval
import re

from earwigbot.commands import Command
from earwigbot.irc import RC

class Stalk(Command):
    """Stalk a particular user (!stalk/!unstalk) or page (!watch/!unwatch) for
    edits. Applies to the current bot session only."""
    edits. Prefix regular expressions with "re:" (uses re.match)."""
    name = "stalk"
    commands = ["stalk", "watch", "unstalk", "unwatch", "stalks", "watches",
                "allstalks", "allwatches", "unstalkall", "unwatchall"]

@@ -79,9 +80,12 @@ class Stalk(Command):
        target = " ".join(data.args).replace("_", " ")
        if target.startswith("[[") and target.endswith("]]"):
            target = target[2:-2]
        if target.startswith("User:") and "stalk" in data.command:
            target = target[5:]
        target = target[0].upper() + target[1:]
        if target.startswith("re:"):
            target = "re:" + target[3:].lstrip()
        else:
            if target.startswith("User:") and "stalk" in data.command:
                target = target[5:]
            target = target[0].upper() + target[1:]

        if data.command in ["stalk", "watch"]:
            if data.is_private:

@@ -119,12 +123,12 @@ class Stalk(Command):
                else:
                    chans[item[0]] = None

        def _wildcard_match(target, tag):
            return target[-1] == "*" and tag.startswith(target[:-1])

        def _regex_match(target, tag):
            return target.startswith("re:") and re.match(target[3:], tag)

        def _process(table, tag):
            for target, stalks in table.iteritems():
                if target == tag or _wildcard_match(target, tag):
                if target == tag or _regex_match(target, tag):
                    _update_chans(stalks)

        chans = {}
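
For example, ``!stalk re:Widget.*`` now watches every username matching that
pattern (the pattern here is illustrative), replacing the old trailing-``*``
prefix wildcard handled by ``_wildcard_match`` above.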
@@ -71,14 +71,11 @@ class Threads(Command):
            tname = thread.name
            ident = thread.ident % 10000
            if tname == "MainThread":
                t = "\x0302MainThread\x0F (id {0})"
                t = "\x0302main\x0F (id {0})"
                normal_threads.append(t.format(ident))
            elif tname in self.config.components:
                t = "\x0302{0}\x0F (id {1})"
                normal_threads.append(t.format(tname, ident))
            elif tname.startswith("remind-"):
                t = "\x0302reminder\x0F (id {0})"
                daemon_threads.append(t.format(tname[len("remind-"):]))
            elif tname.startswith("cvworker-"):
                t = "\x0302copyvio worker\x0F (site {0})"
                daemon_threads.append(t.format(tname[len("cvworker-"):]))

@@ -145,6 +142,9 @@ class Threads(Command):
            return

        data.kwargs["fromIRC"] = True
        data.kwargs["_IRCCallback"] = lambda: self.reply(
            data, "Task \x0302{0}\x0F finished.".format(task_name))

        self.bot.tasks.start(task_name, **data.kwargs)
        msg = "Task \x0302{0}\x0F started.".format(task_name)
        self.reply(data, msg)

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com>
# Copyright (C) 2009-2016 Ben Kurtovic <ben.kurtovic@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal

@@ -143,9 +143,9 @@ class ConfigScript(object):
        self._print("""I can encrypt passwords stored in your config file in
                    addition to preventing other users on your system from
                    reading the file. Encryption is recommended if the bot
                    is to run on a public computer like the Toolserver, but
                    otherwise the need to enter a key everytime you start
                    the bot may be annoying.""")
                    is to run on a public server like Wikimedia Labs, but
                    otherwise the need to enter a key every time you start
                    the bot may be an inconvenience.""")
        self.data["metadata"]["encryptPasswords"] = False
        if self._ask_bool("Encrypt stored passwords?"):
            key = getpass(self.PROMPT + "Enter an encryption key: ")

@@ -270,7 +270,7 @@ class ConfigScript(object):
            password = self._ask_pass("Bot password:", encrypt=False)
            self.data["wiki"]["password"] = password
        self.data["wiki"]["userAgent"] = "EarwigBot/$1 (Python/$2; https://github.com/earwig/earwigbot)"
        self.data["wiki"]["summary"] = "([[WP:BOT|Bot]]): $2"
        self.data["wiki"]["summary"] = "([[WP:BOT|Bot]]) $2"
        self.data["wiki"]["useHTTPS"] = True
        self.data["wiki"]["assert"] = "user"
        self.data["wiki"]["maxlag"] = 10

@@ -442,6 +442,10 @@ class ConfigScript(object):
        """Make a new config file based on the user's input."""
        try:
            makedirs(path.dirname(self.config.path))
        except OSError as exc:
            if exc.errno != 17:  # 17 == errno.EEXIST: the directory exists
                raise
        try:
            open(self.config.path, "w").close()
            chmod(self.config.path, stat.S_IRUSR|stat.S_IWUSR)
        except IOError:
@@ -45,6 +45,7 @@ class IRCConnection(object):
        self._last_recv = time()
        self._last_send = 0
        self._last_ping = 0
        self._myhost = "." * 63  # default: longest possible hostname

    def __repr__(self):
        """Return the canonical string representation of the IRCConnection."""

@@ -100,8 +101,19 @@ class IRCConnection(object):
            self.logger.debug(msg)
        self._last_send = time()

    def _split(self, msgs, maxlen, maxsplits=3):
        """Split a large message into multiple messages smaller than maxlen."""
    def _get_maxlen(self, extra):
        """Return our best guess of the maximum length of a standard message.

        This applies mainly to PRIVMSGs and NOTICEs.
        """
        base_max = 512
        userhost = len(self.nick) + len(self.ident) + len(self._myhost) + 2
        padding = 4  # "\r\n" at end, ":" at beginning, and " " after userhost
        return base_max - userhost - padding - extra

    def _split(self, msgs, extralen, maxsplits=3):
        """Split a large message into multiple messages."""
        maxlen = self._get_maxlen(extralen)
        words = msgs.split(" ")
        splits = 0
        while words and splits < maxsplits:
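
As a worked example of ``_get_maxlen``: a relayed line looks like
``:nick!ident@host PRIVMSG #chan :text\r\n`` and must fit in IRC's 512-byte
line limit. With a 9-character nick and ident (values illustrative), the
63-character placeholder host, and ``extra = len(target) + 10`` from ``say()``
below for a 10-character channel name::

    512 - (9 + 9 + 63 + 2) - 4 - (10 + 10) = 405 characters of message text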
@@ -128,6 +140,19 @@ class IRCConnection(object):
        self._last_recv = time()
        if line[0] == "PING":  # If we are pinged, pong back
            self.pong(line[1][1:])
        elif line[1] == "001":  # Update nickname on startup
            if line[2] != self.nick:
                self.logger.warn("Nickname changed from {0} to {1}".format(
                    self.nick, line[2]))
                self._nick = line[2]
        elif line[1] == "376":  # After sign-on, get our userhost
            self._send("WHOIS {0}".format(self.nick))
        elif line[1] == "311":  # Receiving WHOIS result
            if line[2] == self.nick:
                self._ident = line[4]
                self._myhost = line[5]
        elif line[1] == "396":  # Hostname change
            self._myhost = line[3]

    def _process_message(self, line):
        """To be overridden in subclasses."""

@@ -163,7 +188,7 @@ class IRCConnection(object):
    def say(self, target, msg, hidelog=False):
        """Send a private message to a target on the server."""
        for msg in self._split(msg, 400):
        for msg in self._split(msg, len(target) + 10):
            msg = "PRIVMSG {0} :{1}".format(target, msg)
            self._send(msg, hidelog)

@@ -182,7 +207,7 @@ class IRCConnection(object):
    def notice(self, target, msg, hidelog=False):
        """Send a notice to a target on the server."""
        for msg in self._split(msg, 400):
        for msg in self._split(msg, len(target) + 9):
            msg = "NOTICE {0} :{1}".format(target, msg)
            self._send(msg, hidelog)

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com>
# Copyright (C) 2009-2016 Ben Kurtovic <ben.kurtovic@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal

@@ -50,20 +50,26 @@ class Data(object):
    def _parse(self):
        """Parse a line from IRC into its components as instance attributes."""
        sender = re.findall(r":(.*?)!(.*?)@(.*?)\Z", self.line[0])[0]
        self._chan = self.line[2]
        try:
            sender = re.findall(r":(.*?)!(.*?)@(.*?)\Z", self.line[0])[0]
        except IndexError:
            self._host = self.line[0][1:]
            self._nick = self._ident = self._reply_nick = "*"
            return
        self._nick, self._ident, self._host = sender
        self._reply_nick = self._nick
        self._chan = self.line[2]

        if self._msgtype == "PRIVMSG":
        if self._msgtype in ["PRIVMSG", "NOTICE"]:
            if self.chan.lower() == self.my_nick:
                # This is a privmsg to us, so set 'chan' as the nick of the
                # sender instead of the 'channel', which is ourselves:
                self._chan = self._nick
                self._is_private = True
            self._msg = " ".join(self.line[3:])[1:]
            self._parse_args()
            self._parse_kwargs()
            if self._msgtype == "PRIVMSG":
                self._parse_args()
                self._parse_kwargs()

    def _parse_args(self):
        """Parse command arguments from the message.
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com>
# Copyright (C) 2009-2016 Ben Kurtovic <ben.kurtovic@gmail.com>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal

@@ -20,6 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from time import sleep

from earwigbot.irc import IRCConnection, Data

__all__ = ["Frontend"]

@@ -36,6 +38,7 @@ class Frontend(IRCConnection):
    :py:mod:`earwigbot.commands` or the bot's custom command directory
    (explained in the :doc:`documentation </customizing>`).
    """
    NICK_SERVICES = "NickServ"

    def __init__(self, bot):
        self.bot = bot

@@ -43,6 +46,8 @@ class Frontend(IRCConnection):
        base = super(Frontend, self)
        base.__init__(cf["host"], cf["port"], cf["nick"], cf["ident"],
                      cf["realname"], bot.logger.getChild("frontend"))
        self._auth_wait = False
        self._connect()

    def __repr__(self):

@@ -56,6 +61,11 @@ class Frontend(IRCConnection):
        res = "<Frontend {0}!{1} at {2}:{3}>"
        return res.format(self.nick, self.ident, self.host, self.port)

    def _join_channels(self):
        """Join all startup channels as specified by the config file."""
        for chan in self.bot.config.irc["frontend"]["channels"]:
            self.join(chan)

    def _process_message(self, line):
        """Process a single message from IRC."""
        if line[1] == "JOIN":

@@ -74,17 +84,30 @@ class Frontend(IRCConnection):
                self.bot.commands.call("msg_public", data)
            self.bot.commands.call("msg", data)

        elif line[1] == "NOTICE":
            data = Data(self.nick, line, msgtype="NOTICE")
            if self._auth_wait and data.nick == self.NICK_SERVICES:
                if data.msg.startswith("This nickname is registered."):
                    return
                self._auth_wait = False
                sleep(2)  # Wait for hostname change to propagate
                self._join_channels()

        elif line[1] == "376":  # On successful connection to the server
            # If we're supposed to auth to NickServ, do that:
            try:
                username = self.bot.config.irc["frontend"]["nickservUsername"]
                password = self.bot.config.irc["frontend"]["nickservPassword"]
            except KeyError:
                pass
                self._join_channels()
            else:
                self.logger.debug("Identifying with services")
                msg = "IDENTIFY {0} {1}".format(username, password)
                self.say("NickServ", msg, hidelog=True)
                self.say(self.NICK_SERVICES, msg, hidelog=True)
                self._auth_wait = True

            # Join all of our startup channels:
            for chan in self.bot.config.irc["frontend"]["channels"]:
                self.join(chan)
        elif line[1] == "401":  # No such nickname
            if self._auth_wait and line[3] == self.NICK_SERVICES:
                # Services is down, or something...?
                self._auth_wait = False
                self._join_channels()
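
Taken together, these handlers change the startup sequence: the server's 376
(end of MOTD) triggers ``IDENTIFY`` when NickServ credentials are configured,
and channel joins are deferred until NickServ's NOTICE confirms auth (or a 401
reveals that services are absent), so the bot's hostname cloak is in place
before it appears in any channel.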
@@ -72,14 +72,21 @@ class _ResourceManager(object): | |||
for resource in self._resources.itervalues(): | |||
yield resource | |||
def _is_disabled(self, name): | |||
"""Check whether a resource should be disabled.""" | |||
conf = getattr(self.bot.config, self._resource_name) | |||
disabled = conf.get("disable", []) | |||
enabled = conf.get("enable", []) | |||
return name not in enabled and (disabled is True or name in disabled) | |||
def _load_resource(self, name, path, klass): | |||
"""Instantiate a resource class and add it to the dictionary.""" | |||
res_type = self._resource_name[:-1] # e.g. "command" or "task" | |||
if hasattr(klass, "name"): | |||
res_config = getattr(self.bot.config, self._resource_name) | |||
if getattr(klass, "name") in res_config.get("disable", []): | |||
classname = getattr(klass, "name") | |||
if self._is_disabled(name) and self._is_disabled(classname): | |||
log = "Skipping disabled {0} {1}" | |||
self.logger.debug(log.format(res_type, getattr(klass, "name"))) | |||
self.logger.debug(log.format(res_type, classname)) | |||
return | |||
try: | |||
resource = klass(self.bot) # Create instance of resource | |||
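In words, the new rule: a resource is loaded unless it is missing from
``enable`` and either the whole group is disabled (``disable: true``) or it is
listed in ``disable``. A standalone restatement with hypothetical names::

    def is_disabled(conf, name):
        disabled = conf.get("disable", [])
        enabled = conf.get("enable", [])
        return name not in enabled and (disabled is True or name in disabled)

    conf = {"disable": True, "enable": ["wikiproject_tagger"]}
    assert is_disabled(conf, "afc_statistics")          # everything is off...
    assert not is_disabled(conf, "wikiproject_tagger")  # ...except the enabled one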
@@ -119,8 +126,6 @@ class _ResourceManager(object): | |||
def _load_directory(self, dir): | |||
"""Load all valid resources in a given directory.""" | |||
self.logger.debug("Loading directory {0}".format(dir)) | |||
res_config = getattr(self.bot.config, self._resource_name) | |||
disabled = res_config.get("disable", []) | |||
processed = [] | |||
for name in listdir(dir): | |||
if not name.endswith(".py") and not name.endswith(".pyc"): | |||
@@ -128,14 +133,14 @@ class _ResourceManager(object): | |||
if name.startswith("_") or name.startswith("."): | |||
continue | |||
modname = sub("\.pyc?$", "", name) # Remove extension | |||
if modname in disabled: | |||
if modname in processed: | |||
continue | |||
processed.append(modname) | |||
if self._is_disabled(modname): | |||
log = "Skipping disabled module {0}".format(modname) | |||
self.logger.debug(log) | |||
processed.append(modname) | |||
continue | |||
if modname not in processed: | |||
self._load_module(modname, dir) | |||
processed.append(modname) | |||
self._load_module(modname, dir) | |||
def _unload_resources(self): | |||
"""Unload all resources, calling their unload hooks in the process.""" | |||
@@ -162,7 +167,8 @@ class _ResourceManager(object): | |||
self._unload_resources() | |||
builtin_dir = path.join(path.dirname(__file__), name) | |||
plugins_dir = path.join(self.bot.config.root_dir, name) | |||
if getattr(self.bot.config, name).get("disable") is True: | |||
conf = getattr(self.bot.config, name) | |||
if conf.get("disable") is True and not conf.get("enable"): | |||
log = "Skipping disabled builtins directory: {0}" | |||
self.logger.debug(log.format(builtin_dir)) | |||
else: | |||
@@ -219,6 +225,13 @@ class CommandManager(_ResourceManager): | |||
.. note:: | |||
The special ``rc`` hook actually passes a :class:`~.RC` object. | |||
""" | |||
try: | |||
quiet = self.bot.config.irc["frontend"]["quiet"][data.chan] | |||
if quiet is True or hook in quiet: | |||
return | |||
except KeyError: | |||
pass | |||
for command in self: | |||
if hook in command.hooks and self._wrap_check(command, data): | |||
thread = Thread(target=self._wrap_process, | |||
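The ``quiet`` setting consulted above is per-channel: ``True`` silences every
hook in that channel, while a list silences only the named hooks. A
hypothetical shape::

    config.irc["frontend"]["quiet"] = {
        "#noisy-logs": True,              # ignore every hook here
        "#mostly-quiet": ["msg_public"],  # ignore only public messages
    }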
@@ -247,6 +260,8 @@ class TaskManager(_ResourceManager): | |||
else: | |||
msg = "Task '{0}' finished successfully" | |||
self.logger.info(msg.format(task.name)) | |||
if kwargs.get("fromIRC"): | |||
kwargs.get("_IRCCallback")() | |||
def start(self, task_name, **kwargs): | |||
"""Start a given task in a new daemon thread, and return the thread. | |||
@@ -54,6 +54,10 @@ class Task(object): | |||
self.bot = bot | |||
self.config = bot.config | |||
self.logger = bot.tasks.logger.getChild(self.name) | |||
number = self.config.tasks.get(self.name, {}).get("number") | |||
if number is not None: | |||
self.number = number | |||
self.setup() | |||
def __repr__(self): | |||
@@ -1,6 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2017 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -30,16 +30,16 @@ class WikiProjectTagger(Task): | |||
"""A task to tag talk pages with WikiProject banners. | |||
Usage: :command:`earwigbot -t wikiproject_tagger PATH | |||
--banner BANNER (--category CAT | --file FILE) [--summary SUM] | |||
[--append TEXT] [--autoassess] [--nocreate] [--recursive NUM] | |||
[--site SITE]` | |||
--banner BANNER (--category CAT | --file FILE) [--summary SUM] [--update] | |||
[--append PARAMS] [--autoassess [CLASSES]] [--only-with BANNER] | |||
[--nocreate] [--recursive [NUM]] [--site SITE] [--dry-run]` | |||
.. glossary:: | |||
``--banner BANNER`` | |||
the page name of the banner to add, without a namespace (unless the | |||
namespace is something other than ``Template``) so | |||
``--banner WikiProject Biography`` for ``{{WikiProject Biography}}`` | |||
``--banner "WikiProject Biography"`` for ``{{WikiProject Biography}}`` | |||
``--category CAT`` or ``--file FILE`` | |||
determines which pages to tag; either all pages in a category (to | |||
include subcategories as well, see ``--recursive``) or all | |||
@@ -47,21 +47,33 @@ class WikiProjectTagger(Task): | |||
current directory) | |||
``--summary SUM`` | |||
an optional edit summary to use; defaults to | |||
``"Adding WikiProject banner {{BANNER}}."`` | |||
``--append TEXT`` | |||
optional text to append to the banner (after an autoassessment, if | |||
any), like ``|importance=low`` | |||
``--autoassess`` | |||
``"Tagging with WikiProject banner {{BANNER}}."`` | |||
``--update`` | |||
updates existing banners with new fields; should include at least one | |||
of ``--append`` or ``--autoassess`` to be useful | |||
``--append PARAMS`` | |||
optional comma-separated parameters to append to the banner (after an | |||
auto-assessment, if any); use syntax ``importance=low,taskforce=yes`` | |||
to add ``|importance=low|taskforce=yes`` | |||
``--autoassess [CLASSES]`` | |||
try to assess each article's class automatically based on the class of | |||
other banners on the same page | |||
other banners on the same page; if CLASSES is given as a | |||
comma-separated list, only those classes will be auto-assessed | |||
``--only-with BANNER`` | |||
only tag pages that already have the given banner | |||
``--nocreate`` | |||
don't create new talk pages with just a banner if the page doesn't | |||
already exist | |||
``--recursive NUM`` | |||
recursively go through subcategories up to a maximum depth of ``NUM``, | |||
or if ``NUM`` isn't provided, go infinitely (this can be dangerous) | |||
``--tag-categories`` | |||
also tag category pages | |||
``--site SITE`` | |||
the ID of the site to tag pages on, defaulting to the... default site | |||
the ID of the site to tag pages on, defaulting to the default site | |||
``--dry-run`` | |||
don't actually make any edits, just log the pages that would have been | |||
edited | |||
""" | |||
name = "wikiproject_tagger" | |||
@@ -90,11 +102,10 @@ class WikiProjectTagger(Task): | |||
r"failed ?ga$", | |||
r"old ?prod( ?full)?$", | |||
r"(old|previous) ?afd$", | |||
r"((wikiproject|wp) ?)?bio(graph(y|ies))?$", | |||
] | |||
def _upperfirst(self, text): | |||
@staticmethod | |||
def _upperfirst(text): | |||
"""Try to uppercase the first letter of a string.""" | |||
try: | |||
return text[0].upper() + text[1:] | |||
@@ -114,15 +125,29 @@ class WikiProjectTagger(Task): | |||
site = self.bot.wiki.get_site(name=kwargs.get("site")) | |||
banner = kwargs["banner"] | |||
summary = kwargs.get("summary", "Adding WikiProject banner $3.") | |||
summary = kwargs.get("summary", "Tagging with WikiProject banner $3.") | |||
update = kwargs.get("update", False) | |||
append = kwargs.get("append") | |||
autoassess = kwargs.get("autoassess", False) | |||
ow_banner = kwargs.get("only-with") | |||
nocreate = kwargs.get("nocreate", False) | |||
recursive = kwargs.get("recursive", 0) | |||
tag_categories = kwargs.get("tag-categories", False) | |||
dry_run = kwargs.get("dry-run", False) | |||
banner, names = self.get_names(site, banner) | |||
if not names: | |||
return | |||
job = _Job(banner, names, summary, append, autoassess, nocreate) | |||
if ow_banner: | |||
_, only_with = self.get_names(site, ow_banner) | |||
if not only_with: | |||
return | |||
else: | |||
only_with = None | |||
job = _Job(banner=banner, names=names, summary=summary, update=update, | |||
append=append, autoassess=autoassess, only_with=only_with, | |||
nocreate=nocreate, tag_categories=tag_categories, | |||
dry_run=dry_run) | |||
try: | |||
self.run_job(kwargs, site, job, recursive) | |||
@@ -172,139 +197,237 @@ class WikiProjectTagger(Task): | |||
banner = banner.split(":", 1)[1] | |||
page = site.get_page(title) | |||
if page.exists != page.PAGE_EXISTS: | |||
self.logger.error(u"Banner [[{0}]] does not exist".format(title)) | |||
self.logger.error(u"Banner [[%s]] does not exist", title) | |||
return banner, None | |||
if banner == title: | |||
names = [self._upperfirst(banner)] | |||
else: | |||
names = [self._upperfirst(banner), self._upperfirst(title)] | |||
names = {banner, title} | |||
result = site.api_query(action="query", list="backlinks", bllimit=500, | |||
blfilterredir="redirects", bltitle=title) | |||
for backlink in result["query"]["backlinks"]: | |||
names.append(backlink["title"]) | |||
names.add(backlink["title"]) | |||
if backlink["ns"] == constants.NS_TEMPLATE: | |||
names.append(backlink["title"].split(":", 1)[1]) | |||
names.add(backlink["title"].split(":", 1)[1]) | |||
log = u"Found {0} aliases for banner [[{1}]]".format(len(names), title) | |||
self.logger.debug(log) | |||
log = u"Found %s aliases for banner [[%s]]" | |||
self.logger.debug(log, len(names), title) | |||
return banner, names | |||
def process_category(self, page, job, recursive): | |||
"""Try to tag all pages in the given category.""" | |||
self.logger.info(u"Processing category: [[{0]]".format(page.title)) | |||
if page.title in job.processed_cats: | |||
self.logger.debug(u"Skipping category, already processed: [[%s]]", | |||
page.title) | |||
return | |||
self.logger.info(u"Processing category: [[%s]]", page.title) | |||
job.processed_cats.add(page.title) | |||
if job.tag_categories: | |||
self.process_page(page, job) | |||
for member in page.get_members(): | |||
if member.namespace == constants.NS_CATEGORY: | |||
nspace = member.namespace | |||
if nspace == constants.NS_CATEGORY: | |||
if recursive is True: | |||
self.process_category(member, job, True) | |||
elif recursive: | |||
elif recursive > 0: | |||
self.process_category(member, job, recursive - 1) | |||
elif job.tag_categories: | |||
self.process_page(member, job) | |||
elif nspace in (constants.NS_USER, constants.NS_USER_TALK): | |||
continue | |||
else: | |||
self.process_page(member, job) | |||
def process_page(self, page, job): | |||
"""Try to tag a specific *page* using the *job* description.""" | |||
if not page.is_talkpage: | |||
page = page.toggle_talk() | |||
if page.title in job.processed_pages: | |||
self.logger.debug(u"Skipping page, already processed: [[%s]]", | |||
page.title) | |||
return | |||
job.processed_pages.add(page.title) | |||
if job.counter % 10 == 0: # Do a shutoff check every ten pages | |||
if self.shutoff_enabled(page.site): | |||
raise _ShutoffEnabled() | |||
job.counter += 1 | |||
if not page.is_talkpage: | |||
page = page.toggle_talk() | |||
try: | |||
code = page.parse() | |||
except exceptions.PageNotFoundError: | |||
if job.nocreate: | |||
log = u"Skipping nonexistent page: [[{0}]]".format(page.title) | |||
self.logger.info(log) | |||
else: | |||
log = u"Tagging new page: [[{0}]]".format(page.title) | |||
self.logger.info(log) | |||
banner = "{{" + job.banner + job.append + "}}" | |||
summary = job.summary.replace("$3", banner) | |||
page.edit(banner, self.make_summary(summary)) | |||
self.process_new_page(page, job) | |||
return | |||
except exceptions.InvalidPageError: | |||
log = u"Skipping invalid page: [[{0}]]".format(page.title) | |||
self.logger.error(log) | |||
self.logger.error(u"Skipping invalid page: [[%s]]", page.title) | |||
return | |||
is_update = False | |||
for template in code.ifilter_templates(recursive=True): | |||
name = self._upperfirst(template.name.strip()) | |||
if name in job.names: | |||
log = u"Skipping page: [[{0}]]; already tagged with '{1}'" | |||
self.logger.info(log.format(page.title, name)) | |||
if template.name.matches(job.names): | |||
if job.update: | |||
banner = template | |||
is_update = True | |||
break | |||
else: | |||
log = u"Skipping page: [[%s]]; already tagged with '%s'" | |||
self.logger.info(log, page.title, template.name) | |||
return | |||
if job.only_with: | |||
if not any(template.name.matches(job.only_with) | |||
for template in code.ifilter_templates(recursive=True)): | |||
log = u"Skipping page: [[%s]]; fails only-with condition" | |||
self.logger.info(log, page.title) | |||
return | |||
banner = self.make_banner(job, code) | |||
shell = self.get_banner_shell(code) | |||
if shell: | |||
if shell.has_param(1): | |||
shell.get(1).value.insert(0, banner + "\n") | |||
else: | |||
shell.add(1, banner) | |||
if is_update: | |||
old_banner = unicode(banner) | |||
self.update_banner(banner, job, code) | |||
if banner == old_banner: | |||
log = u"Skipping page: [[%s]]; already tagged and no updates" | |||
self.logger.info(log, page.title) | |||
return | |||
self.logger.info(u"Updating banner on page: [[%s]]", page.title) | |||
banner = banner.encode("utf8") | |||
else: | |||
self.add_banner(code, banner) | |||
self.apply_genfixes(code) | |||
self.logger.info(u"Tagging page: [[%s]]", page.title) | |||
banner = self.make_banner(job, code) | |||
shell = self.get_banner_shell(code) | |||
if shell: | |||
self.add_banner_to_shell(shell, banner) | |||
else: | |||
self.add_banner(code, banner) | |||
self.save_page(page, job, unicode(code), banner) | |||
self.logger.info(u"Tagging page: [[{0}]]".format(page.title)) | |||
summary = job.summary.replace("$3", banner) | |||
page.edit(unicode(code), self.make_summary(summary)) | |||
def process_new_page(self, page, job): | |||
"""Try to tag a *page* that doesn't exist yet using the *job*.""" | |||
if job.nocreate or job.only_with: | |||
log = u"Skipping nonexistent page: [[%s]]" | |||
self.logger.info(log, page.title) | |||
else: | |||
self.logger.info(u"Tagging new page: [[%s]]", page.title) | |||
banner = self.make_banner(job) | |||
self.save_page(page, job, banner, banner) | |||
def save_page(self, page, job, text, banner): | |||
"""Save a page with an updated banner.""" | |||
if job.dry_run: | |||
self.logger.debug(u"[DRY RUN] Banner: %s", banner) | |||
else: | |||
summary = job.summary.replace("$3", banner) | |||
page.edit(text, self.make_summary(summary), minor=True) | |||
def make_banner(self, job, code): | |||
def make_banner(self, job, code=None): | |||
"""Return banner text to add based on a *job* and a page's *code*.""" | |||
banner = "{{" + job.banner | |||
if job.autoassess: | |||
classes = {"fa": 0, "fl": 0, "ga": 0, "a": 0, "b": 0, "start": 0, | |||
"stub": 0, "list": 0, "dab": 0, "c": 0, "redirect": 0, | |||
"book": 0, "template": 0, "category": 0} | |||
for template in code.ifilter_templates(recursive=True): | |||
if template.has_param("class"): | |||
value = unicode(template.get("class").value).lower() | |||
if value in classes: | |||
classes[value] += 1 | |||
values = tuple(classes.values()) | |||
best = max(values) | |||
banner = job.banner | |||
if code is not None and job.autoassess is not False: | |||
assess, reason = self.get_autoassessment(code, job.autoassess) | |||
if assess: | |||
banner += "|class=" + assess | |||
if reason: | |||
banner += "|auto=" + reason | |||
if job.append: | |||
banner += "|" + "|".join(job.append.split(",")) | |||
return "{{" + banner + "}}" | |||
def update_banner(self, banner, job, code): | |||
"""Update an existing *banner* based on a *job* and a page's *code*.""" | |||
has = lambda key: (banner.has(key) and | |||
banner.get(key).value.strip() not in ("", "?")) | |||
if job.autoassess is not False: | |||
if not has("class"): | |||
assess, reason = self.get_autoassessment(code, job.autoassess) | |||
if assess: | |||
banner.add("class", assess) | |||
if reason: | |||
banner.add("auto", reason) | |||
if job.append: | |||
for param in job.append.split(","): | |||
key, value = param.split("=", 1) | |||
if not has(key): | |||
banner.add(key, value) | |||
def get_autoassessment(self, code, only_classes=None): | |||
"""Get an autoassessment for a page. | |||
Return (assessed class as a string or None, assessment reason or None). | |||
""" | |||
if only_classes is None or only_classes is True: | |||
classnames = ["a", "b", "book", "c", "dab", "fa", "fl", "ga", | |||
"list", "redirect", "start", "stub"] | |||
else: | |||
classnames = [klass.strip().lower() | |||
for klass in only_classes.split(",")] | |||
classes = {klass: 0 for klass in classnames} | |||
for template in code.ifilter_templates(recursive=True): | |||
if template.has("class"): | |||
value = unicode(template.get("class").value).lower() | |||
if value in classes: | |||
classes[value] += 1 | |||
values = tuple(classes.values()) | |||
best = max(values) | |||
if best: | |||
confidence = float(best) / sum(values) | |||
if confidence > 0.75: | |||
rank = tuple(classes.keys())[values.index(best)] | |||
if rank in ("fa", "fl", "ga"): | |||
banner += "|class=" + rank.upper() | |||
return rank.upper(), "inherit" | |||
else: | |||
banner += "|class=" + self._upperfirst(rank) | |||
return banner + job.append + "}}" | |||
return self._upperfirst(rank), "inherit" | |||
return None, None | |||
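A worked sketch of the 0.75-confidence rule above, with made-up counts::

    classes = {"c": 4, "start": 1}          # from other banners on the page
    values = tuple(classes.values())
    best = max(values)                      # 4
    confidence = float(best) / sum(values)  # 0.8 > 0.75, so it is trusted
    rank = tuple(classes.keys())[values.index(best)]
    # -> ("C", "inherit"); with counts of 3 and 1, confidence would be
    # exactly 0.75 and get_autoassessment() would return (None, None)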
def get_banner_shell(self, code): | |||
"""Return the banner shell template within *code*, else ``None``.""" | |||
regex = r"^\{\{\s*((WikiProject|WP)[ _]?Banner[ _]?S(hell)?|W(BPS|PBS|PB)|Shell)" | |||
regex = r"^\{\{\s*((WikiProject|WP)[ _]?Banner[ _]?S(hell)?|W(BPS|PBS|PB)|Shell)\s*(\||\}\})" | |||
shells = code.filter_templates(matches=regex) | |||
if not shells: | |||
shells = code.filter_templates(matches=regex, recursive=True) | |||
if shells: | |||
log = u"Inserting banner into shell: {0}" | |||
self.logger.debug(log.format(shells[0].name)) | |||
log = u"Inserting banner into shell: %s" | |||
self.logger.debug(log, shells[0].name) | |||
return shells[0] | |||
def add_banner_to_shell(self, shell, banner): | |||
"""Add *banner* to *shell*.""" | |||
if shell.has_param(1): | |||
if unicode(shell.get(1).value).endswith("\n"): | |||
banner += "\n" | |||
else: | |||
banner = "\n" + banner | |||
shell.get(1).value.append(banner) | |||
else: | |||
shell.add(1, banner) | |||
def add_banner(self, code, banner): | |||
"""Add *banner* to *code*, following template order conventions.""" | |||
index = 0 | |||
for i, template in enumerate(code.ifilter_templates()): | |||
predecessor = None | |||
for template in code.ifilter_templates(recursive=False): | |||
name = template.name.lower().replace("_", " ") | |||
for regex in self.TOP_TEMPS: | |||
if re.match(regex, name): | |||
self.logger.info("Skipping top template: {0}".format(name)) | |||
index = i + 1 | |||
self.logger.debug(u"Inserting banner at index {0}".format(index)) | |||
code.insert(index, banner) | |||
def apply_genfixes(self, code): | |||
"""Apply general fixes to *code*, such as template substitution.""" | |||
regex = r"^\{\{\s*((un|no)?s(i((gn|ng)(ed3?)?|g))?|usu|tilde|forgot to sign|without signature)" | |||
for template in code.ifilter_templates(matches=regex): | |||
self.logger.debug("Applying genfix: substitute {{unsigned}}") | |||
template.name = "subst:unsigned" | |||
self.logger.debug(u"Skipping past top template: %s", name) | |||
predecessor = template | |||
break | |||
if "wikiproject" in name or name.startswith("wp"): | |||
self.logger.debug(u"Skipping past banner template: %s", name) | |||
predecessor = template | |||
if predecessor: | |||
self.logger.debug("Inserting banner after template") | |||
if not unicode(predecessor).endswith("\n"): | |||
banner = "\n" + banner | |||
post = code.index(predecessor) + 1 | |||
if len(code.nodes) > post and not code.get(post).startswith("\n"): | |||
banner += "\n" | |||
code.insert_after(predecessor, banner) | |||
else: | |||
self.logger.debug("Inserting banner at beginning") | |||
code.insert(0, banner + "\n") | |||
class _Job(object): | |||
"""Represents a single wikiproject-tagging task. | |||
@@ -313,14 +436,21 @@ class _Job(object): | |||
or not to autoassess and create new pages from scratch, and a counter of | |||
the number of pages edited. | |||
""" | |||
def __init__(self, banner, names, summary, append, autoassess, nocreate): | |||
self.banner = banner | |||
self.names = names | |||
self.summary = summary | |||
self.append = append | |||
self.autoassess = autoassess | |||
self.nocreate = nocreate | |||
def __init__(self, **kwargs): | |||
self.banner = kwargs["banner"] | |||
self.names = kwargs["names"] | |||
self.summary = kwargs["summary"] | |||
self.update = kwargs["update"] | |||
self.append = kwargs["append"] | |||
self.autoassess = kwargs["autoassess"] | |||
self.only_with = kwargs["only_with"] | |||
self.nocreate = kwargs["nocreate"] | |||
self.tag_categories = kwargs["tag_categories"] | |||
self.dry_run = kwargs["dry_run"] | |||
self.counter = 0 | |||
self.processed_cats = set() | |||
self.processed_pages = set() | |||
class _ShutoffEnabled(Exception): | |||
@@ -1,6 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2016 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -23,15 +23,13 @@ | |||
from time import sleep, time | |||
from urllib2 import build_opener | |||
from earwigbot import exceptions, importer | |||
from earwigbot import exceptions | |||
from earwigbot.wiki.copyvios.markov import MarkovChain | |||
from earwigbot.wiki.copyvios.parsers import ArticleTextParser | |||
from earwigbot.wiki.copyvios.search import YahooBOSSSearchEngine | |||
from earwigbot.wiki.copyvios.search import SEARCH_ENGINES | |||
from earwigbot.wiki.copyvios.workers import ( | |||
globalize, localize, CopyvioWorkspace) | |||
oauth = importer.new("oauth2") | |||
__all__ = ["CopyvioMixIn", "globalize", "localize"] | |||
class CopyvioMixIn(object): | |||
@@ -48,7 +46,8 @@ class CopyvioMixIn(object): | |||
def __init__(self, site): | |||
self._search_config = site._search_config | |||
self._exclusions_db = self._search_config.get("exclusions_db") | |||
self._addheaders = site._opener.addheaders | |||
self._addheaders = [("User-Agent", site.user_agent), | |||
("Accept-Encoding", "gzip")] | |||
def _get_search_engine(self): | |||
"""Return a function that can be called to do web searches. | |||
@@ -63,19 +62,23 @@ class CopyvioMixIn(object): | |||
required package or module, like oauth2 for "Yahoo! BOSS". | |||
""" | |||
engine = self._search_config["engine"] | |||
if engine not in SEARCH_ENGINES: | |||
raise exceptions.UnknownSearchEngineError(engine) | |||
klass = SEARCH_ENGINES[engine] | |||
credentials = self._search_config["credentials"] | |||
opener = build_opener() | |||
opener.addheaders = self._addheaders | |||
if engine == "Yahoo! BOSS": | |||
for dep in klass.requirements(): | |||
try: | |||
oauth.__version__ # Force-load the lazy module | |||
except ImportError: | |||
e = "Yahoo! BOSS requires the 'oauth2' package: https://github.com/simplegeo/python-oauth2" | |||
__import__(dep).__name__ | |||
except (ImportError, AttributeError): | |||
e = "Missing a required dependency ({}) for the {} engine" | |||
e = e.format(dep, engine) | |||
raise exceptions.UnsupportedSearchEngineError(e) | |||
opener = build_opener() | |||
opener.addheaders = self._addheaders | |||
return YahooBOSSSearchEngine(credentials, opener) | |||
raise exceptions.UnknownSearchEngineError(engine) | |||
return klass(credentials, opener) | |||
def copyvio_check(self, min_confidence=0.75, max_queries=15, max_time=-1, | |||
no_searches=False, no_links=False, short_circuit=True): | |||
@@ -114,15 +117,18 @@ class CopyvioMixIn(object): | |||
log = u"Starting copyvio check for [[{0}]]" | |||
self._logger.info(log.format(self.title)) | |||
searcher = self._get_search_engine() | |||
parser = ArticleTextParser(self.get()) | |||
parser = ArticleTextParser(self.get(), { | |||
"nltk_dir": self._search_config["nltk_dir"], | |||
"lang": self._site.lang | |||
}) | |||
article = MarkovChain(parser.strip()) | |||
parser_args = {} | |||
if self._exclusions_db: | |||
self._exclusions_db.sync(self.site.name) | |||
exclude = lambda u: self._exclusions_db.check(self.site.name, u) | |||
parser_args["mirror_hints"] = self._exclusions_db.get_mirror_hints( | |||
self.site.name) | |||
parser_args["mirror_hints"] = \ | |||
self._exclusions_db.get_mirror_hints(self) | |||
else: | |||
exclude = None | |||
@@ -139,7 +145,7 @@ class CopyvioMixIn(object): | |||
workspace.enqueue(parser.get_links(), exclude) | |||
num_queries = 0 | |||
if not no_searches: | |||
chunks = parser.chunk(self._search_config["nltk_dir"], max_queries) | |||
chunks = parser.chunk(max_queries) | |||
for chunk in chunks: | |||
if short_circuit and workspace.finished: | |||
workspace.possible_miss = True | |||
@@ -1,6 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2016 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -122,7 +122,7 @@ class ExclusionsDB(object): | |||
site = self._sitesdb.get_site("enwiki") | |||
else: | |||
site = self._sitesdb.get_site(sitename) | |||
with sqlite.connect(self._dbfile) as conn, self._db_access_lock: | |||
with self._db_access_lock, sqlite.connect(self._dbfile) as conn: | |||
urls = set() | |||
for (source,) in conn.execute(query1, (sitename,)): | |||
urls |= self._load_source(site, source) | |||
@@ -140,7 +140,7 @@ class ExclusionsDB(object): | |||
def _get_last_update(self, sitename): | |||
"""Return the UNIX timestamp of the last time the db was updated.""" | |||
query = "SELECT update_time FROM updates WHERE update_sitename = ?" | |||
with sqlite.connect(self._dbfile) as conn, self._db_access_lock: | |||
with self._db_access_lock, sqlite.connect(self._dbfile) as conn: | |||
try: | |||
result = conn.execute(query, (sitename,)).fetchone() | |||
except sqlite.OperationalError: | |||
@@ -176,7 +176,7 @@ class ExclusionsDB(object): | |||
normalized = re.sub(r"^https?://(www\.)?", "", url.lower()) | |||
query = """SELECT exclusion_url FROM exclusions | |||
WHERE exclusion_sitename = ? OR exclusion_sitename = ?""" | |||
with sqlite.connect(self._dbfile) as conn, self._db_access_lock: | |||
with self._db_access_lock, sqlite.connect(self._dbfile) as conn: | |||
for (excl,) in conn.execute(query, (sitename, "all")): | |||
if excl.startswith("*."): | |||
parsed = urlparse(url.lower()) | |||
@@ -200,21 +200,23 @@ class ExclusionsDB(object): | |||
self._logger.debug(log) | |||
return False | |||
def get_mirror_hints(self, sitename, try_mobile=True): | |||
def get_mirror_hints(self, page, try_mobile=True): | |||
"""Return a list of strings that indicate the existence of a mirror. | |||
The source parser checks for the presence of these strings inside of | |||
certain HTML tag attributes (``"href"`` and ``"src"``). | |||
""" | |||
site = self._sitesdb.get_site(sitename) | |||
base = site.domain + site._script_path | |||
roots = [base] | |||
site = page.site | |||
path = urlparse(page.url).path | |||
roots = [site.domain] | |||
scripts = ["index.php", "load.php", "api.php"] | |||
if try_mobile: | |||
fragments = re.search(r"^([\w]+)\.([\w]+).([\w]+)$", site.domain) | |||
if fragments: | |||
mobile = "{0}.m.{1}.{2}".format(*fragments.groups()) | |||
roots.append(mobile + site._script_path) | |||
roots.append("{0}.m.{1}.{2}".format(*fragments.groups())) | |||
return [root + "/" + script for root in roots for script in scripts] | |||
general = [root + site._script_path + "/" + script | |||
for root in roots for script in scripts] | |||
specific = [root + path for root in roots] | |||
return general + specific |
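For instance, assuming an English Wikipedia page with domain
``en.wikipedia.org``, script path ``/w``, and URL path ``/wiki/Example``, the
hints come out roughly as::

    general  = ["en.wikipedia.org/w/index.php", "en.wikipedia.org/w/load.php",
                "en.wikipedia.org/w/api.php",
                "en.m.wikipedia.org/w/index.php", "en.m.wikipedia.org/w/load.php",
                "en.m.wikipedia.org/w/api.php"]
    specific = ["en.wikipedia.org/wiki/Example", "en.m.wikipedia.org/wiki/Example"]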
@@ -20,7 +20,6 @@ | |||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |||
# SOFTWARE. | |||
from collections import defaultdict | |||
from re import sub, UNICODE | |||
__all__ = ["EMPTY", "EMPTY_INTERSECTION", "MarkovChain", | |||
@@ -34,23 +33,27 @@ class MarkovChain(object): | |||
def __init__(self, text): | |||
self.text = text | |||
self.chain = defaultdict(lambda: defaultdict(lambda: 0)) | |||
words = sub(r"[^\w\s-]", "", text.lower(), flags=UNICODE).split() | |||
self.chain = self._build() | |||
self.size = self._get_size() | |||
def _build(self): | |||
"""Build and return the Markov chain from the input text.""" | |||
padding = self.degree - 1 | |||
words = sub(r"[^\w\s-]", "", self.text.lower(), flags=UNICODE).split() | |||
words = ([self.START] * padding) + words + ([self.END] * padding) | |||
for i in range(len(words) - self.degree + 1): | |||
last = i + self.degree - 1 | |||
self.chain[tuple(words[i:last])][words[last]] += 1 | |||
self.size = self._get_size() | |||
chain = {} | |||
for i in xrange(len(words) - self.degree + 1): | |||
phrase = tuple(words[i:i+self.degree]) | |||
if phrase in chain: | |||
chain[phrase] += 1 | |||
else: | |||
chain[phrase] = 1 | |||
return chain | |||
def _get_size(self): | |||
"""Return the size of the Markov chain: the total number of nodes.""" | |||
size = 0 | |||
for node in self.chain.itervalues(): | |||
for hits in node.itervalues(): | |||
size += hits | |||
return size | |||
return sum(self.chain.itervalues()) | |||
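A sketch of what ``_build`` now stores, assuming ``degree = 3`` and the usual
START/END padding (the class constants themselves are outside this hunk)::

    text = "one two one two"
    # After padding: [START, START, "one", "two", "one", "two", END, END]
    # _build() counts every 3-word window:
    #   (START, START, "one"): 1   ("two", "one", "two"): 1
    #   (START, "one", "two"): 1   ("one", "two", END):   1
    #   ("one", "two", "one"): 1   ("two", END, END):     1
    # _get_size() sums the counts, giving 6.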
def __repr__(self): | |||
"""Return the canonical string representation of the MarkovChain.""" | |||
@@ -65,20 +68,21 @@ class MarkovChainIntersection(MarkovChain): | |||
"""Implements the intersection of two chains (i.e., their shared nodes).""" | |||
def __init__(self, mc1, mc2): | |||
self.chain = defaultdict(lambda: defaultdict(lambda: 0)) | |||
self.mc1, self.mc2 = mc1, mc2 | |||
c1 = mc1.chain | |||
c2 = mc2.chain | |||
for word, nodes1 in c1.iteritems(): | |||
if word in c2: | |||
nodes2 = c2[word] | |||
for node, count1 in nodes1.iteritems(): | |||
if node in nodes2: | |||
count2 = nodes2[node] | |||
self.chain[word][node] = min(count1, count2) | |||
self.chain = self._build() | |||
self.size = self._get_size() | |||
def _build(self): | |||
"""Build and return the Markov chain from the input chains.""" | |||
c1 = self.mc1.chain | |||
c2 = self.mc2.chain | |||
chain = {} | |||
for phrase in c1: | |||
if phrase in c2: | |||
chain[phrase] = min(c1[phrase], c2[phrase]) | |||
return chain | |||
def __repr__(self): | |||
"""Return the canonical string representation of the intersection.""" | |||
res = "MarkovChainIntersection(mc1={0!r}, mc2={1!r})" | |||
@@ -1,6 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2019 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -34,8 +34,6 @@ nltk = importer.new("nltk") | |||
converter = importer.new("pdfminer.converter") | |||
pdfinterp = importer.new("pdfminer.pdfinterp") | |||
pdfpage = importer.new("pdfminer.pdfpage") | |||
pdftypes = importer.new("pdfminer.pdftypes") | |||
psparser = importer.new("pdfminer.psparser") | |||
__all__ = ["ArticleTextParser", "get_parser"] | |||
@@ -61,6 +59,26 @@ class ArticleTextParser(_BaseTextParser): | |||
"""A parser that can strip and chunk wikicode article text.""" | |||
TYPE = "Article" | |||
TEMPLATE_MERGE_THRESHOLD = 35 | |||
NLTK_DEFAULT = "english" | |||
NLTK_LANGS = { | |||
"cs": "czech", | |||
"da": "danish", | |||
"de": "german", | |||
"el": "greek", | |||
"en": "english", | |||
"es": "spanish", | |||
"et": "estonian", | |||
"fi": "finnish", | |||
"fr": "french", | |||
"it": "italian", | |||
"nl": "dutch", | |||
"no": "norwegian", | |||
"pl": "polish", | |||
"pt": "portuguese", | |||
"sl": "slovene", | |||
"sv": "swedish", | |||
"tr": "turkish" | |||
} | |||
def _merge_templates(self, code): | |||
"""Merge template contents in to wikicode when the values are long.""" | |||
@@ -76,6 +94,47 @@ class ArticleTextParser(_BaseTextParser): | |||
else: | |||
code.remove(template) | |||
def _get_tokenizer(self): | |||
"""Return a NLTK punctuation tokenizer for the article's language.""" | |||
datafile = lambda lang: "file:" + path.join( | |||
self._args["nltk_dir"], "tokenizers", "punkt", lang + ".pickle") | |||
lang = self.NLTK_LANGS.get(self._args.get("lang"), self.NLTK_DEFAULT) | |||
try: | |||
nltk.data.load(datafile(self.NLTK_DEFAULT)) | |||
except LookupError: | |||
nltk.download("punkt", self._args["nltk_dir"]) | |||
return nltk.data.load(datafile(lang)) | |||
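Sketch of the lookup for a German-language site (``lang`` comes from the args
passed by ``copyvio_check`` earlier in this diff)::

    lang = ArticleTextParser.NLTK_LANGS.get("de", ArticleTextParser.NLTK_DEFAULT)
    # -> "german"; the tokenizer is then loaded from
    #    file:<nltk_dir>/tokenizers/punkt/german.pickle
    # Unknown language codes fall back to the English model.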
def _get_sentences(self, min_query, max_query, split_thresh): | |||
"""Split the article text into sentences of a certain length.""" | |||
def cut_sentence(words): | |||
div = len(words) | |||
if div == 0: | |||
return [] | |||
length = len(" ".join(words)) | |||
while length > max_query: | |||
div -= 1 | |||
length -= len(words[div]) + 1 | |||
result = [] | |||
if length >= split_thresh: | |||
result.append(" ".join(words[:div])) | |||
return result + cut_sentence(words[div + 1:]) | |||
tokenizer = self._get_tokenizer() | |||
sentences = [] | |||
if not hasattr(self, "clean"): | |||
self.strip() | |||
for sentence in tokenizer.tokenize(self.clean): | |||
if len(sentence) <= max_query: | |||
sentences.append(sentence) | |||
else: | |||
sentences.extend(cut_sentence(sentence.split())) | |||
return [sen for sen in sentences if len(sen) >= min_query] | |||
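In rough terms, with the defaults below (``max_query=128``,
``split_thresh=32``)::

    # cut_sentence(words) on an overlong sentence:
    #   - join as many leading words as fit within max_query characters
    #   - keep that piece only if it is at least split_thresh characters long
    #   - skip the word at the cut point and recurse on the remaining words
    # so short trailing fragments are dropped instead of being searched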
def strip(self): | |||
"""Clean the page's raw text by removing templates and formatting. | |||
@@ -118,7 +177,7 @@ class ArticleTextParser(_BaseTextParser): | |||
self.clean = re.sub("\n\n+", "\n", clean).strip() | |||
return self.clean | |||
def chunk(self, nltk_dir, max_chunks, min_query=8, max_query=128): | |||
def chunk(self, max_chunks, min_query=8, max_query=128, split_thresh=32): | |||
"""Convert the clean article text into a list of web-searchable chunks. | |||
No greater than *max_chunks* will be returned. Each chunk will only be | |||
@@ -130,27 +189,11 @@ class ArticleTextParser(_BaseTextParser): | |||
This is implemented using :py:mod:`nltk` (http://nltk.org/). A base | |||
directory (*nltk_dir*) is required to store nltk's punctuation | |||
database. This is typically located in the bot's working directory. | |||
database, and should be passed as an argument to the constructor. It is | |||
typically located in the bot's working directory. | |||
""" | |||
datafile = path.join(nltk_dir, "tokenizers", "punkt", "english.pickle") | |||
try: | |||
tokenizer = nltk.data.load("file:" + datafile) | |||
except LookupError: | |||
nltk.download("punkt", nltk_dir) | |||
tokenizer = nltk.data.load("file:" + datafile) | |||
sentences = [] | |||
for sentence in tokenizer.tokenize(self.clean): | |||
if len(sentence) > max_query: | |||
words = sentence.split() | |||
while len(" ".join(words)) > max_query: | |||
words.pop() | |||
sentence = " ".join(words) | |||
if len(sentence) < min_query: | |||
continue | |||
sentences.append(sentence) | |||
if max_chunks >= len(sentences): | |||
sentences = self._get_sentences(min_query, max_query, split_thresh) | |||
if len(sentences) <= max_chunks: | |||
return sentences | |||
chunks = [] | |||
@@ -187,6 +230,20 @@ class _HTMLParser(_BaseTextParser): | |||
"script", "style" | |||
] | |||
def _fail_if_mirror(self, soup): | |||
"""Look for obvious signs that the given soup is a wiki mirror. | |||
If so, raise ParserExclusionError, which is caught in the workers and | |||
causes this source to be excluded.
""" | |||
if "mirror_hints" not in self._args: | |||
return | |||
func = lambda attr: attr and any( | |||
hint in attr for hint in self._args["mirror_hints"]) | |||
if soup.find_all(href=func) or soup.find_all(src=func): | |||
raise ParserExclusionError() | |||
def parse(self): | |||
"""Return the actual text contained within an HTML document. | |||
@@ -203,12 +260,7 @@ class _HTMLParser(_BaseTextParser): | |||
# no scrapable content (possibly JS or <frame> magic): | |||
return "" | |||
if "mirror_hints" in self._args: | |||
# Look for obvious signs that this is a mirror: | |||
func = lambda attr: attr and any( | |||
hint in attr for hint in self._args["mirror_hints"]) | |||
if soup.find_all(href=func) or soup.find_all(src=func): | |||
raise ParserExclusionError() | |||
self._fail_if_mirror(soup) | |||
soup = soup.body | |||
is_comment = lambda text: isinstance(text, bs4.element.Comment) | |||
@@ -240,7 +292,7 @@ class _PDFParser(_BaseTextParser): | |||
pages = pdfpage.PDFPage.get_pages(StringIO(self.text)) | |||
for page in pages: | |||
interp.process_page(page) | |||
except (pdftypes.PDFException, psparser.PSException, AssertionError): | |||
except Exception: # pylint: disable=broad-except | |||
return output.getvalue().decode("utf8") | |||
finally: | |||
conv.close() | |||
@@ -1,6 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2016 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -22,19 +22,22 @@ | |||
from gzip import GzipFile | |||
from json import loads | |||
from re import sub as re_sub | |||
from socket import error | |||
from StringIO import StringIO | |||
from urllib import quote | |||
from urllib import quote, urlencode | |||
from urllib2 import URLError | |||
from earwigbot import importer | |||
from earwigbot.exceptions import SearchQueryError | |||
lxml = importer.new("lxml") | |||
oauth = importer.new("oauth2") | |||
__all__ = ["BaseSearchEngine", "YahooBOSSSearchEngine"] | |||
__all__ = ["BingSearchEngine", "GoogleSearchEngine", "YahooBOSSSearchEngine", | |||
"YandexSearchEngine", "SEARCH_ENGINES"] | |||
class BaseSearchEngine(object): | |||
class _BaseSearchEngine(object): | |||
"""Base class for a simple search engine interface.""" | |||
name = "Base" | |||
@@ -42,6 +45,7 @@ class BaseSearchEngine(object): | |||
"""Store credentials (*cred*) and *opener* for searching later on.""" | |||
self.cred = cred | |||
self.opener = opener | |||
self.count = 5 | |||
def __repr__(self): | |||
"""Return the canonical string representation of the search engine.""" | |||
@@ -51,6 +55,31 @@ class BaseSearchEngine(object): | |||
"""Return a nice string representation of the search engine.""" | |||
return "<{0}>".format(self.__class__.__name__) | |||
def _open(self, *args): | |||
"""Open a URL (like urlopen) and try to return its contents.""" | |||
try: | |||
response = self.opener.open(*args) | |||
result = response.read() | |||
except (URLError, error) as exc: | |||
raise SearchQueryError("{0} Error: {1}".format(self.name, exc)) | |||
if response.headers.get("Content-Encoding") == "gzip": | |||
stream = StringIO(result) | |||
gzipper = GzipFile(fileobj=stream) | |||
result = gzipper.read() | |||
code = response.getcode() | |||
if code != 200: | |||
err = "{0} Error: got response code '{1}':\n{2}'" | |||
raise SearchQueryError(err.format(self.name, code, result)) | |||
return result | |||
@staticmethod | |||
def requirements(): | |||
"""Return a list of packages required by this search engine.""" | |||
return [] | |||
def search(self, query): | |||
"""Use this engine to search for *query*. | |||
@@ -59,7 +88,87 @@ class BaseSearchEngine(object): | |||
raise NotImplementedError() | |||
class YahooBOSSSearchEngine(BaseSearchEngine): | |||
class BingSearchEngine(_BaseSearchEngine): | |||
"""A search engine interface with Bing Search (via Azure Marketplace).""" | |||
name = "Bing" | |||
def __init__(self, cred, opener): | |||
super(BingSearchEngine, self).__init__(cred, opener) | |||
key = self.cred["key"] | |||
auth = (key + ":" + key).encode("base64").replace("\n", "") | |||
self.opener.addheaders.append(("Authorization", "Basic " + auth)) | |||
def search(self, query): | |||
"""Do a Bing web search for *query*. | |||
Returns a list of URLs ranked by relevance (as determined by Bing). | |||
Raises :py:exc:`~earwigbot.exceptions.SearchQueryError` on errors. | |||
""" | |||
service = "SearchWeb" if self.cred["type"] == "searchweb" else "Search" | |||
url = "https://api.datamarket.azure.com/Bing/{0}/Web?".format(service) | |||
params = { | |||
"$format": "json", | |||
"$top": str(self.count), | |||
"Query": "'\"" + query.replace('"', "").encode("utf8") + "\"'", | |||
"Market": "'en-US'", | |||
"Adult": "'Off'", | |||
"Options": "'DisableLocationDetection'", | |||
"WebSearchOptions": "'DisableHostCollapsing+DisableQueryAlterations'" | |||
} | |||
result = self._open(url + urlencode(params)) | |||
try: | |||
res = loads(result) | |||
except ValueError: | |||
err = "Bing Error: JSON could not be decoded" | |||
raise SearchQueryError(err) | |||
try: | |||
results = res["d"]["results"] | |||
except KeyError: | |||
return [] | |||
return [result["Url"] for result in results] | |||
class GoogleSearchEngine(_BaseSearchEngine): | |||
"""A search engine interface with Google Search.""" | |||
name = "Google" | |||
def search(self, query): | |||
"""Do a Google web search for *query*. | |||
Returns a list of URLs ranked by relevance (as determined by Google). | |||
Raises :py:exc:`~earwigbot.exceptions.SearchQueryError` on errors. | |||
""" | |||
domain = self.cred.get("proxy", "www.googleapis.com") | |||
url = "https://{0}/customsearch/v1?".format(domain) | |||
params = { | |||
"cx": self.cred["id"], | |||
"key": self.cred["key"], | |||
"q": '"' + query.replace('"', "").encode("utf8") + '"', | |||
"alt": "json", | |||
"num": str(self.count), | |||
"safe": "off", | |||
"fields": "items(link)" | |||
} | |||
result = self._open(url + urlencode(params)) | |||
try: | |||
res = loads(result) | |||
except ValueError: | |||
err = "Google Error: JSON could not be decoded" | |||
raise SearchQueryError(err) | |||
try: | |||
return [item["link"] for item in res["items"]] | |||
except KeyError: | |||
return [] | |||
class YahooBOSSSearchEngine(_BaseSearchEngine): | |||
"""A search engine interface with Yahoo! BOSS.""" | |||
name = "Yahoo! BOSS" | |||
@@ -70,11 +179,14 @@ class YahooBOSSSearchEngine(BaseSearchEngine): | |||
args = ["=".join((enc(k), enc(v))) for k, v in params.iteritems()] | |||
return base + "?" + "&".join(args) | |||
@staticmethod | |||
def requirements(): | |||
return ["oauth2"] | |||
def search(self, query): | |||
"""Do a Yahoo! BOSS web search for *query*. | |||
Returns a list of URLs, no more than five, ranked by relevance | |||
(as determined by Yahoo). | |||
Returns a list of URLs ranked by relevance (as determined by Yahoo). | |||
Raises :py:exc:`~earwigbot.exceptions.SearchQueryError` on errors. | |||
""" | |||
key, secret = self.cred["key"], self.cred["secret"] | |||
@@ -86,34 +198,69 @@ class YahooBOSSSearchEngine(BaseSearchEngine): | |||
"oauth_nonce": oauth.generate_nonce(), | |||
"oauth_timestamp": oauth.Request.make_timestamp(), | |||
"oauth_consumer_key": consumer.key, | |||
"q": '"' + query.encode("utf8") + '"', "count": "5", | |||
"type": "html,text,pdf", "format": "json", | |||
"q": '"' + query.encode("utf8") + '"', | |||
"count": str(self.count), | |||
"type": "html,text,pdf", | |||
"format": "json", | |||
} | |||
req = oauth.Request(method="GET", url=url, parameters=params) | |||
req.sign_request(oauth.SignatureMethod_HMAC_SHA1(), consumer, None) | |||
try: | |||
response = self.opener.open(self._build_url(url, req)) | |||
result = response.read() | |||
except (URLError, error) as exc: | |||
raise SearchQueryError("Yahoo! BOSS Error: " + str(exc)) | |||
if response.headers.get("Content-Encoding") == "gzip": | |||
stream = StringIO(result) | |||
gzipper = GzipFile(fileobj=stream) | |||
result = gzipper.read() | |||
result = self._open(self._build_url(url, req)) | |||
if response.getcode() != 200: | |||
e = "Yahoo! BOSS Error: got response code '{0}':\n{1}'" | |||
raise SearchQueryError(e.format(response.getcode(), result)) | |||
try: | |||
res = loads(result) | |||
except ValueError: | |||
e = "Yahoo! BOSS Error: JSON could not be decoded" | |||
raise SearchQueryError(e) | |||
err = "Yahoo! BOSS Error: JSON could not be decoded" | |||
raise SearchQueryError(err) | |||
try: | |||
results = res["bossresponse"]["web"]["results"] | |||
except KeyError: | |||
return [] | |||
return [result["url"] for result in results] | |||
class YandexSearchEngine(_BaseSearchEngine): | |||
"""A search engine interface with Yandex Search.""" | |||
name = "Yandex" | |||
@staticmethod | |||
def requirements(): | |||
return ["lxml.etree"] | |||
def search(self, query): | |||
"""Do a Yandex web search for *query*. | |||
Returns a list of URLs ranked by relevance (as determined by Yandex). | |||
Raises :py:exc:`~earwigbot.exceptions.SearchQueryError` on errors. | |||
""" | |||
domain = self.cred.get("proxy", "yandex.com") | |||
url = "https://{0}/search/xml?".format(domain) | |||
query = re_sub(r"[^a-zA-Z0-9 ]", "", query).encode("utf8") | |||
params = { | |||
"user": self.cred["user"], | |||
"key": self.cred["key"], | |||
"query": '"' + query + '"', | |||
"l10n": "en", | |||
"filter": "none", | |||
"maxpassages": "1", | |||
"groupby": "mode=flat.groups-on-page={0}".format(self.count) | |||
} | |||
result = self._open(url + urlencode(params)) | |||
try: | |||
data = lxml.etree.fromstring(result) | |||
return [elem.text for elem in data.xpath(".//url")] | |||
except lxml.etree.Error as exc: | |||
raise SearchQueryError("Yandex XML parse error: " + str(exc)) | |||
SEARCH_ENGINES = { | |||
"Bing": BingSearchEngine, | |||
"Google": GoogleSearchEngine, | |||
"Yahoo! BOSS": YahooBOSSSearchEngine, | |||
"Yandex": YandexSearchEngine | |||
} |
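How the registry is consumed by ``CopyvioMixIn._get_search_engine`` earlier in
this diff, with hypothetical credentials::

    from urllib2 import build_opener

    klass = SEARCH_ENGINES["Google"]
    engine = klass({"id": "<cse-id>", "key": "<api-key>"}, build_opener())
    urls = engine.search("some suspicious sentence")  # up to engine.count URLs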
@@ -1,6 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2019 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -128,7 +128,7 @@ class _CopyvioWorker(object): | |||
url = source.url.encode("utf8") | |||
try: | |||
response = self._opener.open(url, timeout=source.timeout) | |||
except (URLError, HTTPException, socket_error): | |||
except (URLError, HTTPException, socket_error, ValueError): | |||
return None | |||
try: | |||
@@ -203,6 +203,28 @@ class _CopyvioWorker(object): | |||
self._queues.lock.release() | |||
return source | |||
def _handle_once(self): | |||
"""Handle a single source from one of the queues.""" | |||
try: | |||
source = self._dequeue() | |||
except Empty: | |||
self._logger.debug("Exiting: queue timed out") | |||
return False | |||
except StopIteration: | |||
self._logger.debug("Exiting: got stop signal") | |||
return False | |||
try: | |||
text = self._open_url(source) | |||
except ParserExclusionError: | |||
self._logger.debug("Source excluded by content parser") | |||
source.skipped = source.excluded = True | |||
source.finish_work() | |||
else: | |||
chain = MarkovChain(text) if text else None | |||
source.workspace.compare(source, chain) | |||
return True | |||
def _run(self): | |||
"""Main entry point for the worker thread. | |||
@@ -211,24 +233,8 @@ class _CopyvioWorker(object): | |||
now empty. | |||
""" | |||
while True: | |||
try: | |||
source = self._dequeue() | |||
except Empty: | |||
self._logger.debug("Exiting: queue timed out") | |||
return | |||
except StopIteration: | |||
self._logger.debug("Exiting: got stop signal") | |||
return | |||
try: | |||
text = self._open_url(source) | |||
except ParserExclusionError: | |||
self._logger.debug("Source excluded by content parser") | |||
source.skipped = source.excluded = True | |||
source.finish_work() | |||
else: | |||
chain = MarkovChain(text) if text else None | |||
source.workspace.compare(source, chain) | |||
if not self._handle_once(): | |||
break | |||
def start(self): | |||
"""Start the copyvio worker in a new thread.""" | |||
@@ -1,6 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2019 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -264,13 +264,15 @@ class Page(CopyvioMixIn): | |||
if not result: | |||
query = self.site.api_query | |||
result = query(action="query", prop="revisions", rvlimit=1, | |||
rvprop="content|timestamp", titles=self._title) | |||
rvprop="content|timestamp", rvslots="main", | |||
titles=self._title) | |||
res = result["query"]["pages"].values()[0] | |||
try: | |||
self._content = res["revisions"][0]["*"] | |||
self._basetimestamp = res["revisions"][0]["timestamp"] | |||
except KeyError: | |||
revision = res["revisions"][0] | |||
self._content = revision["slots"]["main"]["*"] | |||
self._basetimestamp = revision["timestamp"] | |||
except (KeyError, IndexError): | |||
# This can only happen if the page was deleted since we last called | |||
# self._load_attributes(). In that case, some of our attributes are | |||
# outdated, so force another self._load_attributes(): | |||
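Abridged shape of the ``rvslots="main"`` response the new code reads (values
made up)::

    result = {"query": {"pages": {"123": {"revisions": [{
        "timestamp": "2019-01-01T00:00:00Z",
        "slots": {"main": {"*": "page wikitext here"}},
    }]}}}}
    revision = result["query"]["pages"].values()[0]["revisions"][0]
    text = revision["slots"]["main"]["*"]  # the page content
    basetimestamp = revision["timestamp"]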
@@ -552,9 +554,9 @@ class Page(CopyvioMixIn): | |||
else: | |||
new_ns = self._namespace + 1 | |||
try: | |||
if self._namespace != 0: | |||
body = self._title.split(":", 1)[1] | |||
except IndexError: | |||
else: | |||
body = self._title | |||
new_prefix = self.site.namespace_id_to_name(new_ns) | |||
@@ -582,7 +584,7 @@ class Page(CopyvioMixIn): | |||
query = self.site.api_query | |||
result = query(action="query", rvlimit=1, titles=self._title, | |||
prop="info|revisions", inprop="protection|url", | |||
rvprop="content|timestamp") | |||
rvprop="content|timestamp", rvslots="main") | |||
self._load_attributes(result=result) | |||
self._assert_existence() | |||
self._load_content(result=result) | |||
@@ -1,6 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2019 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -21,17 +21,16 @@ | |||
# SOFTWARE. | |||
from cookielib import CookieJar | |||
from gzip import GzipFile | |||
from json import loads | |||
from logging import getLogger, NullHandler | |||
from os.path import expanduser | |||
from StringIO import StringIO | |||
from threading import RLock | |||
from time import sleep, time | |||
from urllib import quote_plus, unquote_plus | |||
from urllib2 import build_opener, HTTPCookieProcessor, URLError | |||
from urlparse import urlparse | |||
import requests | |||
from requests_oauthlib import OAuth1 | |||
from earwigbot import exceptions, importer | |||
from earwigbot.wiki import constants | |||
from earwigbot.wiki.category import Category | |||
@@ -83,15 +82,16 @@ class Site(object): | |||
""" | |||
SERVICE_API = 1 | |||
SERVICE_SQL = 2 | |||
SPECIAL_TOKENS = ["deleteglobalaccount", "patrol", "rollback", | |||
"setglobalaccountstatus", "userrights", "watch"] | |||
SPECIAL_TOKENS = ["createaccount", "deleteglobalaccount", "login", | |||
"patrol", "rollback", "setglobalaccountstatus", | |||
"userrights", "watch"] | |||
def __init__(self, name=None, project=None, lang=None, base_url=None, | |||
article_path=None, script_path=None, sql=None, | |||
namespaces=None, login=(None, None), cookiejar=None, | |||
user_agent=None, use_https=True, assert_edit=None, | |||
maxlag=None, wait_between_queries=2, logger=None, | |||
search_config=None): | |||
namespaces=None, login=(None, None), oauth=None, | |||
cookiejar=None, user_agent=None, use_https=True, | |||
assert_edit=None, maxlag=None, wait_between_queries=2, | |||
logger=None, search_config=None): | |||
"""Constructor for new Site instances. | |||
This probably isn't necessary to call yourself unless you're building a | |||
@@ -100,14 +100,15 @@ class Site(object): | |||
based on your config file and the sites database. We accept a bunch of | |||
kwargs, but the only ones you really "need" are *base_url* and | |||
*script_path*; this is enough to figure out an API url. *login*, a | |||
tuple of (username, password), is highly recommended. *cookiejar* will | |||
be used to store cookies, and we'll use a normal CookieJar if none is | |||
given. | |||
tuple of (username, password), can be used to log in using the legacy | |||
BotPasswords system; otherwise, a dict of OAuth info should be provided | |||
to *oauth*. *cookiejar* will be used to store cookies, and we'll use a | |||
normal CookieJar if none is given. | |||
First, we'll store the given arguments as attributes, then set up our | |||
URL opener. We'll load any of the attributes that weren't given from | |||
the API, and then log in if a username/pass was given and we aren't | |||
already logged in. | |||
requests session. We'll load any of the attributes that weren't given | |||
from the API, and then log in if a username/pass was given and we | |||
aren't already logged in. | |||
""" | |||
# Attributes referring to site information, filled in by an API query | |||
# if they are missing (and an API url can be determined): | |||
@@ -145,16 +146,22 @@ class Site(object): | |||
else: | |||
self._search_config = {} | |||
# Set up cookiejar and URL opener for making API queries: | |||
# Set up cookiejar and requests session for making API queries: | |||
if cookiejar is not None: | |||
self._cookiejar = cookiejar | |||
else: | |||
self._cookiejar = CookieJar() | |||
self._last_cookiejar_save = None | |||
if not user_agent: | |||
user_agent = constants.USER_AGENT # Set default UA | |||
self._opener = build_opener(HTTPCookieProcessor(self._cookiejar)) | |||
self._opener.addheaders = [("User-Agent", user_agent), | |||
("Accept-Encoding", "gzip")] | |||
self._oauth = oauth | |||
self._session = requests.Session() | |||
self._session.cookies = self._cookiejar | |||
self._session.headers["User-Agent"] = user_agent | |||
if oauth: | |||
self._session.auth = OAuth1( | |||
oauth["consumer_token"], oauth["consumer_secret"], | |||
oauth["access_token"], oauth["access_secret"]) | |||
# Set up our internal logger: | |||
if logger: | |||
@@ -168,7 +175,7 @@ class Site(object): | |||
# If we have a name/pass and the API says we're not logged in, log in: | |||
self._login_info = name, password = login | |||
if name and password: | |||
if not self._oauth and name and password: | |||
logged_in_as = self._get_username_from_cookies() | |||
if not logged_in_as or name.replace("_", " ") != logged_in_as: | |||
self._login(login) | |||
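The ``oauth`` dict consumed above takes four keys (names from the code, values
hypothetical), and when present it replaces the legacy username/password
login::

    site = Site(base_url="https://en.wikipedia.org", script_path="/w",
                oauth={"consumer_token": "...", "consumer_secret": "...",
                       "access_token": "...", "access_secret": "..."})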
@@ -180,17 +187,18 @@ class Site(object): | |||
"base_url={_base_url!r}", "article_path={_article_path!r}", | |||
"script_path={_script_path!r}", "use_https={_use_https!r}", | |||
"assert_edit={_assert_edit!r}", "maxlag={_maxlag!r}", | |||
"sql={_sql_data!r}", "login={0}", "user_agent={2!r}", | |||
"cookiejar={1})")) | |||
"sql={_sql_data!r}", "login={0}", "oauth={1}", "user_agent={3!r}", | |||
"cookiejar={2})")) | |||
name, password = self._login_info | |||
login = "({0}, {1})".format(repr(name), "hidden" if password else None) | |||
oauth = "hidden" if self._oauth else None | |||
cookies = self._cookiejar.__class__.__name__ | |||
if hasattr(self._cookiejar, "filename"): | |||
cookies += "({0!r})".format(getattr(self._cookiejar, "filename")) | |||
else: | |||
cookies += "()" | |||
agent = self._opener.addheaders[0][1] | |||
return res.format(login, cookies, agent, **self.__dict__) | |||
agent = self.user_agent | |||
return res.format(login, oauth, cookies, agent, **self.__dict__) | |||
def __str__(self): | |||
"""Return a nice string representation of the Site.""" | |||
@@ -232,28 +240,18 @@ class Site(object): | |||
url, data = self._build_api_query(params, ignore_maxlag, no_assert) | |||
if "lgpassword" in params: | |||
self._logger.debug("{0} -> <hidden>".format(url)) | |||
elif len(data) > 1000: | |||
self._logger.debug("{0} -> {1}...".format(url, data[:997])) | |||
else: | |||
self._logger.debug("{0} -> {1}".format(url, data)) | |||
try: | |||
response = self._opener.open(url, data) | |||
except URLError as error: | |||
if hasattr(error, "reason"): | |||
e = "API query failed: {0}.".format(error.reason) | |||
elif hasattr(error, "code"): | |||
e = "API query failed: got an error code of {0}." | |||
e = e.format(error.code) | |||
else: | |||
e = "API query failed." | |||
raise exceptions.APIError(e) | |||
response = self._session.post(url, data=data) | |||
response.raise_for_status() | |||
except requests.RequestException as exc: | |||
raise exceptions.APIError("API query failed: {0}".format(exc)) | |||
result = response.read() | |||
if response.headers.get("Content-Encoding") == "gzip": | |||
stream = StringIO(result) | |||
gzipper = GzipFile(fileobj=stream) | |||
result = gzipper.read() | |||
return self._handle_api_result(result, params, tries, wait, ae_retry) | |||
return self._handle_api_result(response, params, tries, wait, ae_retry) | |||
def _request_csrf_token(self, params): | |||
"""If possible, add a request for a CSRF token to an API query.""" | |||
@@ -288,14 +286,28 @@ class Site(object): | |||
data = self._urlencode_utf8(params) | |||
return url, data | |||
def _handle_api_result(self, result, params, tries, wait, ae_retry): | |||
"""Given the result of an API query, attempt to return useful data.""" | |||
def _handle_api_result(self, response, params, tries, wait, ae_retry): | |||
"""Given an API query response, attempt to return useful data.""" | |||
try: | |||
res = loads(result) # Try to parse as a JSON object | |||
res = response.json() | |||
except ValueError: | |||
e = "API query failed: JSON could not be decoded." | |||
raise exceptions.APIError(e) | |||
if "warnings" in res: | |||
for name, value in res["warnings"].items(): | |||
try: | |||
warning = value["warnings"] | |||
except KeyError: | |||
try: | |||
warning = value["*"] | |||
except KeyError: | |||
warning = value | |||
self._logger.warning("API warning: %s: %s", name, warning) | |||
if self._should_save_cookiejar(): | |||
self._save_cookiejar() | |||
try: | |||
code = res["error"]["code"] | |||
info = res["error"]["info"] | |||
@@ -315,18 +327,18 @@ class Site(object): | |||
sleep(wait) | |||
return self._api_query(params, tries, wait * 2, ae_retry=ae_retry) | |||
elif code in ["assertuserfailed", "assertbotfailed"]: # AssertEdit | |||
if ae_retry and all(self._login_info): | |||
if ae_retry and all(self._login_info) and not self._oauth: | |||
# Try to log in if we got logged out: | |||
self._login(self._login_info) | |||
if "token" in params: # Fetch a new one; this is invalid now | |||
params["token"] = self.get_token(params["action"]) | |||
return self._api_query(params, tries, wait, ae_retry=False) | |||
if not all(self._login_info): | |||
if not all(self._login_info) and not self._oauth: | |||
e = "Assertion failed, and no login info was provided." | |||
elif code == "assertbotfailed": | |||
e = "Bot assertion failed: we don't have a bot flag!" | |||
else: | |||
e = "User assertion failed due to an unknown issue. Cookie problem?" | |||
e = "User assertion failed due to an unknown issue. Cookie or OAuth problem?" | |||
raise exceptions.PermissionsError("AssertEdit: " + e) | |||
else: # Some unknown error occurred | |||
e = 'API query failed: got error "{0}"; server says: "{1}".' | |||
@@ -463,15 +475,30 @@ class Site(object): | |||
unnecessary API query. For the cookie-detection method, see | |||
_get_username_from_cookies()'s docs. | |||
If our username isn't in cookies, then we're probably not logged in, or | |||
something fishy is going on (like forced logout). In this case, do a | |||
single API query for our username (or IP address) and return that. | |||
If our username isn't in cookies, then we're either using OAuth, not
logged in, or something fishy is going on (like a forced logout). If
we're using OAuth and a username was configured, assume it is accurate
and use it. Otherwise, do a single API query for our username (or IP
address) and return that.
""" | |||
name = self._get_username_from_cookies() | |||
if name: | |||
return name | |||
if self._oauth and self._login_info[0]: | |||
return self._login_info[0] | |||
return self._get_username_from_api() | |||
def _should_save_cookiejar(self): | |||
"""Return a bool indicating whether we should save the cookiejar. | |||
This is True if we haven't saved the cookiejar yet this session, or if | |||
our last save was over a day ago. | |||
""" | |||
max_staleness = 60 * 60 * 24 # 1 day | |||
if not self._last_cookiejar_save: | |||
return True | |||
return time() - self._last_cookiejar_save > max_staleness | |||
def _save_cookiejar(self): | |||
"""Try to save our cookiejar after doing a (normal) login or logout. | |||
@@ -485,8 +512,9 @@ class Site(object): | |||
getattr(self._cookiejar, "save")() | |||
except (NotImplementedError, ValueError): | |||
pass | |||
self._last_cookiejar_save = time() | |||
def _login(self, login, token=None, attempt=0): | |||
def _login(self, login): | |||
"""Safely login through the API. | |||
Normally, this is called by __init__() if a username and password have | |||
@@ -494,45 +522,43 @@ class Site(object): | |||
time it needs to be called is when those cookies expire, which is done | |||
automatically by api_query() if a query fails. | |||
Recent versions of MediaWiki's API have fixed a CSRF vulnerability, | |||
requiring login to be done in two separate requests. If the response | |||
from our initial request is "NeedToken", we'll do another one with
the token. If login is successful, we'll try to save our cookiejar. | |||
*login* is a (username, password) tuple. | |||
Raises LoginError on login errors (duh), like bad passwords and | |||
nonexistent usernames. | |||
*login* is a (username, password) tuple. *token* is the token returned | |||
from our first request, and *attempt* is to prevent getting stuck in a | |||
loop if MediaWiki isn't acting right. | |||
""" | |||
self._tokens.clear() | |||
name, password = login | |||
params = {"action": "login", "lgname": name, "lgpassword": password} | |||
if token: | |||
params["lgtoken"] = token | |||
params = {"action": "query", "meta": "tokens", "type": "login"} | |||
with self._api_lock: | |||
result = self._api_query(params, no_assert=True) | |||
try: | |||
token = result["query"]["tokens"]["logintoken"] | |||
except KeyError: | |||
raise exceptions.LoginError("Couldn't get login token") | |||
params = {"action": "login", "lgname": name, "lgpassword": password, | |||
"lgtoken": token} | |||
with self._api_lock: | |||
result = self._api_query(params, no_assert=True) | |||
res = result["login"]["result"] | |||
if res == "Success": | |||
self._tokens.clear() | |||
self._save_cookiejar() | |||
elif res == "NeedToken" and attempt == 0: | |||
token = result["login"]["token"] | |||
return self._login(login, token, attempt=1) | |||
return | |||
if res == "Illegal": | |||
e = "The provided username is illegal." | |||
elif res == "NotExists": | |||
e = "The provided username does not exist." | |||
elif res == "EmptyPass": | |||
e = "No password was given." | |||
elif res == "WrongPass" or res == "WrongPluginPass": | |||
e = "The given password is incorrect." | |||
else: | |||
if res == "Illegal": | |||
e = "The provided username is illegal." | |||
elif res == "NotExists": | |||
e = "The provided username does not exist." | |||
elif res == "EmptyPass": | |||
e = "No password was given." | |||
elif res == "WrongPass" or res == "WrongPluginPass": | |||
e = "The given password is incorrect." | |||
else: | |||
e = "Couldn't login; server says '{0}'.".format(res) | |||
raise exceptions.LoginError(e) | |||
e = "Couldn't login; server says '{0}'.".format(res) | |||
raise exceptions.LoginError(e) | |||
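For reference, the same two-step flow can be reproduced against a live MediaWiki API with nothing but requests; the endpoint and credentials here are placeholders:

    import requests

    API = "https://en.wikipedia.org/w/api.php"  # placeholder endpoint
    session = requests.Session()

    # Step 1: fetch a login token.
    res = session.post(API, data={"action": "query", "meta": "tokens",
                                  "type": "login", "format": "json"}).json()
    token = res["query"]["tokens"]["logintoken"]

    # Step 2: log in with the token; the session keeps the cookies.
    res = session.post(API, data={"action": "login", "lgname": "MyBot",
                                  "lgpassword": "secret", "lgtoken": token,
                                  "format": "json"}).json()
    assert res["login"]["result"] == "Success"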
def _logout(self): | |||
"""Safely logout through the API. | |||
@@ -650,6 +676,11 @@ class Site(object): | |||
url = "http:" + url | |||
return url | |||
@property | |||
def user_agent(self): | |||
"""The User-Agent header sent to the API by the requests session.""" | |||
return self._session.headers["User-Agent"] | |||
def api_query(self, **kwargs): | |||
"""Do an API query with `kwargs` as the parameters. | |||
@@ -666,10 +697,9 @@ class Site(object): | |||
:py:attr:`self._assert_edit` and :py:attr:`self._maxlag` respectively.
Additionally, we'll sleep a bit if the last query was made fewer than | |||
:py:attr:`self._wait_between_queries` seconds ago. The request is made | |||
through :py:attr:`self._opener`, which has cookie support | |||
(:py:attr:`self._cookiejar`), a ``User-Agent`` | |||
(:py:const:`earwigbot.wiki.constants.USER_AGENT`), and | |||
``Accept-Encoding`` set to ``"gzip"``. | |||
through :py:attr:`self._session`, which has cookie support | |||
(:py:attr:`self._cookiejar`) and a ``User-Agent`` | |||
(:py:const:`earwigbot.wiki.constants.USER_AGENT`). | |||
Assuming everything went well, we'll decode the response as a JSON
object and return it.
@@ -1,6 +1,6 @@ | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2019 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -187,6 +187,7 @@ class SitesDB(object): | |||
config = self.config | |||
login = (config.wiki.get("username"), config.wiki.get("password")) | |||
oauth = config.wiki.get("oauth") | |||
user_agent = config.wiki.get("userAgent") | |||
use_https = config.wiki.get("useHTTPS", True) | |||
assert_edit = config.wiki.get("assert") | |||
@@ -212,7 +213,7 @@ class SitesDB(object): | |||
return Site(name=name, project=project, lang=lang, base_url=base_url, | |||
article_path=article_path, script_path=script_path, | |||
sql=sql, namespaces=namespaces, login=login, | |||
sql=sql, namespaces=namespaces, login=login, oauth=oauth, | |||
cookiejar=cookiejar, user_agent=user_agent, | |||
use_https=use_https, assert_edit=assert_edit, | |||
maxlag=maxlag, wait_between_queries=wait_between_queries, | |||
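These values are read from the bot's wiki configuration; a hypothetical mapping consistent with the get() calls above (only the key names are grounded in the code, the values are placeholders):

    config_wiki = {
        "username": "MyBot",           # paired with "password" (BotPasswords)
        "password": "secret",
        "oauth": {                     # if present, used instead of the pair
            "consumer_token": "...", "consumer_secret": "...",
            "access_token": "...", "access_secret": "...",
        },
        "userAgent": "EarwigBot/0.3",  # optional; default UA used otherwise
        "useHTTPS": True,
        "assert": "user",              # placeholder AssertEdit value
    }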
@@ -239,7 +240,7 @@ class SitesDB(object): | |||
if site: | |||
return site[0] | |||
else: | |||
url = "%{0}.{1}%".format(lang, project) | |||
url = "//{0}.{1}.%".format(lang, project) | |||
site = conn.execute(query2, (url,)).fetchone() | |||
return site[0] if site else None | |||
except sqlite.OperationalError: | |||
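The corrected pattern anchors the match to the start of a protocol-relative URL instead of matching the language/project pair anywhere in the string; a quick sqlite3 sanity check of the idea (table and values are made up):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE sites (name TEXT, url TEXT)")
    conn.execute("INSERT INTO sites VALUES (?, ?)",
                 ("enwiki", "//en.wikipedia.org/w/index.php"))
    # The old pattern "%en.wikipedia%" also matched that substring mid-URL;
    # the new one requires the URL to start with "//en.wikipedia.":
    pattern = "//{0}.{1}.%".format("en", "wikipedia")
    row = conn.execute("SELECT name FROM sites WHERE url LIKE ?",
                       (pattern,)).fetchone()
    print(row)  # -> ('enwiki',)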
@@ -386,6 +387,7 @@ class SitesDB(object): | |||
config = self.config | |||
login = (config.wiki.get("username"), config.wiki.get("password")) | |||
oauth = config.wiki.get("oauth") | |||
user_agent = config.wiki.get("userAgent") | |||
use_https = config.wiki.get("useHTTPS", True) | |||
assert_edit = config.wiki.get("assert") | |||
@@ -398,9 +400,10 @@ class SitesDB(object): | |||
# Create a Site object to log in and load the other attributes: | |||
site = Site(base_url=base_url, script_path=script_path, sql=sql, | |||
login=login, cookiejar=cookiejar, user_agent=user_agent, | |||
use_https=use_https, assert_edit=assert_edit, | |||
maxlag=maxlag, wait_between_queries=wait_between_queries) | |||
login=login, oauth=oauth, cookiejar=cookiejar, | |||
user_agent=user_agent, use_https=use_https, | |||
assert_edit=assert_edit, maxlag=maxlag, | |||
wait_between_queries=wait_between_queries) | |||
self._logger.info("Added site '{0}'".format(site.name)) | |||
self._add_site_to_sitesdb(site) | |||
@@ -1,7 +1,7 @@ | |||
#! /usr/bin/env python | |||
# -*- coding: utf-8 -*- | |||
# | |||
# Copyright (C) 2009-2015 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# Copyright (C) 2009-2019 Ben Kurtovic <ben.kurtovic@gmail.com> | |||
# | |||
# Permission is hereby granted, free of charge, to any person obtaining a copy | |||
# of this software and associated documentation files (the "Software"), to deal | |||
@@ -26,8 +26,10 @@ from setuptools import setup, find_packages | |||
from earwigbot import __version__ | |||
required_deps = [ | |||
"PyYAML >= 3.11", # Parsing config files | |||
"mwparserfromhell >= 0.4.3", # Parsing wikicode for manipulation | |||
"PyYAML >= 3.12", # Parsing config files | |||
"mwparserfromhell >= 0.5", # Parsing wikicode for manipulation | |||
"requests >= 2.21.0", # Wiki API requests | |||
"requests_oauthlib >= 1.2.0", # API authentication via OAuth | |||
] | |||
extra_deps = { | |||
@@ -36,19 +38,19 @@ extra_deps = { | |||
"pycrypto >= 2.6.1", # Storing bot passwords + keys in the config file | |||
], | |||
"sql": [ | |||
"oursql >= 0.9.3.1", # Interfacing with MediaWiki databases | |||
"oursql >= 0.9.3.2", # Interfacing with MediaWiki databases | |||
], | |||
"copyvios": [ | |||
"beautifulsoup4 >= 4.4.1", # Parsing/scraping HTML | |||
"cchardet >= 1.0.0", # Encoding detection for BeautifulSoup | |||
"lxml >= 3.4.4", # Faster parser for BeautifulSoup | |||
"nltk >= 3.1", # Parsing sentences to split article content | |||
"beautifulsoup4 >= 4.6.0", # Parsing/scraping HTML | |||
"cchardet >= 2.1.1", # Encoding detection for BeautifulSoup | |||
"lxml >= 3.8.0", # Faster parser for BeautifulSoup | |||
"nltk >= 3.2.4", # Parsing sentences to split article content | |||
"oauth2 >= 1.9.0", # Interfacing with Yahoo! BOSS Search | |||
"pdfminer >= 20140328", # Extracting text from PDF files | |||
"tldextract >= 1.7.1", # Getting domains for the multithreaded workers | |||
"tldextract >= 2.1.0", # Getting domains for the multithreaded workers | |||
], | |||
"time": [ | |||
"pytz >= 2015.7", # Handling timezones for the !time IRC command | |||
"pytz >= 2017.2", # Handling timezones for the !time IRC command | |||
], | |||
} | |||
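For context, these lists feed directly into the setuptools call at the bottom of setup.py; a sketch with the remaining metadata elided:

    setup(
        name="earwigbot",
        version=__version__,
        packages=find_packages(),
        install_requires=required_deps,  # always installed
        extras_require=extra_deps,       # e.g. pip install earwigbot[copyvios]
    )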