Skip to content

Commit

Permalink
Merge pull request #279 from ncilfone/maps
Browse files Browse the repository at this point in the history
All the check are passed. Adds __map__ functionality is ready to be merged
  • Loading branch information
mmalouane authored Nov 3, 2023
2 parents 837e485 + 9a1a6b4 commit 2e40d04
Show file tree
Hide file tree
Showing 10 changed files with 379 additions and 134 deletions.
7 changes: 6 additions & 1 deletion spock/backend/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,14 @@ def _cast_all_maps(cls, cls_fields: Dict, changed_vars: Set) -> None:
"""
for val in changed_vars:
# Make sure we cast directory and file types back to strings
if getattr(cls.__attrs_attrs__, val).type.__name__ in ("directory", "file"):
value_type = str
else:
value_type = getattr(cls.__attrs_attrs__, val).type
cls_fields[val] = VarResolver._attempt_cast(
maybe_env=cls_fields[val],
value_type=getattr(cls.__attrs_attrs__, val).type,
value_type=value_type,
ref_value=val,
)

Expand Down
92 changes: 85 additions & 7 deletions spock/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
"""Creates the spock config interface that wraps attr"""

import sys
from typing import Dict

import attr

from spock.backend.typed import katra
from spock.exceptions import _SpockInstantiationError, _SpockUndecoratedClass
from spock.utils import _is_spock_instance, vars_dict_non_dunder
from spock.utils import _is_spock_instance, contains_return, vars_dict_non_dunder


def _base_attr(cls, kw_only, make_init, dynamic):
Expand All @@ -27,7 +28,7 @@ def _base_attr(cls, kw_only, make_init, dynamic):
dynamic: allows inherited classes to not be @spock decorated
Returns:
cls: base spock classes derived from the MRO
bases: all the base classes
attrs_dict: the current dictionary of attr.attribute values
merged_annotations: dictionary of type annotations
Expand Down Expand Up @@ -83,6 +84,7 @@ def _base_attr(cls, kw_only, make_init, dynamic):
merged_annotations = {**base_annotation, **new_annotations}

cls_attrs = set()
hooks = set()
# Iterate through the bases first
for val in bases:
# Get the underlying attribute defs
Expand Down Expand Up @@ -135,6 +137,79 @@ def _base_attr(cls, kw_only, make_init, dynamic):
return bases, attrs_dict, merged_annotations


def _handle_hooks(
cls,
bases,
):
"""Handles creating a single function for all hooks from the given class and
all its parents
Args:
cls: basic class definition
Returns:
function that contains all necessary hooks
"""

# Check if the base classes have any hook functions
hooks = [
val.__attrs_post_init__ for val in bases if hasattr(val, "__attrs_post_init__")
]
# maps = [val.__maps__ for val in bases if hasattr(val, "__maps__")]
# Copy over the post init function -- borrow a bit from attrs library to add the
# __post__hook__ method and/or the __maps__ method (via a shim method) to the init
# call via `"__attrs_post_init__"`
if hasattr(cls, "__post_hook__") or hasattr(cls, "__maps__") or (len(hooks) > 0):
# Force the post_hook function to have no explict return
if hasattr(cls, "__post_hook__") and contains_return(cls.__post_hook__):
raise _SpockInstantiationError(
f"__post_hook__ function contains an explict return. This function "
f"cannot return any values (i.e. requires an implicit None return)"
)
if hasattr(cls, "__maps__") and not contains_return(cls.__maps__):
raise _SpockInstantiationError(
f"__maps__ function is missing an explict return. This function "
f"needs to explicitly return any type of values"
)
# if there are parent hooks we need to map them into a function
if len(hooks) > 0:
# Create a shim function to combine __post_hook__ and __maps__
# in addition to the parental hooks
def __shim__(self):
if hasattr(cls, "__post_hook__"):
cls.__post_hook__(self)
# Call the parents hooks
all_hooks = [val(self) for val in hooks]
# Pop any None values
all_hooks = [val for val in all_hooks if val is not None]
# Add in the given hook
if hasattr(cls, "__maps__"):
all_hooks = [cls.__maps__(self)] + all_hooks
if len(all_hooks) == 1:
all_hooks = all_hooks[0]
# Set maps to the mapped values
object.__setattr__(self, "_maps", all_hooks)

else:
# Create a shim function to combine __post_hook__ and __maps__
def __shim__(self):
if hasattr(cls, "__post_hook__"):
cls.__post_hook__(self)
if hasattr(cls, "__maps__"):
object.__setattr__(self, "_maps", cls.__maps__(self))
return cls.__maps__(self)
else:
return None

else:

def __shim__(self):
...

return __shim__


def _process_class(cls, kw_only: bool, make_init: bool, dynamic: bool):
"""Process a given class
Expand All @@ -150,10 +225,12 @@ def _process_class(cls, kw_only: bool, make_init: bool, dynamic: bool):
"""
# Handles the MRO and gets old annotations
bases, attrs_dict, merged_annotations = _base_attr(cls, kw_only, make_init, dynamic)
# Copy over the post init function -- borrow a bit from attrs library to add the __post__hook__ method to the
# init call via `"__attrs_post_init__"`
if hasattr(cls, "__post_hook__"):
attrs_dict.update({"__attrs_post_init__": cls.__post_hook__})
# if hasattr(cls, "__post_hook__"):
# attrs_dict.update({"__post_hook__": cls.__post_hook__})
# if hasattr(cls, "__maps__"):
# attrs_dict.update({"__maps__": cls.__maps__})
# Map the __shim__ function into __attrs_post_init__
attrs_dict.update({"__attrs_post_init__": _handle_hooks(cls, bases)})
# Dynamically make an attr class
obj = attr.make_class(
name=cls.__name__,
Expand All @@ -164,7 +241,8 @@ def _process_class(cls, kw_only: bool, make_init: bool, dynamic: bool):
auto_attribs=True,
init=make_init,
)
# For each class we dynamically create we need to register it within the system modules for pickle to work
# For each class we dynamically create we need to register it within the system
# modules for pickle to work
setattr(sys.modules["spock"].backend.config, obj.__name__, obj)
# Swap the __doc__ string from cls to obj
obj.__doc__ = cls.__doc__
Expand Down
11 changes: 10 additions & 1 deletion spock/backend/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,21 @@ def _clean_up_values(self, payload: Spockspace, remove_crypto: bool = True) -> D
clean_dict = self._clean_output(out_dict)
# Clip any empty dictionaries
clean_dict = {k: v for k, v in clean_dict.items() if len(v) > 0}
# Clean up annotations
if remove_crypto:
if "__salt__" in clean_dict:
_ = clean_dict.pop("__salt__")
if "__key__" in clean_dict:
_ = clean_dict.pop("__key__")
return clean_dict
# Clean up protected attributes
out_dict = {}
for k, v in clean_dict.items():
cls_dict = {}
for ik, iv in v.items():
if not ik.startswith("_"):
cls_dict.update({ik: iv})
out_dict.update({k: cls_dict})
return out_dict

def _clean_tuner_values(self, payload: Spockspace) -> Dict:
# Just a double nested dict comprehension to unroll to dicts
Expand Down
4 changes: 2 additions & 2 deletions spock/backend/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __call__(self, inst: _C, attr: attr.Attribute, value: Any) -> None:
and self.type[0].__name__ == "directory"
):
return _is_directory(
self.type, create=True, check_access=True, attr=attr, value=value
self.type, create=True, check_access=False, attr=attr, value=value
)
# Catch the file type -- tuples suck, so we need to handle them with their own
# condition here -- basically if the tuple is of type directory then we need
Expand All @@ -271,7 +271,7 @@ def __call__(self, inst: _C, attr: attr.Attribute, value: Any) -> None:
and hasattr(self.type[0], "__name__")
and self.type[0].__name__ == "file"
):
return _is_file(type=self.type, check_access=True, attr=attr, value=value)
return _is_file(type=self.type, check_access=False, attr=attr, value=value)
# Fallback on base attr
else:
return _check_instance(value=value, name=attr.name, type=self.type)
Expand Down
14 changes: 11 additions & 3 deletions spock/backend/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,18 @@ def __init__(self, **kwargs):

@property
def __repr_dict__(self):
"""Handles making a clean dict to hind the salt and key on print"""
return {
k: v for k, v in self.__dict__.items() if k not in {"__key__", "__salt__"}
"""Handles making a clean dict to hide the salt and key on print"""
clean_dict = {
k: v
for k, v in self.__dict__.items()
if k not in {"__key__", "__salt__", "__maps__"}
}
repr_dict = {}
for k, v in clean_dict.items():
repr_dict.update(
{k: {ik: iv for ik, iv in vars(v).items() if not ik.startswith("_")}}
)
return repr_dict

def __repr__(self):
"""Overloaded repr to pretty print the spock object"""
Expand Down
23 changes: 22 additions & 1 deletion spock/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
"""Utility functions for Spock"""

import ast
import inspect
import os
import random
import socket
import subprocess
import sys
import textwrap
from argparse import _ArgumentGroup
from enum import EnumMeta
from math import isclose
from pathlib import Path
from time import localtime, strftime
from typing import Any, Dict, List, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union
from warnings import warn

import attr
Expand All @@ -28,6 +30,25 @@
minor = sys.version_info.minor


def contains_return(func: Callable):
"""Checks if a function/callable has an explict return def
Args:
func: function to check for direct return
Returns:
boolean if defined return is found
References:
https://stackoverflow.com/questions/48232810/python-check-if-function-has-return-statement
"""
return any(
isinstance(node, ast.Return)
for node in ast.walk(ast.parse(textwrap.dedent(inspect.getsource(func))))
)


def vars_dict_non_dunder(__obj: object):
"""Gets the user defined attributes from a base object class
Expand Down
61 changes: 61 additions & 0 deletions tests/base/test_maps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
import sys

from typing import List, Tuple, Optional

import pytest

from spock import spock
from spock import SpockBuilder
from spock.exceptions import _SpockInstantiationError


class DummyClass:
def __init__(self, value):
self.value = value


class TestMaps:
def test_return_raise(self, monkeypatch, tmp_path):
with monkeypatch.context() as m:
m.setattr(
sys,
"argv",
[""],
)
with pytest.raises(_SpockInstantiationError):

@spock
class FailReturnConfig:
val_1: float = 0.5

def __maps__(self):
print(self.val_1)

config = SpockBuilder(
FailReturnConfig,
desc="Test Builder",
)
config.generate()

def test_map_return(self, monkeypatch, tmp_path):
with monkeypatch.context() as m:
m.setattr(
sys,
"argv",
[""],
)

@spock
class ReturnConfig:
val_1: float = 0.5

def __maps__(self):
return DummyClass(value=self.val_1)

config = SpockBuilder(
ReturnConfig,
desc="Test Builder",
)
configs = config.generate()
assert configs.ReturnConfig._maps.value == configs.ReturnConfig.val_1
Loading

0 comments on commit 2e40d04

Please sign in to comment.