Skip to content

Commit

Permalink
pass prelimiter into ALL ICL datasets (#3069)
Browse files Browse the repository at this point in the history
* pass prelimiter into ALL ICL datasets

* add prelimiter

* add context option prelimiter

---------

Co-authored-by: Daniel King <[email protected]>
Co-authored-by: Max Marion <[email protected]>
Co-authored-by: Max Marion <[email protected]>
  • Loading branch information
4 people authored Mar 3, 2024
1 parent 548842c commit ab4273d
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions composer/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ def construct_context(self, example, preceding_text: str = '', add_answer: bool
context = context_options[gold_idx]
if len(preceding_text) > 0:
context = f'{self.example_delimiter}{context}'
context = f'{context}{self.continuation_delimiter}{continuation}'
context = f'{self.prelimiter}{context}{self.continuation_delimiter}{continuation}'
return context

def _construct_multiple_contexts(self, example: Dict, preceding_text: str = '') -> List[str]:
Expand All @@ -1151,7 +1151,9 @@ def _construct_multiple_contexts(self, example: Dict, preceding_text: str = '')
cont_del = self.continuation_delimiter.rstrip()
else:
cont_del = self.continuation_delimiter
context_options = [f'{self.example_delimiter}{c}{cont_del}' for c in context_options]
context_options = [f'{self.prelimiter}{self.example_delimiter}{c}{cont_del}' for c in context_options]
else:
context_options = [f'{self.prelimiter}{c}' for c in context_options]
return context_options

def _prep_example(
Expand Down Expand Up @@ -1480,6 +1482,7 @@ def build_icl_dataloader(
example_delimiter=example_delimiter,
continuation_delimiter=continuation_delimiter,
destination_path=destination_path,
prelimiter=prelimiter,
fewshot_random_seed=fewshot_random_seed,
hf_loading_vars=hf_loading_vars,
hf_parsing_map=hf_parsing_map,
Expand All @@ -1498,6 +1501,7 @@ def build_icl_dataloader(
example_delimiter=example_delimiter,
continuation_delimiter=continuation_delimiter,
destination_path=destination_path,
prelimiter=prelimiter,
fewshot_random_seed=fewshot_random_seed,
hf_loading_vars=hf_loading_vars,
hf_parsing_map=hf_parsing_map,
Expand All @@ -1516,6 +1520,7 @@ def build_icl_dataloader(
example_delimiter=example_delimiter,
continuation_delimiter=continuation_delimiter,
destination_path=destination_path,
prelimiter=prelimiter,
fewshot_random_seed=fewshot_random_seed,
hf_loading_vars=hf_loading_vars,
hf_parsing_map=hf_parsing_map,
Expand Down

0 comments on commit ab4273d

Please sign in to comment.