Merge pull request #210 from qiyunzhu/vector
Improved vectorized ordinal mapper
qiyunzhu authored Oct 5, 2024
2 parents da32abf + 511e9cf commit 08ce8f8
Showing 2 changed files with 82 additions and 69 deletions.
132 changes: 70 additions & 62 deletions woltka/ordinal.py
@@ -185,7 +185,7 @@ def ordinal_mapper(fh, coords, idmap, fmt=None, excl=None, n=2**20, th=0.8,
th : float
Minimum threshold of the overlap length : alignment length ratio for a match.
prefix : bool
Prefix gene IDs with nucleotide IDs.
Prefix gene IDs with genome IDs.
See Also
--------
@@ -194,93 +194,103 @@ def ordinal_mapper(fh, coords, idmap, fmt=None, excl=None, n=2**20, th=0.8,
Yields
------
list of str
Query queue.
Read (query) queue.
list of set of str
Subject(s) queue.
Genes (subjects) queue.
"""
it = iter_align(fh, fmt, excl, True)

# arguments for flush_chunk
args = (coords, idmap, th, prefix)

# cached lists of read IDs and lengths (pre-allocated space)
# gene IDs are unique, but read IDs can have duplicates (i.e., one read is
# mapped to multiple loci on a genome), therefore an incremental integer
# here replaces the original read ID as its identifier
rids = [None] * n
lens = np.empty((n,), dtype=np.uint32)
# cached read information
qrys = [None] * n
lens = np.empty(n, dtype=np.uint32)
begs = np.empty(n, dtype=np.int64)
ends = np.empty(n, dtype=np.int64)

# cached map of reads to per-genome coordinates
locmap = defaultdict(list)
# arguments for flush_chunk
args = (qrys, lens, begs, ends, coords, idmap, th, prefix)

# current read index in the cached lists; will reset after each flush
# current read index in cache; will reset after each flush
idx = 0

# subject-to-indices mapping of cache
sub2idx = defaultdict(list)

# parse alignment file
for query, records in it:

# exclude hits with unavailable or zero length
records = [x for x in records if x[2]]

# when chunk limit is about to be exceeded by the next query, match
# currently cached reads with genes, flush, and reset
if idx + len(records) > n:
yield flush_chunk(idx, locmap, rids, lens, *args)
locmap = defaultdict(list)
yield flush_chunk(idx, sub2idx, *args)
idx = 0
sub2idx = defaultdict(list)

# extract alignment info and add to cache
# extract read info and add to cache
# hits with unavailable or zero length are excluded
for subject, _, length, beg, end in records:
rids[idx] = query
lens[idx] = length
locmap[subject].extend((
(beg << 24) + idx,
(end << 24) + (1 << 23) + idx))
idx += 1
if length:
qrys[idx] = query
lens[idx] = length
begs[idx] = beg
ends[idx] = end
sub2idx[subject].append(idx)
idx += 1

# final flush
yield flush_chunk(idx, locmap, rids, lens, *args)
yield flush_chunk(idx, sub2idx, *args)
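
The chunked caching strategy above can be summarized in a standalone sketch (illustrative only, not part of the commit; records are simplified to (subject, length) pairs, and flushing just yields the raw cache instead of calling flush_chunk):

import numpy as np
from collections import defaultdict

def mapper_sketch(alignments, n=8):
    """Cache reads in pre-allocated structures; flush whenever full."""
    qrys = [None] * n                    # read IDs (duplicates allowed)
    lens = np.empty(n, dtype=np.uint32)  # read lengths
    sub2idx = defaultdict(list)          # subject -> cache indices
    idx = 0
    for query, records in alignments:
        # flush before the next query would overflow the cache
        if idx + len(records) > n:
            yield qrys[:idx], dict(sub2idx)
            idx = 0
            sub2idx = defaultdict(list)
        for subject, length in records:
            if length:                   # skip zero-length hits
                qrys[idx], lens[idx] = query, length
                sub2idx[subject].append(idx)
                idx += 1
    yield qrys[:idx], dict(sub2idx)      # final flush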


def flush_chunk(n, rlocmap, rids, rlens, glocmap, gidmap, th, prefix):
def flush_chunk(n, idxmap, rids, lens, begs, ends, glocmap, gidmap, th,
prefix):
"""Match reads in current chunk with genes from all genomes.
Parameters
----------
n : int
Number of reads to flush.
rlocmap : dict of list
Read coordinates per genome.
idxmap : dict of list of int
Read indices per genome.
rids : list of str
Read IDs.
rlens : np.array(-1, dtype=int64)
Read identifiers.
lens : np.array(-1, dtype=uint32)
Read lengths.
begs : np.array(-1, dtype=int64)
Read start coordinates.
ends : np.array(-1, dtype=int64)
Read end coordinates.
glocmap : dict of list
Gene coordinates per genome.
gidmap : dict of list
Gene identifiers.
Gene identifiers per genome.
th : float
Length threshold.
prefix : bool
Prefix gene IDs with nucleotide IDs.
Prefix gene IDs with genome IDs.
Returns
-------
list of str
Query queue.
Read (query) queue.
list of set of str
Subject(s) queue.
Genes (subjects) queue.
"""
# master read-to-gene(s) map
res = defaultdict(set)

# effective length = length * th
rels = np.ceil(rlens[:n] * th).astype(np.uint32)
# calculate effective lengths of reads
rels = np.ceil(lens[:n] * th).astype(np.uint32)

# iterate over nucleotides
for nucl, rlocs in rlocmap.items():
# encode read start and end positions
idx = np.arange(n)
begs[:n] <<= 24
ends[:n] <<= 24
begs[:n] += idx
ends[:n] += idx + (1 << 23)

# it's possible that no gene was annotated on the nucleotide
# iterate over genomes:
for nucl, idx in idxmap.items():

# in case no gene was annotated on the genome
try:
glocs = glocmap[nucl]
except KeyError:
Expand All @@ -292,33 +302,33 @@ def flush_chunk(n, rlocmap, rids, rlens, glocmap, gidmap, th, prefix):
# append prefix if needed
pfx = nucl + '_' if prefix else ''

# convert list to array
rlocs = np.array(rlocs, dtype=np.int64)
# pair read starts and ends
idx = np.array(idx, dtype=np.uint32)
m = idx.size
locs = np.empty(2 * m, dtype=np.int64)
locs[0::2] = begs[idx]
locs[1::2] = ends[idx]

# execute ordinal algorithm when reads are many
# 10 (>5 reads) is an empirically determined cutoff
if rlocs.size > 10:
# 5 is an empirically determined cutoff
if m > 5:

# merge pre-sorted genes with reads of unknown sorting status
queue = np.concatenate((glocs, rlocs))
queue = np.concatenate((glocs, locs))

# sort genes and reads into one mixed queue
# timsort is efficient for this task
queue.sort(kind='stable')

# a potentially more efficient method is to use sortednp:
# rlocs.sort(kind='stable')
# queue = sortednp.merge(glocs, rlocs)

# map reads to genes using the core algorithm
matches = match_read_gene(queue, rels)
gen = match_read_gene(queue, rels)

# execute naive algorithm when reads are few
else:
matches = match_read_gene_quart(glocs, rlocs, rels)
gen = match_read_gene_quart(glocs, locs, rels)

# add read-gene pairs to the master map
for read, gene in matches:
for read, gene in gen:
res[rids[read]].add(pfx + gids[gene])

# return matching read IDs and gene IDs
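
To make the sorted-queue trick concrete: each 64-bit integer packs a coordinate (bits 24 and up), an is-end flag (bit 23), an is-gene flag (bit 22, set by encode_genes below), and a cache index in the low bits, which must stay below 2**22 for the flags to remain unambiguous. A worked toy example (illustrative, not from the commit):

import numpy as np

# a read spanning 5..20 with cache index 2
rlo = (5 << 24) + 2                   # read start: flag bits 22-23 clear
rhi = (20 << 24) + (1 << 23) + 2      # read end: bit 23 set

# a gene spanning 1..15 with gene index 0; the start is shifted down by
# one so it sorts ahead of any read starting at the same coordinate
glo = ((1 - 1) << 24) + (1 << 22)     # gene start: bit 22 set
ghi = (15 << 24) + (3 << 22)          # gene end: bits 22 and 23 set

queue = np.array([rlo, rhi, glo, ghi], dtype=np.int64)
queue.sort(kind='stable')             # orders by coordinate first

for x in queue:
    print(x >> 24,                    # coordinate
          (x >> 23) & 1,              # 1 = end, 0 = start
          (x >> 22) & 1,              # 1 = gene, 0 = read
          x & ((1 << 22) - 1))        # index within its own list
# prints: gene start (0), read start (5), gene end (15), read end (20)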
@@ -446,23 +456,21 @@ def encode_genes(lst):

# order each pair of start and end coordinates such that the smaller one
# comes first
# faster than np.sort since there are only two numbers
# < is slightly faster than np.less
cmp = beg < end
lo = np.where(cmp, beg, end)
hi = np.where(cmp, end, beg)

# encode coordinate, start/end, is gene, and index into one integer
lo = np.left_shift(lo - 1, 24) + (1 << 22) + idx
hi = np.left_shift(hi, 24) + (3 << 22) + idx
lo = (lo - 1 << 24) + (1 << 22) + idx
hi = (hi << 24) + (3 << 22) + idx

# fastest way to interleave two arrays
# https://stackoverflow.com/questions/5347065/
que = np.empty((2 * n,), dtype=np.int64)
que[0::2] = lo
que[1::2] = hi
queue = np.empty(2 * n, dtype=np.int64)
queue[0::2] = lo
queue[1::2] = hi

return que
return queue


@njit((int64[:], uint32[:]))
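
A note on the rewritten arithmetic: in Python, binary minus binds tighter than the shift operator, so (lo - 1 << 24) is ((lo - 1) << 24) and matches the np.left_shift form it replaces. A quick sanity check, plus the interleaving idiom in isolation (illustrative sketch):

import numpy as np

lo = np.array([3, 7], dtype=np.int64)
assert np.array_equal(lo - 1 << 24,            # '-' binds before '<<'
                      np.left_shift(lo - 1, 24))

# the interleaving idiom: even slots take one array, odd slots the other
a, b = np.array([1, 4]), np.array([2, 9])
queue = np.empty(2 * a.size, dtype=np.int64)
queue[0::2] = a
queue[1::2] = b
print(queue)                                   # [1 2 4 9]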
19 changes: 12 additions & 7 deletions woltka/tests/test_ordinal.py
@@ -18,6 +18,7 @@
from tempfile import mkdtemp
from io import StringIO
from functools import partial
from collections import defaultdict

import numpy as np
import numpy.testing as npt
@@ -332,17 +333,21 @@ def test_flush_chunk(self):
'r9 n1 95 20 0 0 1 20 95 82 1 1',
'rx nx 95 0 0 0 0 0 0 0 1 1',
'# end of file')))
idx, rids, rlens, locmap = 0, [], [], {}
idx, sub2idx = 0, defaultdict(list)
qrys, lens, begs, ends = [], [], [], []
for query, records in parse_b6o_file_ex(aln):
for subject, _, length, beg, end in records:
rids.append(query)
rlens.append(length)
locmap.setdefault(subject, []).extend((
(beg << 24) + idx, (end << 24) + (1 << 23) + idx))
qrys.append(query)
lens.append(length)
begs.append(beg)
ends.append(end)
sub2idx[subject].append(idx)
idx += 1
rlens = np.array(rlens)
lens = np.array(lens, dtype=np.uint32)
begs = np.array(begs, dtype=np.int64)
ends = np.array(ends, dtype=np.int64)
obs = flush_chunk(
len(rids), locmap, rids, rlens, coords, idmap, 0.8, False)
idx, sub2idx, qrys, lens, begs, ends, coords, idmap, 0.8, False)
exp = [('r1', 'g1'),
('r5', 'g2'),
('r6', 'g2'),
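
For orientation, downstream code consumes the mapper's paired queues roughly as follows (an illustrative sketch, not from this commit; coords and idmap are assumed to be prebuilt by the caller, and the file plumbing is simplified):

from woltka.ordinal import ordinal_mapper

with open('alignment.b6o') as fh:
    # each yielded chunk pairs a read queue with a parallel queue of
    # gene-ID sets produced by flush_chunk
    for reads, genes in ordinal_mapper(fh, coords, idmap, th=0.8):
        for read, gene_set in zip(reads, genes):
            print(read, sorted(gene_set))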
