diff --git a/bitshift/query/nodes.py b/bitshift/query/nodes.py index 342f8ec..905f65b 100644 --- a/bitshift/query/nodes.py +++ b/bitshift/query/nodes.py @@ -213,16 +213,15 @@ class Symbol(_Node): def parameterize(self, tables): tables |= {"code", "symbols"} if isinstance(self.name, Regex): - cond_base = "(symbol_type = ? AND symbol_name REGEXP ?)" - name = self.name.regex + cond, name = "symbol_name REGEXP ?", self.name.regex else: - cond_base = "(symbol_type = ? AND symbol_name = ?)" - name = self.name.string + cond, name = "symbol_name = ?", self.name.string + if self.type == self.ALL: + types = ", ".join(str(type_) for type_ in self.TYPES) + cond += " AND symbol_type IN (%s)" % types if self.type != self.ALL: - return cond_base, [], [self.type, name] - cond = "(" + " OR ".join([cond_base] * len(self.TYPES)) + ")" - args = zip(self.TYPES.keys(), [name] * len(self.TYPES)) - return cond, [], [arg for tup in args for arg in tup] + cond += " AND symbol_type = %d" % self.type + return "(" + cond + ")", [], [name] class BinaryOp(_Node): diff --git a/bitshift/query/tree.py b/bitshift/query/tree.py index 86392be..8989c31 100644 --- a/bitshift/query/tree.py +++ b/bitshift/query/tree.py @@ -46,20 +46,24 @@ class Tree(object): :return: SQL query data. :rtype: 2-tuple of (SQL statement string, query parameter tuple) """ - def get_table_join(table): - tables = { - "code": ("codelet_code_id", "code_id"), - "authors": ("author_codelet", "codelet_id"), - "symbols": ("symbol_code", "code_id") - } + def get_table_joins(tables): + data = [ + ("code", "codelet_code_id", "code_id"), + ("authors", "author_codelet", "codelet_id"), + ("symbols", "symbol_code", "code_id") + ] tmpl = "INNER JOIN %s ON %s = %s" - return tmpl % (table, tables[table][0], tables[table][1]) + for args in data: + if table in tables: + yield tmpl % args tables = set() cond, ranks, arglist = self._root.parameterize(tables) ranks = ranks or [cond] + # TODO: if the only rank is a single thing and it's a boolean value + # (i.e. not a match statement), get rid of it. score = "((%s) / %d)" % (" + ".join(ranks), len(ranks)) - joins = " ".join(get_table_join(table) for table in tables) + joins = " ".join(get_table_joins(tables)) offset = (page - 1) * page_size ## TODO: handle pretty