Allow writing a pa.Table that is either a subset of the table schema or in arbitrary column order, and support type promotion on write (#921)

* merge

* thanks @HonahX :)

Co-authored-by: Honah J. <[email protected]>

* support promote

* revert promote

* use a visitor

* support promotion on write

* fix

* Thank you @Fokko !

Co-authored-by: Fokko Driesprong <[email protected]>

* revert

* add-files promotiontest

* support promote for add_files

* add tests for uuid

* add_files subset schema test

---------

Co-authored-by: Honah J. <[email protected]>
Co-authored-by: Fokko Driesprong <[email protected]>
3 people authored Jul 17, 2024
1 parent 0f2e19e commit 1ed3abd
Showing 7 changed files with 545 additions and 79 deletions.
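
As a minimal sketch of what this change enables (the catalog name, table identifier, and table schema below are assumptions for illustration, not part of this diff): a PyArrow table can now be appended when its columns are a subset of the Iceberg schema, appear in a different order, or use a type that safely promotes to the table's type, such as int32 to long.

import pyarrow as pa
from pyiceberg.catalog import load_catalog

# Placeholder catalog and table; assume an Iceberg schema of
#   id: long (required), name: string (required), payload: string (optional)
catalog = load_catalog("default")
tbl = catalog.load_table("db.events")

# Columns are ordered differently than the table schema, the optional
# "payload" column is omitted, and "id" arrives as int32 (promotable to long).
df = pa.table({
    "name": pa.array(["a", "b"], type=pa.string()),
    "id": pa.array([1, 2], type=pa.int32()),
})

tbl.append(df)  # accepted after this change; previously rejected as a schema mismatch
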
81 changes: 33 additions & 48 deletions pyiceberg/io/pyarrow.py
@@ -120,6 +120,7 @@
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
_check_schema_compatible,
pre_order_visit,
promote,
prune_columns,
@@ -1407,7 +1408,7 @@ def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array:
# This can be removed once this has been fixed:
# https://github.com/apache/arrow/issues/38809
list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)

value_array = self._cast_if_needed(list_type.element_field, value_array)
arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
return list_array.cast(arrow_field)
else:
@@ -1417,6 +1418,8 @@ def map(
self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
) -> Optional[pa.Array]:
if isinstance(map_array, pa.MapArray) and key_result is not None and value_result is not None:
key_result = self._cast_if_needed(map_type.key_field, key_result)
value_result = self._cast_if_needed(map_type.value_field, value_result)
arrow_field = pa.map_(
self._construct_field(map_type.key_field, key_result.type),
self._construct_field(map_type.value_field, value_result.type),
@@ -1549,9 +1552,16 @@ def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc

expected_physical_type = _primitive_to_physical(iceberg_type)
if expected_physical_type != physical_type_string:
raise ValueError(
f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
)
# Allow promotable physical types
# INT32 -> INT64 and FLOAT -> DOUBLE are safe type casts
if (physical_type_string == "INT32" and expected_physical_type == "INT64") or (
physical_type_string == "FLOAT" and expected_physical_type == "DOUBLE"
):
pass
else:
raise ValueError(
f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
)

self.primitive_type = iceberg_type

@@ -1896,16 +1906,6 @@ def data_file_statistics_from_parquet_metadata(
set the mode for column metrics collection
parquet_column_mapping (Dict[str, int]): The mapping of the parquet file name to the field ID
"""
if parquet_metadata.num_columns != len(stats_columns):
raise ValueError(
f"Number of columns in statistics configuration ({len(stats_columns)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
)

if parquet_metadata.num_columns != len(parquet_column_mapping):
raise ValueError(
f"Number of columns in column mapping ({len(parquet_column_mapping)}) is different from the number of columns in pyarrow table ({parquet_metadata.num_columns})"
)

column_sizes: Dict[int, int] = {}
value_counts: Dict[int, int] = {}
split_offsets: List[int] = []
@@ -1998,8 +1998,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
)

def write_parquet(task: WriteTask) -> DataFile:
table_schema = task.schema

table_schema = table_metadata.schema()
# if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
# otherwise use the original schema
if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
@@ -2011,7 +2010,7 @@ def write_parquet(task: WriteTask) -> DataFile:
batches = [
_to_requested_schema(
requested_schema=file_schema,
file_schema=table_schema,
file_schema=task.schema,
batch=batch,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
include_field_ids=True,
@@ -2070,47 +2069,30 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
return bin_packed_record_batches


def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
def _check_pyarrow_schema_compatible(
requested_schema: Schema, provided_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False
) -> None:
"""
Check if the `table_schema` is compatible with `other_schema`.
Check if the `requested_schema` is compatible with `provided_schema`.
Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
Raises:
ValueError: If the schemas are not compatible.
"""
name_mapping = table_schema.name_mapping
name_mapping = requested_schema.name_mapping
try:
task_schema = pyarrow_to_schema(
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
provided_schema = pyarrow_to_schema(
provided_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(other_schema.column_names) - set(table_schema.column_names)
provided_schema = _pyarrow_to_schema_without_ids(provided_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(provided_schema._name_to_id.keys()) - set(requested_schema._name_to_id.keys())
raise ValueError(
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
) from e

if table_schema.as_struct() != task_schema.as_struct():
from rich.console import Console
from rich.table import Table as RichTable

console = Console(record=True)

rich_table = RichTable(show_header=True, header_style="bold")
rich_table.add_column("")
rich_table.add_column("Table field")
rich_table.add_column("Dataframe field")

for lhs in table_schema.fields:
try:
rhs = task_schema.find_field(lhs.field_id)
rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
except ValueError:
rich_table.add_row("❌", str(lhs), "Missing")

console.print(rich_table)
raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
_check_schema_compatible(requested_schema, provided_schema)


def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
@@ -2124,7 +2106,7 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
)
schema = table_metadata.schema()
_check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
_check_pyarrow_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=parquet_metadata,
@@ -2205,7 +2187,7 @@ def _dataframe_to_data_files(
Returns:
An iterable that supplies datafiles that represent the table.
"""
from pyiceberg.table import PropertyUtil, TableProperties, WriteTask
from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties, WriteTask

counter = counter or itertools.count(0)
write_uuid = write_uuid or uuid.uuid4()
@@ -2214,13 +2196,16 @@
property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)
name_mapping = table_metadata.schema().name_mapping
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
task_schema = pyarrow_to_schema(df.schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)

if table_metadata.spec().is_unpartitioned():
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
for batches in bin_pack_arrow_table(df, target_file_size)
]),
)
@@ -2235,7 +2220,7 @@ def _dataframe_to_data_files(
task_id=next(counter),
record_batches=batches,
partition_key=partition.partition_key,
schema=table_metadata.schema(),
schema=task_schema,
)
for partition in partitions
for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
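
A corresponding sketch for add_files (the table and path are placeholders): Parquet files whose physical types safely widen to the Iceberg schema (INT32 where the table stores a long, FLOAT where it stores a double) can now be registered without rewriting them, matching the physical-type check added above.

# Assumes `tbl` is a loaded Iceberg table whose schema declares long/double
# columns, while the existing Parquet file stores them as INT32/FLOAT.
tbl.add_files(file_paths=["s3://warehouse/db/events/data/part-00000.parquet"])
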
100 changes: 100 additions & 0 deletions pyiceberg/schema.py
@@ -1616,3 +1616,103 @@ def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
return read_type
else:
raise ResolveError(f"Cannot promote {file_type} to {read_type}")


def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None:
"""
Check if the `provided_schema` is compatible with `requested_schema`.
Both Schemas must have valid IDs and share the same ID for the same field names.
Two schemas are considered compatible when:
1. All `required` fields in `requested_schema` are present and are also `required` in the `provided_schema`
2. Field Types are consistent for fields that are present in both schemas. I.e. the field type
in the `provided_schema` can be promoted to the field type of the same field ID in `requested_schema`
Raises:
ValueError: If the schemas are not compatible.
"""
pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema))


class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]):
provided_schema: Schema

def __init__(self, provided_schema: Schema):
from rich.console import Console
from rich.table import Table as RichTable

self.provided_schema = provided_schema
self.rich_table = RichTable(show_header=True, header_style="bold")
self.rich_table.add_column("")
self.rich_table.add_column("Table field")
self.rich_table.add_column("Dataframe field")
self.console = Console(record=True)

def _is_field_compatible(self, lhs: NestedField) -> bool:
# Validate nullability first.
# An optional field can be missing in the provided schema
# But a required field must exist as a required field
try:
rhs = self.provided_schema.find_field(lhs.field_id)
except ValueError:
if lhs.required:
self.rich_table.add_row("❌", str(lhs), "Missing")
return False
else:
self.rich_table.add_row("✅", str(lhs), "Missing")
return True

if lhs.required and not rhs.required:
self.rich_table.add_row("❌", str(lhs), str(rhs))
return False

# Check type compatibility
if lhs.field_type == rhs.field_type:
self.rich_table.add_row("✅", str(lhs), str(rhs))
return True
# We only check that the parent node is also of the same type.
# We check the type of the child nodes when we traverse them later.
elif any(
(isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
for container_type in {StructType, MapType, ListType}
):
self.rich_table.add_row("✅", str(lhs), str(rhs))
return True
else:
try:
# If type can be promoted to the requested schema
# it is considered compatible
promote(rhs.field_type, lhs.field_type)
self.rich_table.add_row("✅", str(lhs), str(rhs))
return True
except ResolveError:
self.rich_table.add_row("❌", str(lhs), str(rhs))
return False

def schema(self, schema: Schema, struct_result: Callable[[], bool]) -> bool:
if not (result := struct_result()):
self.console.print(self.rich_table)
raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}")
return result

def struct(self, struct: StructType, field_results: List[Callable[[], bool]]) -> bool:
results = [result() for result in field_results]
return all(results)

def field(self, field: NestedField, field_result: Callable[[], bool]) -> bool:
return self._is_field_compatible(field) and field_result()

def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool:
return self._is_field_compatible(list_type.element_field) and element_result()

def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool:
return all([
self._is_field_compatible(map_type.key_field),
self._is_field_compatible(map_type.value_field),
key_result(),
value_result(),
])

def primitive(self, primitive: PrimitiveType) -> bool:
return True
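
The rule this visitor implements, as an illustrative sketch (the helper is private, and the field names and IDs below are made up): a provided field may be promoted to the requested type, but a required field must not be missing or arrive as optional.

from pyiceberg.schema import Schema, _check_schema_compatible
from pyiceberg.types import IntegerType, LongType, NestedField

requested = Schema(NestedField(1, "id", LongType(), required=True))

# int promotes to long, so this passes silently.
_check_schema_compatible(requested, Schema(NestedField(1, "id", IntegerType(), required=True)))

# A required field provided as optional is rejected with the rendered field report.
try:
    _check_schema_compatible(requested, Schema(NestedField(1, "id", LongType(), required=False)))
except ValueError as e:
    print(e)  # Mismatch in fields: ...
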
15 changes: 10 additions & 5 deletions pyiceberg/table/__init__.py
@@ -73,7 +73,6 @@
manifest_evaluator,
)
from pyiceberg.io import FileIO, OutputFile, load_file_io
from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.manifest import (
POSITIONAL_DELETE_SCHEMA,
DataFile,
@@ -471,6 +470,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

@@ -481,8 +482,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
_check_pyarrow_schema_compatible(
self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)

manifest_merge_enabled = PropertyUtil.property_as_bool(
@@ -528,6 +529,8 @@ def overwrite(
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e

from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files

if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

@@ -538,8 +541,8 @@ def overwrite(
f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
)
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
_check_schema_compatible(
self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
_check_pyarrow_schema_compatible(
self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)

self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
@@ -566,6 +569,8 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
delete_filter: A boolean expression to delete rows from a table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table

if (
self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
== TableProperties.DELETE_MODE_MERGE_ON_READ
(The remaining 4 changed files are not shown.)
