From 1b644dd0c4a4fd252306933ed2a1247ffda8de83 Mon Sep 17 00:00:00 2001 From: "David J. Malan" Date: Sun, 17 Dec 2023 14:20:44 -0500 Subject: [PATCH] updated IO wrapper, style, version --- setup.py | 2 +- src/cs50/__init__.py | 1 + src/cs50/cs50.py | 49 +++++++---- src/cs50/flask.py | 10 ++- src/cs50/sql.py | 199 ++++++++++++++++++++++++++++++------------- 5 files changed, 184 insertions(+), 77 deletions(-) diff --git a/setup.py b/setup.py index 23f6b01..10ceb30 100644 --- a/setup.py +++ b/setup.py @@ -18,5 +18,5 @@ package_dir={"": "src"}, packages=["cs50"], url="https://github.com/cs50/python-cs50", - version="9.3.0" + version="9.3.1" ) diff --git a/src/cs50/__init__.py b/src/cs50/__init__.py index aaec161..7dd4e17 100644 --- a/src/cs50/__init__.py +++ b/src/cs50/__init__.py @@ -8,6 +8,7 @@ # Import cs50_* from .cs50 import get_char, get_float, get_int, get_string + try: from .cs50 import get_long except ImportError: diff --git a/src/cs50/cs50.py b/src/cs50/cs50.py index 07f13e9..f331a88 100644 --- a/src/cs50/cs50.py +++ b/src/cs50/cs50.py @@ -17,7 +17,9 @@ try: # Patch formatException - logging.root.handlers[0].formatter.formatException = lambda exc_info: _formatException(*exc_info) + logging.root.handlers[ + 0 + ].formatter.formatException = lambda exc_info: _formatException(*exc_info) except IndexError: pass @@ -37,26 +39,31 @@ _logger.addHandler(handler) -class _flushfile(): +class _Unbuffered: """ Disable buffering for standard output and standard error. - http://stackoverflow.com/a/231216 + https://stackoverflow.com/a/107717 + https://docs.python.org/3/library/io.html """ - def __init__(self, f): - self.f = f + def __init__(self, stream): + self.stream = stream - def __getattr__(self, name): - return getattr(self.f, name) + def __getattr__(self, attr): + return getattr(self.stream, attr) - def write(self, x): - self.f.write(x) - self.f.flush() + def write(self, b): + self.stream.write(b) + self.stream.flush() + def writelines(self, lines): + self.stream.writelines(lines) + self.stream.flush() -sys.stderr = _flushfile(sys.stderr) -sys.stdout = _flushfile(sys.stdout) + +sys.stderr = _Unbuffered(sys.stderr) +sys.stdout = _Unbuffered(sys.stdout) def _formatException(type, value, tb): @@ -78,19 +85,29 @@ def _formatException(type, value, tb): lines += line else: matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL) - lines.append(matches.group(1) + colored(matches.group(2), "yellow") + matches.group(3)) + lines.append( + matches.group(1) + + colored(matches.group(2), "yellow") + + matches.group(3) + ) return "".join(lines).rstrip() -sys.excepthook = lambda type, value, tb: print(_formatException(type, value, tb), file=sys.stderr) +sys.excepthook = lambda type, value, tb: print( + _formatException(type, value, tb), file=sys.stderr +) def eprint(*args, **kwargs): - raise RuntimeError("The CS50 Library for Python no longer supports eprint, but you can use print instead!") + raise RuntimeError( + "The CS50 Library for Python no longer supports eprint, but you can use print instead!" + ) def get_char(prompt): - raise RuntimeError("The CS50 Library for Python no longer supports get_char, but you can use get_string instead!") + raise RuntimeError( + "The CS50 Library for Python no longer supports get_char, but you can use get_string instead!" + ) def get_float(prompt): diff --git a/src/cs50/flask.py b/src/cs50/flask.py index 3668007..6e38971 100644 --- a/src/cs50/flask.py +++ b/src/cs50/flask.py @@ -2,6 +2,7 @@ import pkgutil import sys + def _wrap_flask(f): if f is None: return @@ -17,10 +18,15 @@ def _wrap_flask(f): if os.getenv("CS50_IDE_TYPE") == "online": from werkzeug.middleware.proxy_fix import ProxyFix + _flask_init_before = f.Flask.__init__ + def _flask_init_after(self, *args, **kwargs): _flask_init_before(self, *args, **kwargs) - self.wsgi_app = ProxyFix(self.wsgi_app, x_proto=1) # For HTTPS-to-HTTP proxy + self.wsgi_app = ProxyFix( + self.wsgi_app, x_proto=1 + ) # For HTTPS-to-HTTP proxy + f.Flask.__init__ = _flask_init_after @@ -30,7 +36,7 @@ def _flask_init_after(self, *args, **kwargs): # If Flask wasn't imported else: - flask_loader = pkgutil.get_loader('flask') + flask_loader = pkgutil.get_loader("flask") if flask_loader: _exec_module_before = flask_loader.exec_module diff --git a/src/cs50/sql.py b/src/cs50/sql.py index de3ad56..a0b93eb 100644 --- a/src/cs50/sql.py +++ b/src/cs50/sql.py @@ -14,7 +14,6 @@ def _enable_logging(f): @functools.wraps(f) def decorator(*args, **kwargs): - # Infer whether Flask is installed try: import flask @@ -71,17 +70,20 @@ def __init__(self, url, **kwargs): # Create engine, disabling SQLAlchemy's own autocommit mode raising exception if back end's module not installed; # without isolation_level, PostgreSQL warns with "there is already a transaction in progress" for our own BEGIN and # "there is no transaction in progress" for our own COMMIT - self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options(autocommit=False, isolation_level="AUTOCOMMIT") + self._engine = sqlalchemy.create_engine(url, **kwargs).execution_options( + autocommit=False, isolation_level="AUTOCOMMIT" + ) # Get logger self._logger = logging.getLogger("cs50") # Listener for connections def connect(dbapi_connection, connection_record): - # Enable foreign key constraints try: - if isinstance(dbapi_connection, sqlite3.Connection): # If back end is sqlite + if isinstance( + dbapi_connection, sqlite3.Connection + ): # If back end is sqlite cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() @@ -150,14 +152,33 @@ def execute(self, sql, *args, **kwargs): raise RuntimeError("cannot pass both positional and named parameters") # Infer command from flattened statement to a single string separated by spaces - full_statement = ' '.join(str(token) for token in statements[0].tokens if token.ttype in [sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]) + full_statement = " ".join( + str(token) + for token in statements[0].tokens + if token.ttype + in [ + sqlparse.tokens.Keyword, + sqlparse.tokens.Keyword.DDL, + sqlparse.tokens.Keyword.DML, + ] + ) full_statement = full_statement.upper() # Set of possible commands - commands = {"BEGIN", "CREATE VIEW", "DELETE", "INSERT", "SELECT", "START", "UPDATE"} + commands = { + "BEGIN", + "CREATE VIEW", + "DELETE", + "INSERT", + "SELECT", + "START", + "UPDATE", + } # Check if the full_statement starts with any command - command = next((cmd for cmd in commands if full_statement.startswith(cmd)), None) + command = next( + (cmd for cmd in commands if full_statement.startswith(cmd)), None + ) # Flatten statement tokens = list(statements[0].flatten()) @@ -166,10 +187,8 @@ def execute(self, sql, *args, **kwargs): placeholders = {} paramstyle = None for index, token in enumerate(tokens): - # If token is a placeholder if token.ttype == sqlparse.tokens.Name.Placeholder: - # Determine paramstyle, name _paramstyle, name = _parse_placeholder(token) @@ -186,7 +205,6 @@ def execute(self, sql, *args, **kwargs): # If no placeholders if not paramstyle: - # Error-check like qmark if args if args: paramstyle = "qmark" @@ -201,13 +219,20 @@ def execute(self, sql, *args, **kwargs): # qmark if paramstyle == "qmark": - # Validate number of placeholders if len(placeholders) != len(args): if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "fewer placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "more placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) # Escape values for i, index in enumerate(placeholders.keys()): @@ -215,27 +240,34 @@ def execute(self, sql, *args, **kwargs): # numeric elif paramstyle == "numeric": - # Escape values for index, i in placeholders.items(): if i >= len(args): - raise RuntimeError("missing value for placeholder (:{})".format(i + 1, len(args))) + raise RuntimeError( + "missing value for placeholder (:{})".format(i + 1, len(args)) + ) tokens[index] = self._escape(args[i]) # Check if any values unused indices = set(range(len(args))) - set(placeholders.values()) if indices: - raise RuntimeError("unused {} ({})".format( - "value" if len(indices) == 1 else "values", - ", ".join([str(self._escape(args[index])) for index in indices]))) + raise RuntimeError( + "unused {} ({})".format( + "value" if len(indices) == 1 else "values", + ", ".join( + [str(self._escape(args[index])) for index in indices] + ), + ) + ) # named elif paramstyle == "named": - # Escape values for index, name in placeholders.items(): if name not in kwargs: - raise RuntimeError("missing value for placeholder (:{})".format(name)) + raise RuntimeError( + "missing value for placeholder (:{})".format(name) + ) tokens[index] = self._escape(kwargs[name]) # Check if any keys unused @@ -245,13 +277,20 @@ def execute(self, sql, *args, **kwargs): # format elif paramstyle == "format": - # Validate number of placeholders if len(placeholders) != len(args): if len(placeholders) < len(args): - raise RuntimeError("fewer placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "fewer placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) else: - raise RuntimeError("more placeholders ({}) than values ({})".format(_placeholders, _args)) + raise RuntimeError( + "more placeholders ({}) than values ({})".format( + _placeholders, _args + ) + ) # Escape values for i, index in enumerate(placeholders.keys()): @@ -259,40 +298,44 @@ def execute(self, sql, *args, **kwargs): # pyformat elif paramstyle == "pyformat": - # Escape values for index, name in placeholders.items(): if name not in kwargs: - raise RuntimeError("missing value for placeholder (%{}s)".format(name)) + raise RuntimeError( + "missing value for placeholder (%{}s)".format(name) + ) tokens[index] = self._escape(kwargs[name]) # Check if any keys unused keys = kwargs.keys() - placeholders.values() if keys: - raise RuntimeError("unused {} ({})".format( - "value" if len(keys) == 1 else "values", - ", ".join(keys))) + raise RuntimeError( + "unused {} ({})".format( + "value" if len(keys) == 1 else "values", ", ".join(keys) + ) + ) # For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape # https://docs.sqlalchemy.org/en/13/core/sqlelement.html?highlight=text#sqlalchemy.sql.expression.text for index, token in enumerate(tokens): - # In string literal # https://www.sqlite.org/lang_keywords.html - if token.ttype in [sqlparse.tokens.Literal.String, sqlparse.tokens.Literal.String.Single]: + if token.ttype in [ + sqlparse.tokens.Literal.String, + sqlparse.tokens.Literal.String.Single, + ]: token.value = re.sub("(^'|\s+):", r"\1\:", token.value) # In identifier # https://www.sqlite.org/lang_keywords.html elif token.ttype == sqlparse.tokens.Literal.String.Symbol: - token.value = re.sub("(^\"|\s+):", r"\1\:", token.value) + token.value = re.sub('(^"|\s+):', r"\1\:", token.value) # Join tokens into statement statement = "".join([str(token) for token in tokens]) # If no connection yet if not hasattr(_data, self._name()): - # Connect to database setattr(_data, self._name(), self._engine.connect()) @@ -302,9 +345,12 @@ def execute(self, sql, *args, **kwargs): # Disconnect if/when a Flask app is torn down try: import flask + assert flask.current_app + def teardown_appcontext(exception): self._disconnect() + if teardown_appcontext not in flask.current_app.teardown_appcontext_funcs: flask.current_app.teardown_appcontext(teardown_appcontext) except (ModuleNotFoundError, AssertionError): @@ -312,15 +358,20 @@ def teardown_appcontext(exception): # Catch SQLAlchemy warnings with warnings.catch_warnings(): - # Raise exceptions for warnings warnings.simplefilter("error") # Prepare, execute statement try: - # Join tokens into statement, abbreviating binary data as - _statement = "".join([str(bytes) if token.ttype == sqlparse.tokens.Other else str(token) for token in tokens]) + _statement = "".join( + [ + str(bytes) + if token.ttype == sqlparse.tokens.Other + else str(token) + for token in tokens + ] + ) # Check for start of transaction if command in ["BEGIN", "START"]: @@ -342,12 +393,10 @@ def teardown_appcontext(exception): # If SELECT, return result set as list of dict objects if command == "SELECT": - # Coerce types rows = [dict(row) for row in result.mappings().all()] for row in rows: for column in row: - # Coerce decimal.Decimal objects to float objects # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ if isinstance(row[column], decimal.Decimal): @@ -362,15 +411,15 @@ def teardown_appcontext(exception): # If INSERT, return primary key value for a newly inserted row (or None if none) elif command == "INSERT": - # If PostgreSQL if self._engine.url.get_backend_name() == "postgresql": - # Return LASTVAL() or NULL, avoiding # "(psycopg2.errors.ObjectNotInPrerequisiteState) lastval is not yet defined in this session", # a la https://stackoverflow.com/a/24186770/5156190; # cf. https://www.psycopg.org/docs/errors.html re 55000 - result = connection.execute(sqlalchemy.text(""" + result = connection.execute( + sqlalchemy.text( + """ CREATE OR REPLACE FUNCTION _LASTVAL() RETURNS integer LANGUAGE plpgsql AS $$ @@ -382,7 +431,9 @@ def teardown_appcontext(exception): END; END $$; SELECT _LASTVAL(); - """)) + """ + ) + ) ret = result.first()[0] # If not PostgreSQL @@ -405,7 +456,10 @@ def teardown_appcontext(exception): raise e # If user error - except (sqlalchemy.exc.OperationalError, sqlalchemy.exc.ProgrammingError) as e: + except ( + sqlalchemy.exc.OperationalError, + sqlalchemy.exc.ProgrammingError, + ) as e: self._disconnect() self._logger.error(termcolor.colored(_statement, "red")) e = RuntimeError(e.orig) @@ -430,7 +484,6 @@ def _escape(self, value): import sqlparse def __escape(value): - # Lazily import import datetime import sqlalchemy @@ -439,14 +492,21 @@ def __escape(value): if isinstance(value, bool): return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.Boolean().literal_processor(self._engine.dialect)( + value + ), + ) # bytes elif isinstance(value, bytes): if self._engine.url.get_backend_name() in ["mysql", "sqlite"]: - return sqlparse.sql.Token(sqlparse.tokens.Other, f"x'{value.hex()}'") # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html + return sqlparse.sql.Token( + sqlparse.tokens.Other, f"x'{value.hex()}'" + ) # https://dev.mysql.com/doc/refman/8.0/en/hexadecimal-literals.html elif self._engine.url.get_backend_name() == "postgresql": - return sqlparse.sql.Token(sqlparse.tokens.Other, f"'\\x{value.hex()}'") # https://dba.stackexchange.com/a/203359 + return sqlparse.sql.Token( + sqlparse.tokens.Other, f"'\\x{value.hex()}'" + ) # https://dba.stackexchange.com/a/203359 else: raise RuntimeError("unsupported value: {}".format(value)) @@ -454,43 +514,59 @@ def __escape(value): elif isinstance(value, datetime.datetime): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d %H:%M:%S"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value.strftime("%Y-%m-%d %H:%M:%S") + ), + ) # datetime.date elif isinstance(value, datetime.date): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%Y-%m-%d"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value.strftime("%Y-%m-%d") + ), + ) # datetime.time elif isinstance(value, datetime.time): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value.strftime("%H:%M:%S"))) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value.strftime("%H:%M:%S") + ), + ) # float elif isinstance(value, float): return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Float().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.Float().literal_processor(self._engine.dialect)( + value + ), + ) # int elif isinstance(value, int): return sqlparse.sql.Token( sqlparse.tokens.Number, - sqlalchemy.types.Integer().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.Integer().literal_processor(self._engine.dialect)( + value + ), + ) # str elif isinstance(value, str): return sqlparse.sql.Token( sqlparse.tokens.String, - sqlalchemy.types.String().literal_processor(self._engine.dialect)(value)) + sqlalchemy.types.String().literal_processor(self._engine.dialect)( + value + ), + ) # None elif value is None: - return sqlparse.sql.Token( - sqlparse.tokens.Keyword, - sqlalchemy.null()) + return sqlparse.sql.Token(sqlparse.tokens.Keyword, sqlalchemy.null()) # Unsupported value else: @@ -498,7 +574,9 @@ def __escape(value): # Escape value(s), separating with commas as needed if isinstance(value, (list, tuple)): - return sqlparse.sql.TokenList(sqlparse.parse(", ".join([str(__escape(v)) for v in value]))) + return sqlparse.sql.TokenList( + sqlparse.parse(", ".join([str(__escape(v)) for v in value])) + ) else: return __escape(value) @@ -510,7 +588,9 @@ def _parse_exception(e): import re # MySQL - matches = re.search(r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e)) + matches = re.search( + r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$", str(e) + ) if matches: return matches.group(1) @@ -536,7 +616,10 @@ def _parse_placeholder(token): import sqlparse # Validate token - if not isinstance(token, sqlparse.sql.Token) or token.ttype != sqlparse.tokens.Name.Placeholder: + if ( + not isinstance(token, sqlparse.sql.Token) + or token.ttype != sqlparse.tokens.Name.Placeholder + ): raise TypeError() # qmark