Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nh13 committed Dec 19, 2024
1 parent e568eef commit 367d08c
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 134 deletions.
2 changes: 1 addition & 1 deletion pybwa/libbwamem.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class BwaMemOptions:
min_seeded_bases_in_chain: int
seed_occurrence_in_3rd_round: int
xa_max_hits: int | tuple[int, int]
xa_drop_ration: float
xa_drop_ratio: float
gap_open_penalty: int | tuple[int, int]
gap_extension_penalty: int | tuple[int, int]
clipping_penalty: int | tuple[int, int]
Expand Down
234 changes: 101 additions & 133 deletions pybwa/libbwamem.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ cdef class BwaMemOptions:
def __get__(self):
return self._options.max_XA_hits, self._options.max_XA_hits_alt

property xa_drop_ration:
property xa_drop_ratio:
"""bwa mem -y <float>"""
def __get__(self):
return self._options.XA_drop_ratio
Expand Down Expand Up @@ -486,7 +486,7 @@ cdef class BwaMemOptionsBuilder(BwaMemOptions):
self._options.max_XA_hits = left
self._options.max_XA_hits_alt = right

property xa_drop_ration:
property xa_drop_ratio:
"""bwa mem -y <float>"""
def __get__(self):
return self._options.XA_drop_ratio
Expand Down Expand Up @@ -595,97 +595,22 @@ cdef class BwaMem:
rec.query_qualities = query.quality
return rec

# mimics mem_aln2sam in bwamem.c
cdef _build_alignment(self, opt: BwaMemOptions, query: FastxRecord, mem_aln_t *mem_aln_v, int mem_aln_n, int which):
    """Build the AlignedSegment for entry ``which`` of ``mem_aln_v``.

    Follows ``mem_aln2sam`` in bwamem.c: the flag, the (possibly trimmed)
    bases and qualities, the reference id/position/mapq, the CIGAR and the
    NM/MD/AS/XS tags are copied onto a record created by ``self._unmapped``.

    Args:
        opt: the alignment options; only ``soft_clip_supplementary`` is
            read here.
        query: the read whose ``sequence`` and ``quality`` are copied.
        mem_aln_v: array of alignments produced for ``query``.
        mem_aln_n: number of entries in ``mem_aln_v`` (not used to build
            the record; kept for interface compatibility).
        which: index into ``mem_aln_v``; as in bwa, index 0 is treated as
            the primary alignment and is never hard-clipped or trimmed.

    Returns:
        The populated record, or the flag/sequence-only record when the
        alignment is unmapped.
    """
    cdef mem_aln_t mem_aln

    # create a AlignedSegment record here
    rec = self._unmapped(query=query)

    # NB: this is a struct copy, so the flag updates below do not touch mem_aln_v
    mem_aln = mem_aln_v[which]

    # set the flags
    mem_aln.flag |= 0x4 if mem_aln.rid < 0 else 0
    mem_aln.flag |= 0x10 if mem_aln.is_rev > 0 else 0
    rec.flag = (mem_aln.flag & 0xffff) | (0x100 if (mem_aln.flag & 0x10000) != 0 else 0)

    # sequence and qualities, kept in locals so that they can be trimmed later
    # without reading them back from the record
    # NOTE(review): bwa emits the reverse *complement* for reverse-strand
    # alignments; the bases are only reversed here -- confirm and fix upstream.
    bases = query.sequence if rec.is_forward else query.sequence[::-1]
    quals = None
    if query.quality is not None:
        quals = query.quality if rec.is_forward else query.quality[::-1]
    rec.query_sequence = bases
    if quals is not None:
        rec.query_qualities = quals

    if rec.is_unmapped:
        return rec

    # reference id, position, mapq, and cigar
    rec.reference_id = mem_aln.rid
    rec.reference_start = mem_aln.pos
    rec.mapping_quality = mem_aln.mapq

    # Clipping may only be rewritten when soft-clipping of supplementary
    # alignments was not requested and the hit is not on an ALT contig.
    may_rewrite_clips = not opt.soft_clip_supplementary and not mem_aln.is_alt
    cigar_parts = []
    for i in range(mem_aln.n_cigar):
        cigar_op = mem_aln.cigar[i] & 0xf
        if may_rewrite_clips and (cigar_op == 3 or cigar_op == 4):
            # use hard clipping for supplementary alignments, soft for the primary
            cigar_op = 4 if which > 0 else 3
        cigar_len = mem_aln.cigar[i] >> 4
        # the table must include "H" since op 4 is a valid value here
        cigar_parts.append(f"{cigar_len}" + "MIDSH"[cigar_op])
    rec.cigarstring = "".join(cigar_parts)

    # remove leading and trailing clipped bases for non-primary alignments so
    # that the stored bases agree with the hard clips written to the cigar
    if mem_aln.n_cigar > 0 and which > 0 and may_rewrite_clips:
        qb = 0
        qe = len(query.sequence)
        leading_op = mem_aln.cigar[0] & 0xf
        trailing_op = mem_aln.cigar[mem_aln.n_cigar - 1] & 0xf
        if leading_op == 3 or leading_op == 4:
            qb += mem_aln.cigar[0] >> 4
        if trailing_op == 3 or trailing_op == 4:
            qe -= mem_aln.cigar[mem_aln.n_cigar - 1] >> 4
        # assigning query_sequence invalidates the record's qualities, so the
        # qualities are sliced from the local copy rather than from the record
        rec.query_sequence = bases[qb:qe]
        if quals is not None:
            rec.query_qualities = quals[qb:qe]

    # Optional tags
    attrs = dict()
    if mem_aln.n_cigar > 0:
        # NM is an integer tag; MD is stored by bwa right after the cigar array
        attrs["NM"] = mem_aln.NM
        md = <char *> (mem_aln.cigar + mem_aln.n_cigar)
        # NOTE(review): formatting a char* yields a bytes repr unless the module
        # sets c_string_type/c_string_encoding -- confirm against the file header.
        attrs["MD"] = f"{md}"
    # NB: mate tags are not output: MC, MQ
    if mem_aln.score >= 0:
        attrs["AS"] = mem_aln.score
    if mem_aln.sub >= 0:
        attrs["XS"] = mem_aln.sub
    # TODO: the SA, pa, XA/XB and XR tags written by mem_aln2sam are not output
    rec.set_tags(list(attrs.items()))

    return rec
def _add_sa_tag(self, records: list[AlignedSegment]) -> None:
    """Add the SA tag to every non-secondary record of a multi-part alignment.

    Each non-secondary record receives one ``rname,pos,strand,CIGAR,mapQ,NM;``
    entry for every *other* non-secondary record, in input order. Nothing is
    done when there are fewer than two non-secondary records.

    Args:
        records: the records for a single query; modified in place. Every
            non-secondary record must already carry an ``NM`` tag.
    """
    chimeric = [record for record in records if not record.is_secondary]
    if len(chimeric) <= 1:
        return

    # Build each record's own entry once, before any SA tag is written.
    entries = []
    for record in chimeric:
        strand = "+" if record.is_forward else "-"
        entries.append(
            f"{record.reference_name},{record.reference_start + 1},{strand},"
            f"{record.cigarstring},{record.mapping_quality},{record.get_tag('NM')};"
        )

    # A record's SA tag is the concatenation of all entries except its own.
    for i, record in enumerate(chimeric):
        record.set_tag("SA", "".join(entries[:i] + entries[i + 1:]))

cdef _calign(self, opt: BwaMemOptions, queries: List[FastxRecord]):
# TODO: ignore_alt
Expand All @@ -700,83 +625,126 @@ cdef class BwaMem:
cdef char **XA
cdef mem_alnreg_t *mem_alnreg
cdef mem_aln_t mem_aln
cdef mem_aln_t *mem_aln_v
cdef int mem_aln_n
cdef char *md
cdef mem_opt_t *mem_opt
print(f"bwamem [start] num_seqs: {len(queries)}")

mem_opt = opt.mem_opt()

recs_to_return: List[List[AlignedSegment]] = []

# copy FastqProxy into bwa_seq_t
num_seqs = len(queries)
print(f"bwamem [start] num_seqs: {num_seqs}")
mem_opt = opt.mem_opt()

seqs = <kstring_t*>calloc(sizeof(kstring_t), num_seqs)
for i in range(num_seqs):
print(f"bwamem [start] {i}")
seq = &seqs[i]
query = queries[i]
self._copy_seq(queries[i], seq)
print(f"bwamem [mem_align1] i: {i}")
self._copy_seq(query, seq)
mem_alnregs = mem_align1(mem_opt, self._index.bwt(), self._index.bns(), self._index.pac(), seq.l, seq.s)
if opt.query_coord_as_primary:
print(f"bwamem [mem_reorder_primary5] i: {i}")
mem_reorder_primary5(opt.minimum_score, &mem_alnregs)

# mimic mem_reg2sam from bwamem.c
mem_aln_n = 0
mem_aln_v = <mem_aln_t*>calloc(sizeof(mem_aln_t), mem_alnregs.n)

recs = []
XA = NULL
keep_all = opt.output_all_for_fragments
if not keep_all:
print(f"bwamem [mem_gen_alt] i: {i}")
print(f"XA: {keep_all}")
XA = mem_gen_alt(mem_opt, self._index.bns(), self._index.pac(), &mem_alnregs, seq.l, seq.s)

mapped_recs = []
for j in range(mem_alnregs.n):
mem_alnreg = &mem_alnregs.a[j]

print(f"bwamem [loop] i: {i} j: {j}")
if mem_alnreg.score < opt.minimum_score:
continue
if mem_alnreg.secondary >= 0 and (mem_alnreg.is_alt or not keep_all):
continue
if 0 <= mem_alnreg.secondary < INT_MAX and mem_alnreg.score < mem_alnregs.a[mem_alnreg.secondary].score * opt.xa_drop_ration:
if 0 <= mem_alnreg.secondary < INT_MAX and mem_alnreg.score < mem_alnregs.a[mem_alnreg.secondary].score * opt.xa_drop_ratio:
continue
print(f"bwamem [mem_reg2aln] i: {i} j: {j}")
mem_aln = mem_reg2aln(mem_opt, self._index.bns(), self._index.pac(), seq.l, seq.s, mem_alnreg)
mem_aln.XA = XA[j] if not keep_all else NULL
mem_aln.XA = XA[j] if XA != NULL else NULL
if mem_alnreg.secondary >= 0:
mem_aln.sub = -1 # don't output sub-optimal score
if mem_aln_n > 0 > mem_alnreg.secondary: # if supplementary
if len(mapped_recs) > 0 and mem_alnreg.secondary < 0: # if supplementary
mem_aln.flag |= 0x10000 if opt.skip_mate_rescue else 0x800
if opt.keep_mapq_for_supplementary and mem_aln_n > 0 and mem_alnreg.is_alt > 0 and mem_aln.mapq > mem_alnregs.a[0].mapq:
if opt.keep_mapq_for_supplementary and len(mapped_recs) > 0 and mem_alnreg.is_alt > 0 and mem_aln.mapq > mem_alnregs.a[0].mapq:
mem_aln.mapq = mem_alnregs.a[0].mapq # lower

mem_aln_v[mem_aln_n] = mem_aln
mem_aln_n += 1
rec = self._unmapped(query=query)

print(f"bwamem [after] i: {i} mem_aln_n: {mem_aln_n}")
if mem_aln_n == 0:
recs = [self._unmapped(query=queries[i])]
# set the flags
mem_aln.flag |= 0x4 if mem_aln.rid < 0 else 0
mem_aln.flag |= 0x10 if mem_aln.is_rev > 0 else 0
rec.flag = (mem_aln.flag & 0xffff) | (0x100 if (mem_aln.flag & 0x10000) != 0 else 0)
if rec.is_unmapped:
continue

# sequence and qualities
if not rec.is_secondary:
rec.query_sequence = query.sequence if rec.is_forward else query.sequence[::-1]
if query.quality is not None:
rec.query_qualities = query.quality if rec.is_forward else query.quality[::-1]

# reference id, position, mapq, and cigar
rec.reference_id = mem_aln.rid
rec.reference_start = mem_aln.pos
rec.mapping_quality = mem_aln.mapq
cigar = ""
for k in range(mem_aln.n_cigar):
cigar_op = mem_aln.cigar[k] & 0xf
if opt.soft_clip_supplementary and mem_aln.is_alt == 0 and (
cigar_op == 3 or cigar_op == 4):
cigar_op = 4 if rec.is_supplementary else 3 # // use hard clipping for supplementary alignments
cigar_len = mem_aln.cigar[k] >> 4
cigar += f"{cigar_len}" + "MIDS"[cigar_op]
rec.cigarstring = cigar

# remove leading and trailing soft-clipped bases for non-primary etc.
if mem_aln.n_cigar > 0 and not rec.is_secondary and not opt.soft_clip_supplementary and not mem_aln.is_alt:
qb = 0
qe = len(query.sequence)
leading_op = mem_aln.cigar[0] & 0xf
trailing_op = mem_aln.cigar[mem_aln.n_cigar - 1] & 0xf
if leading_op == 3 or leading_op == 4:
qb += mem_aln.cigar[0] >> 4
if trailing_op == 3 or trailing_op == 4:
qe -= mem_aln.cigar[mem_aln.n_cigar - 1] >> 4
rec.query_sequence = rec.query_sequence[qb:qe]
if query.quality is not None:
rec.query_qualities = rec.query_qualities[qb:qe]

# Optional tags
attrs = dict()
if mem_aln.n_cigar > 0:
attrs["NM"] = f"{mem_aln.NM}"
md = <char *> (mem_aln.cigar + mem_aln.n_cigar)
attrs["MD"] = f"{md}"
# NB: mate tags are not output: MC, MQ
if mem_aln.score >= 0:
attrs["AS"] = mem_aln.score
if mem_aln.sub >= 0:
attrs["XS"] = mem_aln.sub
# NB: SA is added after all the records have been created
if mem_aln.XA != NULL:
attrs["XB" if opt.with_xb_tag else "XA"] = mem_aln.XA
if opt.with_xr_tag and self._index.bns().anns[rec.reference_id].anno != 0 and \
self._index.bns().anns[rec.reference_id].anno[0] != 0:
attrs["XR"] = self._index.bns().anns[rec.reference_id].anno
rec.set_tags(list(attrs.items()))

mapped_recs.append(rec)

free(mem_aln.cigar)
if len(mapped_recs) == 0:
recs_to_return.append([self._unmapped(query=query)])
else:
print(f"bwamem [_build_alignment] i: {i} mem_aln_n: {mem_aln_n}")
recs = [
self._build_alignment(opt, query, mem_aln_v, mem_aln_n, j)
for j in range(mem_aln_n)
]
print(f"bwamem [freeing] i: {i} mem_aln_n: {mem_aln_n}")
for j in range(mem_aln_n):
free(XA[j])
free(mem_aln_v[j].cigar)
free(XA)
free(mem_aln_v)
self._add_sa_tag(mapped_recs)
recs_to_return.append(mapped_recs)

if XA != NULL:
for j in range(len(mapped_recs)):
free(XA[j])
free(XA)
free(mem_alnregs.a)
recs_to_return.append(recs)

print(f"bwamem [freeing]")
for i in range(num_seqs):
free(seqs[i].s)
free(seqs)
Expand Down

0 comments on commit 367d08c

Please sign in to comment.