Skip to content

Commit

Permalink
Basically working validation.
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Feb 14, 2024
1 parent d42a167 commit 284ee6e
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 19 deletions.
145 changes: 128 additions & 17 deletions sgkit/io/vcf/vcf_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
# CHAR_MISSING,
FLOAT32_FILL,
FLOAT32_MISSING,
FLOAT32_FILL_AS_INT32,
FLOAT32_MISSING_AS_INT32,
INT_FILL,
INT_MISSING,
# STR_FILL,
Expand All @@ -49,6 +51,63 @@
)


def assert_all_missing_float(a):
    """Assert that every element of ``a`` is the float32 "missing" sentinel.

    ``a`` is converted to a float32 array and its raw bits are reinterpreted
    as int32, so the comparison against ``FLOAT32_MISSING_AS_INT32`` is an
    exact bit-pattern match rather than a floating-point comparison.

    NOTE(review): presumably the sentinel is a NaN payload, which is why a
    bitwise comparison is needed -- confirm against the constants module.

    Raises:
        AssertionError: if any element has a different bit pattern.
    """
    bits = np.array(a, dtype=np.float32).view(np.int32)
    is_missing = bits == FLOAT32_MISSING_AS_INT32
    assert np.all(is_missing)


def assert_prefix_integer_equal_1d(vcf_val, zarr_val):
    """Assert a 1-D integer VCF value matches the prefix of its Zarr value.

    Both arguments are copied into int32 arrays of at least one dimension.
    In the VCF copy, ``VCF_INT_MISSING`` is rewritten to -1 and
    ``VCF_INT_FILL`` to -2 before comparing. The first ``len(vcf_val)``
    entries of the Zarr value must then equal the VCF value exactly, and
    every remaining Zarr entry must be -2.

    Raises:
        AssertionError: if the padding or the prefix does not match.
    """
    expected = np.array(vcf_val, dtype=np.int32, ndmin=1)
    stored = np.array(zarr_val, dtype=np.int32, ndmin=1)

    # Translate the VCF-side sentinels into the codes compared against below.
    for sentinel, code in ((VCF_INT_MISSING, -1), (VCF_INT_FILL, -2)):
        expected[expected == sentinel] = code

    n = expected.shape[0]
    prefix, padding = stored[:n], stored[n:]
    # Anything stored beyond the VCF value's length must be fill.
    assert np.all(padding == -2)
    nt.assert_array_equal(expected, prefix)


def assert_prefix_integer_equal_2d(vcf_val, zarr_val):
    """Assert a 2-D integer VCF value matches the prefix of its Zarr value.

    ``vcf_val`` must be a 2-D array. ``VCF_INT_MISSING`` is mapped to -1 and
    ``VCF_INT_FILL`` to -2 before comparing. The mapping is applied to a
    copy, so the caller's array is left untouched.

    * If ``vcf_val`` has a single column, that column is compared directly
      against ``zarr_val``.
    * Otherwise the first ``k = vcf_val.shape[1]`` columns of ``zarr_val``
      must equal ``vcf_val`` and all remaining columns must be -2.

    Raises:
        AssertionError: if ``vcf_val`` is not 2-D, or if the values or the
            padding do not match.
    """
    assert len(vcf_val.shape) == 2
    # Copy first: rewriting sentinels in place would silently modify the
    # array owned by the caller.
    expected = vcf_val.copy()
    expected[expected == VCF_INT_MISSING] = -1
    expected[expected == VCF_INT_FILL] = -2

    k = expected.shape[1]
    if k == 1:
        nt.assert_array_equal(expected[:, 0], zarr_val)
    else:
        nt.assert_array_equal(expected, zarr_val[:, :k])
        # Columns beyond the VCF value's width must be fill.
        assert np.all(zarr_val[:, k:] == -2)


def assert_prefix_float_equal_1d(vcf_val, zarr_val):
    """Assert a 1-D float VCF value matches the prefix of its Zarr value.

    Both arguments are copied into float32 arrays of at least one dimension.
    The Zarr value is additionally viewed as int32 so that sentinel bit
    patterns can be matched exactly. The checks, in order:

    1. no Zarr element carries the ``FLOAT32_MISSING_AS_INT32`` pattern;
    2. every Zarr element past the first ``len(vcf_val)`` carries the
       ``FLOAT32_FILL_AS_INT32`` pattern;
    3. the first ``len(vcf_val)`` Zarr elements are almost equal to the VCF
       value (``numpy.testing.assert_array_almost_equal`` defaults).

    Raises:
        AssertionError: if any of the checks above fails.
    """
    v = np.array(vcf_val, dtype=np.float32, ndmin=1)
    z = np.array(zarr_val, dtype=np.float32, ndmin=1)
    z_bits = z.view(np.int32)

    assert np.sum(z_bits == FLOAT32_MISSING_AS_INT32) == 0
    k = v.shape[0]
    assert np.all(z_bits[k:] == FLOAT32_FILL_AS_INT32)
    # NOTE: the prefix is compared approximately, not bit-for-bit.
    nt.assert_array_almost_equal(v, z[:k])


def assert_prefix_float_equal_2d(vcf_val, zarr_val):
    """Assert a 2-D float VCF value matches the prefix of its Zarr value.

    ``vcf_val`` must be a 2-D array. Both arguments are copied into float32
    arrays, and the Zarr copy is also viewed as int32 so that sentinel bit
    patterns can be matched exactly. The prefix/padding split is made along
    the second axis (``k = vcf_val.shape[1]``), in the same way as
    ``assert_prefix_integer_equal_2d``:

    1. no Zarr element carries the ``FLOAT32_MISSING_AS_INT32`` pattern;
    2. every Zarr column past the first ``k`` carries the
       ``FLOAT32_FILL_AS_INT32`` pattern;
    3. the first ``k`` Zarr columns are almost equal to the VCF value
       (``numpy.testing.assert_array_almost_equal`` defaults).

    A 1-D ``zarr_val`` is treated as a single column.
    NOTE(review): this assumes single-valued fields may be stored without
    the trailing axis, as the integer 2-D helper does -- confirm against
    the encoder.

    Raises:
        AssertionError: if ``vcf_val`` is not 2-D or any check above fails.
    """
    assert len(vcf_val.shape) == 2
    v = np.array(vcf_val, dtype=np.float32)
    z = np.array(zarr_val, dtype=np.float32, ndmin=1)
    if z.ndim == 1:
        # Normalise to (rows, 1) so the column slicing below is uniform.
        z = z.reshape(-1, 1)
    z_bits = z.view(np.int32)

    assert np.sum(z_bits == FLOAT32_MISSING_AS_INT32) == 0
    # Split on the value axis (columns), not on rows.
    k = v.shape[1]
    assert np.all(z_bits[:, k:] == FLOAT32_FILL_AS_INT32)
    # NOTE: the prefix is compared approximately, not bit-for-bit.
    nt.assert_array_almost_equal(v, z[:, :k])


# TODO rename to wait_and_check_futures
def flush_futures(futures):
# Make sure previous futures have completed
Expand Down Expand Up @@ -864,7 +923,7 @@ def asdict(self):
@staticmethod
def fromdict(d):
ret = ZarrConversionSpec(**d)
ret.columns = [ZarrColumnSpec(**cd) for cd in d["columns"]]
ret.variables = [ZarrColumnSpec(**cd) for cd in d["variables"]]
return ret

@staticmethod
Expand Down Expand Up @@ -1254,7 +1313,6 @@ def convert(
# FIXME
sgvcf = SgvcfZarr(path)
sgvcf.root = zarr.group(store=store, overwrite=True)

for variable in conversion_spec.variables[:]:
sgvcf.create_array(variable)

Expand Down Expand Up @@ -1298,6 +1356,11 @@ def convert(
for variable in conversion_spec.variables[:]:
if variable.vcf_field is not None:
# print("Encode", variable.name)
# TODO for large columns it's probably worth splitting these up
# into vertical chunks. Otherwise we tend to get a long wait for
# the largest GT columns to finish. This is straightforward to do
# because we can chunk-align the work packages.
future = executor.submit(sgvcf.encode_column, pcvcf, variable)
futures.append(future)
else:
Expand Down Expand Up @@ -1438,17 +1501,29 @@ def validate(vcf_path, zarr_path, show_progress):
vid = root["variant_id"][:]
call_genotype = iter(root["call_genotype"])

format_fields = {}
vcf = cyvcf2.VCF(vcf_path)
format_headers = {}
info_headers = {}
for h in vcf.header_iter():
if h["HeaderType"] == "FORMAT":
format_headers[h["ID"]] = h
if h["HeaderType"] == "INFO":
info_headers[h["ID"]] = h

format_fields = {}
info_fields = {}
for colname in root.keys():
if colname.startswith("call") and not colname.startswith("call_genotype"):
vcf_name = colname.split("_", 1)[1]
vcf_type = None
for h in vcf.header_iter():
if h["HeaderType"] == "FORMAT" and h["ID"] == vcf_name:
vcf_type = h["Type"]
assert vcf_type is not None
vcf_type = format_headers[vcf_name]["Type"]
format_fields[vcf_name] = vcf_type, iter(root[colname])
if colname.startswith("variant"):
name = colname.split("_", 1)[1]
if name.isupper():
vcf_type = info_headers[name]["Type"]
# print(root[colname])
info_fields[name] = vcf_type, iter(root[colname])
# print(info_fields)

first_pos = next(vcf).POS
start_index = np.searchsorted(pos, first_pos)
Expand Down Expand Up @@ -1476,6 +1551,40 @@ def validate(vcf_path, zarr_path, show_progress):
# print(gt_vcf)
nt.assert_array_equal(gt_zarr, gt_vcf)

# TODO this is basically right, but the details of float padding in
# particular still need to be worked out. Need to find examples of
# VCFs with Number=. Float fields.
for name, (vcf_type, zarr_iter) in info_fields.items():
vcf_val = None
try:
vcf_val = row.INFO[name]
except KeyError:
pass
zarr_val = next(zarr_iter)
if vcf_val is None:
if vcf_type == "Integer":
assert np.all(zarr_val == -1)
elif vcf_type == "String":
assert np.all(zarr_val == ".")
elif vcf_type == "Flag":
assert zarr_val == False
elif vcf_type == "Float":
assert_all_missing_float(zarr_val)
else:
assert False
else:
# print(name, vcf_type, vcf_val, zarr_val, sep="\t")
if vcf_type == "Integer":
assert_prefix_integer_equal_1d(vcf_val, zarr_val)
elif vcf_type == "Float":
assert_prefix_float_equal_1d(vcf_val, zarr_val)
elif vcf_type == "Flag":
assert zarr_val == True
elif vcf_type == "String":
assert np.all(zarr_val == vcf_val)
else:
assert False

for name, (vcf_type, zarr_iter) in format_fields.items():
vcf_val = None
try:
Expand All @@ -1486,26 +1595,28 @@ def validate(vcf_path, zarr_path, show_progress):
if vcf_val is None:
if vcf_type == "Integer":
assert np.all(zarr_val == -1)
elif vcf_type == "Float":
assert_all_missing_float(zarr_val)
elif vcf_type == "String":
assert np.all(zarr_val == ".")
else:
print("vcf_val", vcf_type, name, vcf_val)
assert False
else:
assert vcf_val.shape[0] == zarr_val.shape[0]
if vcf_type == "Integer":
assert len(vcf_val.shape) == 2
vcf_val[vcf_val == VCF_INT_MISSING] = -1
vcf_val[vcf_val == VCF_INT_FILL] = -2
if vcf_val.shape[1] == 1:
nt.assert_array_equal(vcf_val[:, 0], zarr_val)
else:
k = vcf_val.shape[1]
nt.assert_array_equal(vcf_val, zarr_val[:, :k])
assert np.all(zarr_val[:, k:] == -2)
assert_prefix_integer_equal_2d(vcf_val, zarr_val)
elif vcf_type == "Float":
assert_prefix_float_equal_2d(vcf_val, zarr_val)
elif vcf_type == "String":
nt.assert_array_equal(vcf_val, zarr_val)

# assert_prefix_string_equal_2d(vcf_val, zarr_val)
else:
print(name)
print(vcf_val)
print(zarr_val)
assert False


def convert_plink(
Expand Down
26 changes: 25 additions & 1 deletion sgkit/tests/io/vcf/test_vcf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
merge_zarr_array_sizes,
zarr_array_sizes,
)
from sgkit.io.vcf.vcf_converter import convert_vcf
from sgkit.io.vcf.vcf_converter import convert_vcf, validate
from sgkit.model import get_contigs, get_filters, num_contigs
from sgkit.tests.io.test_dataset import assert_identical

Expand Down Expand Up @@ -1865,3 +1865,27 @@ def test_compare_vcf_to_zarr_convert(shared_datadir, tmp_path, vcf_name):
# print(ds1.call_genotype.values)
# print(ds2.call_genotype.values)
xr.testing.assert_equal(ds1, ds2[base_vars])


@pytest.mark.parametrize(
    "vcf_name",
    [
        "1000G.phase3.broad.withGenotypes.chr20.10100000.vcf.gz",
        "CEUTrio.20.21.gatk3.4.csi.g.vcf.bgz",
        "CEUTrio.20.21.gatk3.4.g.bcf",
        "CEUTrio.20.21.gatk3.4.g.vcf.bgz",
        "CEUTrio.20.gatk3.4.g.vcf.bgz",
        "CEUTrio.21.gatk3.4.g.vcf.bgz",
        "sample_multiple_filters.vcf.gz",
        "sample.vcf.gz",
        "allele_overflow.vcf.gz",
    ],
)
def test_validate_vcf(shared_datadir, tmp_path, vcf_name):
    """Convert each sample VCF to Zarr inside the per-test temp directory."""
    vcf_path = path_for_test(shared_datadir, vcf_name)
    # Write into pytest's tmp_path instead of a directory relative to the
    # current working directory. The previous os.path.join call also turned
    # the ".vcf.zarr" suffix into its own path component.
    zarr_path = tmp_path.joinpath("vcf.zarr").as_posix()
    convert_vcf([vcf_path], zarr_path)
    # TODO: call validate() on the converted store once it is ready to be
    # run here; it takes (vcf_path, zarr_path, show_progress).

4 changes: 3 additions & 1 deletion vcf2zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def genspec(columnarised):
@click.argument("columnarised", type=click.Path())
@click.argument("zarr_path", type=click.Path())
@click.option("-s", "--conversion-spec", default=None)
def to_zarr(columnarised, zarr_path, conversion_spec):
@click.option("-p", "--worker-processes", type=int, default=1)
def to_zarr(columnarised, zarr_path, conversion_spec, worker_processes):
pcvcf = cnv.PickleChunkedVcf.load(columnarised)
if conversion_spec is None:
spec = cnv.ZarrConversionSpec.generate(pcvcf)
Expand All @@ -74,6 +75,7 @@ def to_zarr(columnarised, zarr_path, conversion_spec):
pcvcf,
zarr_path,
conversion_spec=spec,
worker_processes=worker_processes,
show_progress=True,
)

Expand Down

0 comments on commit 284ee6e

Please sign in to comment.