Commit

feat: Codebase Improvements - Modularization and Maintainability
While refactoring `compute_summary_stats()`, we found it would be
helpful if the parent function tracked a list of output strings, so that
each helper function can return a single string (or `str | None`).
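
A minimal sketch of that pattern, not the actual pixy code: the helper
`summarize_stat` and the parent `collect_output_lines` are hypothetical
names, and the values are made up. Each helper returns one formatted line
or `None`, and the parent keeps the list of output strings.

```python
from typing import List, Optional


def summarize_stat(name: str, requested: bool, value: float) -> Optional[str]:
    """Hypothetical helper: one tab-delimited line, or None if the stat was not requested."""
    if not requested:
        return None
    return f"{name}\t{value}"


def collect_output_lines() -> List[str]:
    """Hypothetical parent: owns the list of output strings and skips None results."""
    outputs: List[str] = []
    candidates = (
        summarize_stat("pi", True, 0.25),
        summarize_stat("dxy", False, 0.5),
    )
    for line in candidates:
        if line is not None:
            outputs.append(line)
    return outputs
```

`Optional[str]` is used instead of `str | None` so the sketch also runs on
Python versions older than 3.10.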

feat: add PixyArgs dataclass to hold pixy args (ksamuk#40)

This PR adds a dataclass, `PixyArgs`, to store the arguments that the user
passes to `pixy` at the command line. Addresses the first checkbox on ksamuk#36.

Highlights:
* Any arguments that were listed as `optional` in `__main__.py.main()`
are represented by a `Union[T, None]`.
* There are 2 string-based arguments here that I used `Enum` for: the
`Stats` (one or more of: `pi`, `dxy`, and `fst`) and `FSTEstimator` (one
of either `wc` or `hudson`).
* The `--bypass-invariant-check` that expects a "yes" or "no" value was
converted into a `bool`.
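
A rough sketch of the shape described in the highlights above. The class
name `PixyArgsSketch`, the field names, the enum member names, and the
`parse_yes_no` helper are assumptions for illustration; only the enum
values (`pi`, `dxy`, `fst`, `wc`, `hudson`) and the yes/no-to-`bool`
conversion come from this description.

```python
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Union


class Stats(Enum):
    # Member names are assumed; the values are the ones listed above.
    PI = "pi"
    DXY = "dxy"
    FST = "fst"


class FSTEstimator(Enum):
    WC = "wc"
    HUDSON = "hudson"


@dataclass
class PixyArgsSketch:
    # Hypothetical subset of fields; the real PixyArgs may differ.
    stats: List[Stats]
    vcf_path: Path
    fst_type: FSTEstimator
    bed_path: Union[Path, None] = None  # an "optional" argument
    bypass_invariant_check: bool = False  # formerly a "yes"/"no" string


def parse_yes_no(value: str) -> bool:
    """Hypothetical converter for the old yes/no command-line value."""
    if value not in ("yes", "no"):
        raise ValueError(f"expected 'yes' or 'no', got {value!r}")
    return value == "yes"
```

Building the enums from raw strings, e.g. `Stats("pi")`, raises
`ValueError` on an unknown value, which is one reason to prefer `Enum`
over bare strings here.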

No additional tests are added in this PR -- this dataclass is only being
added here, not used. The next PR will start replacing current
functionality and thus, tests will be included in that PR.

refactor: Extract `precompute_filtered_variant_array` (ksamuk#43)

This is the first of several planned PRs to refactor the monolithic
`compute_summary_stats()` function.

It extracts a helper function to precompute the filtered genotype and
position arrays.

refactor: extract compute_summary_pi (ksamuk#50)

This PR extracts the computation of the summary `pi` statistics. It also
introduces a new dataclass, `PiResult`, to capture the results of the
`calc_pi` function.

To reduce the overhead associated with refactoring, I propose using
`Union[T, Literal["NA"]]` in most situations where we'd use `Union[T,
None]`.
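
A minimal sketch of this proposal. The `PiResult` field names match the
`calc.py` diff in this commit; defining `NA` as `Literal["NA"]` and the
helper `summarize_pi` are assumptions, since `pixy/models.py` is not shown
here.

```python
from dataclasses import dataclass
from typing import Literal, Union

# Assumed definition of the NA alias imported from pixy.models.
NA = Literal["NA"]


@dataclass
class PiResult:
    avg_pi: Union[float, NA]
    total_diffs: int
    total_comps: int
    total_missing: int


def summarize_pi(total_diffs: int, total_comps: int, total_missing: int) -> PiResult:
    """Hypothetical helper mirroring the final lines of calc_pi."""
    # With no valid comparisons, report the "NA" sentinel instead of dividing by zero.
    avg_pi: Union[float, NA] = total_diffs / total_comps if total_comps > 0 else "NA"
    return PiResult(
        avg_pi=avg_pi,
        total_diffs=total_diffs,
        total_comps=total_comps,
        total_missing=total_missing,
    )
```

One likely benefit of the sentinel is that `str(result.avg_pi)` can be
written to the output unchanged, whereas `None` would need a separate
branch wherever results are formatted.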

refactor: extract compute_summary_dxy (ksamuk#51)

Companion to ksamuk#50, this PR extracts `compute_summary_dxy()`.

feat: `PixyTempResult` (ksamuk#49)

Closes ksamuk#42.

This PR adds `PixyTempResult`, a dataclass that stores output from
`pixy` and helps write out results to a tab-delimited file. One test for
the `__str__()` method is added.
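
To illustrate the idea only: the class below is a hypothetical stand-in
with assumed field names, not the real `PixyTempResult`. It shows how a
dataclass `__str__()` can produce one tab-delimited row.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class TempResultSketch:
    # All field names here are assumptions made for this example.
    stat: str
    chromosome: str
    window_start: int
    window_end: int
    value: Union[float, str]  # a number, or the "NA" sentinel

    def __str__(self) -> str:
        # Join the fields in declaration order with tabs so the row can be
        # written directly to a tab-delimited file.
        fields = (
            self.stat,
            self.chromosome,
            self.window_start,
            self.window_end,
            self.value,
        )
        return "\t".join(str(field) for field in fields)
```

A test of `__str__()` can then compare the string against a literal row,
for example `"pi\tchr1\t1\t10000\tNA"`.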

feat: use `PixyTempResult` (ksamuk#52)

Uses the `PixyTempResult` object introduced in ksamuk#49. **Only used with
reworked `pi` and `dxy`-based functions (not `fst`, which is pending
additional updates).** We can hold this PR in draft form until we finalize
the fst functions.

refactor: extract `validate_populations_path()` (ksamuk#59)

This PR extracts the checks related to a user-specified
populations_path. I wrote the docstring to reflect what the function
does now with the intention of future refactoring (i.e., I don't want to
raise a base Exception or use `print` forever).

No underlying code is changed as a result of this PR -- it's just code
movement. Existing tests cover these changes, albeit indirectly (another
item that will be fixed in a future refactoring).

refactor: extract `validate_bed_path()` (ksamuk#54)

This PR moves BED file-related validation and code out of
`check_and_validate_args` in `core.py` to a new function,
`validate_bed_path` in `args_validation.py`.

No underlying code is changed, only moved. A future PR will add
unit-tests, additional error handling, and other code changes.

refactor: extract `validate_sites_path()` (ksamuk#56)

This PR moves sites file-related validation and code out of
`check_and_validate_args()` in `core.py` to a new function,
`validate_sites_path` in args_validation.py.

No underlying code is changed, only moved. A future PR will add
unit-tests, additional error handling, and other code changes.

There is technically testing coverage here but it's a little indirect --
in `test_main.py` we have a test, `test_malformed_sites_file`, that
fails the assertions that were previously in
`check_and_validate_args()`. In a future PR we could refactor
`run_pixy_helper` to instead be `validate_sites_path`. Happy to make
that change now or in the future.

refactor: extract `validate_vcf_path()` (ksamuk#58)

This PR adds `validate_vcf_path()` to `args_validation.py` and moves
code from `core.py` into that function. No underlying code is changed;
this is just code movement.

refactor: extract out `validate_output_path()` (ksamuk#60)

This PR extracts functionality related to the output_path
(output_folder, output_prefix, and temp_file). As with the other
similar PRs, I wrote the docstring to reflect what the function does now
with the intention of future refactoring.

No underlying code is changed as a result of this PR -- it's just code
movement. Existing tests cover these changes, albeit indirectly (another
item that will be fixed in a future refactoring).

refactor: extract window/interval validation (ksamuk#64)

Refactored during sync

refactor: Extract `compute_summary_fst()` (ksamuk#55)

Companion to ksamuk#50 and ksamuk#51.

refactor: move and clean up `check_and_validate_args()` (ksamuk#69)

This PR moves `check_and_validate_args()` out of `core.py` and into
`args_validation.py`. Next, it updates `check_and_validate_args()` to
return an instance of `PixyArgs`. `PixyArgs` is then used throughout
`__main__.py` instead of the large tuple that was previously returned.

Additionally, this PR adds some error handling on the `PixyArgs` class.
While refactoring that, I updated the tests to make sure that we were
catching bad values passed to `run_pixy_helper`. I added one more
unit-test about multiple chromosomes.

@msto this PR grew bigger than I planned for, so please let me know if
you would prefer multiple, smaller PRs.

---------

Co-authored-by: Matt Stone <[email protected]>
Co-authored-by: Erin McAuley <[email protected]>
msto and emmcauley committed Jan 15, 2025
1 parent c718c17 commit 5c4a2b8
Showing 13 changed files with 1,793 additions and 1,046 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -10,3 +10,4 @@ output/*
docs/_build

__pycache__/
.coverage
227 changes: 119 additions & 108 deletions pixy/__main__.py

Large diffs are not rendered by default.

676 changes: 676 additions & 0 deletions pixy/args_validation.py

Large diffs are not rendered by default.

75 changes: 44 additions & 31 deletions pixy/calc.py
@@ -1,13 +1,20 @@
import typing
from typing import Tuple, Union, List, Any
from typing import Tuple, Union, List

from allel import AlleleCountsArray, GenotypeArray

import allel
import numpy as np
from numpy.typing import NDArray

from scipy import special

from pixy.models import NA
from pixy.models import DxyResult
from pixy.models import FstResult
from pixy.models import PiResult
from pixy.enums import FSTEstimator


# vectorized functions for calculating pi and dxy
# these are reimplementations of the original functions
@@ -38,7 +45,7 @@ def count_diff_comp_missing(row: AlleleCountsArray, n_haps: int) -> Tuple[int, i


# function for vectorized calculation of pi from a pre-filtered scikit-allel genotype matrix
def calc_pi(gt_array: GenotypeArray) -> Tuple[Union[float, str], int, int, int]:
def calc_pi(gt_array: GenotypeArray) -> PiResult:
"""Given a filtered genotype matrix, calculate `pi`.
Args:
@@ -78,20 +85,20 @@ def calc_pi(gt_array: GenotypeArray) -> Tuple[Union[float, str], int, int, int]:

# if there are valid data (comparisons between genotypes) at the site, compute average pi
# otherwise return NA
avg_pi: Union[float, str]
if total_comps > 0:
avg_pi = total_diffs / total_comps
else:
avg_pi = "NA"
avg_pi: Union[float, NA] = total_diffs / total_comps if total_comps > 0 else "NA"

return (avg_pi, total_diffs, total_comps, total_missing)
return PiResult(
avg_pi=avg_pi,
total_diffs=total_diffs,
total_comps=total_comps,
total_missing=total_missing,
)


# function for vectorized calculation of dxy from a pre-filtered scikit-allel genotype matrix
def calc_dxy(
pop1_gt_array: GenotypeArray, pop2_gt_array: GenotypeArray
) -> Tuple[Union[float, str], int, int, int]:
"""Given a filtered genotype matrix, calculate `dxy`.
def calc_dxy(pop1_gt_array: GenotypeArray, pop2_gt_array: GenotypeArray) -> DxyResult:
"""
Given a filtered genotype matrix, calculate `dxy`.
Args:
pop1_gt_array: the GenotypeArray representing population-specific allele counts
@@ -103,7 +110,6 @@ def calc_dxy(
total_diffs: sum of the number of differences between the populations
total_comps: sum of the number of comparisons between the populations
total_missing: sum of the number of missing between the populations
"""

# the counts of each of the two alleles in each population at each site
@@ -134,13 +140,14 @@ def calc_dxy(

# if there are valid data (comparisons between genotypes) at the site, compute average dxy
# otherwise return NA
avg_dxy: Union[float, str]
if total_comps > 0:
avg_dxy = total_diffs / total_comps
else:
avg_dxy = "NA"
avg_dxy: Union[float, NA] = total_diffs / total_comps if total_comps > 0 else "NA"

return (avg_dxy, total_diffs, total_comps, total_missing)
return DxyResult(
avg_dxy=avg_dxy,
total_diffs=total_diffs,
total_comps=total_comps,
total_missing=total_missing,
)


# function for obtaining fst AND variance components via scikit allel function
@@ -149,8 +156,8 @@ def calc_dxy(
# in aggregation mode, we just want a,b,c and n_sites for aggregating and fst
@typing.no_type_check
def calc_fst(
gt_array_fst: GenotypeArray, fst_pop_indicies: List[List[int]], fst_type: str
) -> Tuple[Union[float, str], float, float, Union[float, int], int]:
gt_array_fst: GenotypeArray, fst_pop_indicies: List[List[int]], fst_type: FSTEstimator
) -> FstResult:
# TODO: update the return type here after refactoring (2 -> 1 return statements)
"""Calculates FST according to either Weir and Cockerham (1984) or Hudson (1992).
@@ -185,6 +192,9 @@ def calc_fst(
a: List[float]
b: List[float]
c: Union[List[float], int]

result: FstResult

# WC 84
if fst_type == "wc":
a, b, c = allel.weir_cockerham_fst(gt_array_fst, subpops=fst_pop_indicies)
@@ -201,8 +211,7 @@ def calc_fst(
else:
fst = "NA"

return (fst, a, b, c, n_sites)
# TODO: can't get mypy to recognize multiple return statements of slightly different structures?
result = FstResult(fst=fst, a=a, b=b, c=c, n_sites=n_sites)

# Hudson 92
if fst_type == "hudson":
@@ -228,17 +237,20 @@ def calc_fst(

# same abc format as WC84, where 'a' is the numerator and
# 'b' is the denominator, and 'c' is a zero placeholder
return (fst, num, den, c, n_sites)
# TODO: mypy thinks this function is missing a return statement
result = FstResult(fst=fst, a=num, b=den, c=c, n_sites=n_sites)

return result


# simplified version of above to handle the case
# of per-site estimates of FST over whole chunks


def calc_fst_persite(
gt_array_fst: GenotypeArray, fst_pop_indicies: List[List[int]], fst_type: str
) -> Any:
gt_array_fst: GenotypeArray,
fst_pop_indicies: List[List[int]],
fst_type: str,
) -> NDArray[np.float64]:
"""Calculates site-specific FST according to Weir and Cockerham (1984) or Hudson (1992).
Args:
@@ -252,16 +264,14 @@ def calc_fst_persite(
"""

# compute basic (multisite) FST via scikit allel
fst: NDArray[np.float64]

# WC 84
if fst_type == "wc":
a, b, c = allel.weir_cockerham_fst(gt_array_fst, subpops=fst_pop_indicies)

fst = np.sum(a, axis=1) / (np.sum(a, axis=1) + np.sum(b, axis=1) + np.sum(c, axis=1))

return fst
# TODO: fix nested returns

# Hudson 92
elif fst_type == "hudson":
# following scikit allel docs
Expand All @@ -274,4 +284,7 @@ def calc_fst_persite(

fst = num / den

return fst
else:
raise ValueError("fst_type must be either 'wc' or 'hudson'")

return fst
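
The single-return shape used in the refactored `calc_fst_persite()` above
(assign in each branch, raise on anything else, return once) can be
sketched in isolation. This is an illustration under assumptions:
`FSTEstimator` is taken to be a plain `Enum` whose definition is not shown
in this diff, and `estimator_label` is a hypothetical function.

```python
from enum import Enum


class FSTEstimator(Enum):
    # Assumed definition; pixy/enums.py is not part of the rendered diff.
    WC = "wc"
    HUDSON = "hudson"


def estimator_label(fst_type: FSTEstimator) -> str:
    """Hypothetical function showing branch assignment with one return."""
    label: str
    if fst_type is FSTEstimator.WC:
        label = "Weir and Cockerham (1984)"
    elif fst_type is FSTEstimator.HUDSON:
        label = "Hudson (1992)"
    else:
        # An explicit failure avoids reaching `return` with `label` unbound.
        raise ValueError("fst_type must be either 'wc' or 'hudson'")
    return label
```

Under that assumption, a plain `Enum` member does not compare equal to its
string value (`FSTEstimator.WC == "wc"` is `False`), so a branch written as
`fst_type == "wc"` only matches when a raw string is passed. Comparing
with `is` against the member, or against `fst_type.value`, avoids the
mismatch.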