Skip to content

Commit

Permalink
Merge branch 'main' into ninhu/disable_tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
ninghu authored Jan 7, 2025
2 parents 04fbef1 + a981d6a commit 037424e
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/promptflow-devkit/promptflow/_sdk/schemas/_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from marshmallow import fields
from marshmallow.exceptions import FieldInstanceResolutionError, ValidationError
from marshmallow.fields import _T, Field, Nested
from marshmallow.fields import Field, Nested
from marshmallow.utils import RAISE, resolve_field_instance

from promptflow._sdk._constants import BASE_PATH_CONTEXT_KEY
Expand All @@ -17,6 +17,7 @@
# pylint: disable=unused-argument,no-self-use,protected-access

module_logger = LoggerFactory.get_logger(__name__)
T = typing.TypeVar("T")


class StringTransformedEnum(Field):
Expand Down Expand Up @@ -215,7 +216,7 @@ def __init__(self, *args, **kwargs):
class DumpableIntegerField(fields.Integer):
"""An int field that cannot serialize other type of values to int if self.strict."""

def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]:
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
if self.strict and not isinstance(value, int):
# this implementation can serialize bool to bool
raise self.make_error("invalid", input=value)
Expand All @@ -241,7 +242,7 @@ def _validated(self, value):
raise self.make_error("invalid", input=value)
return super()._validated(value)

def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, _T]]:
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[typing.Union[str, T]]:
return super()._serialize(self._validated(value), attr, obj, **kwargs)


Expand Down

0 comments on commit 037424e

Please sign in to comment.