Skip to content

Commit

Permalink
feat: add validation
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulKalho committed Nov 11, 2024
1 parent 771b5b1 commit dde0e50
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 37 deletions.
2 changes: 1 addition & 1 deletion scystream/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def wrapper(*args, **kwargs):
if settings_class is not None:
# Load settings
try:
settings = settings_class.load_settings()
settings = settings_class.from_env()
except ValidationError as e:
raise ValueError(f"Invalid environment configuration: {e}")

Expand Down
66 changes: 53 additions & 13 deletions scystream/sdk/env/settings.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,70 @@
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Type
from typing import Union, List, get_type_hints
from pydantic import Field

ENV_FILE_ENCODING = "utf-8"


class BaseENVSettings(BaseSettings):
"""
This class acts as the BaseClass which can be used to define custom
ENV-Variables which can be used across the ComputeBlock & for entrypoints
This definition, and pydantic, will then take care of validating the envs
"""
Allow kwargs to propagate to any fields whose default factory extends
BaseSettings.
This is mostly to allow _env_file to be passed through.
"""
model_config = SettingsConfigDict(
env_file_encoding=ENV_FILE_ENCODING,
case_sensitive=True,
extra="ignore"
)

@classmethod
def load_settings(
cls: Type["BaseENVSettings"],
env_file: str = ".env"
) -> "BaseENVSettings":
def from_env(
cls,
env_file: Union[str, Path, List[Union[str, Path]]] = None,
*args,
**kwargs
):
return cls(propagate_kwargs={"_env_file": env_file}, *args, **kwargs)

@classmethod
def _basesettings_fields(cls):
"""
load_settings loads the env file. The name of the env_file can be
passed as an argument.
Returns the parsed ENVs
:return a dict of field_name: default_factory for any fields that
extend BaseSettings
"""
return cls(_env_file=env_file, _env_file_encoding=ENV_FILE_ENCODING)
type_hints = get_type_hints(cls)
return {
name: typ for name, typ in type_hints.items()
if isinstance(typ, type) and issubclass(typ, BaseSettings)

}

@classmethod
def _propagate_kwargs(cls, kwargs):
    """
    Build every nested settings field from the given kwargs.

    Each field of ``cls`` whose annotated type extends BaseSettings (as
    reported by ``_basesettings_fields``) is instantiated with ``kwargs``.

    :param kwargs: keyword arguments handed to every nested settings
    class (e.g. ``_env_file``)
    :return a new dict holding the given kwargs plus one
    ``field_name: instance`` entry per nested settings field
    """
    # Work on a copy: the caller's dict stays untouched, and an already
    # built sibling instance is never passed to the next nested class.
    shared_kwargs = dict(kwargs)
    nested = {
        name: field_type(**shared_kwargs)
        for name, field_type in cls._basesettings_fields().items()
    }
    return {**shared_kwargs, **nested}

def __init_subclass__(cls, **kwargs):
    """
    Give each nested settings field of a new subclass a default factory.

    Every resolved type hint of ``cls`` that is a class extending
    BaseSettings is replaced by ``Field(default_factory=<that class>)``.
    """
    super().__init_subclass__(**kwargs)
    for attr_name, annotation in get_type_hints(cls).items():
        # Non-class hints (e.g. typing constructs) cannot be tested with
        # issubclass, so they are filtered out first.
        if not isinstance(annotation, type):
            continue
        if not issubclass(annotation, BaseSettings):
            continue
        # The nested settings class itself serves as the default factory.
        setattr(cls, attr_name, Field(default_factory=annotation))

def __init__(self, propagate_kwargs=None, *args, **kwargs):
    """
    Initialise the settings, optionally building nested settings first.

    :param propagate_kwargs: optional dict of keyword arguments that is
    passed to ``_propagate_kwargs``; the result is merged into the
    keyword arguments given to the parent initialiser
    :param args: positional arguments forwarded to the parent initialiser
    :param kwargs: keyword arguments forwarded to the parent initialiser
    """
    if propagate_kwargs:
        # Merge instead of overwrite: explicitly passed kwargs used to be
        # dropped here. Explicit values take precedence over propagated
        # ones.
        kwargs = {**self._propagate_kwargs(propagate_kwargs), **kwargs}
    super().__init__(*args, **kwargs)
132 changes: 109 additions & 23 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,134 @@
import os
from scystream.sdk.core import entrypoint
from scystream.sdk.env.settings import BaseENVSettings
from scystream.sdk.scheduler import Scheduler


# Settings group used as a nested field in the tests; its only field,
# DUMMY_INPUT, defaults to "test".
class DummyInputSettings(BaseENVSettings):
DUMMY_INPUT: str = "test"


# Top-level settings whose string fields all declare defaults; it nests
# DummyInputSettings through the annotation-only field below.
class WithDefaultSettings(BaseENVSettings):
DUMMY_SETTING: str = "this is a dummy setting"
DUMMY_GLOBAL: str = "dummy global var"

dummy_input_settings: DummyInputSettings


# Nested settings group whose single field DUMMY_INPUT has no default.
class DummyInputSettingsNoDef(BaseENVSettings):
DUMMY_INPUT: str


# Top-level settings without defaults: DUMMY_GLOBAL has none, and the nested
# DummyInputSettingsNoDef field declares a field without a default as well.
class WithoutDefaultSettings(BaseENVSettings):
DUMMY_GLOBAL: str

dummy_input_settings_no_def: DummyInputSettingsNoDef


# Flat settings (no nested settings field): TEST has a default, MUST_SET
# does not.
class WithoutDefaultNoNesting(BaseENVSettings):
TEST: str = "teststr"
MUST_SET: str


# First nested settings group used by TwoSubclasses; neither field has a
# default.
class SubOne(BaseENVSettings):
ONE: str
TWO: str


# Second nested settings group used by TwoSubclasses; neither field has a
# default.
class SubTwo(BaseENVSettings):
TEST: str
NO_DEF: str


class NoDefaultSetting(BaseENVSettings):
DUMMY_SETTING: str
class TwoSubclasses(BaseENVSettings):
GLOBAL: str

input_one: SubOne
input_two: SubTwo


class TestSettings(unittest.TestCase):
def test_entrypoint_with_setting_default(self):
@entrypoint(WithDefaultSettings)
def with_default_settings(settings):
return settings.DUMMY_SETTING
return settings.dummy_input_settings.DUMMY_INPUT

result = with_default_settings()
self.assertEqual(result, "this is a dummy setting")
self.assertEqual(result, "test")

"""
environment is set
"""
os.environ["DUMMY_SETTING"] = "overridden setting"
# set environ
os.environ["DUMMY_INPUT"] = "overridden setting"
result = with_default_settings()
# check if overriding works
self.assertEqual(result, "overridden setting")
del os.environ["DUMMY_SETTING"]

def test_entrypoint_with_no_setting_default(self):
@entrypoint(NoDefaultSetting)
def with_no_default_settings(settings):
return settings.DUMMY_SETTING
del os.environ["DUMMY_INPUT"]

def test_entrypoint_no_setting_default_one(self):
@entrypoint(WithoutDefaultSettings)
def without_def_settings(settings):
print("test...")

# do we fail if environments not set
with self.assertRaises(ValueError):
with_no_default_settings()

"""
environment is set
"""
os.environ["DUMMY_SETTING"] = "required setting"
result = with_no_default_settings()
self.assertEqual(result, "required setting")
del os.environ["DUMMY_SETTING"]
Scheduler.execute_function("without_def_settings")

def test_entrypoint_no_setting_default_two(self):
    """
    Settings without defaults are filled from the environment, both on
    the top-level class and on its nested settings field.
    """
    @entrypoint(WithoutDefaultSettings)
    def without_def_settings(settings):
        return (
            settings.DUMMY_GLOBAL,
            settings.dummy_input_settings_no_def.DUMMY_INPUT
        )

    # set environments
    env = {
        "DUMMY_GLOBAL": "dummy global",
        "DUMMY_INPUT": "dummy input",
    }
    os.environ.update(env)
    try:
        # check if environments have been set
        result = without_def_settings()
        self.assertEqual(result[0], "dummy global")
        self.assertEqual(result[1], "dummy input")
    finally:
        # Always clean up, even when an assertion fails, so the variables
        # cannot leak into other tests.
        for key in env:
            os.environ.pop(key, None)

def test_entrypoint_no_setting_defautl_three(self):
    """
    Executing an entrypoint whose flat settings class has a field without
    a default must raise ValueError while that variable is unset.
    """
    @entrypoint(WithoutDefaultNoNesting)
    def no_nesting(settings):
        print("testing...")

    # Callable form of assertRaises: the scheduler call is made inside it.
    self.assertRaises(
        ValueError, Scheduler.execute_function, "no_nesting"
    )

def test_two_subs(self):
    """
    A settings class with two nested settings fields reads every value,
    top-level and nested, from the environment.
    """
    @entrypoint(TwoSubclasses)
    def two_subs(settings):
        return (
            settings.GLOBAL,
            settings.input_one.ONE,
            settings.input_one.TWO,
            settings.input_two.TEST,
            settings.input_two.NO_DEF
        )

    env = {
        "GLOBAL": "global",
        "ONE": "one",
        "TWO": "two",
        "TEST": "test",
        "NO_DEF": "no_def",
    }
    os.environ.update(env)
    try:
        result = two_subs()
        self.assertEqual(result[0], "global")
        self.assertEqual(result[1], "one")
        self.assertEqual(result[2], "two")
        self.assertEqual(result[3], "test")
        self.assertEqual(result[4], "no_def")
    finally:
        # Always clean up, even when an assertion fails, so the variables
        # cannot leak into other tests.
        for key in env:
            os.environ.pop(key, None)


if __name__ == "__main__":
Expand Down

0 comments on commit dde0e50

Please sign in to comment.