A semantic search engine for source code: https://bitshift.benkurtovic.com/
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

128 lines
4.7 KiB

  1. """
  2. Subpackage with classes and functions to handle communication with the MySQL
  3. database backend, which manages the search index.
  4. """
  5. import os
  6. import mmh3
  7. import oursql
  8. from .migration import VERSION, MIGRATIONS
# Public API of this subpackage: only the Database class is exported by
# ``from bitshift.database import *``.
__all__ = ["Database"]
  10. class Database(object):
  11. """Represents the MySQL database."""
  12. def __init__(self, migrate=False):
  13. self._conn = self._connect()
  14. self._check_version(migrate)
  15. def _connect(self):
  16. """Establish a connection to the database."""
  17. root = os.path.dirname(os.path.abspath(__file__))
  18. default_file = os.path.join(root, ".my.cnf")
  19. return oursql.connect(db="bitshift", read_default_file=default_file,
  20. autoping=True, autoreconnect=True)
  21. def _migrate(self, cursor, current):
  22. """Migrate the database to the latest schema version."""
  23. for version in xrange(current, VERSION):
  24. print "Migrating to %d..." % version + 1
  25. for query in MIGRATIONS[version - 1]:
  26. cursor.execute(query)
  27. cursor.execute("UPDATE version SET version = ?", (version + 1,))
  28. def _check_version(self, migrate):
  29. """Check the database schema version and respond accordingly.
  30. If the schema is out of date, migrate if *migrate* is True, else raise
  31. an exception.
  32. """
  33. with self._conn.cursor() as cursor:
  34. cursor.execute("SELECT version FROM version")
  35. version = cursor.fetchone()[0]
  36. if version < VERSION:
  37. if migrate:
  38. self._migrate(cursor, version)
  39. else:
  40. err = "Database schema out of date. " \
  41. "Run `python -m bitshift.database.migration`."
  42. raise RuntimeError(err)
  43. def _decompose_url(self, url):
  44. """Break up a URL into an origin (with a URL base) and a suffix."""
  45. pass ## TODO
  46. def _insert_symbols(self, cursor, code_id, sym_type, symbols):
  47. """Insert a list of symbols of a given type into the database."""
  48. sym_types = ["functions", "classes", "variables"]
  49. query1 = "INSERT INTO symbols VALUES (DEFAULT, ?, ?, ?)"
  50. query2 = """INSERT INTO symbol_locations VALUES
  51. (DEFAULT, ?, ?, ?, ?, ?, ?)"""
  52. for (name, decls, uses) in symbols:
  53. cursor.execute(query1, (code_id, sym_types.index(sym_type), name))
  54. sym_id = cursor.lastrowid
  55. params = ([tuple([sym_id, 0] + list(loc)) for loc in decls] +
  56. [tuple([sym_id, 1] + list(loc)) for loc in uses])
  57. cursor.executemany(query2, params)
  58. def close(self):
  59. """Disconnect from the database."""
  60. self._conn.close()
  61. def search(self, query, page=1):
  62. """
  63. Search the database for a query and return the *n*\ th page of results.
  64. :param query: The query to search for.
  65. :type query: :py:class:`~.query.tree.Tree`
  66. :param page: The result page to display.
  67. :type page: int
  68. :return: A list of search results.
  69. :rtype: list of :py:class:`.Codelet`\ s
  70. """
  71. # search for cache_hash = mmh3.hash(query.serialize() + str(page))
  72. # cache HIT:
  73. # update cache_last_used
  74. # return codelets
  75. # cache MISS:
  76. # build complex search query
  77. # fetch codelets
  78. # cache results
  79. # return codelets
  80. pass
  81. def insert(self, codelet):
  82. """
  83. Insert a codelet into the database.
  84. :param codelet: The codelet to insert.
  85. :type codelet: :py:class:`.Codelet`
  86. """
  87. query1 = """INSERT INTO code VALUES (?, ?)
  88. ON DUPLICATE KEY UPDATE code_id=code_id"""
  89. query2 = """INSERT INTO codelets VALUES
  90. (DEFAULT, ?, ?, ?, ?, ?, ?, ?, ?)"""
  91. query3 = "INSERT INTO authors VALUES (DEFAULT, ?, ?, ?)"
  92. code_id = mmh3.hash64(codelet.code.encode("utf8"))[0]
  93. origin, url = self._decompose_url(codelet.url)
  94. with self._conn.cursor() as cursor:
  95. cursor.execute(query1, (code_id, codelet.code))
  96. new_code = cursor.rowcount == 1
  97. cursor.execute(query2, (codelet.name, code_id, codelet.language,
  98. origin, url, codelet.rank,
  99. codelet.date_created,
  100. codelet.date_modified))
  101. codelet_id = cursor.lastrowid
  102. authors = [(codelet_id, a[0], a[1]) for a in codelet.authors]
  103. cursor.executemany(query3, authors)
  104. if new_code:
  105. for sym_type, symbols in codelet.symbols.iteritems():
  106. self._insert_symbols(cursor, code_id, sym_type, symbols)