diff --git a/examples/tutorials/t5_demo.py b/examples/tutorials/t5_demo.py
index d6165a2009..3741454edf 100644
--- a/examples/tutorials/t5_demo.py
+++ b/examples/tutorials/t5_demo.py
@@ -2,8 +2,7 @@
 T5-Base Model for Summarization, Sentiment Classification, and Translation
 ==========================================================================
 
-**Author**: `Pendo Abbo `__
-**Author**: `Joe Cummings `__
+**Authors**: `Pendo Abbo `__, `Joe Cummings `__
 
 """
 
@@ -21,14 +20,6 @@
 #
 #
 
-######################################################################
-# Common imports
-# --------------
-import torch
-
-DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
-
 #######################################################################
 # Data Transformation
 # -------------------
@@ -43,7 +34,7 @@
 #
 # T5 uses a SentencePiece model for text tokenization. Below, we use a pre-trained SentencePiece model to build
 # the text pre-processing pipeline using torchtext's T5Transform. Note that the transform supports both
-# batched and non-batched text input (i.e. one can either pass a single sentence or a list of sentences), however
+# batched and non-batched text input (that is, one can either pass a single sentence or a list of sentences); however,
 # the T5 model expects the input to be batched.
 #
 
@@ -64,7 +55,7 @@
 #######################################################################
 # Alternatively, we can also use the transform shipped with the pre-trained models that does all of the above out-of-the-box
 #
-# ::
+# .. code-block:: python
 #
 #    from torchtext.models import T5_BASE_GENERATION
 #    transform = T5_BASE_GENERATION.transform()
@@ -77,8 +68,7 @@
 #
 # torchtext provides SOTA pre-trained models that can be used directly for NLP tasks or fine-tuned on downstream tasks. Below
 # we use the pre-trained T5 model with standard base configuration to perform text summarization, sentiment classification, and
-# translation. For additional details on available pre-trained models, please refer to documentation at
-# https://pytorch.org/text/main/models.html
+# translation. For additional details on available pre-trained models, see `the torchtext documentation <https://pytorch.org/text/main/models.html>`__.
 #
 #
 from torchtext.models import T5_BASE_GENERATION
@@ -88,16 +78,15 @@ transform = t5_base.transform()
 model = t5_base.get_model()
 model.eval()
-model.to(DEVICE)
 
 
 #######################################################################
 # GenerationUtils
 # ------------------
 #
-# We can use torchtext's `GenerationUtils` to produce an output sequence based on the input sequence provided. This calls on the
+# We can use torchtext's ``GenerationUtils`` to produce an output sequence based on the input sequence provided. This calls on the
 # model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
-# for all sequences in the batch. The `generate` method shown below uses greedy search to generate the sequences. Beam search and
+# for all sequences in the batch. The ``generate`` method shown below uses greedy search to generate the sequences. Beam search and
 # other decoding strategies are also supported.
 #
 #
 
@@ -114,12 +103,12 @@
 # datapipes and hence support standard flow-control and mapping/transformation using user defined
 # functions and transforms.
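+#
+# For instance, a user-defined function can be applied to each sample with
+# ``map``. As a minimal sketch (the lowercasing lambda here is illustrative
+# only and is not used in the pipeline below):
+#
+# .. code-block:: python
+#
+#    from torchtext.datasets import CNNDM
+#
+#    dp = CNNDM(split="test")
+#    dp = dp.map(lambda sample: (sample[0].lower(), sample[1]))
+#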
 #
-# Below, we demonstrate how to pre-process the CNNDM dataset to include the prefix necessary for the
-# model to indentify the task it is performing. The CNNDM dataset has a train, validation, and test
+# Below we demonstrate how to pre-process the CNNDM dataset to include the prefix necessary for the
+# model to identify the task it is performing. The CNNDM dataset has a train, validation, and test
 # split. Below we demo on the test split.
 #
 # The T5 model uses the prefix "summarize" for text summarization. For more information on task
-# prefixes, please visit Appendix D of the T5 Paper at https://arxiv.org/pdf/1910.10683.pdf
+# prefixes, please visit Appendix D of the `T5 paper <https://arxiv.org/pdf/1910.10683.pdf>`__.
 #
 # .. note::
 #    Using datapipes is still currently subject to a few caveats. If you wish
@@ -144,12 +133,12 @@ def apply_prefix(task, x):
 cnndm_datapipe = cnndm_datapipe.map(partial(apply_prefix, task))
 cnndm_datapipe = cnndm_datapipe.batch(cnndm_batch_size)
 cnndm_datapipe = cnndm_datapipe.rows2columnar(["article", "abstract"])
-cnndm_dataloader = DataLoader(cnndm_datapipe, batch_size=None)
+cnndm_dataloader = DataLoader(cnndm_datapipe, shuffle=True, batch_size=None)
 
 #######################################################################
-# Alternately we can also use batched API (i.e apply the prefix on the whole batch)
+# Alternatively, we can also use the batched API, that is, apply the prefix to the whole batch:
 #
-# ::
+# .. code-block:: python
 #
 #    def batch_prefix(task, x):
 #        return {
@@ -179,11 +168,11 @@ def apply_prefix(task, x):
 imdb_batch_size = 3
 imdb_datapipe = IMDB(split="test")
 task = "sst2 sentence"
-labels = {"neg": "negative", "pos": "positive"}
+labels = {"1": "negative", "2": "positive"}
 
 
 def process_labels(labels, x):
-    return x[1], labels[x[0]]
+    return x[1], labels[str(x[0])]
 
 
 imdb_datapipe = imdb_datapipe.map(partial(process_labels, labels))
@@ -224,7 +213,7 @@ def process_labels(labels, x):
 beam_size = 1
 
 model_input = transform(input_text)
-model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
+model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
 output_text = transform.decode(model_output.tolist())
 
 for i in range(cnndm_batch_size):
@@ -234,21 +223,19 @@ def process_labels(labels, x):
 
 
 #######################################################################
-# Summarization Output
-# --------------------
+# Summarization Output (may vary since we shuffle the dataloader)
+# ---------------------------------------------------------------
 #
-# ::
+# .. code-block:: text
 #
 #    Example 1:
 #
-#    prediction: the Palestinians become the 123rd member of the international criminal
-#    court . the accession was marked by a ceremony at the Hague, where the court is based .
-#    the ICC opened a preliminary examination into the situation in the occupied
-#    Palestinian territory .
+#    prediction: the 24-year-old has been tattooed for over a decade . he has landed in australia
+#    to start work on a new campaign . he says he is 'taking it in your stride' to be honest .
 #
-#    target: Membership gives the ICC jurisdiction over alleged crimes committed in
-#    Palestinian territories since last June . Israel and the United States opposed the
-#    move, which could open the door to war crimes investigations against Israelis .
+#    target: London-based model Stephen James Hendry famed for his full body tattoo . The supermodel
+#    is in Sydney for a new modelling campaign . Australian fans understood to have already located
+#    him at his hotel . The 24-year-old heartthrob is recently single .
 #
 #
 #    Example 2:
@@ -314,7 +301,7 @@ def process_labels(labels, x):
 beam_size = 1
 
 model_input = transform(input_text)
-model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
+model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
 output_text = transform.decode(model_output.tolist())
 
 for i in range(imdb_batch_size):
@@ -372,7 +359,7 @@ def process_labels(labels, x):
 #    really annoying was the constant cuts to VDs daughter during the last fight scene.
 #    Not bad. Not good. Passable 4.
 #
-#    prediction: negative
+#    prediction: positive
 #
 #    target: negative
 #
@@ -399,16 +386,15 @@ def process_labels(labels, x):
 # ---------------------
 #
 # Finally, we can also use the model to generate English to German translations on the first batch of examples from the Multi30k
-# test set using a beam size of 4.
+# test set.
 #
 
 batch = next(iter(multi_dataloader))
 input_text = batch["english"]
 target = batch["german"]
-beam_size = 4
 
 model_input = transform(input_text)
-model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
+model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
 output_text = transform.decode(model_output.tolist())
 
 for i in range(multi_batch_size):