Skip to content

Commit

Permalink
Fix text2text-generation pipeline output inconsistency w/ python li…
Browse files Browse the repository at this point in the history
…brary (#384)

* Fix `text2text-generation` pipeline inconsistency

See https://huggingface.co/docs/transformers/v4.35.0/en/main_classes/pipelines#transformers.Text2TextGenerationPipeline

* Fix `text2text-generation` example in docs

* Improve text2text-generation output in docs
  • Loading branch information
xenova authored Nov 9, 2023
1 parent e440372 commit 96c5dd4
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 12 deletions.
5 changes: 1 addition & 4 deletions docs/source/pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,10 @@ let result = await poet('Write me a love poem about cheese.', {
temperature: 0.9,
repetition_penalty: 2.0,
no_repeat_ngram_size: 3,

// top_k: 20,
// do_sample: true,
});
```

which outputs:
Logging `result[0].generated_text` to the console gives:

```
Cheese, oh cheese! You're the perfect comfort food.
Expand Down
4 changes: 2 additions & 2 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -479,11 +479,11 @@ export class FillMaskPipeline extends Pipeline {
* let output = await generator('how can I become more healthy?', {
* max_new_tokens: 100,
* });
* // [ 'To become more healthy, you can: 1. Eat a balanced diet with plenty of fruits, vegetables, whole grains, lean proteins, and healthy fats. 2. Stay hydrated by drinking plenty of water. 3. Get enough sleep and manage stress levels. 4. Avoid smoking and excessive alcohol consumption. 5. Regularly exercise and maintain a healthy weight. 6. Practice good hygiene and sanitation. 7. Seek medical attention if you experience any health issues.' ]
* // [{ generated_text: "To become more healthy, you can: 1. Eat a balanced diet with plenty of fruits, vegetables, whole grains, lean proteins, and healthy fats. 2. Stay hydrated by drinking plenty of water. 3. Get enough sleep and manage stress levels. 4. Avoid smoking and excessive alcohol consumption. 5. Regularly exercise and maintain a healthy weight. 6. Practice good hygiene and sanitation. 7. Seek medical attention if you experience any health issues." }]
* ```
*/
export class Text2TextGenerationPipeline extends Pipeline {
_key = null;
_key = 'generated_text';

/**
* Fill the masked token in the text(s) given as inputs.
Expand Down
8 changes: 4 additions & 4 deletions tests/generation.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ describe('Generation parameters', () => {
{
const outputs = await generator(text);

const tokens = generator.tokenizer.encode(outputs[0])
const tokens = generator.tokenizer.encode(outputs[0].generated_text)
expect(tokens.length).toEqual(20);
}

Expand All @@ -37,7 +37,7 @@ describe('Generation parameters', () => {
max_new_tokens: MAX_NEW_TOKENS,
});

const tokens = generator.tokenizer.encode(outputs[0])
const tokens = generator.tokenizer.encode(outputs[0].generated_text)
expect(tokens.length).toEqual(MAX_NEW_TOKENS + 1); // + 1 due to forced BOS token
}

Expand All @@ -52,7 +52,7 @@ describe('Generation parameters', () => {
min_length: MIN_LENGTH,
});

const tokens = generator.tokenizer.encode(outputs[0])
const tokens = generator.tokenizer.encode(outputs[0].generated_text)
expect(tokens.length).toBeGreaterThanOrEqual(MIN_LENGTH);
}

Expand All @@ -67,7 +67,7 @@ describe('Generation parameters', () => {
min_new_tokens: MIN_NEW_TOKENS,
});

const tokens = generator.tokenizer.encode(outputs[0])
const tokens = generator.tokenizer.encode(outputs[0].generated_text)
expect(tokens.length).toBeGreaterThanOrEqual(MIN_NEW_TOKENS);
}

Expand Down
4 changes: 2 additions & 2 deletions tests/pipelines.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ describe('Pipelines', () => {
do_sample: false
});
expect(outputs).toHaveLength(1);
expect(outputs[0].length).toBeGreaterThan(1);
expect(outputs[0].generated_text.length).toBeGreaterThan(1);
}

await generator.dispose();
Expand All @@ -593,7 +593,7 @@ describe('Pipelines', () => {
do_sample: false
});
expect(outputs).toHaveLength(1);
expect(outputs[0].length).toBeGreaterThan(10);
expect(outputs[0].generated_text.length).toBeGreaterThan(10);
}
await generator.dispose();
}, MAX_TEST_EXECUTION_TIME);
Expand Down

0 comments on commit 96c5dd4

Please sign in to comment.