diff --git a/src/alembic_utils/depends.py b/src/alembic_utils/depends.py index 741ea84..e824289 100644 --- a/src/alembic_utils/depends.py +++ b/src/alembic_utils/depends.py @@ -80,6 +80,7 @@ def upgrade() -> None: """ from alembic_utils.pg_function import PGFunction from alembic_utils.pg_materialized_view import PGMaterializedView + from alembic_utils.pg_procedure import PGProcedure from alembic_utils.pg_trigger import PGTrigger from alembic_utils.pg_view import PGView from alembic_utils.replaceable_entity import ReplaceableEntity @@ -91,6 +92,7 @@ def collect_all_db_entities(sess: Session) -> List[ReplaceableEntity]: return [ *PGFunction.from_database(sess, "%"), + *PGProcedure.from_database(sess, "%"), *PGTrigger.from_database(sess, "%"), *PGView.from_database(sess, "%"), *PGMaterializedView.from_database(sess, "%"), diff --git a/src/alembic_utils/pg_procedure.py b/src/alembic_utils/pg_procedure.py new file mode 100644 index 0000000..159ad77 --- /dev/null +++ b/src/alembic_utils/pg_procedure.py @@ -0,0 +1,154 @@ +# pylint: disable=unused-argument,invalid-name,line-too-long +from typing import List + +from parse import parse +from sqlalchemy import text as sql_text + +from alembic_utils.exceptions import SQLParseFailure +from alembic_utils.replaceable_entity import ReplaceableEntity +from alembic_utils.statement import ( + escape_colon_for_plpgsql, + escape_colon_for_sql, + normalize_whitespace, + strip_terminating_semicolon, +) + + +class PGProcedure(ReplaceableEntity): + """A PostgreSQL Procedure compatible with `alembic revision --autogenerate` + + **Parameters:** + + * **schema** - *str*: A SQL schema name + * **signature** - *str*: A SQL procedure's call signature + * **definition** - *str*: The remainig procedure body and identifiers + """ + + type_ = "procedure" + + def __init__(self, schema: str, signature: str, definition: str): + super().__init__(schema, signature, definition) + # Detect if procedure uses plpgsql and update escaping rules to not escape ":=" + is_plpgsql: bool = "language plpgsql" in normalize_whitespace(definition).lower().replace( + "'", "" + ) + escaping_callable = escape_colon_for_plpgsql if is_plpgsql else escape_colon_for_sql + # Override definition with correct escaping rules + self.definition: str = escaping_callable(strip_terminating_semicolon(definition)) + + @classmethod + def from_sql(cls, sql: str) -> "PGProcedure": + """Create an instance instance from a SQL string""" + template = "create{}procedure{:s}{schema}.{signature_name}({signature_arg}){:s}{definition}" + result = parse(template, sql.strip(), case_sensitive=False) + if result is not None: + raw_signature = f'{result["signature_name"]}({result["signature_arg"]})' + # remove possible quotes from signature + signature = ( + "".join(raw_signature.split('"', 2)) + if raw_signature.startswith('"') + else raw_signature + ) + return cls( + schema=result["schema"], + signature=signature, + definition=result["definition"], + ) + raise SQLParseFailure(f'Failed to parse SQL into PGProcedure """{sql}"""') + + @property + def literal_signature(self) -> str: + """Adds quoting around the procedure name when emitting SQL statements + + e.g. + 'toUpper(text) returns text' -> '"toUpper"(text) text' + """ + # May already be quoted if loading from database or SQL file + name, remainder = self.signature.split("(", 1) + return '"' + name.strip() + '"(' + remainder + + def to_sql_statement_create(self): + """Generates a SQL "create procedure" statement for PGProcedure""" + return sql_text( + f"CREATE PROCEDURE {self.literal_schema}.{self.literal_signature} {self.definition}", + ) + + def to_sql_statement_drop(self, cascade=False): + """Generates a SQL "drop procedure" statement for PGProcedure""" + cascade = "cascade" if cascade else "" + template = "{procedure_name}({parameters})" + result = parse(template, self.signature, case_sensitive=False) + try: + procedure_name = result["procedure_name"].strip() + parameters_str = result["parameters"].strip() + except TypeError: + # Did not match, NoneType is not scriptable + result = parse("{procedure_name}()", self.signature, case_sensitive=False) + procedure_name = result["procedure_name"].strip() + parameters_str = "" + + # NOTE: Will fail if a text field has a default and that deafult contains a comma... + parameters: list[str] = parameters_str.split(",") + parameters = [x[: len(x.lower().split("default")[0])] for x in parameters] + parameters = [x.strip() for x in parameters] + drop_params = ", ".join(parameters) + return sql_text( + f'DROP PROCEDURE {self.literal_schema}."{procedure_name}"({drop_params}) {cascade}', + ) + + def to_sql_statement_create_or_replace(self): + """Generates a SQL "create or replace procedure" statement for PGProcedure""" + yield sql_text( + f"CREATE OR REPLACE PROCEDURE {self.literal_schema}.{self.literal_signature} {self.definition}", + ) + + @classmethod + def from_database(cls, sess, schema): + """Get a list of all procedures defined in the db""" + + sql = sql_text( + """ + with extension_functions as ( + select + objid as extension_function_oid + from + pg_depend + where + -- depends on an extension + deptype='e' + -- is a proc/function + and classid = 'pg_proc'::regclass + ) + + select + n.nspname as function_schema, + p.proname as procedure_name, + pg_get_function_arguments(p.oid) as function_arguments, + case + when l.lanname = 'internal' then p.prosrc + else pg_get_functiondef(p.oid) + end as create_statement, + t.typname as return_type, + l.lanname as function_language + from + pg_proc p + left join pg_namespace n on p.pronamespace = n.oid + left join pg_language l on p.prolang = l.oid + left join pg_type t on t.oid = p.prorettype + left join extension_functions ef on p.oid = ef.extension_function_oid + where + n.nspname not in ('pg_catalog', 'information_schema') + -- Filter out functions from extensions + and ef.extension_function_oid is null + and n.nspname = :schema + and p.prokind = 'p' + """ + ) + + rows = sess.execute(sql, {"schema": schema}).fetchall() + db_functions = [cls.from_sql(x[3]) for x in rows] + + for func in db_functions: + assert func is not None + + return db_functions diff --git a/src/test/test_pg_procedure.py b/src/test/test_pg_procedure.py new file mode 100644 index 0000000..2752efb --- /dev/null +++ b/src/test/test_pg_procedure.py @@ -0,0 +1,256 @@ +from typing import List + +from sqlalchemy import text + +from alembic_utils.pg_procedure import PGProcedure +from alembic_utils.replaceable_entity import register_entities +from alembic_utils.testbase import TEST_VERSIONS_ROOT, run_alembic_command + +TO_UPPER = PGProcedure( + schema="public", + signature="toUpper (some_text text default 'my text!')", + definition=""" + language PLPGSQL as $$ + declare result text; + begin result = upper(some_text) || 'abc'; end; $$; + """, +) + + +def test_trailing_whitespace_stripped(): + sql_statements: List[str] = [ + str(TO_UPPER.to_sql_statement_create()), + str(next(iter(TO_UPPER.to_sql_statement_create_or_replace()))), + str(TO_UPPER.to_sql_statement_drop()), + ] + + for statement in sql_statements: + print(statement) + assert '"toUpper"' in statement + assert not '"toUpper "' in statement + + +def test_create_revision(engine) -> None: + register_entities([TO_UPPER], entity_types=[PGProcedure]) + + run_alembic_command( + engine=engine, + command="revision", + command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, + ) + + migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" + + with migration_create_path.open() as migration_file: + migration_contents = migration_file.read() + + assert "op.create_entity" in migration_contents + assert "op.drop_entity" in migration_contents + assert "op.replace_entity" not in migration_contents + assert "from alembic_utils.pg_procedure import PGProcedure" in migration_contents + + # Execute upgrade + run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) + # Execute Downgrade + run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) + + +def test_update_revision(engine) -> None: + with engine.begin() as connection: + connection.execute(TO_UPPER.to_sql_statement_create()) + + # Update definition of TO_UPPER + UPDATED_TO_UPPER = PGProcedure( + TO_UPPER.schema, + TO_UPPER.signature, + r''' + language SQL as $$ + select upper(some_text) || 'def' -- """ \n \\ + $$''', + ) + + register_entities([UPDATED_TO_UPPER], entity_types=[PGProcedure]) + + # Autogenerate a new migration + # It should detect the change we made and produce a "replace_procedure" statement + run_alembic_command( + engine=engine, + command="revision", + command_kwargs={"autogenerate": True, "rev_id": "2", "message": "replace"}, + ) + + migration_replace_path = TEST_VERSIONS_ROOT / "2_replace.py" + + with migration_replace_path.open() as migration_file: + migration_contents = migration_file.read() + + assert "op.replace_entity" in migration_contents + assert "op.create_entity" not in migration_contents + assert "op.drop_entity" not in migration_contents + assert "from alembic_utils.pg_procedure import PGProcedure" in migration_contents + + # Execute upgrade + run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) + + # Execute Downgrade + run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) + + +def test_noop_revision(engine) -> None: + with engine.begin() as connection: + connection.execute(TO_UPPER.to_sql_statement_create()) + + register_entities([TO_UPPER], entity_types=[PGProcedure]) + + output = run_alembic_command( + engine=engine, + command="revision", + command_kwargs={"autogenerate": True, "rev_id": "3", "message": "do_nothing"}, + ) + migration_do_nothing_path = TEST_VERSIONS_ROOT / "3_do_nothing.py" + + with migration_do_nothing_path.open() as migration_file: + migration_contents = migration_file.read() + + assert "op.create_entity" not in migration_contents + assert "op.drop_entity" not in migration_contents + assert "op.replace_entity" not in migration_contents + assert "from alembic_utils" not in migration_contents + + # Execute upgrade + run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) + # Execute Downgrade + run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) + + +def test_drop(engine) -> None: + # Manually create a SQL procedure + with engine.begin() as connection: + connection.execute(TO_UPPER.to_sql_statement_create()) + + # Register no procedure locally + register_entities([], schemas=["public"], entity_types=[PGProcedure]) + + run_alembic_command( + engine=engine, + command="revision", + command_kwargs={"autogenerate": True, "rev_id": "1", "message": "drop"}, + ) + + migration_create_path = TEST_VERSIONS_ROOT / "1_drop.py" + + with migration_create_path.open() as migration_file: + migration_contents = migration_file.read() + + assert "op.drop_entity" in migration_contents + assert "op.create_entity" in migration_contents + assert "from alembic_utils" in migration_contents + assert migration_contents.index("op.drop_entity") < migration_contents.index("op.create_entity") + + # Execute upgrade + run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) + # Execute Downgrade + run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) + + +def test_has_no_parameters(engine) -> None: + # Error was occuring in drop statement when procedure had no parameters + # related to parameter parsing to drop default statements + + SIDE_EFFECT = PGProcedure( + schema="public", + signature="side_effect()", + definition=""" + language SQL as $$ + select 1; $$; + """, + ) + + # Register no procedures locally + register_entities([SIDE_EFFECT], schemas=["public"], entity_types=[PGProcedure]) + + run_alembic_command( + engine=engine, + command="revision", + command_kwargs={"autogenerate": True, "rev_id": "1", "message": "no_arguments"}, + ) + + migration_create_path = TEST_VERSIONS_ROOT / "1_no_arguments.py" + + with migration_create_path.open() as migration_file: + migration_contents = migration_file.read() + + assert "op.drop_entity" in migration_contents + + # Execute upgrade + run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) + # Execute Downgrade + run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"}) + + +def test_ignores_extension_procedures(engine) -> None: + # Extensions contain procedures and don't have local representations + # Unless they are excluded, every autogenerate migration will produce + # drop statements for those procedures + try: + with engine.begin() as connection: + connection.execute(text("create extension if not exists unaccent;")) + register_entities([], schemas=["public"], entity_types=[PGProcedure]) + run_alembic_command( + engine=engine, + command="revision", + command_kwargs={"autogenerate": True, "rev_id": "1", "message": "no_drops"}, + ) + + migration_create_path = TEST_VERSIONS_ROOT / "1_no_drops.py" + + with migration_create_path.open() as migration_file: + migration_contents = migration_file.read() + + assert "op.drop_entity" not in migration_contents + finally: + with engine.begin() as connection: + connection.execute(text("drop extension if exists unaccent;")) + + +def test_plpgsql_colon_esacpe(engine) -> None: + # PGProcedure.__init__ overrides colon escapes for plpgsql + # because := should not be escaped for sqlalchemy.text + # if := is escaped, an exception would be raised + + PLPGSQL_FUNC = PGProcedure( + schema="public", + signature="some_proc(some_text text)", + definition=""" + language plpgsql as $$ + declare + copy_o_text text; + begin + copy_o_text := some_text; + end; + $$; + """, + ) + + register_entities([PLPGSQL_FUNC], entity_types=[PGProcedure]) + + run_alembic_command( + engine=engine, + command="revision", + command_kwargs={"autogenerate": True, "rev_id": "1", "message": "create"}, + ) + + migration_create_path = TEST_VERSIONS_ROOT / "1_create.py" + + with migration_create_path.open() as migration_file: + migration_contents = migration_file.read() + + assert "op.create_entity" in migration_contents + assert "op.drop_entity" in migration_contents + assert "op.replace_entity" not in migration_contents + assert "from alembic_utils.pg_procedure import PGProcedure" in migration_contents + + # Execute upgrade + run_alembic_command(engine=engine, command="upgrade", command_kwargs={"revision": "head"}) + # Execute Downgrade + run_alembic_command(engine=engine, command="downgrade", command_kwargs={"revision": "base"})