From d6ccdbd16d1db369801ebd7a12ba1bf90df5225a Mon Sep 17 00:00:00 2001 From: Ben Kurtovic Date: Fri, 2 May 2014 22:43:16 -0400 Subject: [PATCH] Fix a couble Database bugs. --- bitshift/database/__init__.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/bitshift/database/__init__.py b/bitshift/database/__init__.py index bc4b451..1a2b373 100644 --- a/bitshift/database/__init__.py +++ b/bitshift/database/__init__.py @@ -16,15 +16,15 @@ class Database(object): """Represents the MySQL database.""" def __init__(self, migrate=False): - self._connect() + self._conn = self._connect() self._check_version(migrate) def _connect(self): """Establish a connection to the database.""" root = os.path.dirname(os.path.abspath(__file__)) default_file = os.path.join(root, ".my.cnf") - self._conn = oursql.connect(read_default_file=default_file, - autoping=True, autoreconnect=True) + return oursql.connect(db="bitshift", read_default_file=default_file, + autoping=True, autoreconnect=True) def _migrate(self, cursor, current): """Migrate the database to the latest schema version.""" @@ -58,8 +58,9 @@ class Database(object): def _insert_symbols(self, cursor, code_id, sym_type, symbols): """Insert a list of symbols of a given type into the database.""" sym_types = ["functions", "classes", "variables"] - query1 = "INSERT INTO symbols VALUES (?, ?, ?)" - query2 = "INSERT INTO symbol_locations VALUES (?, ?, ?, ?, ?, ?)" + query1 = "INSERT INTO symbols VALUES (DEFAULT, ?, ?, ?)" + query2 = """INSERT INTO symbol_locations VALUES + (DEFAULT, ?, ?, ?, ?, ?, ?)""" for (name, decls, uses) in symbols: cursor.execute(query1, (code_id, sym_types.index(sym_type), name)) @@ -105,8 +106,8 @@ class Database(object): query1 = """INSERT INTO code VALUES (?, ?) ON DUPLICATE KEY UPDATE code_id=code_id""" query2 = """INSERT INTO codelets VALUES - (?, ?, ?, ?, ?, ?, ?, ?)""" - query3 = "INSERT INTO authors VALUES (?, ?, ?)" + (DEFAULT, ?, ?, ?, ?, ?, ?, ?, ?)""" + query3 = "INSERT INTO authors VALUES (DEFAULT, ?, ?, ?)" code_id = mmh3.hash64(codelet.code.encode("utf8"))[0] origin, url = self._decompose_url(codelet.url)