Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick] Update tutorial to match PyTorch main tutorials page (#2097) #2100

Open
wants to merge 1 commit into
base: release/0.15
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 25 additions & 39 deletions examples/tutorials/t5_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
T5-Base Model for Summarization, Sentiment Classification, and Translation
==========================================================================

**Author**: `Pendo Abbo <[email protected]>`__
**Author**: `Joe Cummings <[email protected]>`__
**Author**: `Pendo Abbo <[email protected]>`__, `Joe Cummings <[email protected]>`__

"""

Expand All @@ -21,14 +20,6 @@
#
#

######################################################################
# Common imports
# --------------
import torch

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


#######################################################################
# Data Transformation
# -------------------
Expand All @@ -43,7 +34,7 @@
#
# T5 uses a SentencePiece model for text tokenization. Below, we use a pre-trained SentencePiece model to build
# the text pre-processing pipeline using torchtext's T5Transform. Note that the transform supports both
# batched and non-batched text input (i.e. one can either pass a single sentence or a list of sentences), however
# batched and non-batched text input (for example, one can either pass a single sentence or a list of sentences), however
# the T5 model expects the input to be batched.
#

Expand All @@ -64,7 +55,7 @@
#######################################################################
# Alternatively, we can also use the transform shipped with the pre-trained models that does all of the above out-of-the-box
#
# ::
# .. code-block::
#
# from torchtext.models import T5_BASE_GENERATION
# transform = T5_BASE_GENERATION.transform()
Expand All @@ -77,8 +68,7 @@
#
# torchtext provides SOTA pre-trained models that can be used directly for NLP tasks or fine-tuned on downstream tasks. Below
# we use the pre-trained T5 model with standard base configuration to perform text summarization, sentiment classification, and
# translation. For additional details on available pre-trained models, please refer to documentation at
# https://pytorch.org/text/main/models.html
# translation. For additional details on available pre-trained models, see `the torchtext documentation <https://pytorch.org/text/main/models.html>`__
#
#
from torchtext.models import T5_BASE_GENERATION
Expand All @@ -88,16 +78,15 @@
transform = t5_base.transform()
model = t5_base.get_model()
model.eval()
model.to(DEVICE)


#######################################################################
# GenerationUtils
# ------------------
#
# We can use torchtext's `GenerationUtils` to produce an output sequence based on the input sequence provided. This calls on the
# We can use torchtext's ``GenerationUtils`` to produce an output sequence based on the input sequence provided. This calls on the
# model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
# for all sequences in the batch. The `generate` method shown below uses greedy search to generate the sequences. Beam search and
# for all sequences in the batch. The ``generate`` method shown below uses greedy search to generate the sequences. Beam search and
# other decoding strategies are also supported.
#
#
Expand All @@ -114,12 +103,12 @@
# datapipes and hence support standard flow-control and mapping/transformation using user defined
# functions and transforms.
#
# Below, we demonstrate how to pre-process the CNNDM dataset to include the prefix necessary for the
# Below we demonstrate how to pre-process the CNNDM dataset to include the prefix necessary for the
# model to indentify the task it is performing. The CNNDM dataset has a train, validation, and test
# split. Below we demo on the test split.
#
# The T5 model uses the prefix "summarize" for text summarization. For more information on task
# prefixes, please visit Appendix D of the T5 Paper at https://arxiv.org/pdf/1910.10683.pdf
# prefixes, please visit Appendix D of the `T5 Paper <https://arxiv.org/pdf/1910.10683.pdf>`__
#
# .. note::
# Using datapipes is still currently subject to a few caveats. If you wish
Expand All @@ -144,12 +133,12 @@ def apply_prefix(task, x):
cnndm_datapipe = cnndm_datapipe.map(partial(apply_prefix, task))
cnndm_datapipe = cnndm_datapipe.batch(cnndm_batch_size)
cnndm_datapipe = cnndm_datapipe.rows2columnar(["article", "abstract"])
cnndm_dataloader = DataLoader(cnndm_datapipe, batch_size=None)
cnndm_dataloader = DataLoader(cnndm_datapipe, shuffle=True, batch_size=None)

#######################################################################
# Alternately we can also use batched API (i.e apply the prefix on the whole batch)
# Alternately, we can also use batched API, for example, apply the prefix on the whole batch:
#
# ::
# .. code-block::
#
# def batch_prefix(task, x):
# return {
Expand Down Expand Up @@ -179,11 +168,11 @@ def apply_prefix(task, x):
imdb_batch_size = 3
imdb_datapipe = IMDB(split="test")
task = "sst2 sentence"
labels = {"neg": "negative", "pos": "positive"}
labels = {"1": "negative", "2": "positive"}


def process_labels(labels, x):
return x[1], labels[x[0]]
return x[1], labels[str(x[0])]


imdb_datapipe = imdb_datapipe.map(partial(process_labels, labels))
Expand Down Expand Up @@ -224,7 +213,7 @@ def process_labels(labels, x):
beam_size = 1

model_input = transform(input_text)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(cnndm_batch_size):
Expand All @@ -234,21 +223,19 @@ def process_labels(labels, x):


#######################################################################
# Summarization Output
# Summarization Output (Might vary since we shuffle the dataloader)
# --------------------
#
# ::
# .. code-block::
#
# Example 1:
#
# prediction: the Palestinians become the 123rd member of the international criminal
# court . the accession was marked by a ceremony at the Hague, where the court is based .
# the ICC opened a preliminary examination into the situation in the occupied
# Palestinian territory .
# prediction: the 24-year-old has been tattooed for over a decade . he has landed in australia
# to start work on a new campaign . he says he is 'taking it in your stride' to be honest .
#
# target: Membership gives the ICC jurisdiction over alleged crimes committed in
# Palestinian territories since last June . Israel and the United States opposed the
# move, which could open the door to war crimes investigations against Israelis .
# target: London-based model Stephen James Hendry famed for his full body tattoo . The supermodel
# is in Sydney for a new modelling campaign . Australian fans understood to have already located
# him at his hotel . The 24-year-old heartthrob is recently single .
#
#
# Example 2:
Expand Down Expand Up @@ -314,7 +301,7 @@ def process_labels(labels, x):
beam_size = 1

model_input = transform(input_text)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(imdb_batch_size):
Expand Down Expand Up @@ -372,7 +359,7 @@ def process_labels(labels, x):
# really annoying was the constant cuts to VDs daughter during the last fight scene.<br /><br />
# Not bad. Not good. Passable 4.
#
# prediction: negative
# prediction: positive
#
# target: negative
#
Expand All @@ -399,16 +386,15 @@ def process_labels(labels, x):
# ---------------------
#
# Finally, we can also use the model to generate English to German translations on the first batch of examples from the Multi30k
# test set using a beam size of 4.
# test set.
#

batch = next(iter(multi_dataloader))
input_text = batch["english"]
target = batch["german"]
beam_size = 4

model_input = transform(input_text)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, num_beams=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(multi_batch_size):
Expand Down