Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
Jianguo99 committed Dec 7, 2023
1 parent 03ef7cb commit a8dbf5c
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 9 deletions.
Empty file added deepcp/__init__.py
Empty file.
5 changes: 4 additions & 1 deletion deepcp/classification/predictor/standard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

import torch
import numpy as np

from .base import BasePredictor

Expand All @@ -15,7 +16,9 @@ def fit(self, x_cal, y_cal, alpha):
for index,(x,y) in enumerate(zip(x_cal,y_cal)):
scores.append(self.score_function(x,y))
scores = torch.tensor(scores)
self.q_hat = torch.quantile(scores,1-alpha)

self.q_hat = torch.quantile( scores , np.ceil((scores.shape[0]+1) * (1-alpha)) / scores.shape[0] )




Expand Down
23 changes: 20 additions & 3 deletions deepcp/classification/utils/metircs.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@


def compute_coverage_rate(prediction_sets,labels):
from deepcp.utils.registry import Registry

METRICS_REGISTRY = Registry("METRICS")


@METRICS_REGISTRY.register()
def coverage_rate(prediction_sets,labels):
    """Empirical coverage: fraction of examples whose prediction set
    contains the true label.

    Args:
        prediction_sets: iterable of prediction sets (each supports ``in``);
            must also support ``len``.
        labels: ground-truth labels, aligned element-wise with prediction_sets.

    Returns:
        Coverage rate in [0, 1].
    """
    # Count hits directly instead of an enumerate loop with an unused index.
    covered = sum(1 for pred_set, label in zip(prediction_sets, labels)
                  if label in pred_set)
    return covered/len(prediction_sets)

@METRICS_REGISTRY.register()
def average_size(prediction_sets,labels):
    """Mean cardinality of the prediction sets.

    Args:
        prediction_sets: iterable of sized prediction sets.
        labels: unused; kept so all metrics share the same call signature.

    Returns:
        Average prediction-set size as a float.
    """
    # Sum sizes directly; the original enumerate index was never used.
    return sum(len(pred_set) for pred_set in prediction_sets)/len(prediction_sets)




class Metrics:
    """Evaluates a list of registered metric functions over prediction sets."""

    def __init__(self,metrics_list=None) -> None:
        # `None` sentinel avoids the shared-mutable-default-argument pitfall:
        # a `metrics_list=[]` default would be one list shared by all instances.
        self.metrics_list = [] if metrics_list is None else metrics_list

    def compute(self,prediction_sets,labels):
        """Return ``{metric_name: value}`` for every requested metric.

        Each name is resolved through METRICS_REGISTRY and called as
        ``metric(prediction_sets, labels)``.
        """
        metrics = {}
        for metric in self.metrics_list:
            metrics[metric] = METRICS_REGISTRY.get(metric)(prediction_sets,labels)
        return metrics



2 changes: 2 additions & 0 deletions deepcp/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .common import *
from .registry import *
7 changes: 6 additions & 1 deletion deepcp/common.py → deepcp/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
import numpy as np
import random


__all__ = ["fix_randomness"]

def fix_randomness(seed=0):
    """Seed every RNG (numpy, torch CPU/CUDA, python) for reproducible runs.

    Args:
        seed: integer seed applied to all generators (default 0).
    """
    np.random.seed(seed=seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # no-op when CUDA is unavailable
    random.seed(seed)  # original had this line duplicated (diff residue)


71 changes: 71 additions & 0 deletions deepcp/utils/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@


from difflib import SequenceMatcher

__all__ = ["Registry"]


class Registry:
    """A registry providing name -> object mapping, to support
    custom modules.
    To create a registry (e.g. a backbone registry):
    .. code-block:: python
        BACKBONE_REGISTRY = Registry('BACKBONE')
    To register an object:
    .. code-block:: python
        @BACKBONE_REGISTRY.register()
        class MyBackbone(nn.Module):
            ...
    Or:
    .. code-block:: python
        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        # Registry display name and its internal name -> object table.
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, force=False):
        # Refuse to silently overwrite an existing entry unless forced.
        already_registered = name in self._obj_map
        if already_registered and not force:
            raise KeyError(
                'An object named "{}" was already '
                'registered in "{}" registry'.format(name, self._name)
            )
        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        # Direct-call form: register the object under its own __name__.
        if obj is not None:
            self._do_register(obj.__name__, obj, force=force)
            return

        # Decorator form: return a wrapper that registers and passes through.
        def wrapper(fn_or_class):
            self._do_register(fn_or_class.__name__, fn_or_class, force=force)
            return fn_or_class

        return wrapper

    def get(self, name):
        # Look up a registered object; unknown names raise a descriptive error.
        if name not in self._obj_map:
            raise KeyError(
                'Object name "{}" does not exist '
                'in "{}" registry'.format(name, self._name)
            )
        return self._obj_map[name]

    def registered_names(self):
        # Snapshot of all currently registered names.
        return list(self._obj_map.keys())

8 changes: 4 additions & 4 deletions imagenet_thr_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from deepcp.classification.scores import THR
from deepcp.classification.predictor import StandardPredictor
from deepcp.classification.utils.metircs import Metrics
from deepcp.common import fix_randomness
from deepcp.classification.utils.metircs import Metrics
from deepcp.utils import fix_randomness


fix_randomness(seed = 0)
Expand Down Expand Up @@ -77,6 +77,6 @@
prediction_set = predictor.predict(ele)
prediction_sets.append(prediction_set)

print("computing metrics...")
metrics = Metrics(["coverage_rate"])
print("Evaluating prediction sets...")
metrics = Metrics(["coverage_rate","average_size"])
print(metrics.compute(prediction_sets,test_labels))

0 comments on commit a8dbf5c

Please sign in to comment.