Skip to content

Commit

Permalink
Add support for Zstd bytes->bytes Codec.
Browse files Browse the repository at this point in the history
  • Loading branch information
zoj613 committed Oct 3, 2024
1 parent d3cb229 commit 703afaa
Show file tree
Hide file tree
Showing 13 changed files with 123 additions and 18 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jobs:
- name: setup
run: |
opam install --deps-only --with-test --with-doc --yes zarr
opam install bytesrw conf-zlib conf-zstd --yes
opam install lwt --yes
opam exec -- dune build zarr zarr-sync zarr-lwt
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ assert (Ndarray.equal x' y);;
```ocaml
let config =
{chunk_shape = [|5; 3; 5|]
;codecs = [`Transpose [|2; 0; 1|]; `Bytes LE; `Gzip L5]
;codecs = [`Transpose [|2; 0; 1|]; `Bytes LE; `Zstd (0, true)]
;index_codecs = [`Bytes BE; `Crc32c]
;index_location = Start};;
Expand Down
2 changes: 1 addition & 1 deletion zarr-sync/test/test_sync.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ let test_storage
{chunk_shape = [|2; 5; 5|]
;index_location = End
;index_codecs = [`Bytes LE; `Crc32c]
;codecs = [`Transpose [|2; 0; 1|]; `Bytes BE; `Gzip L5]} in
;codecs = [`Transpose [|2; 0; 1|]; `Bytes BE; `Zstd (0, false)]} in
let cfg2 =
{chunk_shape = [|2; 5; 5|]
;index_location = Start
Expand Down
3 changes: 3 additions & 0 deletions zarr.opam
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ build: [
]
]
dev-repo: "git+https://github.com/zoj613/zarr-ml.git"
pin-depends: [
["bytesrw.dev" "git+https://erratique.ch/repos/bytesrw.git"]
]
3 changes: 3 additions & 0 deletions zarr.opam.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
pin-depends: [
["bytesrw.dev" "git+https://erratique.ch/repos/bytesrw.git"]
]
2 changes: 1 addition & 1 deletion zarr/src/codecs/array_to_bytes.ml
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ end = struct
let* l = acc in
match c with
| `Crc32c -> Ok (`Crc32c :: l)
| `Gzip _ -> Error msg) ic.b2b (Ok [])
| `Gzip _ | `Zstd _ -> Error msg) ic.b2b (Ok [])
in
let+ a2b = match ic.a2b with
| `Bytes e -> Ok (`Bytes e)
Expand Down
40 changes: 40 additions & 0 deletions zarr/src/codecs/bytes_to_bytes.ml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
open Codecs_intf
open Bytesrw

(* https://zarr-specs.readthedocs.io/en/latest/v3/codecs/gzip/v1.0.html *)
module GzipCodec = struct
Expand Down Expand Up @@ -52,27 +53,66 @@ module Crc32cCodec = struct
Ok `Crc32c
end

(* https://github.com/zarr-developers/zarr-specs/pull/256 *)
module ZstdCodec = struct
let min_clevel = -131072 and max_clevel = 22

let parse_clevel l =
if l < min_clevel || max_clevel < l then (raise Invalid_zstd_level)

let encode clevel checksum x =
let params = Bytesrw_zstd.Cctx_params.make ~checksum ~clevel () in
Bytes.Reader.to_string @@
Bytesrw_zstd.compress_reads ~params () @@ Bytes.Reader.of_string x

let decode x =
let params = Bytesrw_zstd.Dctx_params.default in
Bytes.Reader.to_string @@
Bytesrw_zstd.decompress_reads ~params () @@ Bytes.Reader.of_string x

let to_yojson l c =
`Assoc
[("name", `String "zstd")
;("configuration", `Assoc [("level", `Int l); ("checksum", `Bool c)])]

let of_yojson x =
match Yojson.Safe.Util.(member "configuration" x |> to_assoc) with
| [("level", `Int l); ("checksum", `Bool c)] ->
begin match parse_clevel l with
| exception Invalid_zstd_level -> Error "Invalid_zstd_level"
| () -> Result.ok @@ `Zstd (l, c) end
| _ -> Error "Invalid Zstd configuration."
end

module BytesToBytes = struct
let encoded_size :
int -> fixed_bytestobytes -> int
= fun input_size -> function
| `Crc32c -> Crc32cCodec.encoded_size input_size

let parse = function
| `Gzip _ | `Crc32c -> ()
| `Zstd (l, _) -> ZstdCodec.parse_clevel l

let encode x = function
| `Gzip l -> GzipCodec.encode l x
| `Crc32c -> Crc32cCodec.encode x
| `Zstd (l, c) -> ZstdCodec.encode l c x

let decode t x = match t with
| `Gzip _ -> GzipCodec.decode x
| `Crc32c -> Crc32cCodec.decode x
| `Zstd _ -> ZstdCodec.decode x

let to_yojson = function
| `Gzip l -> GzipCodec.to_yojson l
| `Crc32c -> Crc32cCodec.to_yojson
| `Zstd (l, c) -> ZstdCodec.to_yojson l c

let of_yojson x =
match Util.get_name x with
| "gzip" -> GzipCodec.of_yojson x
| "crc32c" -> Crc32cCodec.of_yojson x
| "zstd" -> ZstdCodec.of_yojson x
| s -> Error (Printf.sprintf "codec %s is not supported." s)
end
1 change: 1 addition & 0 deletions zarr/src/codecs/bytes_to_bytes.mli
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
open Codecs_intf

module BytesToBytes : sig
val parse : bytestobytes -> unit
val encoded_size : int -> fixed_bytestobytes -> int
val encode : string -> bytestobytes -> string
val decode : bytestobytes -> string -> string
Expand Down
1 change: 1 addition & 0 deletions zarr/src/codecs/codecs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ module Chain = struct
| x :: _ as xs ->
ArrayToArray.parse x shape;
List.fold_left ArrayToArray.encoded_repr shape xs);
List.iter BytesToBytes.parse b2b;
{a2a; a2b; b2b}

let encode t x =
Expand Down
2 changes: 2 additions & 0 deletions zarr/src/codecs/codecs.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ module Chain : sig
if [c] contains more than one bytes->bytes codec.
@raise Invalid_transpose_order
if [c] contains a transpose codec with invalid order array.
@raise Invalid_zstd_level
if [c] contains a Zstd codec whose compression level is invalid.
@raise Invalid_sharding_chunk_shape
if [c] contains a shardingindexed codec with an
incorrect inner chunk shape. *)
Expand Down
10 changes: 8 additions & 2 deletions zarr/src/codecs/codecs_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ exception Array_to_bytes_invariant
exception Invalid_transpose_order
exception Invalid_sharding_chunk_shape
exception Invalid_codec_ordering
exception Invalid_zstd_level

type arraytoarray =
[ `Transpose of int array ]
Expand All @@ -13,7 +14,8 @@ type fixed_bytestobytes =
[ `Crc32c ]

type variable_bytestobytes =
[ `Gzip of compression_level ]
[ `Gzip of compression_level
| `Zstd of int * bool ]

type bytestobytes =
[ fixed_bytestobytes | variable_bytestobytes ]
Expand Down Expand Up @@ -62,6 +64,9 @@ module type Interface = sig
(** raised when a codec chain has incorrect ordering of codecs. i.e if the
ordering is not [arraytoarray list -> 1 arraytobytes -> bytestobytes list]. *)

exception Invalid_zstd_level
(** raised when a codec chain contains a Zstd codec with an incorrect compression value.*)

(** The type of [array -> array] codecs. *)
type arraytoarray =
[ `Transpose of int array ]
Expand All @@ -78,7 +83,8 @@ module type Interface = sig
(** A type representing [bytes -> bytes] codecs that produce
variable sized encoded strings. *)
type variable_bytestobytes =
[ `Gzip of compression_level ]
[ `Gzip of compression_level
| `Zstd of int * bool ]

(** The type of [bytes -> bytes] codecs. *)
type bytestobytes =
Expand Down
1 change: 1 addition & 0 deletions zarr/src/dune
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
(libraries
yojson
ezgzip
bytesrw.zstd
stdint
checkseum)
(ocamlopt_flags
Expand Down
73 changes: 60 additions & 13 deletions zarr/test/test_codecs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -238,19 +238,22 @@ let tests = [
[{"name": "bytes", "configuration": {"endian": "big"}}]}}]|}
~msg:"Must be exactly one array->bytes codec.";
(* test violation of index_codec invariant when it contains variable-sized codecs. *)
decode_chain
~shape:[|5; 5; 5|]
~str:{|[
{"name": "sharding_indexed",
"configuration":
{"index_location": "start",
"chunk_shape": [5, 5, 5],
"index_codecs":
[{"name": "bytes", "configuration": {"endian": "big"}},
{"name": "gzip", "configuration": {"level": 1}}],
"codecs":
[{"name": "bytes", "configuration": {"endian": "big"}}]}}]|}
~msg:"Must be exactly one array->bytes codec.";
List.iter
(fun c ->
decode_chain
~shape:[|5; 5; 5|]
~str:(Format.sprintf {|[
{"name": "sharding_indexed",
"configuration":
{"index_location": "start",
"chunk_shape": [5, 5, 5],
"index_codecs":
[{"name": "bytes", "configuration": {"endian": "big"}}, %s],
"codecs":
[{"name": "bytes", "configuration": {"endian": "big"}}]}}]|} c)
~msg:"Must be exactly one array->bytes codec.")
[{|{"name": "zstd", "configuration": {"level": 0, "checksum": false}}|}
;{|{"name": "gzip", "configuration": {"level": 1}}|}];

let shape = [|10; 15; 10|] in
let kind = Ndarray.Float64 in
Expand Down Expand Up @@ -365,6 +368,50 @@ let tests = [
assert_equal arr @@ Chain.decode c {shape; kind} encoded)
[L0; L1; L2; L3; L4; L5; L6; L7; L8; L9])
;

"test zstd codec" >:: (fun _ ->
(* test wrong compression level *)
List.iter
(fun l ->
decode_chain
~shape:[||]
~str:(Format.sprintf {|[{"name": "bytes", "configuration": {"endian": "little"}},
{"name": "zstd", "configuration": {"level": %d, "checksum": false}}]|} l)
~msg:"zstd codec is unsupported or has invalid configuration.")
[50; -500_000];
(* test incorrect configuration *)
decode_chain
~shape:[||]
~str:{|[{"name": "bytes", "configuration": {"endian": "little"}},
{"name": "zstd", "configuration": {"something": -1}}]|}
~msg:"zstd codec is unsupported or has invalid configuration.";

(* test correct deserialization of zstd compression level *)
let shape = [|10; 15; 10|] in
List.iter
(fun level ->
let str =
Format.sprintf
{|[{"name": "bytes", "configuration": {"endian": "little"}},
{"name": "zstd", "configuration": {"level": %d, "checksum": false}}]|} level
in
let r = Chain.of_yojson shape @@ Yojson.Safe.from_string str in
assert_bool "Encoding this chain should not fail" @@ Result.is_ok r)
[-131072; 0];

(* test encoding/decoding for various compression levels *)
let kind = Ndarray.Int in
let fill_value = Int.max_int in
let arr = Ndarray.create kind shape fill_value in
let chain = [`Bytes LE] in
List.iter
(fun (level, checksum) ->
let c = Chain.create shape @@ chain @ [`Zstd (level, checksum)] in
let encoded = Chain.encode c arr in
assert_equal arr @@ Chain.decode c {shape; kind} encoded)
[(-131072, false); (-131072, true); (0, false); (0, true)])
;

"test bytes codec" >:: (fun _ ->
let shape = [|2; 2; 2|] in
(* test decoding of chain with invalid endianness name *)
Expand Down

0 comments on commit 703afaa

Please sign in to comment.