From 651ce510ca6a0eca9a7ece22f17675201a5329bc Mon Sep 17 00:00:00 2001 From: ZENOTME <43447882+ZENOTME@users.noreply.github.com> Date: Fri, 21 Jul 2023 14:46:54 +0800 Subject: [PATCH] fix: fix the serialize way of struct value (#110) --- src/types/in_memory.rs | 104 ++++++++++++++++++++++++++++++++++++- src/types/on_disk/types.rs | 17 ++++-- 2 files changed, 115 insertions(+), 6 deletions(-) diff --git a/src/types/in_memory.rs b/src/types/in_memory.rs index e3d9551..6565dae 100644 --- a/src/types/in_memory.rs +++ b/src/types/in_memory.rs @@ -9,6 +9,7 @@ use chrono::NaiveTime; use chrono::Utc; use rust_decimal::Decimal; use serde::ser::SerializeMap; +use serde::ser::SerializeStruct; use serde::Serialize; use uuid::Uuid; @@ -42,7 +43,7 @@ pub enum AnyValue { /// default value. /// /// struct value stores as a map from field id to field value. - Struct(HashMap), + Struct(StructValue), /// A list type with a list of typed values. List(Vec), /// A map is a collection of key-value pairs with a key type and a value type. @@ -222,6 +223,17 @@ pub struct Struct { pub fields: Vec, } +impl Struct { + /// Generate a map from field id to field name for this struct. + pub fn generate_field_name_map(&self) -> HashMap { + let mut map = HashMap::with_capacity(self.fields.len()); + for field in &self.fields { + map.insert(field.id, field.name.clone()); + } + map + } +} + /// A Field is the field of a struct. #[derive(Debug, PartialEq, Clone)] pub struct Field { @@ -279,6 +291,30 @@ impl Field { } } +/// A StructValue is the value of a struct: its field values keyed by field id, plus the field id to field name mapping. +#[derive(Debug, PartialEq, Clone)] +pub struct StructValue { + /// fields is a map from field id to field value. + pub fields: HashMap, + /// field_names is a map from field id to field name. 
+ pub field_names: HashMap, +} + +impl Serialize for StructValue { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let mut record = serializer.serialize_struct("", self.fields.len())?; + for (id, value) in self.fields.iter() { + let key = self.field_names.get(id).unwrap(); + // NOTE: `serialize_field` takes a `&'static str` key, so the field name is converted with `Box::leak`. This is memory-safe, but it leaks one allocation per field on every `serialize` call. + record.serialize_field(Box::leak(key.clone().into_boxed_str()), value)?; + } + record.end() + } +} + /// A list is a collection of values with some element type. /// /// - The element field has an integer id that is unique in the table schema. @@ -1698,3 +1734,69 @@ impl ToString for TableFormatVersion { } } } + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use apache_avro::{schema, types::Value}; + + use super::{Any, AnyValue, Field, Primitive, Struct}; + + #[test] + fn test_struct_to_avro() { + let value = { + let struct_types = Struct { + fields: vec![ + Field::required(1, "a", Any::Primitive(Primitive::Int)), + Field::required(2, "b", Any::Primitive(Primitive::String)), + ], + }; + + let struct_value = { + let mut fields = HashMap::new(); + fields.insert(1, AnyValue::Primitive(super::PrimitiveValue::Int(1))); + fields.insert( + 2, + AnyValue::Primitive(super::PrimitiveValue::String("hello".to_string())), + ); + AnyValue::Struct(super::StructValue { + fields, + field_names: struct_types.generate_field_name_map(), + }) + }; + + let mut res = apache_avro::to_value(struct_value).unwrap(); + + // Sort the record fields by name so that the field order is deterministic for the comparison below. 
+ if let Value::Record(ref mut record) = res { + record.sort_by(|a, b| a.0.cmp(&b.0)); + } + + res + }; + + let expect_value = { + let raw_schema = r#" + { + "type": "record", + "name": "test", + "fields": [ + {"name": "a", "type": "int"}, + {"name": "b", "type": "string"} + ] + } + "#; + + let schema = schema::Schema::parse_str(raw_schema).unwrap(); + + let mut record = apache_avro::types::Record::new(&schema).unwrap(); + record.put("a", 1); + record.put("b", "hello"); + + record.into() + }; + + assert_eq!(value, expect_value); + } +} diff --git a/src/types/on_disk/types.rs b/src/types/on_disk/types.rs index 0b65c69..cda83a9 100644 --- a/src/types/on_disk/types.rs +++ b/src/types/on_disk/types.rs @@ -19,6 +19,8 @@ use uuid::Uuid; use crate::types; use crate::types::Any; +use crate::types::AnyValue; +use crate::types::StructValue; use crate::Error; use crate::ErrorKind; use crate::Result; @@ -773,13 +775,15 @@ fn parse_json_value_to_struct( expect_struct: &types::Struct, value: serde_json::Value, ) -> Result { - let fields = expect_struct + let field_types = expect_struct .fields .iter() .map(|v| (v.id, &v.field_type)) .collect::>(); - let mut values = HashMap::with_capacity(fields.len()); + let field_names = expect_struct.generate_field_name_map(); + + let mut fields: HashMap = HashMap::with_capacity(field_types.len()); match value { serde_json::Value::Object(o) => { @@ -792,7 +796,7 @@ fn parse_json_value_to_struct( .set_source(err) })?; - let expect_type = fields.get(&key).ok_or_else(|| { + let expect_type = field_types.get(&key).ok_or_else(|| { Error::new( ErrorKind::IcebergDataInvalid, format!("expect filed id {:?} but not exist", key), @@ -801,7 +805,7 @@ fn parse_json_value_to_struct( let value = parse_json_value(expect_type, value)?; - values.insert(key, value); + fields.insert(key, value); } } _ => { @@ -812,7 +816,10 @@ fn parse_json_value_to_struct( } } - Ok(types::AnyValue::Struct(values)) + Ok(types::AnyValue::Struct(StructValue { + fields, + 
field_names, + })) } /// JSON single-value serialization requires List been