Skip to content

Commit

Permalink
Add some required fields
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Jan 23, 2024
1 parent 7641481 commit 0ee7356
Showing 1 changed file with 71 additions and 17 deletions.
88 changes: 71 additions & 17 deletions vcf2zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
import tqdm
import zarr

from sgkit.io.vcf.vcf_reader import VcfFieldHandler, _normalize_fields
# from sgkit.io.vcf.vcf_reader import VcfFieldHandler, _normalize_fields
from sgkit.io.utils import FLOAT32_MISSING

numcodecs.blosc.use_threads = False

Expand Down Expand Up @@ -101,7 +102,9 @@ def write_partition(
store = zarr.DirectoryStore(zarr_path)
root = zarr.group(store=store)
# These are the bare-minimum arrays that must be present in the store.
contig_array = root["variant_contig"]
pos_array = root["variant_position"]
qual_array = root["variant_quality"]
gt_array = root["call_genotype"]
gt_phased_array = root["call_genotype_phased"]

Expand All @@ -110,12 +113,15 @@ def write_partition(

# TODO generalise this so we're allocating the buffer and the array
# at the same time.
contig_buffer = np.zeros((chunk_length), dtype=contig_array.dtype)
pos_buffer = np.zeros((chunk_length), dtype=np.int32)
qual_buffer = np.zeros((chunk_length), dtype=np.float32)
gt_buffer = np.zeros((chunk_length, n, 2), dtype=np.int8)
gt_phased_buffer = np.zeros((chunk_length, n), dtype=bool)

buffered_arrays = [
BufferedArray(pos_buffer, pos_array),
BufferedArray(qual_buffer, qual_array),
BufferedArray(gt_buffer, gt_array),
BufferedArray(gt_phased_buffer, gt_phased_array),
]
Expand Down Expand Up @@ -154,6 +160,8 @@ def flush_buffers(futures, start=0, stop=chunk_length):

return futures

variant_contig_names = vcf.seqnames

# Flushing out the chunks takes less time than reading the records in here on
# the main thread, so there is no real point in using lots of threads.
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
Expand All @@ -168,7 +176,18 @@ def flush_buffers(futures, start=0, stop=chunk_length):
# Translate this record into numpy buffers. There is some compute
# done here, but it's not releasing the GIL, so may not be worth
# moving to threads.
try:
# TODO make this faster - there can be a large number of contigs
# in the header, and this lookup is repeated for every record
contig_buffer[j] = variant_contig_names.index(variant.CHROM)
except ValueError:
raise ValueError(
f"Contig '{variant.CHROM}' is not defined in the header."
)
pos_buffer[j] = variant.POS
qual_buffer[j] = (
variant.QUAL if variant.QUAL is not None else FLOAT32_MISSING
)
gt = variant.genotype.array()
assert gt.shape[1] == 3
gt_buffer[j] = gt[:, :-1]
Expand Down Expand Up @@ -203,13 +222,31 @@ class VcfChunk:
class VcfFields:
samples: list
# TODO other stuff like sgkit does
# field_handlers: list


def scan_vcfs(paths):
chunks = []
vcf_fields = None
for path in tqdm.tqdm(paths, desc="Scan"):
vcf = cyvcf2.VCF(path)
# Hack
# field_names = _normalize_fields(
# vcf, ["FORMAT/GQ", "FORMAT/DP", "INFO/AA", "INFO/DP"]
# )
# field_handlers = [
# VcfFieldHandler.for_field(
# vcf,
# field_name,
# chunk_length=0,
# ploidy=2,
# mixed_ploidy=False,
# truncate_calls=False,
# max_alt_alleles=4,
# field_def={},
# )
# for field_name in field_names
# ]
fields = VcfFields(samples=vcf.samples)
if vcf_fields is None:
vcf_fields = fields
Expand All @@ -218,22 +255,6 @@ def scan_vcfs(paths):
raise ValueError("Incompatible VCF chunks")
record = next(vcf)

fields = _normalize_fields(
vcf, ["FORMAT/GT", "FORMAT/GQ", "FORMAT/DP", "INFO/AA", "INFO/DP"]
)
field_handlers = [
VcfFieldHandler.for_field(
vcf,
field,
chunk_length=0,
ploidy=2,
mixed_ploidy=False,
truncate_calls=False,
max_alt_alleles=4,
field_def={},
)
for field in fields
]
chunks.append(
# Requires cyvcf2>=0.30.27
VcfChunk(path=path, num_records=vcf.num_records, first_position=record.POS)
Expand All @@ -260,13 +281,27 @@ def create_zarr(path, vcf_fields, partitions):
compressor = numcodecs.Blosc(
cname="zstd", clevel=7, shuffle=numcodecs.Blosc.AUTOSHUFFLE
)
root.empty(
"variant_contig",
shape=(m),
chunks=(chunk_length),
dtype=np.int8, # FIXME use smallest_np_dtype
compressor=compressor,
)
root.empty(
"variant_position",
shape=(m),
chunks=(chunk_length),
dtype=np.int32,
compressor=compressor,
)
root.empty(
"variant_quality",
shape=(m),
chunks=(chunk_length),
dtype=np.float32,
compressor=compressor,
)
root.empty(
"call_genotype",
shape=(m, n, 2),
Expand All @@ -282,6 +317,25 @@ def create_zarr(path, vcf_fields, partitions):
compressor=compressor,
)

# for handler in vcf_fields.field_handlers:
# if handler.dims == ["variants"]:
# shape = m,
# chunks = chunk_length,
# elif handler.dims == ["variants", "samples"]:
# shape = m, n
# chunks = chunk_length, chunk_width
# else:
# raise ValueError("Not handled")

# root.empty(
# handler.variable_name,
# shape=shape,
# chunks=chunks,
# dtype=handler.array.dtype,
# compressor=compressor,

# )


def update_bar(progress_counter, num_variants):
pbar = tqdm.tqdm(total=num_variants)
Expand Down

0 comments on commit 0ee7356

Please sign in to comment.