Skip to content

Commit

Permalink
Add CSV plugin (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
faph authored Dec 18, 2023
2 parents 97e7941 + aeed85d commit 328ed8c
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 7 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ dependencies = [
"avro~=1.11",
"fastavro~=1.8", # TODO: consider moving Avro-related dependencies to optional dependencies
"memoization~=0.4",
"more-itertools~=10.0",
"orjson~=3.0",
"pluggy~=1.3",
"py-avro-schema~=3.0",
Expand Down
8 changes: 3 additions & 5 deletions src/py_adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import importlib.metadata
import inspect
import io
import itertools
import logging
import uuid
from collections.abc import Iterable, Iterator
Expand All @@ -42,6 +41,7 @@
import avro.schema
import dateutil.parser
import memoization
import more_itertools
import orjson
import py_avro_schema as pas

Expand Down Expand Up @@ -151,12 +151,10 @@ def serialize_many_to_stream(objs: Iterable[Any], stream: BinaryIO, *, format: s
:param writer_schema: Data schema to serialize the data with, as JSON bytes.
"""
serialize_fn = py_adapter.plugin.plugin_hook(format, "serialize_many")
objs_iter = iter(objs)
# Use the first object to find the class, assuming all objects share the same type
first_obj = next(objs_iter)
(first_obj,), objs = more_itertools.spy(objs) # This will fail if the iterable is empty
py_type = type(first_obj)
# Then iterate over all objects again to convert to basic types
basic_objs = (to_basic_type(obj) for obj in itertools.chain([first_obj], objs_iter))
basic_objs = (to_basic_type(obj) for obj in objs)
serialize_fn(objs=basic_objs, stream=stream, py_type=py_type, writer_schema=writer_schema)


Expand Down
3 changes: 2 additions & 1 deletion src/py_adapter/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ def manager() -> pluggy.PluginManager:

def _load_default_plugins(manager_: pluggy.PluginManager) -> None:
"""Load plugins that are packaged with py-adapter"""
from py_adapter.plugin import _avro, _json
from py_adapter.plugin import _avro, _csv, _json

default_plugins = {
"Avro": _avro,
"CSV": _csv,
"JSON": _json,
}
for name, plugin in default_plugins.items():
Expand Down
82 changes: 82 additions & 0 deletions src/py_adapter/plugin/_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2023 J.P. Morgan Chase & Co.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

"""
CSV serializer/deserializer **py-adapter** plugin
"""

import io
from typing import BinaryIO, Iterable, Iterator

import more_itertools

import py_adapter


@py_adapter.plugin.hook
def serialize(obj: py_adapter.Basic, stream: BinaryIO) -> BinaryIO:
"""
Serialize an object of basic Python types as CSV data
:param obj: Python object to serialize
:param stream: File-like object to serialize data to
"""
serialize_many([obj], stream)
return stream


@py_adapter.plugin.hook
def serialize_many(objs: Iterable[py_adapter.Basic], stream: BinaryIO) -> BinaryIO:
"""
Serialize multiple Python objects of basic types as CSV data
:param objs: Python objects to serialize
:param stream: File-like object to serialize data to
"""
import csv

text_stream = io.StringIO(newline="") # csv modules writes as text
(first_obj,), objs = more_itertools.spy(objs) # this fails if the iterable is empty
assert isinstance(first_obj, dict), "CSV serializer supports 'record' types only."
csv_writer = csv.DictWriter(text_stream, fieldnames=first_obj.keys())
csv_writer.writeheader()
csv_writer.writerows(objs) # type:ignore[arg-type] # We know it's a dict
text_stream.flush()
text_stream.seek(0)

stream.write(text_stream.read().encode("utf-8"))
stream.flush()
return stream


@py_adapter.plugin.hook
def deserialize(stream: BinaryIO) -> py_adapter.Basic:
"""
Deserialize CSV data as an object of basic Python types
:param stream: File-like object to deserialize
"""
obj = next(deserialize_many(stream))
return obj


@py_adapter.plugin.hook
def deserialize_many(stream: BinaryIO) -> Iterator[py_adapter.Basic]:
"""
Deserialize CSV data as an iterator over objects of basic Python types
:param stream: File-like object to deserialize
"""
import csv

text_stream = io.StringIO(stream.read().decode("utf-8"))
csv_reader = csv.DictReader(text_stream)
return csv_reader
63 changes: 63 additions & 0 deletions tests/test_csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023 J.P. Morgan Chase & Co.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

"""
Test script for CSV serialization/deserialization
"""

import dataclasses
import datetime
from typing import Optional

import pytest

import py_adapter


@dataclasses.dataclass
class SimpleShip:
name: str
build_on: Optional[datetime.date] = None


@pytest.fixture
def simple_ship():
return SimpleShip(
name="Elvira",
build_on=datetime.date(1970, 12, 31),
)


def test_serialize_1_record(simple_ship):
data = py_adapter.serialize(simple_ship, format="CSV")
expected_lines = [b"name,build_on", b"Elvira,1970-12-31"]
assert data.splitlines() == expected_lines
obj_out = py_adapter.deserialize(data, SimpleShip, format="CSV")
assert obj_out == simple_ship


def test_serialize_many_records(simple_ship):
objs_in = [simple_ship, simple_ship]
data = py_adapter.serialize_many(objs_in, format="CSV")
expected_lines = [b"name,build_on", b"Elvira,1970-12-31", b"Elvira,1970-12-31"]
assert data.splitlines() == expected_lines
objs_out = list(py_adapter.deserialize_many(data, SimpleShip, format="CSV"))
assert objs_out == objs_in


@pytest.mark.xfail(reason="Not supported")
def test_serialize_many_no_records():
objs_in = []
data = py_adapter.serialize_many(objs_in, format="CSV")
expected_lines = [b"name,build_on"]
assert data.splitlines() == expected_lines
objs_out = list(py_adapter.deserialize_many(data, SimpleShip, format="CSV"))
assert objs_out == objs_in
2 changes: 1 addition & 1 deletion tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_invalid_format(ship_obj):
py_adapter.plugin.InvalidFormat,
match=re.escape(
"A plugin for serialization format 'does not exist' is not available. Installed plugins/formats are: "
"['Avro', 'JSON']."
"['Avro', 'CSV', 'JSON']."
),
):
py_adapter.serialize(ship_obj, format="does not exist")
Expand Down

0 comments on commit 328ed8c

Please sign in to comment.