# split_bam_by_cluster.py
from __future__ import print_function
# Split the bam file by cluster ID.
# Credited to https://divingintogeneticsandgenomics.rbind.io/post/split-a-10xscatac-bam-file-by-cluster/
import os
import sys
import pysam
import csv
# Positional command-line arguments:
#   argv[1]  tab-separated table mapping cell barcode (column 1) to cluster
#            ID (column 2); its first row is treated as a header
#   argv[2]  input BAM file
#   argv[3]  directory the per-cluster BAMs are written to
#   argv[4]  prefix for the output file names
#   argv[5]  strand selector, read only when strand splitting is enabled
#            ("reversed" keeps reverse-strand reads, any other value keeps
#            forward-strand reads)
cluster_file = sys.argv[1]
bam_file = sys.argv[2]
output_location = sys.argv[3]
output_prefix = sys.argv[4]
# Strand splitting is switched on by passing exactly six arguments. The
# strand_type assignment used to be commented out, which made that mode crash
# with a NameError as soon as strand_type was referenced.
# NOTE(review): argv[6] is never read -- confirm why six arguments (rather
# than five) are required to enable strand splitting.
strand_split = len(sys.argv) == 7
strand_type = sys.argv[5] if strand_split else None
# Build the cell-barcode -> cluster-ID lookup from the tab-separated table,
# plus the set of distinct cluster IDs (one output BAM is opened per ID).
cluster_dict = {}
# newline="" is what the csv module documentation asks for when opening
# files that are handed to csv.reader.
with open(cluster_file, newline="") as csv_file:
    csv_reader = csv.reader(csv_file, delimiter="\t")
    # Skip the header row; the default avoids StopIteration on an empty file.
    next(csv_reader, None)
    for row in csv_reader:
        # csv.reader yields [] for blank lines (e.g. a trailing newline).
        if not row:
            continue
        if len(row) < 2:
            raise ValueError(
                f"{cluster_file}:{csv_reader.line_num}: expected at least 2 "
                f"tab-separated columns (barcode, cluster), got {row!r}"
            )
        cluster_dict[row[0]] = row[1]
clusters = set(cluster_dict.values())
# Open one output BAM per cluster (each reusing the input's header via
# template=fin) and route every qualifying read to the file of the cluster its
# cell barcode belongs to.
fin = pysam.AlignmentFile(bam_file, "rb")
fouts_dict = {}
try:
    # Output names are <prefix>_<cluster>.bam, or
    # <prefix>_<cluster>_<strand_type>.bam when strand splitting is enabled.
    name_suffix = f"_{strand_type}" if strand_split else ""
    for cluster in clusters:
        output_filename = os.path.join(
            output_location, f"{output_prefix}_{cluster}{name_suffix}.bam"
        )
        fouts_dict[cluster] = pysam.AlignmentFile(
            output_filename, "wb", template=fin
        )

    # Short-circuit so strand_type is only consulted when splitting by strand.
    keep_reverse = strand_split and strand_type == "reversed"
    for read in fin:
        # Only reads flagged as properly paired are written out.
        if not read.is_proper_pair:
            continue
        # When splitting by strand, keep reverse-strand reads for "reversed"
        # and forward-strand reads for any other strand_type value.
        if strand_split and read.is_reverse != keep_reverse:
            continue
        # Reads without a CB (cell barcode) tag cannot be assigned a cluster.
        if not read.has_tag("CB"):
            continue
        # The BAM may contain barcodes absent from the cluster table; those
        # map to None and are skipped. Compare against None so a cluster ID
        # that happens to be an empty string is still written.
        cluster_id = cluster_dict.get(read.get_tag("CB"))
        if cluster_id is not None:
            fouts_dict[cluster_id].write(read)
finally:
    # Close every handle even if opening or reading fails part-way through.
    for fout in fouts_dict.values():
        fout.close()
    fin.close()