Browse Source

Support new CSRF token API.

tags/v0.2
Ben Kurtovic 10 years ago
parent
commit
459c252fc7
2 changed files with 59 additions and 50 deletions
  1. +19
    -33
      earwigbot/wiki/page.py
  2. +40
    -17
      earwigbot/wiki/site.py

+ 19
- 33
earwigbot/wiki/page.py View File

@@ -116,7 +116,6 @@ class Page(CopyvioMixIn):
self._creator = None self._creator = None


# Attributes used for editing/deleting/protecting/etc: # Attributes used for editing/deleting/protecting/etc:
self._token = None
self._basetimestamp = None self._basetimestamp = None
self._starttimestamp = None self._starttimestamp = None


@@ -199,18 +198,18 @@ class Page(CopyvioMixIn):
"""Load various data from the API in a single query. """Load various data from the API in a single query.


Loads self._title, ._exists, ._is_redirect, ._pageid, ._fullurl, Loads self._title, ._exists, ._is_redirect, ._pageid, ._fullurl,
._protection, ._namespace, ._is_talkpage, ._creator, ._lastrevid,
._token, and ._starttimestamp using the API. It will do a query of
its own unless *result* is provided, in which case we'll pretend
*result* is what the query returned.
._protection, ._namespace, ._is_talkpage, ._creator, ._lastrevid, and
._starttimestamp using the API. It will do a query of its own unless
*result* is provided, in which case we'll pretend *result* is what the
query returned.


Assuming the API is sound, this should not raise any exceptions. Assuming the API is sound, this should not raise any exceptions.
""" """
if not result: if not result:
query = self.site.api_query query = self.site.api_query
result = query(action="query", rvprop="user", intoken="edit",
prop="info|revisions", rvlimit=1, rvdir="newer",
titles=self._title, inprop="protection|url")
result = query(action="query", prop="info|revisions",
inprop="protection|url", rvprop="user", rvlimit=1,
rvdir="newer", titles=self._title)


res = result["query"]["pages"].values()[0] res = result["query"]["pages"].values()[0]


@@ -233,13 +232,7 @@ class Page(CopyvioMixIn):


self._fullurl = res["fullurl"] self._fullurl = res["fullurl"]
self._protection = res["protection"] self._protection = res["protection"]

try:
self._token = res["edittoken"]
except KeyError:
pass
else:
self._starttimestamp = strftime("%Y-%m-%dT%H:%M:%SZ", gmtime())
self._starttimestamp = strftime("%Y-%m-%dT%H:%M:%SZ", gmtime())


# We've determined the namespace and talkpage status in __init__() # We've determined the namespace and talkpage status in __init__()
# based on the title, but now we can be sure: # based on the title, but now we can be sure:
@@ -291,13 +284,6 @@ class Page(CopyvioMixIn):
in _handle_edit_errors(). We'll then throw these back as subclasses of in _handle_edit_errors(). We'll then throw these back as subclasses of
EditError. EditError.
""" """
# Try to get our edit token, and die if we can't:
if not self._token:
self._load_attributes()
if not self._token:
e = "You don't have permission to edit this page."
raise exceptions.PermissionsError(e)

# Weed out invalid pages before we get too far: # Weed out invalid pages before we get too far:
self._assert_validity() self._assert_validity()


@@ -306,8 +292,7 @@ class Page(CopyvioMixIn):
params = self._build_edit_params(text, summary, minor, bot, force, params = self._build_edit_params(text, summary, minor, bot, force,
section, captcha_id, captcha_word) section, captcha_id, captcha_word)
else: # Make sure we have the right token: else: # Make sure we have the right token:
params["token"] = self._token
self._token = None # Token now invalid
params["token"] = self.site.get_token()


# Try the API query, catching most errors with our handler: # Try the API query, catching most errors with our handler:
try: try:
@@ -332,8 +317,9 @@ class Page(CopyvioMixIn):
"""Given some keyword arguments, build an API edit query string.""" """Given some keyword arguments, build an API edit query string."""
unitxt = text.encode("utf8") if isinstance(text, unicode) else text unitxt = text.encode("utf8") if isinstance(text, unicode) else text
hashed = md5(unitxt).hexdigest() # Checksum to ensure text is correct hashed = md5(unitxt).hexdigest() # Checksum to ensure text is correct
params = {"action": "edit", "title": self._title, "text": text,
"token": self._token, "summary": summary, "md5": hashed}
params = {
"action": "edit", "title": self._title, "text": text,
"token": self.site.get_token(), "summary": summary, "md5": hashed}


if section: if section:
params["section"] = section params["section"] = section
@@ -378,13 +364,13 @@ class Page(CopyvioMixIn):
self._exists = self.PAGE_UNKNOWN self._exists = self.PAGE_UNKNOWN
raise exceptions.EditConflictError(error.info) raise exceptions.EditConflictError(error.info)
elif error.code == "badtoken" and retry: elif error.code == "badtoken" and retry:
params["token"] = self.site.get_token("edit")
params["token"] = self.site.get_token(force=True)
try: try:
return self.site.api_query(**params) return self.site.api_query(**params)
except exceptions.APIError as error:
if not hasattr(error, "code"):
except exceptions.APIError as err:
if not hasattr(err, "code"):
raise # We can only handle errors with a code attribute raise # We can only handle errors with a code attribute
return self._handle_edit_errors(error, params, retry=False)
return self._handle_edit_errors(err, params, retry=False)
elif error.code in ["emptypage", "emptynewsection"]: elif error.code in ["emptypage", "emptynewsection"]:
raise exceptions.NoContentError(error.info) raise exceptions.NoContentError(error.info)
elif error.code == "contenttoobig": elif error.code == "contenttoobig":
@@ -577,7 +563,7 @@ class Page(CopyvioMixIn):
query = self.site.api_query query = self.site.api_query
result = query(action="query", rvlimit=1, titles=self._title, result = query(action="query", rvlimit=1, titles=self._title,
prop="info|revisions", inprop="protection|url", prop="info|revisions", inprop="protection|url",
intoken="edit", rvprop="content|timestamp")
rvprop="content|timestamp")
self._load_attributes(result=result) self._load_attributes(result=result)
self._assert_existence() self._assert_existence()
self._load_content(result=result) self._load_content(result=result)
@@ -610,7 +596,7 @@ class Page(CopyvioMixIn):
:py:exc:`~earwigbot.exceptions.RedirectError` if the page is not a :py:exc:`~earwigbot.exceptions.RedirectError` if the page is not a
redirect. redirect.
""" """
re_redirect = "^\s*\#\s*redirect\s*\[\[(.*?)\]\]"
re_redirect = r"^\s*\#\s*redirect\s*\[\[(.*?)\]\]"
content = self.get() content = self.get()
try: try:
return re.findall(re_redirect, content, flags=re.I)[0] return re.findall(re_redirect, content, flags=re.I)[0]
@@ -709,7 +695,7 @@ class Page(CopyvioMixIn):
username = username.lower() username = username.lower()
optouts = [optout.lower() for optout in optouts] if optouts else [] optouts = [optout.lower() for optout in optouts] if optouts else []


r_bots = "\{\{\s*(no)?bots\s*(\||\}\})"
r_bots = r"\{\{\s*(no)?bots\s*(\||\}\})"
filter = self.parse().ifilter_templates(recursive=True, matches=r_bots) filter = self.parse().ifilter_templates(recursive=True, matches=r_bots)
for template in filter: for template in filter:
if template.has_param("deny"): if template.has_param("deny"):


+ 40
- 17
earwigbot/wiki/site.py View File

@@ -83,6 +83,8 @@ class Site(object):
""" """
SERVICE_API = 1 SERVICE_API = 1
SERVICE_SQL = 2 SERVICE_SQL = 2
SPECIAL_TOKENS = ["deleteglobalaccount", "patrol", "rollback",
"setglobalaccountstatus", "userrights", "watch"]


def __init__(self, name=None, project=None, lang=None, base_url=None, def __init__(self, name=None, project=None, lang=None, base_url=None,
article_path=None, script_path=None, sql=None, article_path=None, script_path=None, sql=None,
@@ -124,6 +126,7 @@ class Site(object):
self._wait_between_queries = wait_between_queries self._wait_between_queries = wait_between_queries
self._max_retries = 6 self._max_retries = 6
self._last_query_time = 0 self._last_query_time = 0
self._tokens = {}
self._api_lock = RLock() self._api_lock = RLock()
self._api_info_cache = {"maxlag": 0, "lastcheck": 0} self._api_info_cache = {"maxlag": 0, "lastcheck": 0}


@@ -252,13 +255,25 @@ class Site(object):


return self._handle_api_result(result, params, tries, wait, ae_retry) return self._handle_api_result(result, params, tries, wait, ae_retry)


def _request_csrf_token(self, params):
"""If possible, add a request for a CSRF token to an API query."""
if params.get("action") == "query":
if params.get("meta"):
if "tokens" not in params["meta"].split("|"):
params["meta"] += "|tokens"
else:
params["meta"] = "tokens"
if params.get("type"):
if "csrf" not in params["type"].split("|"):
params["type"] += "|csrf"

def _build_api_query(self, params, ignore_maxlag, no_assert): def _build_api_query(self, params, ignore_maxlag, no_assert):
"""Given API query params, return the URL to query and POST data.""" """Given API query params, return the URL to query and POST data."""
if not self._base_url or self._script_path is None: if not self._base_url or self._script_path is None:
e = "Tried to do an API query, but no API URL is known." e = "Tried to do an API query, but no API URL is known."
raise exceptions.APIError(e) raise exceptions.APIError(e)


url = ''.join((self.url, self._script_path, "/api.php"))
url = self.url + self._script_path + "/api.php"
params["format"] = "json" # This is the only format we understand params["format"] = "json" # This is the only format we understand
if self._assert_edit and not no_assert: if self._assert_edit and not no_assert:
# If requested, ensure that we're logged in # If requested, ensure that we're logged in
@@ -266,6 +281,9 @@ class Site(object):
if self._maxlag and not ignore_maxlag: if self._maxlag and not ignore_maxlag:
# If requested, don't overload the servers: # If requested, don't overload the servers:
params["maxlag"] = self._maxlag params["maxlag"] = self._maxlag
if "csrf" not in self._tokens:
# If we don't have a CSRF token, try to fetch one:
self._request_csrf_token(params)


data = self._urlencode_utf8(params) data = self._urlencode_utf8(params)
return url, data return url, data
@@ -282,6 +300,9 @@ class Site(object):
code = res["error"]["code"] code = res["error"]["code"]
info = res["error"]["info"] info = res["error"]["info"]
except (TypeError, KeyError): # If there's no error code/info, return except (TypeError, KeyError): # If there's no error code/info, return
if "query" in res and "tokens" in res["query"]:
for name, token in res["query"]["tokens"].iteritems():
self._tokens[name.split("token")[0]] = token
return res return res


if code == "maxlag": # We've been throttled by the server if code == "maxlag": # We've been throttled by the server
@@ -326,7 +347,7 @@ class Site(object):
# All attributes to be loaded, except _namespaces, which is a special # All attributes to be loaded, except _namespaces, which is a special
# case because it requires additional params in the API query: # case because it requires additional params in the API query:
attrs = [self._name, self._project, self._lang, self._base_url, attrs = [self._name, self._project, self._lang, self._base_url,
self._article_path, self._script_path]
self._article_path, self._script_path]


params = {"action": "query", "meta": "siteinfo", "siprop": "general"} params = {"action": "query", "meta": "siteinfo", "siprop": "general"}


@@ -485,6 +506,7 @@ class Site(object):
from our first request, and *attempt* is to prevent getting stuck in a from our first request, and *attempt* is to prevent getting stuck in a
loop if MediaWiki isn't acting right. loop if MediaWiki isn't acting right.
""" """
self._tokens.clear()
name, password = login name, password = login


params = {"action": "login", "lgname": name, "lgpassword": password} params = {"action": "login", "lgname": name, "lgpassword": password}
@@ -764,25 +786,26 @@ class Site(object):
result = list(self.sql_query(query)) result = list(self.sql_query(query))
return int(result[0][0]) return int(result[0][0])


def get_token(self, action):
def get_token(self, action=None, force=False):
"""Return a token for a data-modifying API action. """Return a token for a data-modifying API action.


*action* must be one of the types listed on
<https://www.mediawiki.org/wiki/API:Tokens>. If it's given as a union
of types separated by |, then the function will return a dictionary
of tokens instead of a single one.
In general, this will be a CSRF token, unless *action* is in a special
list of non-CSRF tokens. Tokens are cached for the session (until
:meth:`_login` is called again); set *force* to ``True`` to force a new
token to be fetched.


Raises :py:exc:`~earwigbot.exceptions.PermissionsError` if we don't
have permissions for the requested action(s), or they are invalid.
Raises :py:exc:`~earwigbot.exceptions.APIError` if there was some other
API issue.
Raises :exc:`.APIError` if there was an API issue.
""" """
res = self.api_query(action="tokens", type=action)
if "warnings" in res and "tokens" in res["warnings"]:
raise exceptions.PermissionsError(res["warnings"]["tokens"]["*"])
if "|" in action:
return res["tokens"]
return res["tokens"].values()[0]
if action not in self.SPECIAL_TOKENS:
action = "csrf"
if action in self._tokens and not force:
return self._tokens[action]

res = self.api_query(action="query", meta="tokens", type=action)
if action not in self._tokens:
err = "Tried to fetch a {0} token, but API returned: {1}"
raise exceptions.APIError(err.format(action, res))
return self._tokens[action]


def namespace_id_to_name(self, ns_id, all=False): def namespace_id_to_name(self, ns_id, all=False):
"""Given a namespace ID, returns associated namespace names. """Given a namespace ID, returns associated namespace names.


Loading…
Cancel
Save