@@ -53,61 +53,18 @@ class Database(object): | |||
"Run `python -m bitshift.database.migration`." | |||
raise RuntimeError(err) | |||
def _explode_query_tree(self, tree): | |||
"""Convert a query tree into components of an SQL SELECT statement.""" | |||
def _parse_node(node, tables): | |||
if isinstance(node, Text): | |||
tables |= {"code", "symbols"} | |||
# (FTS: codelet_name, =: symbol_name, FTS: code_code) vs. node.text (_Literal) | |||
pass | |||
elif isinstance(node, Language): | |||
tables |= {"code"} | |||
return "(code_lang = ?)", tables, [node.lang] | |||
elif isinstance(node, Author): | |||
tables |= {"authors"} | |||
if isinstance(node.name, Regex): | |||
return "(author_name REGEXP ?)", [node.name.regex] | |||
cond = "(MATCH(author_name) AGAINST (? IN BOOLEAN MODE))" | |||
return cond, tables, [node.name.string] | |||
elif isinstance(node, Date): | |||
column = {node.CREATE: "codelet_date_created", | |||
node.MODIFY: "codelet_date_modified"}[node.type] | |||
op = {node.BEFORE: "<=", node.AFTER: ">="}[node.relation] | |||
return "(" + column + " " + op + " ?)", tables, [node.date] | |||
elif isinstance(node, Symbol): | |||
tables |= {"symbols"} | |||
cond_base = "(symbol_type = ? AND symbol_name = ?)" | |||
if node.type != node.ALL: | |||
return cond_base, tables, [node.type, node.name] | |||
cond = "(" + " OR ".join([cond_base] * len(node.TYPES)) + ")" | |||
args = zip(node.TYPES.keys(), [node.name] * len(node.TYPES)) | |||
return cond, tables, [arg for tup in args for arg in tup] | |||
elif isinstance(node, BinaryOp): | |||
left_cond, tbls, left_args = _parse_node(node.left, tables) | |||
right_cond, tables, right_args = _parse_node(node.right, tbls) | |||
op = node.OPS[node.op] | |||
cond = "(" + left_cond + " " + op + " " + right_cond + ")" | |||
return cond, tables, left_args + right_args | |||
elif isinstance(node, UnaryOp): | |||
cond, tables, args = _parse_node(node.node, tables) | |||
return "(" + node.OPS[node.op] + " " + cond + ")", tables, args | |||
conditional, tables, arglist = _parse_node(tree.root, set()) | |||
# joins = " ".join(tables) | |||
return conditional, joins, tuple(arglist) | |||
def _search_with_query(self, cursor, query, page): | |||
"""Execute an SQL query based on a query tree, and return results. | |||
The returned data is a 2-tuple of (list of codelet IDs, estimated | |||
number of total results). | |||
""" | |||
conditional, joins, args = self._explode_query_tree(query) | |||
base = """SELECT codelet_id | |||
FROM codelets %s | |||
WHERE %s | |||
ORDER BY codelet_rank LIMIT 10""" | |||
conditional, tables, args = query.parameterize() | |||
joins = " ".join(tables) | |||
qstring = base % (joins, conditional) | |||
if page > 1: | |||
qstring += " OFFSET %d" % ((page - 1) * 10) | |||
@@ -139,7 +139,7 @@ class _QueryParser(object): | |||
Returns a 2-tuple of (first_marker_found, marker_index). | |||
""" | |||
def _is_escaped(query, index): | |||
def is_escaped(query, index): | |||
"""Return whether a query marker is backslash-escaped.""" | |||
return (index > 0 and query[index - 1] == "\\" and | |||
(index < 2 or query[index - 2] != "\\")) | |||
@@ -147,7 +147,7 @@ class _QueryParser(object): | |||
best_marker, best_index = None, maxsize | |||
for marker in markers: | |||
index = query.find(marker) | |||
if _is_escaped(query, index): | |||
if is_escaped(query, index): | |||
_, new_index = self._scan_query(query[index + 1:], marker) | |||
index += new_index + 1 | |||
if index >= 0 and index < best_index: | |||
@@ -15,6 +15,13 @@ class _Node(object): | |||
"""Return a string sort key for the node.""" | |||
return "" | |||
def parameterize(self, tables): | |||
"""Parameterize the node. | |||
Returns a 3-tuple of (query conditional string, table set, param list). | |||
""" | |||
return "", tables, [] | |||
class _Literal(object): | |||
"""Represents a literal component of a search query, present at the leaves. | |||
@@ -75,6 +82,11 @@ class Text(_Node): | |||
def sortkey(self): | |||
return self.text.sortkey() | |||
def parameterize(self, tables): | |||
tables |= {"code", "symbols"} | |||
# (FTS: codelet_name, =: symbol_name, FTS: code_code) vs. node.text (_Literal) | |||
pass | |||
class Language(_Node): | |||
"""Represents a language node. | |||
@@ -94,6 +106,10 @@ class Language(_Node): | |||
def sortkey(self): | |||
return LANGS[self.lang] | |||
def parameterize(self, tables): | |||
tables |= {"code"} | |||
return "(code_lang = ?)", tables, [self.lang] | |||
class Author(_Node): | |||
"""Represents a author node. | |||
@@ -113,6 +129,13 @@ class Author(_Node): | |||
def sortkey(self): | |||
return self.name.sortkey() | |||
def parameterize(self, tables): | |||
tables |= {"authors"} | |||
if isinstance(self.name, Regex): | |||
return "(author_name REGEXP ?)", [self.name.regex] | |||
cond = "(MATCH(author_name) AGAINST (? IN BOOLEAN MODE))" | |||
return cond, tables, [self.name.string] | |||
class Date(_Node): | |||
"""Represents a date node. | |||
@@ -144,6 +167,12 @@ class Date(_Node): | |||
def sortkey(self): | |||
return self.date.strftime("%Y%m%d%H%M%S") | |||
def parameterize(self, tables): | |||
column = {self.CREATE: "codelet_date_created", | |||
self.MODIFY: "codelet_date_modified"}[self.type] | |||
op = {self.BEFORE: "<=", self.AFTER: ">="}[self.relation] | |||
return "(" + column + " " + op + " ?)", tables, [self.date] | |||
class Symbol(_Node): | |||
"""Represents a symbol node. | |||
@@ -171,6 +200,15 @@ class Symbol(_Node): | |||
def sortkey(self): | |||
return self.name.sortkey() | |||
def parameterize(self, tables): | |||
tables |= {"symbols"} | |||
cond_base = "(symbol_type = ? AND symbol_name = ?)" | |||
if self.type != self.ALL: | |||
return cond_base, tables, [self.type, self.name] | |||
cond = "(" + " OR ".join([cond_base] * len(self.TYPES)) + ")" | |||
args = zip(self.TYPES.keys(), [self.name] * len(self.TYPES)) | |||
return cond, tables, [arg for tup in args for arg in tup] | |||
class BinaryOp(_Node): | |||
"""Represents a relationship between two nodes: ``and``, ``or``.""" | |||
@@ -190,6 +228,13 @@ class BinaryOp(_Node): | |||
def sortkey(self): | |||
return self.left.sortkey() + self.right.sortkey() | |||
def parameterize(self, tables): | |||
left_cond, tables, left_args = self.left.parameterize(tables) | |||
right_cond, tables, right_args = self.right.parameterize(tables) | |||
op = self.OPS[self.op] | |||
cond = "(" + left_cond + " " + op + " " + right_cond + ")" | |||
return cond, tables, left_args + right_args | |||
class UnaryOp(_Node): | |||
"""Represents a transformation applied to one node: ``not``.""" | |||
@@ -205,3 +250,7 @@ class UnaryOp(_Node): | |||
def sortkey(self): | |||
return self.node.sortkey() | |||
def parameterize(self, tables): | |||
cond, tables, args = self.node.parameterize(tables) | |||
return "(" + self.OPS[self.op] + " " + cond + ")", tables, args |
@@ -25,3 +25,12 @@ class Tree(object): | |||
:rtype: str | |||
""" | |||
return repr(self) | |||
def parameterize(self): | |||
"""Parameterize the query tree for an SQL SELECT statement. | |||
:return: SQL query data. | |||
:rtype: 3-tuple of (query conditional string, table set, param tuple) | |||
""" | |||
conditional, tables, arglist = self._root.parameterize(set()) | |||
return conditional, tables, tuple(arglist) |