From 2a32126fd7cdd6328f6dd54c4926565d6cf34856 Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Fri, 8 Nov 2024 05:30:16 +0200 Subject: [PATCH 1/3] Initial test; Need to be refactored! --- tests/unit/test_readme_example.py | 45 +++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 tests/unit/test_readme_example.py diff --git a/tests/unit/test_readme_example.py b/tests/unit/test_readme_example.py new file mode 100644 index 0000000..6e9b788 --- /dev/null +++ b/tests/unit/test_readme_example.py @@ -0,0 +1,45 @@ +from typing import Protocol, Iterable +from sqlite3 import Connection, connect +from dishka import Provider, Scope, make_container, provide + + +class DAO(Protocol): + ... + + +class Service: + def __init__(self, dao: DAO): + ... + + +class DAOImpl(DAO): + def __init__(self, connection: Connection): + ... + + +class SomeClient: + ... + + +provider = Provider() + +service_provider = Provider(scope=Scope.REQUEST) +service_provider.provide(Service) +service_provider.provide(DAOImpl, provides=DAO) +service_provider.provide(SomeClient, scope=Scope.APP) + +class ConnectionProvider(Provider): + @provide(scope=Scope.REQUEST) + def new_connection(self) -> Iterable[Connection]: + conn = sqlite3.connect(":memory:") + yield conn + conn.close() + + +def test_get_client(): + container = make_container(service_provider, ConnectionProvider()) + client_1 = container.get(SomeClient) + client_2 = container.get(SomeClient) + + assert isinstance(client_1, SomeClient) + assert client_1 is client_2 From db1746f4e6e378151347a90c6273d9aafcd921b7 Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Fri, 8 Nov 2024 06:22:38 +0200 Subject: [PATCH 2/3] Refactoring of test_readme_example.py --- tests/unit/test_readme_example.py | 43 ++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_readme_example.py b/tests/unit/test_readme_example.py index 6e9b788..db34445 100644 --- a/tests/unit/test_readme_example.py +++ b/tests/unit/test_readme_example.py @@ -1,5 +1,10 @@ -from typing import Protocol, Iterable -from sqlite3 import Connection, connect +import sqlite3 +from collections.abc import Iterable +from sqlite3 import Connection +from typing import Protocol + +import pytest + from dishka import Provider, Scope, make_container, provide @@ -21,12 +26,19 @@ class SomeClient: ... -provider = Provider() +@pytest.fixture +def service_provider(): + provider = Provider(scope=Scope.REQUEST) + provider.provide(Service) + provider.provide(DAOImpl, provides=DAO) + provider.provide(SomeClient, scope=Scope.APP) + return provider + + +@pytest.fixture +def container(service_provider): + return make_container(service_provider, ConnectionProvider()) -service_provider = Provider(scope=Scope.REQUEST) -service_provider.provide(Service) -service_provider.provide(DAOImpl, provides=DAO) -service_provider.provide(SomeClient, scope=Scope.APP) class ConnectionProvider(Provider): @provide(scope=Scope.REQUEST) @@ -36,10 +48,23 @@ def new_connection(self) -> Iterable[Connection]: conn.close() -def test_get_client(): - container = make_container(service_provider, ConnectionProvider()) +def test_get_client(container): client_1 = container.get(SomeClient) client_2 = container.get(SomeClient) assert isinstance(client_1, SomeClient) assert client_1 is client_2 + + +def test_subcontainers(container): + with container() as request_container: + service_1 = request_container.get(Service) + service_2 = request_container.get(Service) + + assert service_1 is service_2 + assert isinstance(service_1, Service) + + with container() as new_request_container: + service_3 = new_request_container.get(Service) + + assert service_1 is not service_3 From 749881391263a26ac029b490f4c93911c0a80f03 Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Fri, 8 Nov 2024 06:25:35 +0200 Subject: [PATCH 3/3] sqlite3.connect() -> sqlite3.connect(':memory:') --- README.md | 2 +- docs/quickstart.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 88aa9fe..d5bb788 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ from dishka import Provider, provide, Scope class ConnectionProvider(Provider): @provide(scope=Scope.REQUEST) def new_connection(self) -> Iterable[Connection]: - conn = sqlite3.connect() + conn = sqlite3.connect(":memory:") yield conn conn.close() ``` diff --git a/docs/quickstart.rst b/docs/quickstart.rst index bce8ed5..896b072 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -58,7 +58,7 @@ To provide connection we might need to write some custom code: class ConnectionProvider(Provider): @provide(scope=Scope.REQUEST) def new_connection(self) -> Iterable[Connection]: - conn = sqlite3.connect() + conn = sqlite3.connect(":memory:") yield conn conn.close()