diff --git a/calefaction/database.py b/calefaction/database.py index 012b2f9..4297016 100644 --- a/calefaction/database.py +++ b/calefaction/database.py @@ -4,7 +4,7 @@ from datetime import datetime import random import sqlite3 -from flask import g +from flask import current_app, g from werkzeug.local import LocalProxy __all__ = ["Database"] @@ -62,6 +62,9 @@ class Database: period of time, or if their absolute age exceeds some number. We don't actually remove them until a bit after this time. """ + if current_app.debug: + return # Sessions don't expire in debug mode + query = """DELETE FROM session WHERE strftime("%s", "now") - strftime("%s", session_created) >= {} OR strftime("%s", "now") - strftime("%s", session_touched) >= {}""" @@ -71,6 +74,22 @@ class Database: with self._conn as conn: conn.execute(query.format(create_thresh, touch_thresh)) + def _build_expiry_check(self): + """Build and return a snippet of SQL to check for valid sessions. + + The SQL should be inserted in a WHERE clause. If debug mode is active, + we just return an empty string. + """ + if current_app.debug: + return "" + + check = """ AND + strftime("%s", "now") - + strftime("%s", session_created) < {} AND + strftime("%s", "now") - + strftime("%s", session_touched) < {}""" + return check.format(self.MAX_SESSION_AGE, self.MAX_SESSION_STALENESS) + def new_session(self): """Allocate a new session in the database. @@ -94,11 +113,7 @@ class Database: self._clear_old_sessions() query = """SELECT session_created FROM session - WHERE session_id = ? AND - strftime("%s", "now") - strftime("%s", session_created) < {} AND - strftime("%s", "now") - strftime("%s", session_touched) < {}""" - query = query.format(self.MAX_SESSION_AGE, self.MAX_SESSION_STALENESS) - + WHERE session_id = ?""" + self._build_expiry_check() res = self._conn.execute(query, (sid,)).fetchall() if not res: return None @@ -107,11 +122,7 @@ class Database: def read_session(self, sid): """Return the character associated with the given session, or None.""" query = """SELECT session_character FROM session - WHERE session_id = ? AND - strftime("%s", "now") - strftime("%s", session_created) < {} AND - strftime("%s", "now") - strftime("%s", session_touched) < {}""" - query = query.format(self.MAX_SESSION_AGE, self.MAX_SESSION_STALENESS) - + WHERE session_id = ?""" + self._build_expiry_check() res = self._conn.execute(query, (sid,)).fetchall() return res[0][0] if res else None