More-or-less full minimal VCF parsing
jeromekelleher committed Jan 23, 2024
1 parent 0ee7356 commit 3a530c7
Showing 1 changed file with 109 additions and 21 deletions.
130 changes: 109 additions & 21 deletions vcf2zarr.py
@@ -16,6 +16,7 @@

# from sgkit.io.vcf.vcf_reader import VcfFieldHandler, _normalize_fields
from sgkit.io.utils import FLOAT32_MISSING
+from sgkit.utils import smallest_numpy_int_dtype

numcodecs.blosc.use_threads = False

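The new import above replaces a hard-coded np.int8 for the contig index array (see the FIXME removed in create_zarr below). A rough sketch of what such a helper does, assuming it simply returns the narrowest NumPy integer dtype that can represent a given maximum value:

import numpy as np

def smallest_int_dtype(max_value):
    # Pick the narrowest signed integer type that can hold max_value.
    for dtype in (np.int8, np.int16, np.int32, np.int64):
        if max_value <= np.iinfo(dtype).max:
            return dtype
    raise OverflowError(f"{max_value} does not fit in int64")

smallest_int_dtype(25)       # np.int8: plenty for 25 contigs
smallest_int_dtype(100_000)  # np.int32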
@@ -97,11 +98,11 @@ def write_partition(
):
# print(f"process {os.getpid()} starting")
vcf = cyvcf2.VCF(partition.path)
-offset = partition.offset
+offset = partition.start_offset

store = zarr.DirectoryStore(zarr_path)
root = zarr.group(store=store)
-# These are the bare-minimum

contig_array = root["variant_contig"]
pos_array = root["variant_position"]
qual_array = root["variant_quality"]
@@ -120,6 +121,7 @@ def write_partition(
gt_phased_buffer = np.zeros((chunk_length, n), dtype=bool)

buffered_arrays = [
+BufferedArray(contig_buffer, contig_array),
BufferedArray(pos_buffer, pos_array),
BufferedArray(qual_buffer, qual_array),
BufferedArray(gt_buffer, gt_array),
@@ -160,10 +162,11 @@ def flush_buffers(futures, start=0, stop=chunk_length):

return futures

-variant_contig_names = vcf.seqnames
+contig_name_map = {name: j for j, name in enumerate(vcf_fields.contig_names)}

# Flushing out the chunks takes less time than reading in here in the
# main thread, so no real point in using lots of threads.
+alleles = []
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
j = offset % chunk_length
chunk_start = j
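flush_buffers above follows a double-buffering pattern: wait for any previous flush to finish, hand the filled buffer to a worker thread that copies it into the zarr array, and give the main loop a fresh buffer so VCF decoding can continue while blosc compresses the previous chunk. A condensed sketch of the idea (names hypothetical, not the exact implementation):

import concurrent.futures
import numpy as np

def flush_buffer_async(executor, prior_futures, zarr_array, buff, offset, start, stop):
    # Never hand out a buffer that a worker thread may still be reading.
    concurrent.futures.wait(prior_futures)
    filled = buff

    def sync_flush():
        # The actual write; blosc compression releases the GIL here.
        zarr_array[offset + start : offset + stop] = filled[start:stop]

    # Fresh buffer for the main loop, plus the pending write future.
    return np.zeros_like(buff), [executor.submit(sync_flush)]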
@@ -177,13 +180,12 @@ def flush_buffers(futures, start=0, stop=chunk_length):
# done here, but it's not releasing the GIL, so may not be worth
# moving to threads.
try:
-# TODO make this faster - can have large number of contigs
-# in the header
-contig_buffer[j] = variant_contig_names.index(variant.CHROM)
-except ValueError:
+contig_buffer[j] = contig_name_map[variant.CHROM]
+except KeyError:
raise ValueError(
f"Contig '{variant.CHROM}' is not defined in the header."
)

pos_buffer[j] = variant.POS
qual_buffer[j] = (
variant.QUAL if variant.QUAL is not None else FLOAT32_MISSING
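The lookup change above is the point of this hunk: list.index scans the header's contig names for every variant, which is O(number of contigs) per record, while the precomputed dict is O(1) and turns an unknown contig into a KeyError that is re-raised with a readable message. In miniature:

contig_names = ["chr1", "chr2", "chrX"]  # from the VCF header
contig_name_map = {name: j for j, name in enumerate(contig_names)}

def contig_index(chrom):
    try:
        return contig_name_map[chrom]  # O(1) dict lookup
    except KeyError:
        raise ValueError(f"Contig '{chrom}' is not defined in the header.")

assert contig_index("chrX") == 2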
@@ -193,6 +195,10 @@ def flush_buffers(futures, start=0, stop=chunk_length):
gt_buffer[j] = gt[:, :-1]
gt_phased_buffer[j] = gt[:, -1]

+# Alleles are treated separately. Store the alleles for each site
+# in a list and return to the main thread for later processing.
+alleles.append([variant.REF] + variant.ALT)

j += 1
if j == chunk_length:
futures = flush_buffers(futures, start=chunk_start)
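The start=chunk_start argument on the first flush matters because a partition's start_offset generally lands mid-chunk: j = offset % chunk_length above means this worker only owns the tail of its first chunk, while the head belongs to the neighbouring partition (which is what the shared chunk locks in main coordinate). A worked example with hypothetical numbers and the chunk_length of 2000 set in create_zarr:

chunk_length = 2000
start_offset = 4500               # first variant this partition owns
j = start_offset % chunk_length   # -> 500
chunk_start = j                   # first flush writes buffer[500:2000]
# That slice lands in the chunk covering variants 4000-5999; rows
# 0-499 of the same chunk are written by the previous partition.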
@@ -209,19 +215,24 @@ def flush_buffers(futures, start=0, stop=chunk_length):
# Flush the last chunk
flush_buffers(futures, stop=j)

+# Send the alleles list back to the main process.
+partition.alleles = alleles
+return partition


@dataclasses.dataclass
class VcfChunk:
path: str
num_records: int
first_position: int
-offset: int = 0
+start_offset: int = 0
+alleles: list = None


@dataclasses.dataclass
class VcfFields:
samples: list
-# TODO other stuff like sgkit does
+contig_names: list
+# field_handlers: list


@@ -247,7 +258,7 @@ def scan_vcfs(paths):
# )
# for field_name in field_names
# ]
-fields = VcfFields(samples=vcf.samples)
+fields = VcfFields(samples=vcf.samples, contig_names=vcf.seqnames)
if vcf_fields is None:
vcf_fields = fields
else:
@@ -257,65 +268,92 @@

chunks.append(
# Requires cyvcf2>=0.30.27
-VcfChunk(path=path, num_records=vcf.num_records, first_position=record.POS)
+VcfChunk(
+path=path,
+num_records=vcf.num_records,
+first_position=(record.CHROM, record.POS),
+)
)

# Assuming these are all on the same contig for now.
chunks.sort(key=lambda x: x.first_position)
offset = 0
for chunk in chunks:
-chunk.offset = offset
+chunk.start_offset = offset
offset += chunk.num_records
return vcf_fields, chunks


def create_zarr(path, vcf_fields, partitions):
-chunk_width = 10001
-chunk_length = 2001
+chunk_width = 10000
+chunk_length = 2000

-n = len(vcf_fields.samples)
+sample_id = np.array(vcf_fields.samples, dtype="O")
+n = sample_id.shape[0]
m = sum(partition.num_records for partition in partitions)

store = zarr.DirectoryStore(path)
root = zarr.group(store=store, overwrite=True)
compressor = numcodecs.Blosc(
cname="zstd", clevel=7, shuffle=numcodecs.Blosc.AUTOSHUFFLE
)
-root.empty(
+a = root.array("sample_id", sample_id, dtype="str", compressor=compressor)
+a.attrs["_ARRAY_DIMENSIONS"] = ["samples"]

+a = root.array(
+"variant_contig_names",
+vcf_fields.contig_names,
+dtype="str",
+compressor=compressor,
+)
+a.attrs["_ARRAY_DIMENSIONS"] = ["contig"]

+a = root.empty(
"variant_contig",
shape=(m),
chunks=(chunk_length),
-dtype=np.int8, # FIXME use smallest_np_dtype
+dtype=smallest_numpy_int_dtype(len(vcf_fields.contig_names)),
compressor=compressor,
)
-root.empty(
+a.attrs["_ARRAY_DIMENSIONS"] = ["variants"]

+a = root.empty(
"variant_position",
shape=(m),
chunks=(chunk_length),
dtype=np.int32,
compressor=compressor,
)
-root.empty(
+a.attrs["_ARRAY_DIMENSIONS"] = ["variants"]

+a = root.empty(
"variant_quality",
shape=(m),
chunks=(chunk_length),
dtype=np.float32,
compressor=compressor,
)
-root.empty(
+a.attrs["_ARRAY_DIMENSIONS"] = ["variants"]

+a = root.empty(
"call_genotype",
shape=(m, n, 2),
chunks=(chunk_length, chunk_width),
dtype=np.int8,
compressor=compressor,
)
-root.empty(
+a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples", "ploidy"]

+# TODO add call_genotype_mask. What's the point of it, though?

+a = root.empty(
"call_genotype_phased",
shape=(m, n),
chunks=(chunk_length, chunk_width),
dtype=bool,
compressor=compressor,
)
+a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples"]

# for handler in vcf_fields.field_handlers:
# if handler.dims == ["variants"]:
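Every array created above is tagged with an _ARRAY_DIMENSIONS attribute. This is the zarr convention xarray uses for naming dimensions, and it is what lets sgkit.load_dataset in main below recover a labelled Dataset rather than bare axes. A minimal standalone illustration:

import numpy as np
import zarr

store = zarr.DirectoryStore("mini.zarr")
root = zarr.group(store=store, overwrite=True)
a = root.empty("variant_position", shape=(10,), chunks=(10,), dtype=np.int32)
a.attrs["_ARRAY_DIMENSIONS"] = ["variants"]
zarr.consolidate_metadata(store)

# xarray.open_zarr("mini.zarr") now reports a "variants" dimension.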
@@ -337,6 +375,39 @@ def create_zarr(path, vcf_fields, partitions):
# )


+def finalise_zarr(path, partitions):
+m = sum(partition.num_records for partition in partitions)

+alleles = []
+for part in partitions:
+alleles.extend(part.alleles)

+max_num_alleles = 0
+for row in alleles:
+max_num_alleles = max(max_num_alleles, len(row))

+variant_alleles = np.full((m, max_num_alleles), "", dtype="O")
+for j, row in enumerate(alleles):
+variant_alleles[j, : len(row)] = row

+variant_allele_array = np.array(variant_alleles, dtype="O")

+store = zarr.DirectoryStore(path)
+root = zarr.group(store=store, overwrite=False)
+compressor = numcodecs.Blosc(
+cname="zstd", clevel=7, shuffle=numcodecs.Blosc.AUTOSHUFFLE
+)
+a = root.array(
+"variant_alleles", variant_allele_array, dtype="str", compressor=compressor
+)
+a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "alleles"]
+print(a)

+# print(all_alleles)

+zarr.consolidate_metadata(path)


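The padding loop in finalise_zarr deals with ragged allele lists: sites carry different numbers of ALT alleles, but the zarr array needs a fixed number of columns, so shorter rows are filled out with empty strings. In brief:

import numpy as np

alleles = [["A", "T"], ["G", "C", "GTC"], ["T"]]
max_num_alleles = max(len(row) for row in alleles)

variant_alleles = np.full((len(alleles), max_num_alleles), "", dtype="O")
for j, row in enumerate(alleles):
    variant_alleles[j, : len(row)] = row
# -> [["A", "T", ""], ["G", "C", "GTC"], ["T", "", ""]]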
def update_bar(progress_counter, num_variants):
pbar = tqdm.tqdm(total=num_variants)

Expand Down Expand Up @@ -391,13 +462,30 @@ def main(vcfs, out_path):
last_chunk_lock=locks[j + 1],
)
)
+completed = []
for future in concurrent.futures.as_completed(futures):
exception = future.exception()
if exception is not None:
raise exception
+completed.append(future.result())

assert progress_counter.value == total_variants

+completed.sort(key=lambda x: x.first_position)
+finalise_zarr(out_path, completed)

+import sgkit

+ds = sgkit.load_dataset(out_path)
+print(ds)
+print(ds.variant_contig_names.values)
+print(ds.sample_id.values)
+print(ds.variant_contig.values)
+print(ds.variant_position.values)
+print("PROBLEM!! Why are lots of these 0?")
+print(np.sum(ds.variant_position.values == 0))
+print(ds.variant_alleles.values)


if __name__ == "__main__":
main()
