Skip to content

Commit

Permalink
add word
Browse files Browse the repository at this point in the history
  • Loading branch information
Jianguo99 committed Dec 8, 2023
1 parent 8063d61 commit 2cdd4aa
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions deepcp/classification/predictor/standard.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch
import numpy as np

from .base import BasePredictor
from deepcp.classification.predictor.base import BasePredictor


class StandardPredictor(BasePredictor):
def __init__(self, score_function):0
def __init__(self, score_function):
super().__init__(score_function)


Expand All @@ -17,10 +17,8 @@ def fit(self, x_cal, y_cal, alpha):

self.q_hat = torch.quantile(scores, np.ceil((scores.shape[0] + 1) * (1 - alpha)) / scores.shape[0])


def predict(self, x):
scores = self.score_function.predict(x)
S = torch.argwhere(scores < self.q_hat).reshape(-1).tolist()
return S


0 comments on commit 2cdd4aa

Please sign in to comment.