diff --git a/bitshift/database/__init__.py b/bitshift/database/__init__.py index 56f8b1c..b857ff8 100644 --- a/bitshift/database/__init__.py +++ b/bitshift/database/__init__.py @@ -62,7 +62,7 @@ class Database(object): """ query, args = tree.build_query(page) cursor.execute(query, args) - ids = [id for id, _ in cursor.fetchall()] + ids = [cid for cid, _ in cursor.fetchall()] num_results = len(ids) # TODO: This is not entirely correct return ids, num_results @@ -75,16 +75,25 @@ class Database(object): cursor.execute(query, (codelet_id,)) return cursor.fetchall() - def _get_symbols_for_code(self, cursor, code_id): + def _get_symbols_for_code(self, cursor, code_id, tree): """Return a list of symbols for a given codelet.""" query = """SELECT symbol_type, symbol_name, sloc_type, sloc_row, sloc_col, sloc_end_row, sloc_end_col FROM symbols INNER JOIN symbol_locations ON sloc_symbol = symbol_id - WHERE symbol_code = ?""" + WHERE symbol_code = ? AND (%s)""" + + conds, args = [], [code_id] + for node in tree.walk(Symbol): + node_cond, node_args, _, _ = node.parameterize(set()) + conds.append(node_cond) + args += node_args + if not conds: + return {} + cond = " OR ".join(conds) symbols = {type_: {} for type_ in Symbol.TYPES_INV} - cursor.execute(query, (code_id,)) + cursor.execute(query % cond, tuple(args)) for type_, name, loc_type, row, col, erow, ecol in cursor.fetchall(): sdict = symbols[Symbol.TYPES_INV[type_]] if name not in sdict: @@ -94,7 +103,7 @@ class Database(object): symbols[type_] = [(n, d, u) for n, (d, u) in sdict.iteritems()] return symbols - def _get_codelets_from_ids(self, cursor, ids): + def _get_codelets_from_ids(self, cursor, ids, tree): """Return a list of Codelet objects given a list of codelet IDs.""" query = """SELECT * FROM codelets @@ -106,6 +115,7 @@ class Database(object): for codelet_id in ids: dict_cursor.execute(query, (codelet_id,)) row = dict_cursor.fetchall()[0] + code_id = row["code_id"] if row["origin_url_base"]: url = row["origin_url_base"] + row["codelet_url"] else: @@ -113,7 +123,7 @@ class Database(object): origin = (row["origin_name"], row["origin_url"], row["origin_image"]) authors = self._get_authors_for_codelet(cursor, codelet_id) - symbols = self._get_symbols_for_code(cursor, row["code_id"]) + symbols = self._get_symbols_for_code(cursor, code_id, tree) yield Codelet( row["codelet_name"], row["code_code"], None, row["code_lang"], authors, url, @@ -148,12 +158,12 @@ class Database(object): """Disconnect from the database.""" self._conn.close() - def search(self, query, page=1): + def search(self, tree, page=1): """ Search the database for a query and return the *n*\ th page of results. - :param query: The query to search for. - :type query: :py:class:`~.query.tree.Tree` + :param tree: The query to search for. + :type tree: :py:class:`~.query.tree.Tree` :param page: The result page to display. :type page: int @@ -169,7 +179,7 @@ class Database(object): query3 = "INSERT INTO cache VALUES (?, ?, ?, DEFAULT)" query4 = "INSERT INTO cache_data VALUES (?, ?, ?)" - cache_id = mmh3.hash64(str(page) + ":" + query.serialize())[0] + cache_id = mmh3.hash64(str(page) + ":" + tree.serialize())[0] with self._conn.cursor() as cursor: cursor.execute(query1, (cache_id,)) @@ -180,14 +190,14 @@ class Database(object): num_results = rows[0][1] * (10 ** rows[0][2]) if rows else 0 ids = [row[0] for row in rows] else: - ids, num_results = self._search_with_query(cursor, query, page) + ids, num_results = self._search_with_query(cursor, tree, page) num_exp = max(len(str(num_results)) - 3, 0) num_results = int(round(num_results, -num_exp)) num_mnt = num_results / (10 ** num_exp) cursor.execute(query3, (cache_id, num_mnt, num_exp)) cdata = [(cache_id, c_id, i) for i, c_id in enumerate(ids)] cursor.executemany(query4, cdata) - codelet_gen = self._get_codelets_from_ids(cursor, ids) + codelet_gen = self._get_codelets_from_ids(cursor, ids, tree) return (num_results, list(codelet_gen)) def insert(self, codelet): diff --git a/bitshift/query/tree.py b/bitshift/query/tree.py index 5da3f02..54461bc 100644 --- a/bitshift/query/tree.py +++ b/bitshift/query/tree.py @@ -1,3 +1,5 @@ +from . import nodes + __all__ = ["Tree"] QUERY_TEMPLATE = """SELECT codelet_id, (codelet_rank%s) AS score @@ -33,6 +35,18 @@ class Tree(object): """ return repr(self) + def walk(self, node_type=None): + """Walk through the query tree, returning nodes of a specific type.""" + pending = [self._root] + while pending: + node = pending.pop() + if not node_type or isinstance(node, node_type): + yield node + if isinstance(node, nodes.UnaryOp): + pending.append(node.node) + elif isinstance(node, nodes.BinaryOp): + pending.extend([node.left, node.right]) + def build_query(self, page=1, page_size=10): """Convert the query tree into a parameterized SQL SELECT statement.