diff --git a/.travis.yml b/.travis.yml index a3dfb1d..82338b0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,7 @@ install: before_script: - mysql -e 'CREATE DATABASE IF NOT EXISTS test;' - psql -c 'create database test;' -U postgres -- touch test.db +- touch test.db test1.db script: python tests/sql.py after_script: rm -f test.db jobs: diff --git a/setup.py b/setup.py index 9f11ae5..1a9c080 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="2.4.0" + version="2.4.1" ) diff --git a/src/cs50/sql.py b/src/cs50/sql.py index eef4af9..2ede1ca 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -5,6 +5,7 @@ import os import re import sqlalchemy +import sqlite3 import sqlparse import sys import termcolor @@ -32,12 +33,25 @@ def __init__(self, url, **kwargs): if not os.path.isfile(matches.group(1)): raise RuntimeError("not a file: {}".format(matches.group(1))) - # Create engine, raising exception if back end's module not installed - self.engine = sqlalchemy.create_engine(url, **kwargs) + # Remember foreign_keys and remove it from kwargs + foreign_keys = kwargs.pop("foreign_keys", False) + + # Create engine, raising exception if back end's module not installed + self.engine = sqlalchemy.create_engine(url, **kwargs) + + # Enable foreign key constraints + if foreign_keys: + sqlalchemy.event.listen(self.engine, "connect", _connect) + else: + + # Create engine, raising exception if back end's module not installed + self.engine = sqlalchemy.create_engine(url, **kwargs) + # Log statements to standard error logging.basicConfig(level=logging.DEBUG) self.logger = logging.getLogger("cs50") + disabled = self.logger.disabled # Test database try: @@ -48,7 +62,7 @@ def __init__(self, url, **kwargs): e.__cause__ = None raise e else: - self.logger.disabled = False + self.logger.disabled = disabled def _parse(self, e): """Parses an exception, returns its message.""" @@ -133,6 +147,8 @@ def process(value): return process(value) # Allow only one statement at a time + # SQLite does not support executing many statements + # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute if len(sqlparse.split(text)) > 1: raise RuntimeError("too many statements at once") @@ -211,3 +227,16 @@ def process(value): else: self.logger.debug(termcolor.colored(log, "green")) return ret + + +# http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support +def _connect(dbapi_connection, connection_record): + """Enables foreign key support.""" + + # Ensure backend is sqlite + if type(dbapi_connection) is sqlite3.Connection: + cursor = dbapi_connection.cursor() + + # Respect foreign key constraints by default + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() diff --git a/tests/sql.py b/tests/sql.py index c2af662..67b4e94 100644 --- a/tests/sql.py +++ b/tests/sql.py @@ -107,12 +107,19 @@ class SQLiteTests(SQLTests): @classmethod def setUpClass(self): self.db = SQL("sqlite:///test.db") + self.db1 = SQL("sqlite:///test1.db", foreign_keys=True) def setUp(self): self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)") - def multi_inserts_enabled(self): - return False + def test_foreign_key_support(self): + self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") + self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))") + self.assertEqual(self.db.execute("INSERT INTO bar VALUES(50)"), 1) + + self.db1.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)") + self.db1.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))") + self.assertEqual(self.db1.execute("INSERT INTO bar VALUES(50)"), None) if __name__ == "__main__": suite = unittest.TestSuite([ @@ -120,5 +127,6 @@ def multi_inserts_enabled(self): unittest.TestLoader().loadTestsFromTestCase(MySQLTests), unittest.TestLoader().loadTestsFromTestCase(PostgresTests) ]) - logging.getLogger("cs50.sql").disabled = True + + logging.getLogger("cs50").disabled = True sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())