A semantic search engine for source code https://bitshift.benkurtovic.com/
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

240 lines
9.9 KiB

  1. """
  2. Subpackage with classes and functions to handle communication with the MySQL
  3. database backend, which manages the search index.
  4. """
  5. import codecs
  6. import os
  7. import mmh3
  8. import oursql
  9. from .migration import VERSION, MIGRATIONS
  10. from ..codelet import Codelet
  11. from ..query.nodes import (String, Regex, Text, Language, Author, Date, Symbol,
  12. BinaryOp, UnaryOp)
# Public API of this subpackage: only the Database class is exported.
__all__ = ["Database"]
  14. class Database(object):
  15. """Represents the MySQL database."""
  16. def __init__(self, migrate=False):
  17. self._conn = self._connect()
  18. self._check_version(migrate)
  19. def _connect(self):
  20. """Establish a connection to the database."""
  21. try:
  22. codecs.lookup("utf8mb4")
  23. except LookupError:
  24. utf8 = codecs.lookup("utf8")
  25. codecs.register(lambda name: utf8 if name == "utf8mb4" else None)
  26. root = os.path.dirname(os.path.abspath(__file__))
  27. default_file = os.path.join(root, ".my.cnf")
  28. return oursql.connect(
  29. db="bitshift", read_default_file=default_file, autoping=True,
  30. autoreconnect=True, charset="utf8mb4")
  31. def _migrate(self, cursor, current):
  32. """Migrate the database to the latest schema version."""
  33. for version in xrange(current, VERSION):
  34. print "Migrating to %d..." % (version + 1)
  35. for query in MIGRATIONS[version - 1]:
  36. cursor.execute(query)
  37. cursor.execute("UPDATE version SET version = ?", (version + 1,))
  38. def _check_version(self, migrate):
  39. """Check the database schema version and respond accordingly.
  40. If the schema is out of date, migrate if *migrate* is True, else raise
  41. an exception.
  42. """
  43. with self._conn.cursor() as cursor:
  44. cursor.execute("SELECT version FROM version")
  45. version = cursor.fetchone()[0]
  46. if version < VERSION:
  47. if migrate:
  48. self._migrate(cursor, version)
  49. else:
  50. err = "Database schema out of date. " \
  51. "Run `python -m bitshift.database.migration`."
  52. raise RuntimeError(err)
  53. def _search_with_query(self, cursor, tree, page):
  54. """Execute an SQL query based on a query tree, and return results.
  55. The returned data is a 2-tuple of (list of codelet IDs, estimated
  56. number of total results).
  57. """
  58. query, args = tree.build_query(page)
  59. cursor.execute(query, args)
  60. ids = [cid for cid, _ in cursor.fetchall()]
  61. num_results = len(ids) # TODO: This is not entirely correct
  62. return ids, num_results
  63. def _get_authors_for_codelet(self, cursor, codelet_id):
  64. """Return a list of authors for a given codelet."""
  65. query = """SELECT author_name, author_url
  66. FROM authors
  67. WHERE author_codelet = ?"""
  68. cursor.execute(query, (codelet_id,))
  69. return cursor.fetchall()
  70. def _get_symbols_for_code(self, cursor, code_id, tree):
  71. """Return a list of symbols for a given codelet."""
  72. query = """SELECT symbol_type, symbol_name, sloc_type, sloc_row,
  73. sloc_col, sloc_end_row, sloc_end_col
  74. FROM symbols
  75. INNER JOIN symbol_locations ON sloc_symbol = symbol_id
  76. WHERE symbol_code = ? AND (%s)"""
  77. conds, args = [], [code_id]
  78. for node in tree.walk(Symbol):
  79. node_cond, node_args, _, _ = node.parameterize(set())
  80. conds.append(node_cond)
  81. args += node_args
  82. if not conds:
  83. return {}
  84. cond = " OR ".join(conds)
  85. symbols = {type_: {} for type_ in Symbol.TYPES}
  86. cursor.execute(query % cond, tuple(args))
  87. for type_, name, loc_type, row, col, erow, ecol in cursor.fetchall():
  88. sdict = symbols[Symbol.TYPES[type_]]
  89. if name not in sdict:
  90. sdict[name] = ([], [])
  91. sdict[name][loc_type].append((row, col, erow, ecol))
  92. for type_, sdict in symbols.items():
  93. symbols[type_] = [(n, d, u) for n, (d, u) in sdict.iteritems()]
  94. return symbols
  95. def _get_codelets_from_ids(self, cursor, ids, tree):
  96. """Return a list of Codelet objects given a list of codelet IDs."""
  97. query = """SELECT *
  98. FROM codelets
  99. INNER JOIN code ON codelet_code_id = code_id
  100. INNER JOIN origins ON codelet_origin = origin_id
  101. WHERE codelet_id = ?"""
  102. with self._conn.cursor(oursql.DictCursor) as dict_cursor:
  103. for codelet_id in ids:
  104. dict_cursor.execute(query, (codelet_id,))
  105. row = dict_cursor.fetchall()[0]
  106. code_id = row["code_id"]
  107. if row["origin_url_base"]:
  108. url = row["origin_url_base"] + row["codelet_url"]
  109. else:
  110. url = row["codelet_url"]
  111. origin = (row["origin_name"], row["origin_url"])
  112. authors = self._get_authors_for_codelet(cursor, codelet_id)
  113. symbols = self._get_symbols_for_code(cursor, code_id, tree)
  114. yield Codelet(
  115. row["codelet_name"], row["code_code"], None,
  116. row["code_lang"], authors, url,
  117. row["codelet_date_created"], row["codelet_date_modified"],
  118. row["codelet_rank"], symbols, origin)
  119. def _decompose_url(self, cursor, url):
  120. """Break up a URL into an origin (with a URL base) and a suffix."""
  121. query = """SELECT origin_id, SUBSTR(?, LENGTH(origin_url_base) + 1)
  122. FROM origins
  123. WHERE origin_url_base IS NOT NULL
  124. AND ? LIKE CONCAT(origin_url_base, "%")"""
  125. cursor.execute(query, (url, url))
  126. result = cursor.fetchone()
  127. return result if result else (1, url)
  128. def _insert_symbols(self, cursor, code_id, sym_type, symbols):
  129. """Insert a list of symbols of a given type into the database."""
  130. query1 = "INSERT INTO symbols VALUES (DEFAULT, ?, ?, ?)"
  131. query2 = """INSERT INTO symbol_locations VALUES
  132. (DEFAULT, ?, ?, ?, ?, ?, ?)"""
  133. build = lambda id, L, typ: [tuple([id, typ] + list(loc)) for loc in L]
  134. type_id = Symbol.TYPES.index(sym_type)
  135. for (name, defs, uses) in symbols:
  136. cursor.execute(query1, (code_id, type_id, name))
  137. sym_id = cursor.lastrowid
  138. params = (build(sym_id, defs, Symbol.DEFINE) +
  139. build(sym_id, uses, Symbol.USE))
  140. cursor.executemany(query2, params)
  141. def close(self):
  142. """Disconnect from the database."""
  143. self._conn.close()
  144. def search(self, tree, page=1):
  145. """
  146. Search the database for a query and return the *n*\ th page of results.
  147. :param tree: The query to search for.
  148. :type tree: :py:class:`~.query.tree.Tree`
  149. :param page: The result page to display.
  150. :type page: int
  151. :return: The total number of results, and the *n*\ th page of results.
  152. :rtype: 2-tuple of (long, list of :py:class:`.Codelet`\ s)
  153. """
  154. query1 = "SELECT 1 FROM cache WHERE cache_id = ?"
  155. query2 = """SELECT cdata_codelet, cache_count_mnt, cache_count_exp
  156. FROM cache
  157. INNER JOIN cache_data ON cache_id = cdata_cache
  158. WHERE cache_id = ?
  159. ORDER BY cdata_index ASC"""
  160. query3 = "INSERT INTO cache VALUES (?, ?, ?, DEFAULT)"
  161. query4 = "INSERT INTO cache_data VALUES (?, ?, ?)"
  162. cache_id = mmh3.hash64(str(page) + ":" + tree.serialize())[0]
  163. with self._conn.cursor() as cursor:
  164. cursor.execute(query1, (cache_id,))
  165. cache_hit = cursor.fetchall()
  166. if cache_hit:
  167. cursor.execute(query2, (cache_id,))
  168. rows = cursor.fetchall()
  169. num_results = rows[0][1] * (10 ** rows[0][2]) if rows else 0
  170. ids = [row[0] for row in rows]
  171. else:
  172. ids, num_results = self._search_with_query(cursor, tree, page)
  173. num_exp = max(len(str(num_results)) - 3, 0)
  174. num_results = int(round(num_results, -num_exp))
  175. num_mnt = num_results / (10 ** num_exp)
  176. cursor.execute(query3, (cache_id, num_mnt, num_exp))
  177. cdata = [(cache_id, c_id, i) for i, c_id in enumerate(ids)]
  178. cursor.executemany(query4, cdata)
  179. codelet_gen = self._get_codelets_from_ids(cursor, ids, tree)
  180. return (num_results, list(codelet_gen))
  181. def insert(self, codelet):
  182. """
  183. Insert a codelet into the database.
  184. :param codelet: The codelet to insert.
  185. :type codelet: :py:class:`.Codelet`
  186. """
  187. query1 = """INSERT INTO code VALUES (?, ?, ?)
  188. ON DUPLICATE KEY UPDATE code_id=code_id"""
  189. query2 = """INSERT INTO codelets VALUES
  190. (DEFAULT, ?, ?, ?, ?, ?, ?, ?)"""
  191. query3 = "INSERT INTO authors VALUES (DEFAULT, ?, ?, ?)"
  192. hash_key = str(codelet.language) + ":" + codelet.code.encode("utf8")
  193. code_id = mmh3.hash64(hash_key)[0]
  194. with self._conn.cursor() as cursor:
  195. cursor.execute(query1, (code_id, codelet.language, codelet.code))
  196. if cursor.rowcount == 1:
  197. for sym_type, symbols in codelet.symbols.iteritems():
  198. self._insert_symbols(cursor, code_id, sym_type, symbols)
  199. origin, url = self._decompose_url(cursor, codelet.url)
  200. cursor.execute(query2, (codelet.name, code_id, origin, url,
  201. codelet.rank, codelet.date_created,
  202. codelet.date_modified))
  203. codelet_id = cursor.lastrowid
  204. authors = [(codelet_id, a[0], a[1]) for a in codelet.authors]
  205. cursor.executemany(query3, authors)