Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Jianguo99 committed Dec 25, 2023
1 parent 2fa0979 commit 1a1869b
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions torchcp/classification/loss/conftr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
__all__ = ["ConfTr"]


import torch

Expand All @@ -14,9 +16,8 @@
class ConfTr(nn.Module):
"""
Conformal Training (Stutz et al., 2021).
Paper: https://arxiv.org/abs/2110.09192.
Paper: https://arxiv.org/abs/2110.09192
:param weight: the weight of each loss function
:param predictor: the CP predictors
:param alpha: the significance level for each training batch
Expand All @@ -31,10 +32,13 @@ def __init__(self, weight, predictor, alpha, fraction, loss_type="valid", target

super(ConfTr, self).__init__()
assert weight>0, "weight must be greater than 0."
assert (fraction > 0 and fraction<1), "fraction should be a value in (0,1)."
assert loss_type in ["valid", "classification", "probs", "coverage"], 'loss_type should be a value in ["valid", "classification", "probs", "coverage"].'
assert (0 < fraction < 1), "fraction should be a value in (0,1)."
assert loss_type in ["valid", "classification", "probs", "coverage"], ('loss_type should be a value in ['
'"valid", "classification", "probs", '
'"coverage"].')
assert target_size==0 or target_size ==1, "target_size should be 0 or 1."
assert loss_transform in ["square", "abs", "log"], 'loss_transform should be a value in ["square", "abs", "log"].'
assert loss_transform in ["square", "abs", "log"], ('loss_transform should be a value in ["square", "abs", '
'"log"].')
self.weight = weight
self.predictor = predictor
self.alpha = alpha
Expand Down

0 comments on commit 1a1869b

Please sign in to comment.