From d10b689e8779a52e7a352ede5c3c32ab535a7b71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20G=C3=A9r=C3=B4me?= Date: Tue, 9 May 2023 17:09:40 +0200 Subject: [PATCH] fix(connection): Create a dataclass on the fly for custom connection --- openhexa/sdk/workspaces/connection.py | 5 ----- openhexa/sdk/workspaces/workspace.py | 16 +++++++++------- tests/test_workspace.py | 7 +++---- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/openhexa/sdk/workspaces/connection.py b/openhexa/sdk/workspaces/connection.py index fe59b7b..6d29665 100644 --- a/openhexa/sdk/workspaces/connection.py +++ b/openhexa/sdk/workspaces/connection.py @@ -35,8 +35,3 @@ class S3Connection: class GCSConnection: service_account_key: str bucket_name: str - - -@dataclasses.dataclass -class CustomConnection: - fields: dict diff --git a/openhexa/sdk/workspaces/workspace.py b/openhexa/sdk/workspaces/workspace.py index bafd0ee..2a27402 100644 --- a/openhexa/sdk/workspaces/workspace.py +++ b/openhexa/sdk/workspaces/workspace.py @@ -1,12 +1,11 @@ import os -import re +from dataclasses import make_dataclass import stringcase from openhexa.sdk.utils import Environments, get_environment from .connection import ( - CustomConnection, DHIS2Connection, GCSConnection, PostgreSQLConnection, @@ -154,13 +153,16 @@ def gcs_connection(self, slug: str) -> GCSConnection: bucket_name=bucket_name, ) - def custom_connection(slef, slug: str) -> CustomConnection: - env_variable_prefix = stringcase.constcase(slug.lower()) + def custom_connection(self, slug: str): + slug = slug.lower() + env_variable_prefix = stringcase.constcase(slug) fields = {} for key, value in os.environ.items(): - if re.match(rf"^{env_variable_prefix}_", key): - fields[key] = os.environ[key] - return CustomConnection(fields=fields) + if key.startswith(env_variable_prefix): + field_key = key[len(f"{env_variable_prefix}_") :].lower() + fields[field_key] = value + CustomConnection = make_dataclass(slug, fields.keys()) + return CustomConnection(**fields) workspace = CurrentWorkspace() diff --git a/tests/test_workspace.py b/tests/test_workspace.py index 45ee658..4ef7a3a 100644 --- a/tests/test_workspace.py +++ b/tests/test_workspace.py @@ -119,7 +119,7 @@ def test_workspace_gcs_connection(): def test_workspace_custom_connection(): - slug = "polio-ff3a0d" + slug = "my_connection" env_variable_prefix = stringcase.constcase(slug) username = "kaggle_username" password = "root" @@ -132,6 +132,5 @@ def test_workspace_custom_connection(): }, ): custom_connection = workspace.custom_connection(slug=slug) - assert len(custom_connection.fields) == 2 - assert f"{env_variable_prefix}_USERNAME" in custom_connection.fields - assert f"{env_variable_prefix}_PASSWORD" in custom_connection.fields + assert custom_connection.username == username + assert custom_connection.password == password