docs(examples): add notes to gpt-2 example
* Leaving Error Note for Examples

* added docs for gpt2.rs

* Delete examples/Readme.md

* Update gpt2.rs

---------

Co-authored-by: Carson M <[email protected]>
JewishLewish and decahedron1 authored Nov 13, 2023
1 parent e5d8b6d commit 4778871
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions examples/gpt2/examples/gpt2.rs
@@ -9,15 +9,24 @@ use rand::Rng;
use tokenizers::Tokenizer;

const PROMPT: &str = "The corsac fox (Vulpes corsac), also known simply as a corsac, is a medium-sized fox found in";
/// Maximum number of tokens to generate.
const GEN_TOKENS: i32 = 90;
/// Top-k sampling: sample from the `k` most likely next tokens at each step. A lower `k` focuses on higher-probability tokens.
const TOP_K: usize = 5;

/// GPT-2 Text Generation
///
/// This example demonstrates text generation with the GPT-2 language model using `ort`.
/// It initializes the model, tokenizes a prompt, and generates a sequence of tokens,
/// using top-k sampling to produce diverse yet contextually relevant text.
fn main() -> ort::Result<()> {
// Initialize tracing to receive debug messages from `ort`.
tracing_subscriber::fmt::init();

let mut stdout = io::stdout();
let mut rng = rand::thread_rng();

// Create the ONNX Runtime environment and session for the GPT-2 model.
let environment = Environment::builder()
.with_name("GPT-2")
.with_execution_providers([CUDAExecutionProvider::default().build()])
@@ -29,6 +38,7 @@ fn main() -> ort::Result<()> {
.with_intra_threads(1)?
.with_model_downloaded(GPT2::GPT2LmHead)?;

// Load the tokenizer and encode the prompt into a sequence of tokens.
let tokenizer = Tokenizer::from_file(Path::new(env!("CARGO_MANIFEST_DIR")).join("data").join("tokenizer.json")).unwrap();
let tokens = tokenizer.encode(PROMPT, false).unwrap();
let tokens = tokens.get_ids().iter().map(|i| *i as i64).collect::<Vec<_>>();
@@ -44,6 +54,7 @@ fn main() -> ort::Result<()> {
let generated_tokens: Tensor<f32> = outputs["output1"].extract_tensor()?;
let generated_tokens = generated_tokens.view();

// Collect the logits for the last position and sort them in descending order.
let probabilities = &mut generated_tokens
.slice(s![0, 0, -1, ..])
.insert_axis(Axis(0))
@@ -54,6 +65,7 @@ fn main() -> ort::Result<()> {
.collect::<Vec<_>>();
probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));

// Randomly sample one of the TOP_K most likely tokens; `0..TOP_K` yields exactly TOP_K candidates.
let token = probabilities[rng.gen_range(0..TOP_K)].0;
tokens = concatenate![Axis(0), tokens, array![token.try_into().unwrap()]];

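For reference, the top-k step that these comments describe can be distilled into a standalone function. The sketch below is illustrative only and is not part of the commit: `sample_top_k` is a hypothetical helper that operates on a plain slice of logits, whereas the example above does the same work inline on the model's output tensor.

use rand::Rng;

// Pick the next token id by top-k sampling: keep the `k` highest-logit
// candidates and choose one of them uniformly at random.
fn sample_top_k(logits: &[f32], k: usize, rng: &mut impl Rng) -> usize {
    // Pair each token id with its logit.
    let mut candidates: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
    // Sort in descending order of logit; NaNs sort last.
    candidates.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Less));
    // `0..k` yields exactly `k` candidates.
    candidates[rng.gen_range(0..k.min(candidates.len()))].0
}

fn main() {
    let mut rng = rand::thread_rng();
    let logits = [0.1, 2.3, -0.4, 1.9, 0.7];
    // With k = 2, only token ids 1 and 3 (the two highest logits) can be chosen.
    println!("sampled token id: {}", sample_top_k(&logits, 2, &mut rng));
}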
