Skip to content

Commit

Permalink
Documentation improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbuyl committed Apr 26, 2024
1 parent 7f7bcae commit 5fcf252
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 12 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pypi_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: "3.x"
cache: `pip`
- name: Install pypa/build
run: >-
python3 -m
Expand Down
1 change: 0 additions & 1 deletion .github/workflows/sphinx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,3 @@ jobs:
uses: sphinx-notes/pages@v3
with:
documentation_path: docs/source
cache: true
16 changes: 8 additions & 8 deletions fairret/loss/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ class ProjectionLoss(FairnessLoss):
The projections are computed using cvxpy. Hence, any subclass is expected to implement the statistical distance
between distributions in both cvxpy and PyTorch by implementing the
:py:meth:`~projection.ProjectionLoss.cvxpy_distance` method and the
:py:meth:`~projection.ProjectionLoss.torch_distance` method respectively.
:py:meth:`~fairret.loss.projection.ProjectionLoss.cvxpy_distance` method and the
:py:meth:`~fairret.loss.projection.ProjectionLoss.torch_distance` method respectively.
Optionally, the :py:meth:`~projection.ProjectionLoss.torch_distance_with_logits` method can be overwritten to
provide a more numerically stable handling of predictions that are provided as logits. If left unimplemented,
:py:meth:`~projection.ProjectionLoss.torch_distance` will be called instead, after applying the sigmoid function to
the predictions.
Optionally, the :py:meth:`~fairret.loss.projection.ProjectionLoss.torch_distance_with_logits` method can be
overwritten to provide a more numerically stable handling of predictions that are provided as logits. If left
unimplemented, :py:meth:`~fairret.loss.projection.ProjectionLoss.torch_distance` will be called instead,
after applying the sigmoid function to the predictions.
Note:
We use 'statistical distance' in a broad sense here, and do not require that the distance is a metric. See
Expand Down Expand Up @@ -131,8 +131,8 @@ def torch_distance(self, pred: torch.Tensor, proj: torch.Tensor) -> torch.Tensor

def torch_distance_with_logits(self, pred, proj):
"""
A more numerically stable alternative method to :py:meth:`~projection.ProjectionLoss.torch_distance`, where `pred`
is assumed to be logits.
A more numerically stable alternative method to
:py:meth:`~fairret.loss.projection.ProjectionLoss.torch_distance`, where `pred` is assumed to be logits.
Args:
pred (torch.Tensor): The predicted distribution as logits, in shape (N,1). As we assume binary
Expand Down
2 changes: 1 addition & 1 deletion fairret/loss/violation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def forward(self, pred: torch.Tensor, sens: torch.Tensor, *stat_args, pred_as_lo
target_statistic: Optional[torch.Tensor] = None, **stat_kwargs: Any) -> torch.Tensor:
"""
Calculate the violation vector in relation to the `target_statistic` and penalize this violation using the
:py:meth:`~violation.ViolationLoss.penalize_violation` method implemented by the subclass.
:py:meth:`~fairret.loss.violation.ViolationLoss.penalize_violation` method implemented by the subclass.
Args:
pred (torch.Tensor): Predictions of shape :math:`(N, 1)`, as we assume to be performing binary
Expand Down
5 changes: 3 additions & 2 deletions fairret/statistic/linear_fractional.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,14 +595,15 @@ def denom_slope(self, label: torch.Tensor) -> torch.Tensor:
class StackedLinearFractionalStatistic(LinearFractionalStatistic):
"""
A vector-valued statistic that combines the outputs of K
:py:class:`fairret.statistic.linear_fractional.LinearFractionalStatistic` 's into a single statistic with output
:py:class:`~fairret.statistic.linear_fractional.LinearFractionalStatistic` 's into a single statistic with output
(K, S) by stacking all outputs in the second-to-last dimension (`dim=-2`).
"""

def __init__(self, *statistics: LinearFractionalStatistic):
"""
Args:
*statistics: The :py:class:`fairret.statistic.linear_fractional.LinearFractionalStatistic` 's to be stacked.
*statistics: The :py:class:`~fairret.statistic.linear_fractional.LinearFractionalStatistic` 's to be
stacked.
"""

super().__init__()
Expand Down

0 comments on commit 5fcf252

Please sign in to comment.