Skip to content

Commit

Permalink
Update processing.py
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelsty committed Oct 11, 2024
1 parent 98c4e2b commit 7074625
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions pylate/utils/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class KDProcessing:
Queries dataset.
documents
Documents dataset.
n_scores
split
Split to use for the queries and documents datasets. Used only if the queries and documents are of type `datasets.DatasetDict`.
n_ways
Number of scores to keep for the distillation.
Examples
Expand Down Expand Up @@ -55,10 +57,22 @@ class KDProcessing:
"""

def __init__(
self, queries: datasets.Dataset, documents: datasets.Dataset, n_ways: int = 32
self,
queries: datasets.Dataset | datasets.DatasetDict,
documents: datasets.Dataset | datasets.DatasetDict,
split: str = "train",
n_ways: int = 32,
) -> None:
self.queries = queries["train"] if "train" in queries else queries
self.documents = documents["train"] if "train" in documents else documents
if isinstance(queries, datasets.DatasetDict):
self.queries = queries[split]
else:
self.queries = queries

if isinstance(documents, datasets.DatasetDict):
self.documents = documents[split]
else:
self.documents = documents

self.n_ways = n_ways

self.queries_index = {
Expand Down

0 comments on commit 7074625

Please sign in to comment.