Getting same results as existing converter
jeromekelleher committed Jan 24, 2024
1 parent 1cded23 commit cad2eb8
Showing 2 changed files with 129 additions and 15 deletions.
96 changes: 82 additions & 14 deletions sgkit/io/vcf/vcf_converter.py
@@ -104,7 +104,7 @@ def swap_buffers(self):
 
 
 def write_partition(
-    vcf_fields, zarr_path, partition, *, first_chunk_lock, last_chunk_lock
+    vcf_fields, zarr_path, partition, max_num_alleles, *, first_chunk_lock, last_chunk_lock
 ):
     # print(f"process {os.getpid()} starting")
     vcf = cyvcf2.VCF(partition.path)
@@ -118,6 +118,7 @@ def write_partition(
     vid = BufferedArray(root["variant_id"], ".")
     vid_mask = BufferedArray(root["variant_id_mask"], True)
     qual = BufferedArray(root["variant_quality"], FLOAT32_MISSING)
+    filt = BufferedArray(root["variant_filter"], False)
     gt = BufferedArray(root["call_genotype"], -1)
     gt_phased = BufferedArray(root["call_genotype_phased"], 0)
     gt_mask = BufferedArray(root["call_genotype_mask"], 0)
@@ -128,6 +129,7 @@ def write_partition(
         vid,
         vid_mask,
         qual,
+        filt,
         gt,
         gt_phased,
         gt_mask,
@@ -167,9 +169,10 @@ def flush_buffers(futures, start=0, stop=chunk_length):
         return futures
 
     contig_name_map = {name: j for j, name in enumerate(vcf_fields.contig_names)}
+    filter_map = {filter_id: j for j, filter_id in enumerate(vcf_fields.filters)}
 
     gt_min = -1  # TODO make this -2 if mixed_ploidy
-    gt_max = 7  # TODO based on dtype
+    gt_max = max_num_alleles - 1
 
     # Flushing out the chunks takes less time than reading in here in the
     # main thread, so no real point in using lots of threads.
@@ -199,6 +202,12 @@ def flush_buffers(futures, start=0, stop=chunk_length):
         if variant.ID is not None:
             vid.buff[j] = variant.ID
             vid_mask.buff[j] = False
+        try:
+            for f in variant.FILTERS:
+                filt.buff[j, filter_map[f]] = True
+        except KeyError:
+            raise ValueError(f"Filter '{f}' is not defined in the header.")
+
         vcf_gt = variant.genotype.array()
         assert vcf_gt.shape[1] == 3
         gt.buff[j] = np.clip(vcf_gt[:, :-1], gt_min, gt_max)
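
Aside on the genotype handling in this hunk: cyvcf2 returns one row per diploid call as [allele_0, allele_1, phased_flag], so the phased column is dropped and the allele columns are clipped into the representable range. A minimal sketch of that clipping under the new bounds, assuming max_num_alleles=4 (the array values are illustrative):

import numpy as np

vcf_gt = np.array([[0, 1, 1], [2, 5, 0], [-1, -1, 0]], dtype=np.int16)
gt_min, gt_max = -1, 4 - 1  # missing (-1) survives; allele index 5 saturates to 3
gt = np.clip(vcf_gt[:, :-1], gt_min, gt_max)
# gt -> [[0, 1], [2, 3], [-1, -1]]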
@@ -231,7 +240,7 @@ def flush_buffers(futures, start=0, stop=chunk_length):
 
 
 @dataclasses.dataclass
-class VcfChunk:
+class VcfPartition:
     path: str
     num_records: int
     first_position: int
@@ -243,6 +252,8 @@ class VcfChunk:
 class VcfFields:
     samples: list
     contig_names: list
+    filters: list
+    contig_lengths: list = None
     # field_handlers: list
 
 
@@ -268,7 +279,25 @@ def scan_vcfs(paths, show_progress):
         #     )
         #     for field_name in field_names
         # ]
-        fields = VcfFields(samples=vcf.samples, contig_names=vcf.seqnames)
+
+        filters = [
+            h["ID"]
+            for h in vcf.header_iter()
+            if h["HeaderType"] == "FILTER" and isinstance(h["ID"], str)
+        ]
+        # Ensure PASS is the first filter if present
+        if "PASS" in filters:
+            filters.remove("PASS")
+            filters.insert(0, "PASS")
+
+        fields = VcfFields(
+            samples=vcf.samples, contig_names=vcf.seqnames, filters=filters
+        )
+        try:
+            fields.contig_lengths = vcf.seqlens
+        except AttributeError:
+            pass
+
         if vcf_fields is None:
             vcf_fields = fields
         else:
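
To make the PASS reordering above concrete, a hedged illustration (the header IDs are made up): header_iter() yields FILTER definitions in header order, and PASS, when present, is promoted to index 0 so that column 0 of variant_filter always means PASS.

filters = ["q10", "PASS", "s50"]  # as scanned from a hypothetical header
filters.remove("PASS")
filters.insert(0, "PASS")
assert filters == ["PASS", "q10", "s50"]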
@@ -278,7 +307,7 @@ def scan_vcfs(paths, show_progress):
 
         chunks.append(
             # Requires cyvcf2>=0.30.27
-            VcfChunk(
+            VcfPartition(
                 path=path,
                 num_records=vcf.num_records,
                 first_position=(record.CHROM, record.POS),
@@ -294,7 +323,8 @@
     return vcf_fields, chunks
 
 
-def create_zarr(path, vcf_fields, partitions, *, chunk_length, chunk_width):
+def create_zarr(path, vcf_fields, partitions, *, chunk_length, chunk_width,
+                max_num_alleles):
     sample_id = np.array(vcf_fields.samples, dtype="O")
     n = sample_id.shape[0]
     m = sum(partition.num_records for partition in partitions)
@@ -304,6 +334,16 @@ def create_zarr(path, vcf_fields, partitions, *, chunk_length, chunk_width):
     compressor = numcodecs.Blosc(
         cname="zstd", clevel=7, shuffle=numcodecs.Blosc.AUTOSHUFFLE
     )
+
+    root.attrs["filters"] = vcf_fields.filters
+    a = root.array(
+        "filter_id",
+        vcf_fields.filters,
+        dtype="str",
+        compressor=compressor,
+    )
+    a.attrs["_ARRAY_DIMENSIONS"] = ["filters"]
+
     a = root.array(
         "sample_id",
         sample_id,
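
The _ARRAY_DIMENSIONS attribute written here follows the xarray Zarr convention, which is what lets the store open later as a labelled dataset. A hedged sketch of the round trip (the store path is hypothetical):

import xarray as xr

ds = xr.open_zarr("out.zarr")  # a store written by create_zarr
print(ds["filter_id"].dims)    # ('filters',)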
@@ -319,7 +359,16 @@ def create_zarr(path, vcf_fields, partitions, *, chunk_length, chunk_width):
         dtype="str",
         compressor=compressor,
     )
-    a.attrs["_ARRAY_DIMENSIONS"] = ["contig"]
+    a.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
+
+    if vcf_fields.contig_lengths is not None:
+        a = root.array(
+            "contig_length",
+            vcf_fields.contig_lengths,
+            dtype=np.int64,
+            compressor=compressor,
+        )
+        a.attrs["_ARRAY_DIMENSIONS"] = ["contigs"]
 
     a = root.empty(
         "variant_contig",
@@ -366,11 +415,20 @@ def create_zarr(path, vcf_fields, partitions, *, chunk_length, chunk_width):
     )
     a.attrs["_ARRAY_DIMENSIONS"] = ["variants"]
 
+    a = root.empty(
+        "variant_filter",
+        shape=(m, len(vcf_fields.filters)),
+        chunks=(chunk_length,),
+        dtype=bool,
+        compressor=compressor,
+    )
+    a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "filters"]
+
     a = root.empty(
         "call_genotype",
         shape=(m, n, 2),
         chunks=(chunk_length, chunk_width),
-        dtype=np.int8,
+        dtype=smallest_numpy_int_dtype(max_num_alleles),
         compressor=compressor,
     )
     a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples", "ploidy"]
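
smallest_numpy_int_dtype is assumed to pick the narrowest signed integer dtype that can hold max_num_alleles - 1 while keeping sentinels such as -1 (missing) representable. A plausible sketch of such a helper, not the actual sgkit implementation:

import numpy as np

def smallest_numpy_int_dtype(max_value):
    # Signed dtypes only, so that -1 (missing) and -2 (padding) stay representable.
    for dtype in (np.int8, np.int16, np.int32, np.int64):
        if max_value <= np.iinfo(dtype).max:
            return dtype
    raise OverflowError(f"{max_value} does not fit in a NumPy integer dtype")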
@@ -413,16 +471,17 @@ def create_zarr(path, vcf_fields, partitions, *, chunk_length, chunk_width):
     # )
 
 
-def finalise_zarr(path, partitions, chunk_length):
+def finalise_zarr(path, partitions, chunk_length, max_num_alleles):
     m = sum(partition.num_records for partition in partitions)
 
     alleles = []
     for part in partitions:
         alleles.extend(part.alleles)
 
-    max_num_alleles = 0
-    for row in alleles:
-        max_num_alleles = max(max_num_alleles, len(row))
+    # TODO raise a warning here if this isn't met.
+    # max_num_alleles = 0
+    # for row in alleles:
+    #     max_num_alleles = max(max_num_alleles, len(row))
 
     variant_alleles = np.full((m, max_num_alleles), "", dtype="O")
     for j, row in enumerate(alleles):
@@ -463,7 +522,9 @@ def init_workers(counter):
 
 
 def convert_vcf(
-    vcfs, out_path, *, chunk_length=None, chunk_width=None, show_progress=False
+    vcfs, out_path, *, chunk_length=None, chunk_width=None,
+    max_alt_alleles=None,
+    show_progress=False
 ):
     # TODO add a try-except here for KeyboardInterrupt which will kill
     # various things and clean-up.
@@ -475,6 +536,11 @@ def convert_vcf(
     if chunk_length is None:
         chunk_length = 2000
 
+    if max_alt_alleles is None:
+        max_alt_alleles = 3
+
+    max_num_alleles = max_alt_alleles + 1
+
     # TODO write the Zarr to a temporary name, only renaming at the end
     # on success.
     create_zarr(
@@ -483,6 +549,7 @@ def convert_vcf(
         partitions,
         chunk_width=chunk_width,
         chunk_length=chunk_length,
+        max_num_alleles=max_num_alleles,
     )
 
     total_variants = sum(partition.num_records for partition in partitions)
@@ -511,6 +578,7 @@ def convert_vcf(
                     vcf_fields,
                     out_path,
                     part,
+                    max_num_alleles,
                     first_chunk_lock=locks[j],
                     last_chunk_lock=locks[j + 1],
                 )
@@ -525,4 +593,4 @@ def convert_vcf(
     assert progress_counter.value == total_variants
 
     completed_partitions.sort(key=lambda x: x.first_position)
-    finalise_zarr(out_path, completed_partitions, chunk_length)
+    finalise_zarr(out_path, completed_partitions, chunk_length, max_num_alleles)
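
With the signature changes above, calling the converter looks like this (a usage sketch; the paths are hypothetical):

convert_vcf(
    ["chr20.vcf.gz", "chr21.vcf.gz"],  # one partition per input VCF
    "cohort.vcf.zarr",
    max_alt_alleles=3,                 # the default, giving max_num_alleles = 4
    show_progress=True,
)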
48 changes: 47 additions & 1 deletion sgkit/tests/io/vcf/test_vcf_reader.py
@@ -42,7 +42,11 @@
 @pytest.mark.parametrize("method", ["to_zarr", "convert", "load"])
 @pytest.mark.filterwarnings("ignore::xarray.coding.variables.SerializationWarning")
 def test_vcf_to_zarr__small_vcf(
-    shared_datadir, is_path, read_chunk_length, tmp_path, method,
+    shared_datadir,
+    is_path,
+    read_chunk_length,
+    tmp_path,
+    method,
 ):
     path = path_for_test(shared_datadir, "sample.vcf.gz", is_path)
     output = tmp_path.joinpath("vcf.zarr").as_posix()
@@ -68,6 +72,21 @@ def test_vcf_to_zarr__small_vcf(
     else:
         ds = read_vcf(path, chunk_length=5, chunk_width=2)
 
+    assert_array_equal(ds["filter_id"], ["PASS", "s50", "q10"])
+    assert_array_equal(
+        ds["variant_filter"],
+        [
+            [False, False, False],
+            [False, False, False],
+            [True, False, False],
+            [False, False, True],
+            [True, False, False],
+            [True, False, False],
+            [True, False, False],
+            [False, False, False],
+            [True, False, False],
+        ],
+    )
     assert_array_equal(ds["contig_id"], ["19", "20", "X"])
     assert "contig_length" not in ds
     assert_array_equal(ds["variant_contig"], [0, 0, 1, 1, 1, 1, 1, 1, 2])
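
Because filter_id is written PASS-first, a per-variant PASS mask falls straight out of column 0 of variant_filter. A hedged reading sketch (the output path is hypothetical):

from sgkit import load_dataset

ds = load_dataset("vcf.zarr")
pass_mask = ds["variant_filter"].values[:, 0]  # True where the variant passed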
Expand Down Expand Up @@ -1689,3 +1708,30 @@ def test_vcf_to_zarr__no_samples(shared_datadir, tmp_path):
assert_array_equal(ds["sample_id"], [])
assert_array_equal(ds["contig_id"], ["1"])
assert ds.sizes["variants"] == 973


@pytest.mark.parametrize(
"vcf_name",
[
"1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz",
"CEUTrio.20.21.gatk3.4.csi.g.vcf.bgz",
"CEUTrio.20.21.gatk3.4.g.bcf",
"CEUTrio.20.21.gatk3.4.g.vcf.bgz",
"CEUTrio.20.gatk3.4.g.vcf.bgz",
"CEUTrio.21.gatk3.4.g.vcf.bgz",
"NA12878.prod.chr20snippet.g.vcf.gz",
"sample_multiple_filters.vcf.gz",
"sample.vcf.gz",
"allele_overflow.vcf.gz",
],
)
def test_compare_vcf_to_zarr_convert(shared_datadir, tmp_path, vcf_name):
max_alt_alleles = 200
vcf_path = path_for_test(shared_datadir, vcf_name)
zarr1_path = tmp_path.joinpath("vcf1.zarr").as_posix()
vcf_to_zarr(vcf_path, zarr1_path, max_alt_alleles=max_alt_alleles)
zarr2_path = tmp_path.joinpath("vcf2.zarr").as_posix()
convert_vcf([vcf_path], zarr2_path, max_alt_alleles=max_alt_alleles)
ds1 = load_dataset(zarr1_path)
ds2 = load_dataset(zarr2_path)
xr.testing.assert_equal(ds1, ds2)
