From 1a1869ba857e1992fc6ef8ecc20e4d4112d600e3 Mon Sep 17 00:00:00 2001
From: huangjg
Date: Mon, 25 Dec 2023 19:07:23 +0800
Subject: [PATCH] update docs

---
 torchcp/classification/loss/conftr.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/torchcp/classification/loss/conftr.py b/torchcp/classification/loss/conftr.py
index 5ea8bd5..0380699 100644
--- a/torchcp/classification/loss/conftr.py
+++ b/torchcp/classification/loss/conftr.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
+__all__ = ["ConfTr"]
+
 import torch
 
 
@@ -14,9 +16,8 @@ class ConfTr(nn.Module):
     """
     Conformal Training (Stutz et al., 2021).
 
-    Paper: https://arxiv.org/abs/2110.09192.
-
-
+    Paper: https://arxiv.org/abs/2110.09192
+
     :param weight: the weight of each loss function
     :param predictor: the CP predictors
     :param alpha: the significance level for each training batch
@@ -31,10 +32,13 @@ def __init__(self, weight, predictor, alpha, fraction, loss_type="valid", target
         super(ConfTr, self).__init__()
         assert weight>0, "weight must be greater than 0."
 
-        assert (fraction > 0 and fraction<1), "fraction should be a value in (0,1)."
-        assert loss_type in ["valid", "classification", "probs", "coverage"], 'loss_type should be a value in ["valid", "classification", "probs", "coverage"].'
+        assert (0 < fraction < 1), "fraction should be a value in (0,1)."
+        assert loss_type in ["valid", "classification", "probs", "coverage"], ('loss_type should be a value in ['
+                                                                               '"valid", "classification", "probs", '
+                                                                               '"coverage"].')
         assert target_size==0 or target_size ==1, "target_size should be 0 or 1."
-        assert loss_transform in ["square", "abs", "log"], 'loss_transform should be a value in ["square", "abs", "log"].'
+        assert loss_transform in ["square", "abs", "log"], ('loss_transform should be a value in ["square", "abs", '
+                                                            '"log"].')
         self.weight = weight
         self.predictor = predictor
         self.alpha = alpha
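
A minimal usage sketch of the constructor whose argument checks this patch
touches, illustrating values that satisfy the asserts. The SplitPredictor/THR
imports and the forward(logits, labels) call are assumptions based on the
wider torchcp API, not part of this patch:

    import torch
    from torchcp.classification.loss import ConfTr
    from torchcp.classification.predictors import SplitPredictor  # assumed path
    from torchcp.classification.scores import THR                 # assumed path

    # Arguments mirror the asserts above: weight > 0, fraction in (0, 1),
    # loss_type in {"valid", "classification", "probs", "coverage"}.
    # target_size and loss_transform keep their defaults here.
    criterion = ConfTr(weight=0.01,
                       predictor=SplitPredictor(score_function=THR()),
                       alpha=0.05,
                       fraction=0.5,
                       loss_type="valid")

    logits = torch.randn(128, 10)          # one training batch of model outputs
    labels = torch.randint(0, 10, (128,))
    loss = criterion(logits, labels)       # assumed forward signature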