Skip to content

Commit

Permalink
fix: python pydantic unwrap properties (#2111)
Browse files Browse the repository at this point in the history
  • Loading branch information
memdal authored Oct 30, 2024
1 parent f4976c9 commit 9f2b80e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
exports[`Should be able to render python models and should log expected output to console: class-model 1`] = `
Array [
"class Root(BaseModel):
optional_field: Optional[str] = Field(description='''this field is optional''', default=None, alias='''optionalField''')
required_field: str = Field(description='''this field is required''', alias='''requiredField''')
no_description: Optional[str] = Field(default=None, alias='''noDescription''')
optional_field: Optional[str] = Field(description='''this field is optional''', default=None)
required_field: str = Field(description='''this field is required''')
no_description: Optional[str] = Field(default=None)
options: Optional[Options.Options] = Field(default=None)
content_type: Optional[str] = Field(default=None, alias='''content-type''')
content_type: Optional[str] = Field(default=None)
",
]
`;
Expand Down
9 changes: 4 additions & 5 deletions src/generators/python/presets/Pydantic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ const PYTHON_PYDANTIC_CLASS_PRESET: ClassPresetType<PythonOptions> = {
if (!property.required) {
decoratorArgs.push('default=None');
}
if (property.propertyName !== property.unconstrainedPropertyName) {
decoratorArgs.push(`alias='''${property.unconstrainedPropertyName}'''`);
}

return `${propertyName}: ${type} = Field(${decoratorArgs.join(', ')})`;
},
Expand Down Expand Up @@ -101,6 +98,8 @@ def custom_serializer(self, handler):
@model_validator(mode='before')
@classmethod
def unwrap_${dictionaryModel?.propertyName}(cls, data):
if not isinstance(data, dict):
data = data.model_dump()
json_properties = list(data.keys())
known_object_properties = [${allProperties
.map((value) => `'${value}'`)
Expand All @@ -113,11 +112,11 @@ def unwrap_${dictionaryModel?.propertyName}(cls, data):
known_json_properties = [${Object.values(model.properties)
.map((value) => `'${value.unconstrainedPropertyName}'`)
.join(', ')}]
${dictionaryModel?.propertyName} = {}
${dictionaryModel?.propertyName} = data.get('${dictionaryModel?.propertyName}', {})
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
${dictionaryModel?.propertyName}[obj_key] = data.pop(obj_key, None)
data['${dictionaryModel?.unconstrainedPropertyName}'] = ${dictionaryModel?.propertyName}
data['${dictionaryModel?.propertyName}'] = ${dictionaryModel?.propertyName}
return data
${content}`;
}
Expand Down
38 changes: 23 additions & 15 deletions test/generators/python/presets/__snapshots__/Pydantic.spec.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ exports[`PYTHON_PYDANTIC_PRESET should render pydantic for class 1`] = `
multi
line
description''', default=None)
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None, alias='''additionalProperties''')
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None)
@model_serializer(mode='wrap')
def custom_serializer(self, handler):
Expand All @@ -23,6 +23,8 @@ exports[`PYTHON_PYDANTIC_PRESET should render pydantic for class 1`] = `
@model_validator(mode='before')
@classmethod
def unwrap_additional_properties(cls, data):
if not isinstance(data, dict):
data = data.model_dump()
json_properties = list(data.keys())
known_object_properties = ['prop', 'additional_properties']
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
Expand All @@ -31,11 +33,11 @@ exports[`PYTHON_PYDANTIC_PRESET should render pydantic for class 1`] = `
return data
known_json_properties = ['prop', 'additionalProperties']
additional_properties = {}
additional_properties = data.get('additional_properties', {})
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
additional_properties[obj_key] = data.pop(obj_key, None)
data['additionalProperties'] = additional_properties
data['additional_properties'] = additional_properties
return data
"
Expand All @@ -44,8 +46,8 @@ exports[`PYTHON_PYDANTIC_PRESET should render pydantic for class 1`] = `
exports[`PYTHON_PYDANTIC_PRESET should render union to support Python < 3.10 1`] = `
Array [
"class UnionTest(BaseModel):
union_test: Optional[Union[Union1.Union1, Union2.Union2]] = Field(default=None, alias='''unionTest''')
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None, alias='''additionalProperties''')
union_test: Optional[Union[Union1.Union1, Union2.Union2]] = Field(default=None)
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None)
@model_serializer(mode='wrap')
def custom_serializer(self, handler):
Expand All @@ -62,6 +64,8 @@ Array [
@model_validator(mode='before')
@classmethod
def unwrap_additional_properties(cls, data):
if not isinstance(data, dict):
data = data.model_dump()
json_properties = list(data.keys())
known_object_properties = ['union_test', 'additional_properties']
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
Expand All @@ -70,17 +74,17 @@ Array [
return data
known_json_properties = ['unionTest', 'additionalProperties']
additional_properties = {}
additional_properties = data.get('additional_properties', {})
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
additional_properties[obj_key] = data.pop(obj_key, None)
data['additionalProperties'] = additional_properties
data['additional_properties'] = additional_properties
return data
",
"class Union1(BaseModel):
test_prop1: Optional[str] = Field(default=None, alias='''testProp1''')
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None, alias='''additionalProperties''')
test_prop1: Optional[str] = Field(default=None)
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None)
@model_serializer(mode='wrap')
def custom_serializer(self, handler):
Expand All @@ -97,6 +101,8 @@ Array [
@model_validator(mode='before')
@classmethod
def unwrap_additional_properties(cls, data):
if not isinstance(data, dict):
data = data.model_dump()
json_properties = list(data.keys())
known_object_properties = ['test_prop1', 'additional_properties']
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
Expand All @@ -105,17 +111,17 @@ Array [
return data
known_json_properties = ['testProp1', 'additionalProperties']
additional_properties = {}
additional_properties = data.get('additional_properties', {})
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
additional_properties[obj_key] = data.pop(obj_key, None)
data['additionalProperties'] = additional_properties
data['additional_properties'] = additional_properties
return data
",
"class Union2(BaseModel):
test_prop2: Optional[str] = Field(default=None, alias='''testProp2''')
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None, alias='''additionalProperties''')
test_prop2: Optional[str] = Field(default=None)
additional_properties: Optional[dict[str, Any]] = Field(exclude=True, default=None)
@model_serializer(mode='wrap')
def custom_serializer(self, handler):
Expand All @@ -132,6 +138,8 @@ Array [
@model_validator(mode='before')
@classmethod
def unwrap_additional_properties(cls, data):
if not isinstance(data, dict):
data = data.model_dump()
json_properties = list(data.keys())
known_object_properties = ['test_prop2', 'additional_properties']
unknown_object_properties = [element for element in json_properties if element not in known_object_properties]
Expand All @@ -140,11 +148,11 @@ Array [
return data
known_json_properties = ['testProp2', 'additionalProperties']
additional_properties = {}
additional_properties = data.get('additional_properties', {})
for obj_key in list(data.keys()):
if not known_json_properties.__contains__(obj_key):
additional_properties[obj_key] = data.pop(obj_key, None)
data['additionalProperties'] = additional_properties
data['additional_properties'] = additional_properties
return data
",
Expand Down

0 comments on commit 9f2b80e

Please sign in to comment.