Browse Source

Only return the right codelet symbols from the database (closes #46).

tags/v1.0^2
Ben Kurtovic 10 years ago
parent
commit
00058d3879
2 changed files with 36 additions and 12 deletions
  1. +22
    -12
      bitshift/database/__init__.py
  2. +14
    -0
      bitshift/query/tree.py

+ 22
- 12
bitshift/database/__init__.py View File

@@ -62,7 +62,7 @@ class Database(object):
""" """
query, args = tree.build_query(page) query, args = tree.build_query(page)
cursor.execute(query, args) cursor.execute(query, args)
ids = [id for id, _ in cursor.fetchall()]
ids = [cid for cid, _ in cursor.fetchall()]
num_results = len(ids) # TODO: This is not entirely correct num_results = len(ids) # TODO: This is not entirely correct
return ids, num_results return ids, num_results


@@ -75,16 +75,25 @@ class Database(object):
cursor.execute(query, (codelet_id,)) cursor.execute(query, (codelet_id,))
return cursor.fetchall() return cursor.fetchall()


def _get_symbols_for_code(self, cursor, code_id):
def _get_symbols_for_code(self, cursor, code_id, tree):
"""Return a list of symbols for a given codelet.""" """Return a list of symbols for a given codelet."""
query = """SELECT symbol_type, symbol_name, sloc_type, sloc_row, query = """SELECT symbol_type, symbol_name, sloc_type, sloc_row,
sloc_col, sloc_end_row, sloc_end_col sloc_col, sloc_end_row, sloc_end_col
FROM symbols FROM symbols
INNER JOIN symbol_locations ON sloc_symbol = symbol_id INNER JOIN symbol_locations ON sloc_symbol = symbol_id
WHERE symbol_code = ?"""
WHERE symbol_code = ? AND (%s)"""

conds, args = [], [code_id]
for node in tree.walk(Symbol):
node_cond, node_args, _, _ = node.parameterize(set())
conds.append(node_cond)
args += node_args
if not conds:
return {}
cond = " OR ".join(conds)


symbols = {type_: {} for type_ in Symbol.TYPES_INV} symbols = {type_: {} for type_ in Symbol.TYPES_INV}
cursor.execute(query, (code_id,))
cursor.execute(query % cond, tuple(args))
for type_, name, loc_type, row, col, erow, ecol in cursor.fetchall(): for type_, name, loc_type, row, col, erow, ecol in cursor.fetchall():
sdict = symbols[Symbol.TYPES_INV[type_]] sdict = symbols[Symbol.TYPES_INV[type_]]
if name not in sdict: if name not in sdict:
@@ -94,7 +103,7 @@ class Database(object):
symbols[type_] = [(n, d, u) for n, (d, u) in sdict.iteritems()] symbols[type_] = [(n, d, u) for n, (d, u) in sdict.iteritems()]
return symbols return symbols


def _get_codelets_from_ids(self, cursor, ids):
def _get_codelets_from_ids(self, cursor, ids, tree):
"""Return a list of Codelet objects given a list of codelet IDs.""" """Return a list of Codelet objects given a list of codelet IDs."""
query = """SELECT * query = """SELECT *
FROM codelets FROM codelets
@@ -106,6 +115,7 @@ class Database(object):
for codelet_id in ids: for codelet_id in ids:
dict_cursor.execute(query, (codelet_id,)) dict_cursor.execute(query, (codelet_id,))
row = dict_cursor.fetchall()[0] row = dict_cursor.fetchall()[0]
code_id = row["code_id"]
if row["origin_url_base"]: if row["origin_url_base"]:
url = row["origin_url_base"] + row["codelet_url"] url = row["origin_url_base"] + row["codelet_url"]
else: else:
@@ -113,7 +123,7 @@ class Database(object):
origin = (row["origin_name"], row["origin_url"], origin = (row["origin_name"], row["origin_url"],
row["origin_image"]) row["origin_image"])
authors = self._get_authors_for_codelet(cursor, codelet_id) authors = self._get_authors_for_codelet(cursor, codelet_id)
symbols = self._get_symbols_for_code(cursor, row["code_id"])
symbols = self._get_symbols_for_code(cursor, code_id, tree)
yield Codelet( yield Codelet(
row["codelet_name"], row["code_code"], None, row["codelet_name"], row["code_code"], None,
row["code_lang"], authors, url, row["code_lang"], authors, url,
@@ -148,12 +158,12 @@ class Database(object):
"""Disconnect from the database.""" """Disconnect from the database."""
self._conn.close() self._conn.close()


def search(self, query, page=1):
def search(self, tree, page=1):
""" """
Search the database for a query and return the *n*\ th page of results. Search the database for a query and return the *n*\ th page of results.


:param query: The query to search for.
:type query: :py:class:`~.query.tree.Tree`
:param tree: The query to search for.
:type tree: :py:class:`~.query.tree.Tree`
:param page: The result page to display. :param page: The result page to display.
:type page: int :type page: int


@@ -169,7 +179,7 @@ class Database(object):
query3 = "INSERT INTO cache VALUES (?, ?, ?, DEFAULT)" query3 = "INSERT INTO cache VALUES (?, ?, ?, DEFAULT)"
query4 = "INSERT INTO cache_data VALUES (?, ?, ?)" query4 = "INSERT INTO cache_data VALUES (?, ?, ?)"


cache_id = mmh3.hash64(str(page) + ":" + query.serialize())[0]
cache_id = mmh3.hash64(str(page) + ":" + tree.serialize())[0]


with self._conn.cursor() as cursor: with self._conn.cursor() as cursor:
cursor.execute(query1, (cache_id,)) cursor.execute(query1, (cache_id,))
@@ -180,14 +190,14 @@ class Database(object):
num_results = rows[0][1] * (10 ** rows[0][2]) if rows else 0 num_results = rows[0][1] * (10 ** rows[0][2]) if rows else 0
ids = [row[0] for row in rows] ids = [row[0] for row in rows]
else: else:
ids, num_results = self._search_with_query(cursor, query, page)
ids, num_results = self._search_with_query(cursor, tree, page)
num_exp = max(len(str(num_results)) - 3, 0) num_exp = max(len(str(num_results)) - 3, 0)
num_results = int(round(num_results, -num_exp)) num_results = int(round(num_results, -num_exp))
num_mnt = num_results / (10 ** num_exp) num_mnt = num_results / (10 ** num_exp)
cursor.execute(query3, (cache_id, num_mnt, num_exp)) cursor.execute(query3, (cache_id, num_mnt, num_exp))
cdata = [(cache_id, c_id, i) for i, c_id in enumerate(ids)] cdata = [(cache_id, c_id, i) for i, c_id in enumerate(ids)]
cursor.executemany(query4, cdata) cursor.executemany(query4, cdata)
codelet_gen = self._get_codelets_from_ids(cursor, ids)
codelet_gen = self._get_codelets_from_ids(cursor, ids, tree)
return (num_results, list(codelet_gen)) return (num_results, list(codelet_gen))


def insert(self, codelet): def insert(self, codelet):


+ 14
- 0
bitshift/query/tree.py View File

@@ -1,3 +1,5 @@
from . import nodes

__all__ = ["Tree"] __all__ = ["Tree"]


QUERY_TEMPLATE = """SELECT codelet_id, (codelet_rank%s) AS score QUERY_TEMPLATE = """SELECT codelet_id, (codelet_rank%s) AS score
@@ -33,6 +35,18 @@ class Tree(object):
""" """
return repr(self) return repr(self)


def walk(self, node_type=None):
    """Iterate over the nodes of the query tree, optionally filtered by type.

    The traversal starts at ``self._root`` and is driven by an explicit
    stack rather than recursion. Children of a ``nodes.UnaryOp`` (its
    ``node``) and of a ``nodes.BinaryOp`` (its ``left`` and ``right``) are
    visited; any other node is treated as a leaf. Because the most recently
    pushed node is taken first, the right operand of a binary operator is
    visited before its left operand.

    :param node_type: If given (and truthy), only nodes for which
        ``isinstance(node, node_type)`` holds are yielded; otherwise every
        visited node is yielded.

    :return: A generator over the matching nodes.
    """
    stack = [self._root]
    while stack:
        current = stack.pop()
        # A falsy node_type (the default, None) disables filtering.
        matches = (not node_type) or isinstance(current, node_type)
        if matches:
            yield current
        # Queue up the children of operator nodes for later visits.
        if isinstance(current, nodes.UnaryOp):
            stack.append(current.node)
        elif isinstance(current, nodes.BinaryOp):
            stack.append(current.left)
            stack.append(current.right)

def build_query(self, page=1, page_size=10): def build_query(self, page=1, page_size=10):
"""Convert the query tree into a parameterized SQL SELECT statement. """Convert the query tree into a parameterized SQL SELECT statement.




Loading…
Cancel
Save