Skip to content

Commit

Permalink
pydantic support
Browse files Browse the repository at this point in the history
  • Loading branch information
e3rd committed Sep 6, 2024
1 parent 9edf1b5 commit e20fa46
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 80 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pip install mininterface

# Docs

Use a common [dataclass](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass) for the configuration. Wrap it to the [run](#run) method that returns an interface `m`. Access the configuration via `m.env` or use it to prompt the user `m.ask_yes("Is that alright?")`.
Use a common [dataclass](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass) or a Pydantic [BaseModel](https://brentyi.github.io/tyro/examples/04_additional/08_pydantic/) to store the configuration. Wrap it to the [run](#run) method that returns an interface `m`. Access the configuration via `m.env` or use it to prompt the user `m.ask_yes("Is that alright?")`.

## Basic usage

Expand Down
34 changes: 18 additions & 16 deletions mininterface/FormDict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
from typing import Any, Callable, Optional, TypeVar, Union, get_type_hints


from .FormField import FormField

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -36,8 +35,8 @@ def dict_to_formdict(data: dict) -> FormDict:
if isinstance(val, dict): # nested config hierarchy
fd[key] = dict_to_formdict(val)
else: # scalar value
# NOTE name=param is not set (yet?) in `config_to_formdict`, neither `src`
fd[key] = FormField(val, "", name=key, _src_dict=(data, key)) if not isinstance(val, FormField) else val
fd[key] = FormField(val, "", name=key, _src_dict=data, _src_key=key) \
if not isinstance(val, FormField) else val
return fd


Expand All @@ -55,29 +54,32 @@ def formdict_to_widgetdict(d: FormDict | Any, widgetize_callback: Callable, _key
def dataclass_to_formdict(env: EnvClass, descr: dict, _path="") -> FormDict:
""" Convert the dataclass produced by tyro into dict of dicts. """
main = ""
params = {main: {}} if not _path else {}
subdict = {main: {}} if not _path else {}
for param, val in vars(env).items():
annotation = None
annotation = get_type_hints(env.__class__).get(param)
if val is None:
wanted_type = get_type_hints(env.__class__).get(param)
if wanted_type in (Optional[int], Optional[str]):
if annotation in (Optional[int], Optional[str]):
# Since tkinter_form does not handle None yet, we have help it.
# We need it to be able to write a number and if empty, return None.
# This would fail: `severity: int | None = None`
# Here, we convert None to str(""), in normalize_types we convert it back.
annotation = wanted_type
val = ""
else:
# An unknown type annotation encountered-
# An unknown type annotation encountered.
# Since tkinter_form does not handle None yet, this will display as checkbox.
# Which is not probably wanted.
val = False
logger.warn(f"Annotation {wanted_type} of `{param}` not supported by Mininterface."
logger.warn(f"Annotation {annotation} of `{param}` not supported by Mininterface."
"None converted to False.")
if hasattr(val, "__dict__"): # nested config hierarchy
params[param] = dataclass_to_formdict(val, descr, _path=f"{_path}{param}.")
elif not _path: # scalar value in root
params[main][param] = FormField(val, descr.get(param), annotation, param, _src_obj=(env, param))
else: # scalar value in nested
params[param] = FormField(val, descr.get(f"{_path}{param}"), annotation, param, _src_obj=(env, param))
return params
subdict[param] = dataclass_to_formdict(val, descr, _path=f"{_path}{param}.")
else:
params = {"val": val,
"_src_key": param,
"_src_obj": env
}
if not _path: # scalar value in root
subdict[main][param] = FormField(description=descr.get(param), **params)
else: # scalar value in nested
subdict[param] = FormField(description=descr.get(f"{_path}{param}"), **params)
return subdict
144 changes: 105 additions & 39 deletions mininterface/FormField.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from ast import literal_eval
from dataclasses import dataclass, fields
from types import UnionType
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, get_args
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, get_args, get_type_hints

from .auxiliary import flatten

if TYPE_CHECKING:
from .FormDict import FormDict

# Pydantic is not a project dependency, that is just an optional integration
try: # Pydantic is not a dependency but integration
from pydantic import ValidationError as PydanticValidationError
from pydantic import create_model
pydantic = True
except:
pydantic = False
PydanticValidationError = None
create_model = None

FFValue = TypeVar("FFValue")
TD = TypeVar("TD")
""" dict """
Expand All @@ -19,6 +29,8 @@
""" Callback validation error message"""
ValidationResult = bool | ErrorMessage
""" Callback validation result is either boolean or an error message. """
PydanticFieldInfo = TypeVar("PydanticFieldInfo")


@dataclass
class FormField:
Expand Down Expand Up @@ -55,7 +67,7 @@ class FormField:
"""

validation: Callable[["FormField"], ValidationResult | tuple[ValidationResult,
FieldValue]] | None = None
FieldValue]] | None = None
""" When the user submits the form, the values are validated (and possibly transformed) with a callback function.
If the validation fails, user is prompted to edit the value.
Return True if validation succeeded or False or an error message when it failed.
Expand All @@ -82,13 +94,24 @@ def check(ff: FormField):
I am not sure whether to store the transformed value in the ui_value or fixed_value.
"""

_src_dict: tuple[TD, TK] | None = None
""" The original dict to be updated when UI ends.
"""
_src_obj: tuple[TD, TK] | None = None
# _pydantic_model: type = None
# """ NOTE Experimental

# Annotation of pydantic model is what?

# """

_src_dict: TD | None = None
""" The original dict to be updated when UI ends."""

_src_obj: TD | None = None
""" The original object to be updated when UI ends.
NOTE might be merged to `src`
If not set earlier, fetches name, annotation, _pydantic_field from this class.
"""
_src_key: str | None = None
""" Key in the src object / src dict """
_src_class: type | None = None
""" If not set earlier, fetch name, annotation and _pydantic_field from this class. """

#
# Following attributes are not meant to be set externally.
Expand All @@ -111,8 +134,20 @@ def check(ff.val):
# """ Distinguish _ui_val default None from the user UI input None """
# _ui_val = None
# """ Auxiliary variable. UI state → validation fails on a field, we need to restore """
_pydantic_field: PydanticFieldInfo = None

def __post_init__(self):
# Fetch information from the parent object
if self._src_obj and not self._src_class:
self._src_class = self._src_obj
if self._src_class:
if not self.annotation: # when we have _src_class, we must have _src_key too
self.annotation = get_type_hints(self._src_class).get(self._src_key)
if pydantic: # Pydantic integration
self._pydantic_field: dict | None = getattr(self._src_class, "model_fields", {}).get(self._src_key)
if not self.name and self._src_key:
self.name = self._src_key

if not self.annotation:
self.annotation = type(self.val)
self._original_desc = self.description
Expand All @@ -123,11 +158,16 @@ def __repr__(self):
field_strings = []
for field in fields(self):
field_value = getattr(self, field.name)
# clean-up protected members
if field.name.startswith("_"):
continue

# Display 'validation=not_empty' instead of 'validation=<function not_empty at...>'
if field.name == 'validation' and (func_name := getattr(field_value, "__name__", "")):
v = f"{field.name}={func_name}"
else:
v = f"{field.name}={field_value!r}"

field_strings.append(v)
return f"{self.__class__.__name__}({', '.join(field_strings)})"

Expand All @@ -150,6 +190,43 @@ def _repr_annotation(self):
else:
return self.annotation.__name__

def _validate(self, out_value) -> FieldValue:
""" Runs
* self.validation callback
* pydantic validation
* annotation type validation
If succeeded, return the (possibly transformed) value.
If failed, raises ValueError.
"""
if self.validation:
last = self.val
self.val = out_value
res = self.validation(self)
if isinstance(res, tuple):
passed, out_value = res
self.val = out_value
else:
passed = res
self.val = last
if passed is not True: # we did not pass, there might be an error message in passed
self.set_error_text(passed or f"Validation fail")
raise ValueError

# pydantic_check
if self._pydantic_field:
try:
create_model('ValidationModel', check=(self.annotation, self._pydantic_field))(check=out_value)
except PydanticValidationError as e:
self.set_error_text(e.errors()[0]["msg"])
raise ValueError

# Type check
if self.annotation and not isinstance(out_value, self.annotation):
self.set_error_text(f"Type must be {self._repr_annotation()}!")
raise ValueError
return out_value

def update(self, ui_value) -> bool:
""" UI value → FormField value → original value. (With type conversion and checks.)
Expand Down Expand Up @@ -190,8 +267,20 @@ def update(self, ui_value) -> bool:
self.set_error_text(f"Not a valid {self._repr_annotation()}")
return False

if not isinstance(out_value, self.annotation):
if isinstance(out_value, str):
try:
seems_bad = not isinstance(out_value, self.annotation) and isinstance(out_value, str)
except TypeError:
# Why checking TypeError? Due to Pydantic.
# class Inner(BaseModel):
# id: int
# class Model(BaseModel):
# items1: List[Item] = []
# 'TypeError: Subscripted generics cannot be used with class and instance checks'
# items2: list[Item] = []
# 'TypeError: cannot be a parameterized generic'
pass
else:
if seems_bad:
try:
# Textual ask_number -> user writes '123', this has to be converted to int 123
# NOTE: Unfortunately, type(list) looks awful here. @see TextualInterface.form comment.
Expand All @@ -200,40 +289,17 @@ def update(self, ui_value) -> bool:
# Automatic conversion failed
pass

# User validation check
if self.validation:
last = self.val
self.val = out_value
res = self.validation(self)
if isinstance(res, tuple):
passed, out_value = res
self.val = ui_value = out_value
else:
passed = res
self.val = last
if passed is not True: # we did not pass, there might be an error message in passed
self.set_error_text(passed or f"Validation fail")
# self.val = last
return False

# Type check
if self.annotation and not isinstance(out_value, self.annotation):
self.set_error_text(f"Type must be {self._repr_annotation()}!")
# self.val = last
return False # revision needed

# keep values if revision needed
# We merge new data to the origin. If form is re-submitted, the values will stay there.
self.val = out_value # checks succeeded, confirm the value

# User and type validation check
try:
self.val = self._validate(out_value) # checks succeeded, confirm the value
except ValueError:
return False

# Store to the source user data
if self._src_dict:
d, k = self._src_dict
d[k] = out_value
self._src_dict[self._src_key] = out_value
elif self._src_obj:
d, k = self._src_obj
setattr(d, k, out_value)
setattr(self._src_obj, self._src_key, out_value)
else:
# This might be user-created object. There is no need to update anything as the user reads directly from self.val.
pass
Expand Down
1 change: 0 additions & 1 deletion mininterface/TextualInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def form(self, form: FormDictOrEnv | None = None, title: str = "") -> FormDictOr

# NOTE we should implement better, now the user does not know it needs an int
def ask_number(self, text: str):
# TODO suggestion fail
return self.form({text: FormField("", "", int, text)})[text].val

def is_yes(self, text: str):
Expand Down
2 changes: 1 addition & 1 deletion mininterface/auxiliary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_terminal_size():

def get_descriptions(parser: ArgumentParser) -> dict:
""" Load descriptions from the parser. Strip argparse info about the default value as it will be editable in the form. """
return {action.dest.replace("-", "_"): re.sub(r"\(default.*\)", "", action.help)
return {action.dest.replace("-", "_"): re.sub(r"\(default.*\)", "", action.help or "")
for action in parser._actions}

def recursive_set_focus(widget: Widget):
Expand Down
56 changes: 37 additions & 19 deletions mininterface/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
from .FormField import FormField
from .validators import not_empty

# Pydantic is not a project dependency, that is just an optional integration
try: # Pydantic is not a dependency but integration
from pydantic import BaseModel
pydantic = True
except:
pydantic = False
BaseModel = False


WrongFields = dict[str, FormField]

eavesdrop = ""
Expand Down Expand Up @@ -108,25 +117,26 @@ def run_tyro_parser(env_class: Type[EnvClass],
# (with a graceful message from tyro)
pass
else:
# NOTE: For a missing int, we put '' to the UI.
# The UI is then not able to use the number filtering capabilities.
# However, we insist on having '' value as it clearly states to the user
# that the value is missing.
type_ = env_class.__annotations__[argument.dest]


wf[argument.dest] = FormField("",
argument.help.replace("(required)", ""),
type_,
validation=not_empty
)
setattr(kwargs["default"], argument.dest, None)

# second attempt to parse CLI
# NOTE: We put '' to the UI to clearly state that the value is missing.
# However, the UI then is not able to use the number filtering capabilities.
ff = wf[argument.dest] = FormField("",
argument.help.replace("(required)", ""),
validation=not_empty,
_src_class=env_class,
_src_key=argument.dest
)
# Why `type_()`? We need to put a default value so that the parsing will not fail.
# A None would be enough because Mininterface will ask for the missing values
# promply, however, Pydantic model would fail.
setattr(kwargs["default"], argument.dest, ff.annotation())

# Second attempt to parse CLI
# Why catching warnings? All the meaningful warnings
# have been produces during the first attempt.
# Now, when we defaulted all the missing fields with None,
# tyro produces 'UserWarning: The field (...) but the default value has type <class 'str'>.'
# (This is not true anymore; to support pydantic we put a default value of the type,
# so there is probably no more warning to be caught.)
with warnings.catch_warnings():
warnings.simplefilter('ignore')
return cli(env_class, **kwargs), wf
Expand Down Expand Up @@ -156,10 +166,18 @@ def _parse_cli(env_class: Type[EnvClass],
# Nested dataclasses have to be properly initialized. YAML gave them as dicts only.
for key in (key for key, val in disk.items() if isinstance(val, dict)):
disk[key] = env_class.__annotations__[key](**disk[key])
# To ensure the configuration file does not need to contain all keys, we have to fill in the missing ones.
# Otherwise, tyro will spawn warnings about missing fields.
static = {key: getattr(env_class, key, MISSING)
for key in env_class.__annotations__ if not key.startswith("__") and not key in disk}

# Fill default fields
if pydantic and issubclass(env_class, BaseModel):
# Unfortunately, pydantic needs to fill the default with the actual values,
# the default value takes the precedence over the hard coded one, even if missing.
static = {key: env_class.model_fields.get(key).default
for key in env_class.__annotations__ if not key.startswith("__") and not key in disk}
else:
# To ensure the configuration file does not need to contain all keys, we have to fill in the missing ones.
# Otherwise, tyro will spawn warnings about missing fields.
static = {key: getattr(env_class, key, MISSING)
for key in env_class.__annotations__ if not key.startswith("__") and not key in disk}
kwargs["default"] = SimpleNamespace(**(disk | static))

# Load configuration from CLI
Expand Down
Loading

0 comments on commit e20fa46

Please sign in to comment.