# -*- coding: utf-8 -*- from functools import wraps from hashlib import md5 from os import path from traceback import format_exc from flask import flash, g, request, url_for from flask_mako import render_template, TemplateError from werkzeug.exceptions import HTTPException from .exceptions import AccessDeniedError, EVEAPIError from .messages import Messages __all__ = [ "try_func", "make_error_catcher", "make_route_restricter", "set_up_asset_versioning"] def try_func(inner): """Evaluate inner(), catching subclasses of CalefactionError. If nothing was caught, return (inner(), False). Otherwise, flash an appropriate error message and return (False, True). """ try: result = inner() except EVEAPIError: flash(Messages.EVE_API_ERROR, "error") return (False, True) except AccessDeniedError: flash(Messages.ACCESS_DENIED, "error") return (False, True) if getattr(g, "_session_expired", False): flash(Messages.SESSION_EXPIRED, "error") return (result, True) return (result, False) def make_error_catcher(app, error_template): """Wrap a route to display and log any uncaught exceptions.""" def callback(func): @wraps(func) def inner(*args, **kwargs): try: return func(*args, **kwargs) except HTTPException: raise except TemplateError as exc: app.logger.error("Caught exception:\n{0}".format(exc.text)) trace = exc.text except Exception: app.logger.exception("Caught exception:") trace = format_exc() return render_template(error_template, traceback=trace), 500 return inner return callback def make_route_restricter(auth, on_failure): """Wrap a route to ensure the user is authenticated.""" def callback(func): @wraps(func) def inner(*args, **kwargs): success, caught = try_func(auth.is_authenticated) if success: return func(*args, **kwargs) if not caught: flash(Messages.LOG_IN_FIRST, "error") return on_failure() return inner return callback def set_up_asset_versioning(app): """Add a staticv endpoint that adds hash versioning to static assets.""" def callback(app, error, endpoint, values): if endpoint == "staticv": filename = values["filename"] fpath = path.join(app.static_folder, filename) try: mtime = path.getmtime(fpath) except OSError: return url_for("static", filename=filename) cache = app._hash_cache.get(fpath) if cache and cache[0] == mtime: hashstr = cache[1] else: with open(fpath, "rb") as fp: hashstr = md5(fp.read()).hexdigest() app._hash_cache[fpath] = (mtime, hashstr) return url_for("static", filename=filename, v=hashstr) raise error old_get_max_age = app.get_send_file_max_age def extend_max_age(filename): if "v" in request.args: return 60 * 60 * 24 * 365 # 1 year return old_get_max_age(filename) app._hash_cache = {} app.url_build_error_handlers.append(lambda a, b, c: callback(app, a, b, c)) app.get_send_file_max_age = extend_max_age