Skip to content

Commit

Permalink
Column sanitisation seems to be mostly working
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Feb 2, 2024
1 parent 45b26df commit 028ee11
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 89 deletions.
284 changes: 199 additions & 85 deletions sgkit/io/vcf/vcf_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
# CHAR_FILL,
# CHAR_MISSING,
FLOAT32_FILL,
# FLOAT32_MISSING,
FLOAT32_MISSING,
INT_FILL,
# INT_MISSING,
INT_MISSING,
STR_FILL,
# STR_MISSING,
str_is_int,
Expand Down Expand Up @@ -120,22 +120,25 @@ def smallest_dtype(self):
"""
s = self.summary
if self.vcf_type == "Float":
return "f4"
if self.vcf_type == "Integer":
ret = "f4"
elif self.vcf_type == "Integer":
dtype = "i4"
for a_dtype in ["i1", "i2"]:
info = np.iinfo(a_dtype)
if info.min <= s.min_value and s.max_value <= info.max:
dtype = a_dtype
break
return dtype
if self.vcf_type == "Flag":
return "bool"
assert self.vcf_type == "String"

if s.max_number == 0:
return "str"
return "O"
ret = dtype
elif self.vcf_type == "Flag":
ret = "bool"
else:
assert self.vcf_type == "String"
if s.max_number == 0:
ret = "str"
else:
ret = "O"
# print("smallest dtype", self.name, self.vcf_type,":", ret)
return ret


@dataclasses.dataclass
Expand Down Expand Up @@ -258,6 +261,113 @@ def scan_vcfs(paths, show_progress):
return vcf_metadata


def sanitise_value_bool(buff, j, value):
    """
    Store a Flag value in slot ``j`` of ``buff``.

    Any value other than None is recorded as True; None is recorded
    as False.
    """
    buff[j] = value is not None


def sanitise_value_float_scalar(buff, j, value):
    """
    Store a scalar Float value in slot ``j`` of ``buff``, writing
    FLOAT32_MISSING in its place when the value is None.
    """
    buff[j] = FLOAT32_MISSING if value is None else value


def sanitise_value_int_scalar(buff, j, value):
    """
    Store a scalar Integer value in slot ``j`` of ``buff``, writing -1
    in its place when the value is None.
    """
    # TODO check for missing values as well
    buff[j] = -1 if value is None else value


def sanitise_value_string_scalar(buff, j, value):
    """
    Store a scalar String value in slot ``j`` of ``buff``, writing the
    empty string in its place when the value is None.
    """
    # TODO check for missing values as well
    buff[j] = "" if value is None else value


def sanitise_value_string_1d(buff, j, value):
    """
    Write a variable-length run of strings into row ``j`` of ``buff``.

    A None value blanks the whole row. Otherwise the row is first reset
    to empty strings and the values are copied into its leading entries,
    so any unused trailing entries stay "".
    """
    if value is None:
        buff[j] = ""
    else:
        # np.array(..., copy=False) raises ValueError on NumPy >= 2 whenever
        # a copy cannot be avoided (e.g. list/tuple input or a dtype
        # conversion). asarray copies only when it has to, on every version.
        value = np.atleast_1d(np.asarray(value, dtype=buff.dtype))
        value = drop_empty_second_dim(value)
        buff[j] = ""
        # TODO check for missing?
        buff[j, : value.shape[0]] = value


def sanitise_value_string_2d(buff, j, value):
    """
    Write String values into entry ``j`` of ``buff``.

    A None value blanks the whole entry. Otherwise the entry is reset to
    empty strings and the values are copied into its leading positions.

    NOTE(review): the logic here is the same as sanitise_value_string_1d.
    In particular a value with more than one column is rejected by
    drop_empty_second_dim, and the assignment slices only the first axis
    of ``buff[j]``. Confirm whether genuinely 2-D per-entry values are
    expected here.
    """
    if value is None:
        buff[j] = ""
    else:
        # np.array(..., copy=False) raises ValueError on NumPy >= 2 whenever
        # a copy cannot be avoided (e.g. list/tuple input or a dtype
        # conversion). asarray copies only when it has to, on every version.
        value = np.atleast_1d(np.asarray(value, dtype=buff.dtype))
        value = drop_empty_second_dim(value)
        buff[j] = ""
        # TODO check for missing?
        buff[j, : value.shape[0]] = value


def drop_empty_second_dim(value):
    """
    Return ``value`` with a length-1 second dimension removed.

    A 1-D array is returned unchanged; a 2-D array of shape (n, 1) is
    returned as its first column, shape (n,). The assertion rejects
    arrays that are not 1-D and whose second dimension is not 1.
    """
    shape = value.shape
    assert len(shape) == 1 or shape[1] == 1
    if len(shape) == 2 and shape[1] == 1:
        return value[..., 0]
    return value


def sanitise_value_float_1d(buff, j, value):
    """
    Write a variable-length run of Float values into row ``j`` of ``buff``.

    A None value sets the whole row to FLOAT32_MISSING. Otherwise the row
    is first set to FLOAT32_FILL and the values are copied into its
    leading entries, so unused trailing entries keep the fill value.
    """
    if value is None:
        buff[j] = FLOAT32_MISSING
    else:
        # np.array(..., copy=False) raises ValueError on NumPy >= 2 whenever
        # a copy cannot be avoided (e.g. tuple input or float64 -> float32).
        # asarray copies only when it has to, on every version.
        value = np.atleast_1d(np.asarray(value, dtype=buff.dtype))
        value = drop_empty_second_dim(value)
        buff[j] = FLOAT32_FILL
        # TODO check for missing?
        buff[j, : value.shape[0]] = value


def sanitise_value_float_2d(buff, j, value):
    """
    Write a 2-D block of Float values into entry ``j`` of ``buff``.

    A None value sets the whole entry to FLOAT32_MISSING. Otherwise the
    entry is first set to FLOAT32_FILL and the values are copied into the
    leading columns of every row, so unused trailing columns keep the
    fill value.
    """
    if value is None:
        buff[j] = FLOAT32_MISSING
    else:
        # np.array(..., copy=False) raises ValueError on NumPy >= 2 whenever
        # a copy cannot be avoided; asarray copies only when it has to.
        # atleast_2d turns a 1-D input into a single row, which then
        # broadcasts across the rows of buff[j].
        value = np.atleast_2d(np.asarray(value, dtype=buff.dtype))
        buff[j] = FLOAT32_FILL
        # TODO check for missing?
        # The column bound is the value's second dimension, matching
        # sanitise_value_int_2d. Using shape[0] (the row count) here made
        # ragged values fail to broadcast or overwrite the fill.
        buff[j, :, : value.shape[1]] = value


def sanitise_int_array(value, ndmin, dtype):
    """
    Convert ``value`` to an array of ``dtype`` with at least ``ndmin``
    dimensions, rewriting two sentinel values on the way.

    Entries equal to ``MIN_INT_VALUE - 2`` become -1 and entries equal to
    ``MIN_INT_VALUE - 1`` become -2. When ``value`` is already an array of
    ``dtype`` it is modified in place and returned without copying.

    NOTE(review): presumably these sentinels are the upstream missing/fill
    markers and -1/-2 their on-disk equivalents; confirm against the
    definition of MIN_INT_VALUE and consider named constants.
    """
    value = np.asarray(value)
    if value.ndim < ndmin:
        # Prepend length-1 axes, as np.array(..., ndmin=ndmin) does.
        value = value.reshape((1,) * (ndmin - value.ndim) + value.shape)
    # Find the sentinels before converting: a cast to a narrower dtype
    # would wrap them to unrelated values and the comparison would never
    # match afterwards.
    first_sentinel = value == (MIN_INT_VALUE - 2)
    second_sentinel = value == (MIN_INT_VALUE - 1)
    # astype(copy=False) replaces np.array(..., copy=False), which raises
    # ValueError on NumPy >= 2 whenever a copy cannot be avoided.
    value = value.astype(dtype, copy=False)
    value[first_sentinel] = -1
    value[second_sentinel] = -2
    return value


def sanitise_value_int_1d(buff, j, value):
    """
    Write a variable-length run of Integer values into row ``j`` of ``buff``.

    A None value sets the whole row to -1. Otherwise the values are passed
    through sanitise_int_array, the row is set to -2, and the values are
    copied into its leading entries, so unused trailing entries stay -2.
    """
    if value is None:
        buff[j] = -1
        return
    cleaned = drop_empty_second_dim(sanitise_int_array(value, 1, buff.dtype))
    buff[j] = -2
    count = cleaned.shape[0]
    buff[j, :count] = cleaned


def sanitise_value_int_2d(buff, j, value):
    """
    Write a 2-D block of Integer values into entry ``j`` of ``buff``.

    A None value sets the whole entry to -1. Otherwise the values are
    passed through sanitise_int_array (forced to at least two dimensions),
    the entry is set to -2, and the values are copied into the leading
    columns of every row, so unused trailing columns stay -2.
    """
    if value is None:
        buff[j] = -1
        return
    cleaned = sanitise_int_array(value, 2, buff.dtype)
    buff[j] = -2
    width = cleaned.shape[1]
    buff[j, :, :width] = cleaned


class PickleChunkedVcfField:
def __init__(self, vcf_field, base_path):
self.vcf_field = vcf_field
Expand Down Expand Up @@ -320,33 +430,41 @@ def iter_values(self):
f"Corruption detected: incorrect number of records in {str(self.path)}."
)

def get_bounds(self):
filter_missing_int = False
if self.vcf_field.vcf_type == "Integer":
filter_missing_int = True
# cyvcf2 represents missing Integer values as the minimum
# int32 value and fill as minimum int32 value + 1
sentinel = np.iinfo(np.int32).min + 1

min_value = np.inf
max_value = -np.inf
max_second_dimension = 0
num_missing = 0
for value in self.iter_values():
if value is not None:
value = np.array(value)
if filter_missing_int:
value[value <= sentinel] = 0
max_value = max(max_value, np.max(value))
min_value = min(min_value, np.min(value))
assert len(value.shape) <= 2
if len(value.shape) == 2:
max_second_dimension = max(max_second_dimension, value.shape[1])
def sanitiser_factory(self, shape):
"""
Return a function that sanitises values from this column
and writes into a buffer of the specified shape.
"""
assert len(shape) <= 3
if self.vcf_field.vcf_type == "Flag":
assert len(shape) == 1
return sanitise_value_bool
elif self.vcf_field.vcf_type == "Float":
if len(shape) == 1:
return sanitise_value_float_scalar
elif len(shape) == 2:
return sanitise_value_float_1d
else:
num_missing += 1
return NumericColumnBounds(
min_value, max_value, max_second_dimension, num_missing
)
return sanitise_value_float_2d
elif self.vcf_field.vcf_type == "Integer":
if len(shape) == 1:
return sanitise_value_int_scalar
elif len(shape) == 2:
return sanitise_value_int_1d
else:
return sanitise_value_int_2d
else:
assert self.vcf_field.vcf_type == "String"
if len(shape) == 1:
return sanitise_value_string_scalar
elif len(shape) == 2:
return sanitise_value_string_1d
else:
return sanitise_value_string_2d

print(shape)

# return ret


def update_bounds_float(summary, value, number_dim):
Expand Down Expand Up @@ -892,74 +1010,70 @@ def empty_fixed_field_array(name, dtype, shape=None):
a.attrs["_ARRAY_DIMENSIONS"] = dimensions
# print(a)

def encode_flag_column(self, source_col, array):
print("FLAG", source_col, array)
a = np.zeros_like(array)
# print(a)
for j, val in enumerate(source_col.iter_values()):
if val is not None:
a[j] = True
# print(a)
array[:] = a

def encode_gt_int_column(self, source_col, array, executor):
def encode_column(self, pcvcf, column):
source_col = pcvcf.columns[column.vcf_field]
array = self.root[column.name]
ba = BufferedArray(array)
sanitiser = source_col.sanitiser_factory(ba.buff.shape)
chunk_length = array.chunks[0]
num_variants = array.shape[0]
futures = []

chunk_start = 0
j = 0
pbar = tqdm.tqdm(total=num_variants, desc=source_col.vcf_field.full_name)
for index, value in enumerate(source_col.iter_values()):
pbar.update(1)
if value is not None:
ba.buff[j] = value.reshape(ba.buff.shape[1:])
j += 1
if j == chunk_length:

with cf.ThreadPoolExecutor(max_workers=4) as executor:
futures = []
chunk_start = 0
j = 0
iterator = tqdm.tqdm(
source_col.iter_values(),
total=num_variants,
desc=source_col.vcf_field.full_name,
)
for value in iterator:
sanitiser(ba.buff, j, value)
j += 1
if j == chunk_length:
flush_futures(futures)
futures.extend(
async_flush_array(executor, ba.buff, ba.array, chunk_start)
)
ba.swap_buffers()
j = 0
chunk_start += chunk_length
if j != 0:
flush_futures(futures)
futures.extend(
async_flush_array(executor, ba.buff, ba.array, chunk_start)
async_flush_array(executor, ba.buff[:j], ba.array, chunk_start)
)
ba.swap_buffers()
j = 0
chunk_start += chunk_length
if j != 0:
flush_futures(futures)
futures.extend(
async_flush_array(executor, ba.buff[:j], ba.array, chunk_start)
)
pbar.close()

def encode_column(self, pcvcf, column):
source_col = pcvcf.columns[column.vcf_field]
array = self.root[column.name]
with cf.ThreadPoolExecutor(max_workers=4) as executor:
self.encode_gt_int_column(source_col, array, executor)

@staticmethod
def convert(pcvcf, path, conversion_spec, show_progress=False):
sgvcf = SgvcfZarr(path)
sgvcf.create_arrays(pcvcf, conversion_spec)

for column in conversion_spec.columns:
for column in conversion_spec.columns[::-1]:
# TODO change this variable to array_name or something, this is
# getting very confusing.
# print(column.name)
# if column.name == "call_GQ":
# if column.name == "variant_position":

# if "GT" not in column.name:
if "GT" in column.name:
continue

# FIXME we seem to be calling FORMAT/PID as a String type not integer
# for some reason. Need to dig in. Looks like the GIL starts hitting
# when we have large string columns.
if "AB" not in column.name and "GT" not in column.name:
try:
sgvcf.encode_column(pcvcf, column)
except Exception as e:
print("ERROR", e)
# break
# if column.dtype == "bool":
# if column.dtype.startswith("s"):
sgvcf.encode_column(pcvcf, column)

# if "variant_POSITIVE_TRAIN_SITE" in column.name:
# sgvcf.encode_column(pcvcf, column)

# if "AB" not in column.name and "GT" not in column.name:
# try:
# sgvcf.encode_column(pcvcf, column)
# except Exception as e:
# print("ERROR", e)
# # break


def sync_flush_array(np_buffer, zarr_array, offset):
Expand Down
9 changes: 5 additions & 4 deletions vcf2zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ def summarise(columnarised):

@click.command
@click.argument("columnarised", type=click.Path())
@click.argument("specfile", type=click.Path())
def genspec(columnarised, specfile):
# @click.argument("specfile", type=click.Path())
def genspec(columnarised):
pcvcf = cnv.PickleChunkedVcf.load(columnarised)
spec = cnv.ZarrConversionSpec.generate(pcvcf)
with open(specfile, "w") as f:
json.dump(spec.asdict(), f, indent=4)
# with open(specfile, "w") as f:
stream = click.get_text_stream("stdout")
json.dump(spec.asdict(), stream, indent=4)



Expand Down

0 comments on commit 028ee11

Please sign in to comment.