Skip to content

Commit

Permalink
Merge pull request #81 from Pogchamp-company/bug/80/mypy-errors
Browse files Browse the repository at this point in the history
Ignore mypy errors for migrations
  • Loading branch information
RustyGuard authored Nov 19, 2024
2 parents 4ecc17f + 915551c commit 05f2fd1
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 114 deletions.
97 changes: 73 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from alembic_postgresql_enum.configuration import Config

# alembic-postgresql-enum
[<img src="https://img.shields.io/pypi/pyversions/alembic-postgresql-enum">](https://pypi.org/project/alembic-postgresql-enum/)
[<img src="https://img.shields.io/pypi/v/alembic-postgresql-enum">](https://pypi.org/project/alembic-postgresql-enum/)
Expand Down Expand Up @@ -30,6 +32,21 @@ import alembic_postgresql_enum

To the top of your migrations/env.py file.

## Configuration

You can configure this extension to disable parts of it, or to enable some feature flags

To do so you need to call set_configuration function after the import:

```python
import alembic_postgresql_enum

alembic_postgresql_enum.set_configuration(
alembic_postgresql_enum.Config(
add_type_ignore=True,
)
)
```
## Features

* [Creation of enums](#creation-of-enum)
Expand Down Expand Up @@ -147,17 +164,25 @@ class MyEnum(enum.Enum):
```python
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('public', 'myenum', ['one', 'two', 'three', 'four'],
[('example_table', 'enum_field')],
enum_values_to_rename=[])
op.sync_enum_values(
enum_schema='public',
enum_name='myenum',
new_values=['one', 'two', 'three', 'four'],
affected_columns=[TableReference(table_schema='public', table_name='example_table', column_name='enum_field')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('public', 'myenum', ['one', 'two', 'three'],
[('example_table', 'enum_field')],
enum_values_to_rename=[])
op.sync_enum_values(
enum_schema='public',
enum_name='myenum',
new_values=['one', 'two', 'three'],
affected_columns=[TableReference(table_schema='public', table_name='example_table', column_name='enum_field')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###
```

Expand All @@ -175,17 +200,25 @@ class MyEnum(enum.Enum):
```python
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('public', 'myenum', ['one', 'two'],
[('example_table', 'enum_field')],
enum_values_to_rename=[])
op.sync_enum_values(
enum_schema='public',
enum_name='myenum',
new_values=['one', 'two'],
affected_columns=[TableReference(table_schema='public', table_name='example_table', column_name='enum_field')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('public', 'myenum', ['one', 'two', 'three'],
[('example_table', 'enum_field')],
enum_values_to_rename=[])
op.sync_enum_values(
enum_schema='public',
enum_name='myenum',
new_values=['one', 'two', 'three'],
affected_columns=[TableReference(table_schema='public', table_name='example_table', column_name='enum_field')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###
```

Expand All @@ -203,17 +236,25 @@ This code will generate this migration:
```python
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('public', 'myenum', ['one', 'two', 'three'],
[('example_table', 'enum_field')],
enum_values_to_rename=[])
op.sync_enum_values(
enum_schema='public',
enum_name='myenum',
new_values=['one', 'two', 'three'],
affected_columns=[TableReference(table_schema='public', table_name='example_table', column_name='enum_field')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('public', 'myenum', ['one', 'two', 'tree'],
[('example_table', 'enum_field')],
enum_values_to_rename=[])
op.sync_enum_values(
enum_schema='public',
enum_name='myenum',
new_values=['one', 'two', 'tree'],
affected_columns=[TableReference(table_schema='public', table_name='example_table', column_name='enum_field')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###
```

Expand All @@ -223,15 +264,23 @@ So adjust migration like that

```python
def upgrade():
op.sync_enum_values('public', 'myenum', ['one', 'two', 'three'],
[('example_table', 'enum_field')],
enum_values_to_rename=[('tree', 'three')])
op.sync_enum_values(
enum_schema='public',
enum_name='myenum',
new_values=['one', 'two', 'three'],
affected_columns=[TableReference(table_schema='public', table_name='example_table', column_name='enum_field')],
enum_values_to_rename=[('tree', 'three')],
)


def downgrade():
op.sync_enum_values('public', 'myenum', ['one', 'two', 'tree'],
[('example_table', 'enum_field')],
enum_values_to_rename=[('three', 'tree')])
op.sync_enum_values(
enum_schema='public',
enum_name='myenum',
new_values=['one', 'two', 'tree'],
affected_columns=[TableReference(table_schema='public', table_name='example_table', column_name='enum_field')],
enum_values_to_rename=[('three', 'tree')],
)
```

Do not forget to switch places old and new values for downgrade
Expand Down
1 change: 1 addition & 0 deletions alembic_postgresql_enum/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .compare_dispatch import compare_enums
from .get_enum_data import ColumnType, TableReference
from .configuration import set_configuration, Config
18 changes: 18 additions & 0 deletions alembic_postgresql_enum/configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dataclasses import dataclass


@dataclass
class Config:
add_type_ignore: bool = False


_config = Config()


def set_configuration(config: Config):
global _config
_config = config


def get_configuration() -> Config:
return _config
12 changes: 9 additions & 3 deletions alembic_postgresql_enum/operations/sync_enum_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from alembic.autogenerate.api import AutogenContext
from sqlalchemy.exc import DataError

from alembic_postgresql_enum.configuration import get_configuration
from alembic_postgresql_enum.get_enum_data.types import Unspecified
from alembic_postgresql_enum.sql_commands.column_default import (
get_column_default,
Expand Down Expand Up @@ -203,12 +204,17 @@ def is_column_type_import_needed(self) -> bool:

@alembic.autogenerate.render.renderers.dispatch_for(SyncEnumValuesOp)
def render_sync_enum_value_op(autogen_context: AutogenContext, op: SyncEnumValuesOp):
config = get_configuration()
if op.is_column_type_import_needed:
autogen_context.imports.add("from alembic_postgresql_enum import ColumnType")
autogen_context.imports.add("from alembic_postgresql_enum import TableReference")

return (
f"op.sync_enum_values({op.schema!r}, {op.name!r}, {op.new_values!r},\n"
f" {op.affected_columns!r},\n"
f" enum_values_to_rename=[])"
f"op.sync_enum_values({' # type: ignore[attr-defined]' if config.add_type_ignore else ''}\n"
f" enum_schema={op.schema!r},\n"
f" enum_name={op.name!r},\n"
f" new_values={op.new_values!r},\n"
f" affected_columns={op.affected_columns!r},\n"
f" enum_values_to_rename=[],\n"
f")"
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "alembic-postgresql-enum"
version = "1.3.0"
version = "1.4.0"
description = "Alembic autogenerate support for creation, alteration and deletion of enums"
authors = ["RustyGuard"]
license = "MIT"
Expand Down
5 changes: 4 additions & 1 deletion tests/base/render_and_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import textwrap
from typing import TYPE_CHECKING, Union, List
from typing import TYPE_CHECKING, Union, List, Optional

import sqlalchemy
from alembic import autogenerate
Expand All @@ -20,6 +20,7 @@ def compare_and_run(
*,
expected_upgrade: str,
expected_downgrade: str,
expected_imports: Optional[str],
disable_running: bool = False,
):
"""Compares generated migration script is equal to expected_upgrade and expected_downgrade, then runs it"""
Expand All @@ -37,6 +38,8 @@ def compare_and_run(
expected_upgrade = textwrap.dedent(expected_upgrade).strip("\n ")
expected_downgrade = textwrap.dedent(expected_downgrade).strip("\n ")

if expected_imports is not None:
assert template_args["imports"] == expected_imports
assert upgrade_code == expected_upgrade, f"Got:\n{upgrade_code!r}\nExpected:\n{expected_upgrade!r}"
assert downgrade_code == expected_downgrade, f"Got:\n{downgrade_code!r}\nExpected:\n{expected_downgrade!r}"

Expand Down
12 changes: 11 additions & 1 deletion tests/base/run_migration_test_abc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import alembic_postgresql_enum
from alembic_postgresql_enum.configuration import Config, get_configuration
from tests.base.render_and_run import compare_and_run

if TYPE_CHECKING:
Expand All @@ -14,6 +16,7 @@ class CompareAndRunTestCase(ABC):
"""

disable_running = False
config = Config()

@abstractmethod
def get_database_schema(self) -> MetaData: ...
Expand All @@ -30,7 +33,12 @@ def get_expected_upgrade(self) -> str: ...
@abstractmethod
def get_expected_downgrade(self) -> str: ...

def get_expected_imports(self) -> Optional[str]:
return None

def test_run(self, connection: "Connection"):
old_config = get_configuration()
alembic_postgresql_enum.set_configuration(self.config)
database_schema = self.get_database_schema()
target_schema = self.get_target_schema()

Expand All @@ -42,5 +50,7 @@ def test_run(self, connection: "Connection"):
target_schema,
expected_upgrade=self.get_expected_upgrade(),
expected_downgrade=self.get_expected_downgrade(),
expected_imports=self.get_expected_imports(),
disable_running=self.disable_running,
)
alembic_postgresql_enum.set_configuration(old_config)
77 changes: 42 additions & 35 deletions tests/sync_enum_values/test_array_column.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from alembic import autogenerate
from alembic.autogenerate import api
Expand All @@ -7,9 +7,11 @@
from alembic_postgresql_enum import ColumnType
from alembic_postgresql_enum.get_enum_data import TableReference
from alembic_postgresql_enum.operations import SyncEnumValuesOp
from tests.base.run_migration_test_abc import CompareAndRunTestCase

if TYPE_CHECKING:
from sqlalchemy import Connection
from sqlalchemy import MetaData

from tests.schemas import (
get_schema_with_enum_in_array_variants,
Expand All @@ -21,43 +23,48 @@
from tests.utils.migration_context import create_migration_context


def test_add_new_enum_value_render_with_array(connection: "Connection"):
class TestAddNewEnumValueRenderWithArray(CompareAndRunTestCase):
"""Check that enum variants are updated when new variant is added"""
old_enum_variants = ["black", "white", "red", "green", "blue", "other"]

database_schema = get_schema_with_enum_in_array_variants(old_enum_variants)
database_schema.create_all(connection)

new_enum_variants = old_enum_variants.copy()
new_enum_variants.append("violet")

target_schema = get_schema_with_enum_in_array_variants(new_enum_variants)

context = create_migration_context(connection, target_schema)

template_args = {}
autogenerate._render_migration_diffs(context, template_args)

assert template_args["imports"] == (
"from alembic_postgresql_enum import ColumnType" "\nfrom alembic_postgresql_enum import TableReference"
)
old_enum_variants = ["black", "white", "red", "green", "blue", "other"]
new_enum_variants = old_enum_variants + ["violet"]

def get_database_schema(self) -> MetaData:
schema = get_schema_with_enum_in_array_variants(self.old_enum_variants)
return schema

def get_target_schema(self) -> MetaData:
schema = get_schema_with_enum_in_array_variants(self.new_enum_variants)
return schema

def get_expected_upgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values(
enum_schema='{DEFAULT_SCHEMA}',
enum_name='{CAR_COLORS_ENUM_NAME}',
new_values=[{', '.join(map(repr, self.new_enum_variants))}],
affected_columns=[TableReference(table_schema='{DEFAULT_SCHEMA}', table_name='{CAR_TABLE_NAME}', column_name='{CAR_COLORS_COLUMN_NAME}', column_type=ColumnType.ARRAY)],
enum_values_to_rename=[],
)
# ### end Alembic commands ###
"""

def get_expected_downgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values(
enum_schema='{DEFAULT_SCHEMA}',
enum_name='{CAR_COLORS_ENUM_NAME}',
new_values=[{', '.join(map(repr, self.old_enum_variants))}],
affected_columns=[TableReference(table_schema='{DEFAULT_SCHEMA}', table_name='{CAR_TABLE_NAME}', column_name='{CAR_COLORS_COLUMN_NAME}', column_type=ColumnType.ARRAY)],
enum_values_to_rename=[],
)
# ### end Alembic commands ###
"""

assert (
template_args["upgrades"]
== f"""# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('{DEFAULT_SCHEMA}', '{CAR_COLORS_ENUM_NAME}', [{', '.join(map(repr, new_enum_variants))}],
[TableReference(table_schema='{DEFAULT_SCHEMA}', table_name='{CAR_TABLE_NAME}', column_name='{CAR_COLORS_COLUMN_NAME}', column_type=ColumnType.ARRAY)],
enum_values_to_rename=[])
# ### end Alembic commands ###"""
)
assert (
template_args["downgrades"]
== f"""# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values('{DEFAULT_SCHEMA}', '{CAR_COLORS_ENUM_NAME}', [{', '.join(map(repr, old_enum_variants))}],
[TableReference(table_schema='{DEFAULT_SCHEMA}', table_name='{CAR_TABLE_NAME}', column_name='{CAR_COLORS_COLUMN_NAME}', column_type=ColumnType.ARRAY)],
enum_values_to_rename=[])
# ### end Alembic commands ###"""
)
def get_expected_imports(self) -> Optional[str]:
return "from alembic_postgresql_enum import ColumnType" "\nfrom alembic_postgresql_enum import TableReference"


def test_add_new_enum_value_diff_tuple_with_array(connection: "Connection"):
Expand Down
Loading

0 comments on commit 05f2fd1

Please sign in to comment.