Commit: Seems to be working

jeromekelleher committed Jan 26, 2024
1 parent bab86d0 commit 7d0e398

Showing 2 changed files with 119 additions and 41 deletions.
111 changes: 77 additions & 34 deletions sgkit/io/vcf/vcf_converter.py
@@ -142,12 +142,17 @@ class BufferedUnsizedField:
def swap_buffers(self):
self.buff = []

def sync_flush(self, zarr_path, partition_index, chunk_index):
dest_file = (
zarr_path / "tmp" / self.variable_name / f"{partition_index}.{chunk_index}"
)
with open(dest_file, "wb") as f:
pickle.dump(self.buff, f)

def sync_flush_unsized_buffer(buff, file_path):
with open(file_path, "wb") as f:
pickle.dump(buff, f)


def async_flush_unsized_buffer(
executor, buff, zarr_path, variable_name, partition_index, chunk_index
):
dest_file = zarr_path / "tmp" / variable_name / f"{partition_index}.{chunk_index}"
return [executor.submit(sync_flush_unsized_buffer, buff, dest_file)]


def flush_futures(futures):
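As a rough sketch of how the two new helpers compose: the buffer for one chunk is pickled to `tmp/<variable_name>/<partition>.<chunk>` on a worker thread, and the returned futures are waited on later. The `ThreadPoolExecutor`, the example values, and the read-back at the end are illustrative assumptions, not part of this commit:

```python
# Sketch only: flush one buffered chunk asynchronously, then read it back.
import pickle
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

zarr_path = Path("out.zarr")
(zarr_path / "tmp" / "variant_AC").mkdir(parents=True, exist_ok=True)

executor = ThreadPoolExecutor(max_workers=4)
buff = [(3, 1), (-2,)]  # per-variant tuples accumulated for one chunk
futures = async_flush_unsized_buffer(executor, buff, zarr_path, "variant_AC", 0, 0)
for future in futures:
    future.result()  # wait, as flush_futures does for a batch of futures

with open(zarr_path / "tmp" / "variant_AC" / "0.0", "rb") as f:
    assert pickle.load(f) == buff
```

Note that the caller (`flush_unsized_buffers`, below) swaps in a fresh list via `swap_buffers()` straight after submitting, so the worker pickles a list that is no longer being appended to.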
@@ -198,24 +203,31 @@ def write_partition(
gt_phased,
gt_mask,
]

# The unbound fields. These are buffered in Python lists and stored
# in pickled chunks for later analysis
allele = BufferedUnsizedField("variant_allele")
buffered_unsized_fields = [allele]

unsized_info_fields = []
fixed_info_fields = []
for field in vcf_metadata.info_fields:
if field.dimension is not None:
ba = BufferedArray(root[field.variable_name])
buffered_arrays.append(ba)
fixed_info_fields.append((field, ba))
else:
buf = BufferedUnsizedField(field.variable_name)
buffered_unsized_fields.append(buf)
unsized_info_fields.append((field, buf))

fixed_format_fields = []
for field in vcf_metadata.format_fields:
if field.dimension is not None:
ba = BufferedArray(root[field.variable_name])
buffered_arrays.append(ba)
fixed_format_fields.append((field, ba))

# The unbound fields. These are buffered in Python lists and stored
# in pickled chunks for later analysis
allele = BufferedUnsizedField("variant_allele")
buffered_unsized_fields = [allele]

chunk_length = pos.buff.shape[0]

def flush_fixed_buffers(start=0, stop=chunk_length):
@@ -244,9 +256,20 @@ def flush_fixed_buffers(start=0, stop=chunk_length):
return futures

def flush_unsized_buffers(chunk_index):
futures = []
for buf in buffered_unsized_fields:
buf.sync_flush(zarr_path, partition.index, chunk_index)
futures.extend(
async_flush_unsized_buffer(
executor,
buf.buff,
zarr_path,
buf.variable_name,
partition.index,
chunk_index,
)
)
buf.swap_buffers()
return futures

contig_name_map = {name: j for j, name in enumerate(vcf_metadata.contig_names)}
filter_map = {filter_id: j for j, filter_id in enumerate(vcf_metadata.filters)}
@@ -299,6 +322,16 @@ def flush_unsized_buffers(chunk_index):
buffered_array.buff[j] = variant.INFO[field.name]
except KeyError:
pass
for field, buffered_unsized_field in unsized_info_fields:
val = tuple()
try:
val = variant.INFO[field.name]
except KeyError:
pass
if not isinstance(val, tuple):
val = (val,)
buffered_unsized_field.buff.append(val)

for field, buffered_array in fixed_format_fields:
# NOTE not sure the semantics is correct here
val = None
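The unsized INFO handling added above normalises cyvcf2's mixed return types to one tuple per variant: an absent field raises `KeyError` and yields an empty tuple, a `Number=1` scalar is wrapped in a 1-tuple, and multi-valued fields come back as tuples already. A stand-alone restatement, with a plain dict standing in for `variant.INFO` and hypothetical field values:

```python
# Sketch: normalise an INFO value to a tuple, as in the loop above.
def normalise_info_value(info, name):
    val = tuple()
    try:
        val = info[name]  # raises KeyError when the field is absent
    except KeyError:
        pass
    if not isinstance(val, tuple):
        val = (val,)  # wrap Number=1 scalars
    return val

assert normalise_info_value({"AN": 6}, "AN") == (6,)         # scalar field
assert normalise_info_value({"AC": (3, 1)}, "AC") == (3, 1)  # multi-valued
assert normalise_info_value({}, "AC") == ()                  # absent field
```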
@@ -318,10 +351,7 @@ def flush_unsized_buffers(chunk_index):
if j == chunk_length:
flush_futures(futures)
futures = flush_fixed_buffers(start=chunk_start)
flush_unsized_buffers(chunk_index)
# flush_info_buffers(
# zarr_path, buffered_infos, partition.index, chunk_index
# )
futures.extend(flush_unsized_buffers(chunk_index))
j = 0
offset += chunk_length - chunk_start
chunk_start = 0
@@ -336,7 +366,7 @@ def flush_unsized_buffers(chunk_index):
# Flush the last chunk
flush_futures(futures)
futures = flush_fixed_buffers(start=chunk_start, stop=j)
flush_unsized_buffers(chunk_index)
futures.extend(flush_unsized_buffers(chunk_index))

# Wait for the last batch of futures to complete
flush_futures(futures)
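The surrounding loop overlaps parsing with writing: each time a chunk fills, it waits for the previous chunk's writes to land, then submits the next batch and carries on parsing into the freshly swapped buffers. A toy version of that cadence, with `print` standing in for the real Zarr and pickle writes and a stub for `flush_futures` (whose body is not shown in this diff):

```python
# Toy sketch of the flush cadence in write_partition.
from concurrent.futures import ThreadPoolExecutor

def flush_futures(futures):
    for future in futures:
        future.result()

executor = ThreadPoolExecutor(max_workers=2)
futures = []
for chunk_index in range(3):
    # ... parse variants and fill buffers for this chunk ...
    flush_futures(futures)  # wait for the previous chunk's writes
    futures = [executor.submit(print, "flushing chunk", chunk_index)]
flush_futures(futures)  # wait for the final batch
```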
@@ -645,7 +675,7 @@ def create_zarr(

else:
# print(field)
field_dir = tmp_dir / f"INFO_{field.name}"
field_dir = tmp_dir / field.variable_name
field_dir.mkdir()

for field in vcf_metadata.format_fields:
@@ -711,22 +741,35 @@ def finalise_zarr(path, vcf_metadata, partitions, chunk_length):
)
a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "alleles"]

# tmp_dir = path / "tmp"
# for info_field in vcf_metadata.info_fields:
# field_dir = tmp_dir / f"INFO_{info_field.name}"
# data = []
# for partition in partitions:
# for chunk in range(partition.num_chunks):
# filename = field_dir / f"{partition.index}.{chunk}"
# with open(filename, "rb") as f:
# data.extend(pickle.load(f))

# print(info_field, ":", data[:10])
# try:
# np_array = np.array(data)
# print("\t", np_array)
# except ValueError as e:
# print("\terror", e)
for field in vcf_metadata.info_fields:
if field.dimension is None:
print("Write", field.variable_name)
py_array = join_partitioned_lists(tmp_dir / field.variable_name, partitions)
# if field.vcf_number == ".":
max_len = 0
for row in py_array:
max_len = max(max_len, len(row))
shape = (m, max_len)
a = root.empty(
field.variable_name,
shape=shape,
chunks=(chunk_length),
dtype=field.dtype,
compressor=compressor,
)
a.attrs["_ARRAY_DIMENSIONS"] = ["variants", field.name]
# print(a)
np_array = np.full(
(m, max_len), _dtype_to_fill[a.dtype.str], dtype=field.dtype
)
for j, row in enumerate(py_array):
np_array[j, : len(row)] = row
a[:] = np_array

# print(field)
# print(np_array)
# np_array = np.array(py_array, dtype=field.dtype)
# print(np_array)

zarr.consolidate_metadata(path)
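`join_partitioned_lists` is not shown in this diff; judging from the pickle layout and the commented-out reader it replaces, it presumably concatenates each partition's chunk files in order before the ragged rows are padded out to the longest row. A hedged reconstruction of both steps (the function body and the example values are assumptions):

```python
import pickle
import numpy as np

# Assumed reconstruction: concatenate the pickled chunks for one variable,
# in partition order then chunk order.
def join_partitioned_lists(field_dir, partitions):
    data = []
    for partition in partitions:
        for chunk in range(partition.num_chunks):
            with open(field_dir / f"{partition.index}.{chunk}", "rb") as f:
                data.extend(pickle.load(f))
    return data

# Padding, as in finalise_zarr above: fill with a dtype-specific sentinel
# (-2 is sgkit's INT_FILL, which the new variant_AC test below expects).
rows = [(3, 1), (7,)]
max_len = max(len(row) for row in rows)
padded = np.full((len(rows), max_len), -2, dtype=np.int32)
for j, row in enumerate(rows):
    padded[j, : len(row)] = row
# padded -> [[3, 1], [7, -2]]
```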

49 changes: 42 additions & 7 deletions sgkit/tests/io/vcf/test_vcf_reader.py
@@ -9,10 +9,10 @@
import xarray as xr
import zarr
from numcodecs import Blosc, Delta, FixedScaleOffset, PackBits, VLenUTF8
from numpy.testing import assert_allclose, assert_array_equal
from numpy.testing import assert_allclose, assert_array_equal, assert_array_almost_equal

from sgkit import load_dataset, save_dataset
from sgkit.io.utils import FLOAT32_FILL, INT_FILL, INT_MISSING
from sgkit.io.utils import FLOAT32_FILL, FLOAT32_MISSING, INT_FILL, INT_MISSING
from sgkit.io.vcf import (
MaxAltAllelesExceededWarning,
partition_into_regions,
@@ -55,15 +55,22 @@ def test_vcf_to_zarr__small_vcf(
"INFO/AN",
"INFO/AA",
"INFO/DB",
"INFO/AC",
"INFO/AF",
"FORMAT/GT",
"FORMAT/DP",
"FORMAT/HQ",
]
field_defs = {"FORMAT/HQ": {"dimension": "ploidy"}}
field_defs = {
"FORMAT/HQ": {"dimension": "ploidy"},
"INFO/AF": {"Number": "2", "dimension": "AF"},
"INFO/AC": {"Number": "2", "dimension": "AC"},
}
if method == "to_zarr":
vcf_to_zarr(
path,
output,
max_alt_alleles=3,
chunk_length=5,
chunk_width=2,
read_chunk_length=read_chunk_length,
@@ -155,6 +162,30 @@ def test_vcf_to_zarr__small_vcf(
)
assert ds["variant_AN"].chunks[0][0] == 5

variant_AF = np.full((9, 2), FLOAT32_MISSING, dtype=np.float32)
variant_AF[2, 0] = 0.5
variant_AF[3, 0] = 0.017
variant_AF[4, 0] = 0.333
variant_AF[4, 1] = 0.667
assert_array_almost_equal(ds["variant_AF"], variant_AF, 3)
assert ds["variant_AF"].chunks[0][0] == 5

assert_array_equal(
ds["variant_AC"],
[
[-2, -2],
[-2, -2],
[-2, -2],
[-2, -2],
[-2, -2],
[-2, -2],
[3, 1],
[-2, -2],
[-2, -2],
],
)
assert ds["variant_AC"].chunks[0][0] == 5

assert_array_equal(
ds["variant_allele"].values.tolist(),
[
@@ -1809,14 +1840,18 @@ def test_vcf_to_zarr__no_samples(shared_datadir, tmp_path):
],
)
def test_compare_vcf_to_zarr_convert(shared_datadir, tmp_path, vcf_name):
max_alt_alleles = 200
vcf_path = path_for_test(shared_datadir, vcf_name)
zarr1_path = tmp_path.joinpath("vcf1.zarr").as_posix()
vcf_to_zarr(vcf_path, zarr1_path, max_alt_alleles=max_alt_alleles)
zarr2_path = tmp_path.joinpath("vcf2.zarr").as_posix()
convert_vcf([vcf_path], zarr2_path, max_alt_alleles=max_alt_alleles)
# convert reads all variables by default.

    # Convert gets the actual number of alleles by default, so use this as the
    # max_alt_alleles input for vcf_to_zarr below.
convert_vcf([vcf_path], zarr2_path)
ds2 = load_dataset(zarr2_path)
vcf_to_zarr(vcf_path, zarr1_path, max_alt_alleles=ds2.variant_allele.shape[1] - 1)
ds1 = load_dataset(zarr1_path)

# convert reads all variables by default.
base_vars = list(ds1)
ds2 = load_dataset(zarr2_path)
xr.testing.assert_equal(ds1, ds2[base_vars])
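The shape arithmetic behind the rewritten test: `variant_allele`'s second dimension holds REF plus the ALT columns, so the ALT count fed back into `vcf_to_zarr` is one less than that dimension. A toy check with hypothetical alleles (`""` pads the biallelic site):

```python
import numpy as np

variant_allele = np.array([["A", "T", ""], ["C", "G", "CT"]], dtype=object)
max_alt_alleles = variant_allele.shape[1] - 1  # drop the REF column
assert max_alt_alleles == 2
```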
