Skip to content

Commit

Permalink
feat: check for connection before using
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Aug 30, 2024
1 parent 8889740 commit b94a318
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 27 deletions.
22 changes: 14 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,31 @@ from __future__ import annotations

from typing import TYPE_CHECKING

from litestar import Controller, Litestar, get
from litestar import Controller, Litestar, Request, get

from litestar_oracledb import AsyncDatabaseConfig, AsyncPoolConfig, OracleDatabasePlugin

if TYPE_CHECKING:
from oracledb import AsyncConnection


class SampleController(Controller):
@get(path="/sample")
async def sample_route(self, db_connection: AsyncConnection) -> dict[str, str]:
@get(path="/")
async def sample_route(self, request: Request, db_connection: AsyncConnection) -> dict[str, str]:
"""Check database available and returns app config info."""
cursor = db_connection.cursor()
await cursor.execute("select 1 from dual")
result = await cursor.fetchone()
return {"select_1": str(result)}
with db_connection.cursor() as cursor:
await cursor.execute("select 'a database value' a_column from dual")
result = await cursor.fetchone()
request.logger.info(result[0])
if result:
return {"a_column": result[0]}
return {"a_column": "dunno"}


oracledb = OracleDatabasePlugin(
config=AsyncDatabaseConfig(pool_config=AsyncPoolConfig(user="app", password="super-secret", dsn="localhost:1521/freepdb1"))
config=AsyncDatabaseConfig(
pool_config=AsyncPoolConfig(user="system", password="super-secret", dsn="localhost:1513/FREEPDB1") # noqa: S106
)
)
app = Litestar(plugins=[oracledb], route_handlers=[SampleController])

Expand Down
17 changes: 6 additions & 11 deletions examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import TYPE_CHECKING

import msgspec
from litestar import Controller, Litestar, Request, get

from litestar_oracledb import AsyncDatabaseConfig, AsyncPoolConfig, OracleDatabasePlugin
Expand All @@ -11,26 +10,22 @@
from oracledb import AsyncConnection


class HealthCheck(msgspec.Struct):
status: str


class SampleController(Controller):
@get(path="/sample")
async def sample_route(self, request: Request, db_connection: AsyncConnection) -> HealthCheck:
@get(path="/")
async def sample_route(self, request: Request, db_connection: AsyncConnection) -> dict[str, str]:
"""Check database available and returns app config info."""
with db_connection.cursor() as cursor:
await cursor.execute("select 'a database value' a_column from dual")
result = await cursor.fetchone()
request.logger.info(result)
request.logger.info(result[0])
if result:
return HealthCheck(status="online")
return HealthCheck(status="offline")
return {"a_column": result[0]}
return {"a_column": "dunno"}


oracledb = OracleDatabasePlugin(
config=AsyncDatabaseConfig(
pool_config=AsyncPoolConfig(user="app", password="super-secret", dsn="localhost:1521/freepdb1") # noqa: S106
pool_config=AsyncPoolConfig(user="system", password="super-secret", dsn="localhost:1513/FREEPDB1") # noqa: S106
)
)
app = Litestar(plugins=[oracledb], route_handlers=[SampleController])
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ license = { text = "MIT" }
name = "litestar-oracledb"
readme = "README.md"
requires-python = ">=3.8"
version = "0.1.1"
version = "0.1.2"

[project.urls]
Changelog = "https://litestar-org.github.io/litesatr-oracledb/latest/changelog"
Expand Down
15 changes: 11 additions & 4 deletions src/litestar_oracledb/config/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ async def handler(message: Message, scope: Scope) -> None:
None
"""
connection = cast("AsyncConnection | None", get_scope_state(scope, connection_scope_key))
if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
await connection.close()
if connection is not None and message["type"] in SESSION_TERMINUS_ASGI_EVENTS and connection._impl is not None: # noqa: SLF001
# checks to to see if connected without raising an exception: https://github.com/oracle/python-oracledb/blob/main/src/oracledb/connection.py#L80
if connection._impl is not None: # noqa: SLF001
await connection.close()
delete_scope_state(scope, connection_scope_key)

return handler
Expand Down Expand Up @@ -97,15 +99,20 @@ async def handler(message: Message, scope: Scope) -> None:
"""
connection = cast("AsyncConnection | None", get_scope_state(scope, connection_scope_key))
try:
if connection is not None and message["type"] == HTTP_RESPONSE_START:
if connection is not None and message["type"] == HTTP_RESPONSE_START and connection._impl is not None: # noqa: SLF001
if (message["status"] in commit_range or message["status"] in extra_commit_statuses) and message[
"status"
] not in extra_rollback_statuses:
await connection.commit()
else:
await connection.rollback()
finally:
if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
# checks to to see if connected without raising an exception: https://github.com/oracle/python-oracledb/blob/main/src/oracledb/connection.py#L80
if (
connection is not None
and message["type"] in SESSION_TERMINUS_ASGI_EVENTS
and connection._impl is not None # noqa: SLF001
):
await connection.close()
delete_scope_state(scope, connection_scope_key)

Expand Down
7 changes: 4 additions & 3 deletions src/litestar_oracledb/config/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ async def handler(message: Message, scope: Scope) -> None:
None
"""
connection = cast("Connection | None", get_scope_state(scope, connection_scope_key))
if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
# checks to to see if connected without raising an exception: https://github.com/oracle/python-oracledb/blob/main/src/oracledb/connection.py#L80
if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS and connection._impl is not None: # noqa: SLF001
connection.close()
delete_scope_state(scope, connection_scope_key)

Expand Down Expand Up @@ -98,15 +99,15 @@ def handler(message: Message, scope: Scope) -> None:
"""
connection = cast("Connection | None", get_scope_state(scope, connection_scope_key))
try:
if connection is not None and message["type"] == HTTP_RESPONSE_START:
if connection is not None and message["type"] == HTTP_RESPONSE_START and connection._impl is not None: # noqa: SLF001
if (message["status"] in commit_range or message["status"] in extra_commit_statuses) and message[
"status"
] not in extra_rollback_statuses:
connection.commit()
else:
connection.rollback()
finally:
if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS:
if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS and connection._impl is not None: # noqa: SLF001
connection.close()
delete_scope_state(scope, connection_scope_key)

Expand Down

0 comments on commit b94a318

Please sign in to comment.