Skip to content

Commit

Permalink
Merge pull request #50 from cs50/foreign-key-constraint
Browse files Browse the repository at this point in the history
added foreign key constraint support to SQLite
  • Loading branch information
Kareem Zidane authored Apr 26, 2018
2 parents f19934d + afb6cba commit bdb8128
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ install:
before_script:
- mysql -e 'CREATE DATABASE IF NOT EXISTS test;'
- psql -c 'create database test;' -U postgres
- touch test.db
- touch test.db test1.db
script: python tests/sql.py
after_script: rm -f test.db
jobs:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
package_dir={"": "src"},
packages=["cs50"],
url="https://github.com/cs50/python-cs50",
version="2.4.0"
version="2.4.1"
)
35 changes: 32 additions & 3 deletions src/cs50/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import sqlalchemy
import sqlite3
import sqlparse
import sys
import termcolor
Expand Down Expand Up @@ -32,12 +33,25 @@ def __init__(self, url, **kwargs):
if not os.path.isfile(matches.group(1)):
raise RuntimeError("not a file: {}".format(matches.group(1)))

# Create engine, raising exception if back end's module not installed
self.engine = sqlalchemy.create_engine(url, **kwargs)
# Remember foreign_keys and remove it from kwargs
foreign_keys = kwargs.pop("foreign_keys", False)

# Create engine, raising exception if back end's module not installed
self.engine = sqlalchemy.create_engine(url, **kwargs)

# Enable foreign key constraints
if foreign_keys:
sqlalchemy.event.listen(self.engine, "connect", _connect)
else:

# Create engine, raising exception if back end's module not installed
self.engine = sqlalchemy.create_engine(url, **kwargs)


# Log statements to standard error
logging.basicConfig(level=logging.DEBUG)
self.logger = logging.getLogger("cs50")
disabled = self.logger.disabled

# Test database
try:
Expand All @@ -48,7 +62,7 @@ def __init__(self, url, **kwargs):
e.__cause__ = None
raise e
else:
self.logger.disabled = False
self.logger.disabled = disabled

def _parse(self, e):
"""Parses an exception, returns its message."""
Expand Down Expand Up @@ -133,6 +147,8 @@ def process(value):
return process(value)

# Allow only one statement at a time
# SQLite does not support executing many statements
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
if len(sqlparse.split(text)) > 1:
raise RuntimeError("too many statements at once")

Expand Down Expand Up @@ -211,3 +227,16 @@ def process(value):
else:
self.logger.debug(termcolor.colored(log, "green"))
return ret


# http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support
def _connect(dbapi_connection, connection_record):
"""Enables foreign key support."""

# Ensure backend is sqlite
if type(dbapi_connection) is sqlite3.Connection:
cursor = dbapi_connection.cursor()

# Respect foreign key constraints by default
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
14 changes: 11 additions & 3 deletions tests/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,26 @@ class SQLiteTests(SQLTests):
@classmethod
def setUpClass(self):
self.db = SQL("sqlite:///test.db")
self.db1 = SQL("sqlite:///test1.db", foreign_keys=True)

def setUp(self):
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")

def multi_inserts_enabled(self):
return False
def test_foreign_key_support(self):
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
self.assertEqual(self.db.execute("INSERT INTO bar VALUES(50)"), 1)

self.db1.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
self.db1.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
self.assertEqual(self.db1.execute("INSERT INTO bar VALUES(50)"), None)

if __name__ == "__main__":
suite = unittest.TestSuite([
unittest.TestLoader().loadTestsFromTestCase(SQLiteTests),
unittest.TestLoader().loadTestsFromTestCase(MySQLTests),
unittest.TestLoader().loadTestsFromTestCase(PostgresTests)
])
logging.getLogger("cs50.sql").disabled = True

logging.getLogger("cs50").disabled = True
sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())

0 comments on commit bdb8128

Please sign in to comment.