diff --git a/src/jcvi/graphics/base.py b/src/jcvi/graphics/base.py index 93676a4d..d2373143 100644 --- a/src/jcvi/graphics/base.py +++ b/src/jcvi/graphics/base.py @@ -690,7 +690,7 @@ def get_shades(base_color: str, n: int) -> np.array: # Generate lighter shades by blending with white white = np.array([1, 1, 1]) # White color as a NumPy array lighter_shades = [ - (white * (1 - i) + base_rgb * i) for i in np.linspace(0.3, 1, lighter + 1) + (white * (1 - i) + base_rgb * i) for i in np.linspace(0.2, 1, lighter + 1) ][:-1] # Generate darker shades by blending with black diff --git a/src/jcvi/graphics/chromosome.py b/src/jcvi/graphics/chromosome.py index 6aa6fada..97bd9dad 100644 --- a/src/jcvi/graphics/chromosome.py +++ b/src/jcvi/graphics/chromosome.py @@ -9,7 +9,7 @@ from itertools import groupby from math import ceil -from typing import Optional, Tuple +from typing import List, Optional, Tuple import numpy as np @@ -43,15 +43,15 @@ class Chromosome(BaseGlyph): def __init__( self, ax, - x, - y1, - y2, - width=0.015, - ec="k", - patch=None, - patchcolor="lightgrey", - lw=1, - zorder=2, + x: float, + y1: float, + y2: float, + width: float = 0.015, + ec: str = "k", + patch: Optional[List[float]] = None, + patchcolor: str = "lightgrey", + lw: int = 1, + zorder: int = 2, ): """ Chromosome with positions given in (x, y1) => (x, y2) diff --git a/src/jcvi/projects/sugarcane.py b/src/jcvi/projects/sugarcane.py index 1ed0b252..eaee80b3 100644 --- a/src/jcvi/projects/sugarcane.py +++ b/src/jcvi/projects/sugarcane.py @@ -30,7 +30,15 @@ from ..apps.base import ActionDispatcher, OptionParser, flatten, logger, mkdir from ..formats.blast import Blast -from ..graphics.base import adjust_spines, markup, normalize_axes, savefig +from ..graphics.base import ( + Rectangle, + adjust_spines, + get_shades, + markup, + normalize_axes, + savefig, +) +from ..graphics.chromosome import Chromosome as ChromosomePlot SoColor = "#7436a4" # Purple SsColor = "#5a8340" # Green @@ -159,6 +167,12 @@ def __init__( def __str__(self): return self.name + ": " + ";".join(str(_) for _ in self.chromosomes) + def ploidy(self, chrom: str) -> int: + """ + Return the ploidy of a chromosome. + """ + return sum(1 for x in self.chromosomes if x.chrom == chrom) + @classmethod def from_str(cls, s: str) -> "Genome": """ @@ -926,6 +940,47 @@ def divergence(args): savefig(image_name, dpi=iopts.dpi, iopts=iopts) +def plot_genome( + ax, + x: float, + y: float, + height: float, + genome: Genome, + haplotype_colors: Dict[str, str], + chrom_width: float = 0.012, + gap_width: float = 0.008, +): + """ + Plot the genome in the axes, centered around (x, y). + """ + target = "chr01" # Arbitrary target chromosome + ploidy = genome.ploidy(target) + total_width = ploidy * (chrom_width + gap_width) - gap_width + xx = x - total_width / 2 + for chrom in genome.chromosomes: + if chrom.chrom != target: + continue + ChromosomePlot(ax, xx, y, y - height, ec="lightslategray") + gene_height = height / len(chrom) + yy = y + for haplotype, genes in groupby(chrom.genes, key=lambda x: x.haplotype): + genes = list(genes) + g1, g2 = genes[0].idx - 1, genes[-1].idx + patch_height = gene_height * (g2 - g1) + color = haplotype_colors[haplotype] + ax.add_patch( + Rectangle( + (xx - chrom_width / 2, yy - patch_height), + chrom_width, + patch_height, + fc=color, + lw=0, + ) + ) + yy -= patch_height + xx += chrom_width + gap_width + + def chromosome(args): """ %prog chromosome [2n+n_FDR|2n+n_SDR|nx2+n] @@ -985,6 +1040,19 @@ def chromosome(args): fontweight="semibold", ) + SO_colors = get_shades(SoColor, SO_PLOIDY) + SS_colors = get_shades(SsColor, SS_PLOIDY) + SO_haplotypes = [chr(ord("a") + i) for i in range(SO_PLOIDY)] + SS_haplotypes = [chr(ord("a") + i) for i in range(SS_PLOIDY)] + SO_haplotype_colors = dict(zip(SO_haplotypes, SO_colors)) + SS_haplotype_colors = dict(zip(SS_haplotypes, SS_colors)) + + # Plotting + chrom_height = 0.1 + yy = 0.92 + plot_genome(root, 0.35, yy, chrom_height, SO, SO_haplotype_colors) + plot_genome(root, 0.75, yy, chrom_height, SS, SS_haplotype_colors) + # Title mode_title = get_mode_title(mode) root.text(0.5, 0.95, f"Transmission: {mode_title}", ha="center")