Commit 9de0913

Stuff

jeromekelleher committed Feb 1, 2024
1 parent 0f17a29 commit 9de0913
Showing 2 changed files with 257 additions and 120 deletions.
336 changes: 231 additions & 105 deletions sgkit/io/vcf/vcf_converter.py
@@ -113,6 +113,30 @@ def full_name(self):
return self.name
return f"{self.category}/{self.name}"

def smallest_dtype(self):
"""
Returns the smallest dtype suitable for this field, based on its
VCF type and the observed range of values.
"""
s = self.summary
if self.vcf_type == "Float":
return "f4"
if self.vcf_type == "Integer":
dtype = "i4"
for a_dtype in ["i1", "i2"]:
info = np.iinfo(a_dtype)
if info.min <= s.min_value and s.max_value <= info.max:
dtype = a_dtype
break
return dtype
if self.vcf_type == "Flag":
return "bool"
assert self.vcf_type == "String"

if s.max_number == 0:
return "str"
return "O"


@dataclasses.dataclass
class VcfPartition:
@@ -261,7 +285,7 @@ def is_numeric(self):

def __repr__(self):
# TODO add class name
return repr({"path": str(self.path)})
return repr({"path": str(self.path), **self.vcf_field.summary.asdict()})

def write_chunk(self, partition_index, chunk_index, data):
path = self.path / f"p{partition_index}" / f"c{chunk_index}"
@@ -352,6 +376,7 @@ def update_bounds_integer(summary, value, number_dim):
assert len(value.shape) <= number_dim + 1
if len(value.shape) == number_dim + 1:
number = value.shape[number_dim]
summary.max_number = max(summary.max_number, number)


def update_bounds_string(summary, value):
@@ -462,12 +487,16 @@ def display_size(n):
return data

@functools.cached_property
def num_records(self):
return sum(partition.num_records for partition in self.metadata.partitions)

@property
def num_partitions(self):
return len(self.metadata.partitions)

@functools.cached_property
def num_records(self):
return sum(partition.num_records for partition in self.metadata.partitions)
@property
def num_samples(self):
return len(self.metadata.samples)

def mkdirs(self):
self.path.mkdir()
@@ -668,107 +697,204 @@ def explode(
)


# @dataclasses.dataclass
# class ZarrArrayDefinition:
# name: str
# dtype: str
# shape: tuple

# @staticmethod
# def from_numeric_column(col, bounds):
# if col.vcf_field.vcf_type == "Integer":
# dtype = None
# for a_dtype in ("i1", "i2", "i4"):
# info = np.iinfo(a_dtype)
# if info.min <= bounds.min_value and bounds.max_value <= info.max:
# dtype = a_dtype
# break
# else:
# raise ValueError("Value too something")
# else:
# dtype = "f4"
# shape = []
# if bounds.max_second_dimension > 1:
# shape.append(bounds.max_second_dimension)

# return ZarrArrayDefinition("", dtype, shape)


# # @dataclasses.dataclass
# # class ColumnConversionSpec:
# # vcf_field: VcfFieldDefinition
# # zarr_array: ZarrArrayDefinition


# # @dataclasses.dataclass
# # class ConversionSpec:
# # columns: list


# def plan_conversion(columnarised_path, out_file):
# pcv = PickleChunkedVcf.load(pathlib.Path(columnarised_path))
# # extract
# convert_columns = {
# name: col for name, col in pcv.columns.items() if name not in ["REF", "ALT"]
# }
# out = []
# for name, col in convert_columns.items():
# prefix = ""
# if col.vcf_field.category == "INFO":
# prefix = "variant_"
# elif col.vcf_field.category == "FORMAT":
# prefix = "call_"
# else:
# continue
# array_name = prefix + col.vcf_field.name
# # print(name, col)
# if col.is_numeric():
# bounds = col.get_bounds()
# # print(bounds)
# zarr_definition = ZarrArrayDefinition.from_numeric_column(col, bounds)
# zarr_definition.shape = [pcv.num_records] + zarr_definition.shape
# zarr_definition.name = array_name
# # print(zarr_definition)
# out.append(ColumnConversionSpec(col.vcf_field, zarr_definition))

# spec = ConversionSpec(out)
# print(json.dumps(dataclasses.asdict(spec), indent=4))


# def encode_zarr(
# columnarised_path,
# out_path,
# *,
# chunk_width=None,
# chunk_length=None,
# show_progress=False,
# ):
# pcv = PickleChunkedVcf.load(pathlib.Path(columnarised_path))

# # d = pcv.columns["CHROM"].get_counts()
# # print(d)

# # d= pcv.columns["FILTERS"].get_counts()
# # print(d)
# # ref = columns["REF"]
# # alt = columns["ALT"]

# # # print(pcv.columns["FORMAT/AD"].get_bounds())
# # with cf.ProcessPoolExecutor(max_workers=8) as executor:

# # future_to_col = {}

# # for col in pcv.columns.values():
# # if col.is_numeric():
# # print("dispatch", col)
# # future = executor.submit(col.get_bounds)
# # future_to_col[future] = col
# # # print(col)
# # # print(col.get_bounds())
# # for future in cf.as_completed(future_to_col):
# # col = future_to_col[future]
# # bounds = future.result()
# # print(col, bounds)
@dataclasses.dataclass
class ZarrColumnSpec:
vcf_field: str
name: str
dtype: str
shape: tuple


@dataclasses.dataclass
class ZarrConversionSpec:
chunk_width: int
chunk_length: int
columns: list

def asdict(self):
return dataclasses.asdict(self)

@staticmethod
def fromdict(d):
ret = ZarrConversionSpec(**d)
ret.columns = [ZarrColumnSpec(**cd) for cd in d["columns"]]
return ret

@staticmethod
def generate(pcvcf):
m = pcvcf.num_records
n = pcvcf.num_samples
colspecs = []
for field in pcvcf.metadata.fields:
if field.category == "fixed":
continue
shape = [m]
prefix = "variant_"
if field.category == "FORMAT":
prefix = "call_"
shape.append(n)
if field.summary.max_number > 1:
shape.append(field.summary.max_number)
if field.name == "GT":
# GT is a special case: the last value encodes phasing, which is
# stored separately, so drop it from the final dimension.
shape[2] -= 1
variable_name = prefix + field.name
colspec = ZarrColumnSpec(
vcf_field=field.full_name,
name=variable_name,
dtype=field.smallest_dtype(),
shape=shape,
)
colspecs.append(colspec)
# Arbitrary defaults here, we'll want to do something much more
# sophisticated I'd imagine.
return ZarrConversionSpec(
columns=colspecs, chunk_width=1000, chunk_length=10_000
)
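# Round-trip sketch (hypothetical usage; assumes a PickleChunkedVcf
# instance pcvcf is already loaded):
#
# spec = ZarrConversionSpec.generate(pcvcf)
# d = spec.asdict() # plain dict, e.g. for JSON serialisation
# assert ZarrConversionSpec.fromdict(d) == spec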


class SgvcfZarr:
def __init__(self, path):
self.path = pathlib.Path(path)
self.root = None

def create_arrays(self, pcvcf, spec):
store = zarr.DirectoryStore(self.path)
num_variants = pcvcf.num_records
num_samples = pcvcf.num_samples

self.root = zarr.group(store=store, overwrite=True)
compressor = numcodecs.Blosc(
cname="zstd", clevel=7, shuffle=numcodecs.Blosc.AUTOSHUFFLE
)

def full_array(name, data, dimensions, *, dtype=None, chunks=None):
a = self.root.array(
name,
data,
dtype=dtype,
chunks=chunks,
compressor=compressor,
)
a.attrs["_ARRAY_DIMENSIONS"] = dimensions
return a
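# "_ARRAY_DIMENSIONS" is the attribute xarray uses to attach dimension
# names to Zarr arrays, so the output store can be opened with
# xarray.open_zarr (an assumption about the intended consumer here).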


self.root.attrs["filters"] = pcvcf.metadata.filters
full_array("filter_id", pcvcf.metadata.filters, ["filters"], dtype="str")
full_array("contig_id", pcvcf.metadata.contig_names, ["configs"], dtype="str")
full_array(
"sample_id",
pcvcf.metadata.samples,
["samples"],
dtype="str",
chunks=[spec.chunk_width],
)

if pcvcf.metadata.contig_lengths is not None:
full_array(
"contig_length",
pcvcf.metadata.contig_lengths,
["configs"],
dtype=np.int64,
)

def empty_fixed_field_array(name, dtype):
a = self.root.empty(
name,
shape=(num_variants,),
dtype=dtype,
chunks=(spec.chunk_length,),
compressor=compressor,
)
a.attrs["_ARRAY_DIMENSIONS"] = ["variants"]
return a

# FIXME get dtype from lookup table
empty_fixed_field_array("variant_contig", np.int16)
empty_fixed_field_array("variant_position", np.int32)
empty_fixed_field_array("variant_id", "str")
empty_fixed_field_array("variant_id_mask", bool)
empty_fixed_field_array("variant_quality", np.float32)
# TODO FILTER
# empty_fixed_field_array("variant_filter",
# shape=(m, len(vcf_metadata.filters)),
# chunks=(chunk_length),
# dtype=bool,
# compressor=compressor,
# )
# a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "filters"]

# a = root.empty(
# "call_genotype",
# shape=(m, n, 2),
# chunks=(chunk_length, chunk_width),
# dtype=smallest_numpy_int_dtype(max_num_alleles),
# compressor=compressor,
# )
# a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples", "ploidy"]

# a = root.empty(
# "call_genotype_mask",
# shape=(m, n, 2),
# chunks=(chunk_length, chunk_width),
# dtype=bool,
# compressor=compressor,
# )
# a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples", "ploidy"]

# a = root.empty(
# "call_genotype_phased",
# shape=(m, n),
# chunks=(chunk_length, chunk_width),
# dtype=bool,
# compressor=compressor,
# )
# a.attrs["_ARRAY_DIMENSIONS"] = ["variants", "samples"]

for column in spec.columns:
if column.name == "call_GT":
# FIXME this shouldn't be in here.
continue
chunks = [spec.chunk_length]
dimensions = ["variants"]
if len(column.shape) > 1:
# TODO this should all be in the column spec
chunks.append(spec.chunk_width)
dimensions.append("samples")
if len(column.shape) > 2:
dimensions.append(column.vcf_field)
a = self.root.empty(
column.name,
shape=column.shape,
chunks=chunks,
dtype=column.dtype,
compressor=compressor,
)
a.attrs["_ARRAY_DIMENSIONS"] = dimensions
# print(a)

def encode_column(self, pcvcf, column):
source_col = pcvcf.columns[column.vcf_field]
array = self.root[column.name]
# FIXME reads the entire column into memory; this should be buffered
# and flushed to the Zarr array chunk-by-chunk.
a = np.array(list(source_col.iter_values()))
array[:] = a

@staticmethod
def convert(pcvcf, path, conversion_spec, show_progress=False):
sgvcf = SgvcfZarr(path)
sgvcf.create_arrays(pcvcf, conversion_spec)

for column in conversion_spec.columns:
if "GT" not in column.name:
sgvcf.encode_column(pcvcf, column)
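# End-to-end sketch (hypothetical paths; assumes the explode step has
# already produced the pickle-chunked representation on disk):
#
# pcvcf = PickleChunkedVcf.load(pathlib.Path("sample.exploded"))
# spec = ZarrConversionSpec.generate(pcvcf)
# SgvcfZarr.convert(pcvcf, "sample.zarr", spec)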



# def sync_flush_array(np_buffer, zarr_array, offset):