Skip to content

Commit

Permalink
fix(connection): Create a dataclass on the fly for custom connection
Browse files Browse the repository at this point in the history
  • Loading branch information
qgerome committed May 10, 2023
1 parent 51ea0ed commit d10b689
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 16 deletions.
5 changes: 0 additions & 5 deletions openhexa/sdk/workspaces/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,3 @@ class S3Connection:
class GCSConnection:
service_account_key: str
bucket_name: str


@dataclasses.dataclass
class CustomConnection:
fields: dict
16 changes: 9 additions & 7 deletions openhexa/sdk/workspaces/workspace.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os
import re
from dataclasses import make_dataclass

import stringcase

from openhexa.sdk.utils import Environments, get_environment

from .connection import (
CustomConnection,
DHIS2Connection,
GCSConnection,
PostgreSQLConnection,
Expand Down Expand Up @@ -154,13 +153,16 @@ def gcs_connection(self, slug: str) -> GCSConnection:
bucket_name=bucket_name,
)

def custom_connection(slef, slug: str) -> CustomConnection:
env_variable_prefix = stringcase.constcase(slug.lower())
def custom_connection(self, slug: str):
slug = slug.lower()
env_variable_prefix = stringcase.constcase(slug)
fields = {}
for key, value in os.environ.items():
if re.match(rf"^{env_variable_prefix}_", key):
fields[key] = os.environ[key]
return CustomConnection(fields=fields)
if key.startswith(env_variable_prefix):
field_key = key[len(f"{env_variable_prefix}_") :].lower()
fields[field_key] = value
CustomConnection = make_dataclass(slug, fields.keys())
return CustomConnection(**fields)


workspace = CurrentWorkspace()
7 changes: 3 additions & 4 deletions tests/test_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_workspace_gcs_connection():


def test_workspace_custom_connection():
slug = "polio-ff3a0d"
slug = "my_connection"
env_variable_prefix = stringcase.constcase(slug)
username = "kaggle_username"
password = "root"
Expand All @@ -132,6 +132,5 @@ def test_workspace_custom_connection():
},
):
custom_connection = workspace.custom_connection(slug=slug)
assert len(custom_connection.fields) == 2
assert f"{env_variable_prefix}_USERNAME" in custom_connection.fields
assert f"{env_variable_prefix}_PASSWORD" in custom_connection.fields
assert custom_connection.username == username
assert custom_connection.password == password

0 comments on commit d10b689

Please sign in to comment.