A semantic search engine for source code https://bitshift.benkurtovic.com/
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

231 lines
9.6 KiB

  1. """
  2. Subpackage with classes and functions to handle communication with the MySQL
  3. database backend, which manages the search index.
  4. """
  5. import os
  6. import mmh3
  7. import oursql
  8. from .migration import VERSION, MIGRATIONS
  9. from ..codelet import Codelet
  10. from ..query.nodes import (String, Regex, Text, Language, Author, Date, Symbol,
  11. BinaryOp, UnaryOp)
  12. __all__ = ["Database"]
  13. class Database(object):
  14. """Represents the MySQL database."""
  15. def __init__(self, migrate=False):
  16. self._conn = self._connect()
  17. self._check_version(migrate)
  18. def _connect(self):
  19. """Establish a connection to the database."""
  20. root = os.path.dirname(os.path.abspath(__file__))
  21. default_file = os.path.join(root, ".my.cnf")
  22. return oursql.connect(db="bitshift", read_default_file=default_file,
  23. autoping=True, autoreconnect=True)
  24. def _migrate(self, cursor, current):
  25. """Migrate the database to the latest schema version."""
  26. for version in xrange(current, VERSION):
  27. print "Migrating to %d..." % (version + 1)
  28. for query in MIGRATIONS[version - 1]:
  29. cursor.execute(query)
  30. cursor.execute("UPDATE version SET version = ?", (version + 1,))
  31. def _check_version(self, migrate):
  32. """Check the database schema version and respond accordingly.
  33. If the schema is out of date, migrate if *migrate* is True, else raise
  34. an exception.
  35. """
  36. with self._conn.cursor() as cursor:
  37. cursor.execute("SELECT version FROM version")
  38. version = cursor.fetchone()[0]
  39. if version < VERSION:
  40. if migrate:
  41. self._migrate(cursor, version)
  42. else:
  43. err = "Database schema out of date. " \
  44. "Run `python -m bitshift.database.migration`."
  45. raise RuntimeError(err)
  46. def _search_with_query(self, cursor, tree, page):
  47. """Execute an SQL query based on a query tree, and return results.
  48. The returned data is a 2-tuple of (list of codelet IDs, estimated
  49. number of total results).
  50. """
  51. query, args = tree.build_query(page)
  52. cursor.execute(query, args)
  53. ids = [cid for cid, _ in cursor.fetchall()]
  54. num_results = len(ids) # TODO: This is not entirely correct
  55. return ids, num_results
  56. def _get_authors_for_codelet(self, cursor, codelet_id):
  57. """Return a list of authors for a given codelet."""
  58. query = """SELECT author_name, author_url
  59. FROM authors
  60. WHERE author_codelet = ?"""
  61. cursor.execute(query, (codelet_id,))
  62. return cursor.fetchall()
  63. def _get_symbols_for_code(self, cursor, code_id, tree):
  64. """Return a list of symbols for a given codelet."""
  65. query = """SELECT symbol_type, symbol_name, sloc_type, sloc_row,
  66. sloc_col, sloc_end_row, sloc_end_col
  67. FROM symbols
  68. INNER JOIN symbol_locations ON sloc_symbol = symbol_id
  69. WHERE symbol_code = ? AND (%s)"""
  70. conds, args = [], [code_id]
  71. for node in tree.walk(Symbol):
  72. node_cond, node_args, _, _ = node.parameterize(set())
  73. conds.append(node_cond)
  74. args += node_args
  75. if not conds:
  76. return {}
  77. cond = " OR ".join(conds)
  78. symbols = {type_: {} for type_ in Symbol.TYPES}
  79. cursor.execute(query % cond, tuple(args))
  80. for type_, name, loc_type, row, col, erow, ecol in cursor.fetchall():
  81. sdict = symbols[Symbol.TYPES[type_]]
  82. if name not in sdict:
  83. sdict[name] = ([], [])
  84. sdict[name][loc_type].append((row, col, erow, ecol))
  85. for type_, sdict in symbols.items():
  86. symbols[type_] = [(n, d, u) for n, (d, u) in sdict.iteritems()]
  87. return symbols
  88. def _get_codelets_from_ids(self, cursor, ids, tree):
  89. """Return a list of Codelet objects given a list of codelet IDs."""
  90. query = """SELECT *
  91. FROM codelets
  92. INNER JOIN code ON codelet_code_id = code_id
  93. INNER JOIN origins ON codelet_origin = origin_id
  94. WHERE codelet_id = ?"""
  95. with self._conn.cursor(oursql.DictCursor) as dict_cursor:
  96. for codelet_id in ids:
  97. dict_cursor.execute(query, (codelet_id,))
  98. row = dict_cursor.fetchall()[0]
  99. code_id = row["code_id"]
  100. if row["origin_url_base"]:
  101. url = row["origin_url_base"] + row["codelet_url"]
  102. else:
  103. url = row["codelet_url"]
  104. origin = (row["origin_name"], row["origin_url"])
  105. authors = self._get_authors_for_codelet(cursor, codelet_id)
  106. symbols = self._get_symbols_for_code(cursor, code_id, tree)
  107. yield Codelet(
  108. row["codelet_name"], row["code_code"], None,
  109. row["code_lang"], authors, url,
  110. row["codelet_date_created"], row["codelet_date_modified"],
  111. row["codelet_rank"], symbols, origin)
  112. def _decompose_url(self, cursor, url):
  113. """Break up a URL into an origin (with a URL base) and a suffix."""
  114. query = """SELECT origin_id, SUBSTR(?, LENGTH(origin_url_base) + 1)
  115. FROM origins
  116. WHERE origin_url_base IS NOT NULL
  117. AND ? LIKE CONCAT(origin_url_base, "%")"""
  118. cursor.execute(query, (url, url))
  119. result = cursor.fetchone()
  120. return result if result else (1, url)
  121. def _insert_symbols(self, cursor, code_id, sym_type, symbols):
  122. """Insert a list of symbols of a given type into the database."""
  123. query1 = "INSERT INTO symbols VALUES (DEFAULT, ?, ?, ?)"
  124. query2 = """INSERT INTO symbol_locations VALUES
  125. (DEFAULT, ?, ?, ?, ?, ?, ?)"""
  126. type_id = Symbol.TYPES.index(sym_type)
  127. for (name, assigns, uses) in symbols:
  128. cursor.execute(query1, (code_id, type_id, name))
  129. sym_id = cursor.lastrowid
  130. params = ([tuple([sym_id, 0] + list(loc)) for loc in assigns] +
  131. [tuple([sym_id, 1] + list(loc)) for loc in uses])
  132. cursor.executemany(query2, params)
  133. def close(self):
  134. """Disconnect from the database."""
  135. self._conn.close()
  136. def search(self, tree, page=1):
  137. """
  138. Search the database for a query and return the *n*\ th page of results.
  139. :param tree: The query to search for.
  140. :type tree: :py:class:`~.query.tree.Tree`
  141. :param page: The result page to display.
  142. :type page: int
  143. :return: The total number of results, and the *n*\ th page of results.
  144. :rtype: 2-tuple of (long, list of :py:class:`.Codelet`\ s)
  145. """
  146. query1 = "SELECT 1 FROM cache WHERE cache_id = ?"
  147. query2 = """SELECT cdata_codelet, cache_count_mnt, cache_count_exp
  148. FROM cache
  149. INNER JOIN cache_data ON cache_id = cdata_cache
  150. WHERE cache_id = ?
  151. ORDER BY cdata_index ASC"""
  152. query3 = "INSERT INTO cache VALUES (?, ?, ?, DEFAULT)"
  153. query4 = "INSERT INTO cache_data VALUES (?, ?, ?)"
  154. cache_id = mmh3.hash64(str(page) + ":" + tree.serialize())[0]
  155. with self._conn.cursor() as cursor:
  156. cursor.execute(query1, (cache_id,))
  157. cache_hit = cursor.fetchall()
  158. if cache_hit:
  159. cursor.execute(query2, (cache_id,))
  160. rows = cursor.fetchall()
  161. num_results = rows[0][1] * (10 ** rows[0][2]) if rows else 0
  162. ids = [row[0] for row in rows]
  163. else:
  164. ids, num_results = self._search_with_query(cursor, tree, page)
  165. num_exp = max(len(str(num_results)) - 3, 0)
  166. num_results = int(round(num_results, -num_exp))
  167. num_mnt = num_results / (10 ** num_exp)
  168. cursor.execute(query3, (cache_id, num_mnt, num_exp))
  169. cdata = [(cache_id, c_id, i) for i, c_id in enumerate(ids)]
  170. cursor.executemany(query4, cdata)
  171. codelet_gen = self._get_codelets_from_ids(cursor, ids, tree)
  172. return (num_results, list(codelet_gen))
  173. def insert(self, codelet):
  174. """
  175. Insert a codelet into the database.
  176. :param codelet: The codelet to insert.
  177. :type codelet: :py:class:`.Codelet`
  178. """
  179. query1 = """INSERT INTO code VALUES (?, ?, ?)
  180. ON DUPLICATE KEY UPDATE code_id=code_id"""
  181. query2 = """INSERT INTO codelets VALUES
  182. (DEFAULT, ?, ?, ?, ?, ?, ?, ?)"""
  183. query3 = "INSERT INTO authors VALUES (DEFAULT, ?, ?, ?)"
  184. hash_key = str(codelet.language) + ":" + codelet.code.encode("utf8")
  185. code_id = mmh3.hash64(hash_key)[0]
  186. with self._conn.cursor() as cursor:
  187. cursor.execute(query1, (code_id, codelet.language, codelet.code))
  188. if cursor.rowcount == 1:
  189. for sym_type, symbols in codelet.symbols.iteritems():
  190. self._insert_symbols(cursor, code_id, sym_type, symbols)
  191. origin, url = self._decompose_url(cursor, codelet.url)
  192. cursor.execute(query2, (codelet.name, code_id, origin, url,
  193. codelet.rank, codelet.date_created,
  194. codelet.date_modified))
  195. codelet_id = cursor.lastrowid
  196. authors = [(codelet_id, a[0], a[1]) for a in codelet.authors]
  197. cursor.executemany(query3, authors)