diff --git a/earwigbot/wiki/sitesdb.py b/earwigbot/wiki/sitesdb.py index c6e3f57..a436609 100644 --- a/earwigbot/wiki/sitesdb.py +++ b/earwigbot/wiki/sitesdb.py @@ -55,6 +55,7 @@ class SitesDB(object): def __init__(self, config): """Set up the manager with an attribute for the BotConfig object.""" self.config = config + self._sites = {} # Internal site cache self._sitesdb = path.join(config.root_dir, "sites.db") self._cookie_file = path.join(config.root_dir, ".cookies") self._cookiejar = None @@ -103,6 +104,19 @@ class SitesDB(object): with sqlite.connect(self._sitesdb) as conn: conn.executescript(script) + def _get_site_object(self, name): + """Return the site from our cache, or create it if it doesn't exist. + + This is essentially just a wrapper around _make_site_object that + returns the same object each time a specific site is asked for. + """ + try: + return self._sites[name] + except KeyError: + site = self._make_site_object(name) + self._sites[name] = site + return site + def _load_site_from_sitesdb(self, name): """Return all information stored in the sitesdb relating to given site. @@ -221,7 +235,12 @@ class SitesDB(object): conn.executemany("INSERT INTO namespaces VALUES (?, ?, ?, ?)", ns_data) def _remove_site_from_sitesdb(self, name): - """Remove a site by name from the sitesdb.""" + """Remove a site by name from the sitesdb and the internal cache.""" + try: + del self._sites[name] + except KeyError: + pass + with sqlite.connect(self._sitesdb) as conn: cursor = conn.execute("DELETE FROM sites WHERE site_name = ?", (name,)) if cursor.rowcount == 0: @@ -267,23 +286,23 @@ class SitesDB(object): except KeyError: e = "Default site is not specified in config." raise SiteNotFoundError(e) - return self._make_site_object(default) + return self._get_site_object(default) # Name arg given, but don't look at others unless `name` isn't found: if name: try: - return self._make_site_object(name) + return self._get_site_object(name) except SiteNotFoundError: if project and lang: name = self._get_site_name_from_sitesdb(project, lang) if name: - return self._make_site_object(name) + return self._get_site_object(name) raise # If we end up here, then project and lang are the only args given: name = self._get_site_name_from_sitesdb(project, lang) if name: - return self._make_site_object(name) + return self._get_site_object(name) e = "Site '{0}:{1}' not found in the sitesdb.".format(project, lang) raise SiteNotFoundError(e) @@ -333,6 +352,7 @@ class SitesDB(object): search_config=search_config) self._add_site_to_sitesdb(site) + self._sites[site.name()] = site return site def remove_site(self, name=None, project=None, lang=None):