[xdoctest][task 313] reformat example code with google style in python/paddle/sparse/nn/functional/transformer.py (PaddlePaddle#57131)

* [Doctest]fix No.313, test=docs_preview

* cast to float32

---------

Co-authored-by: SigureMo <[email protected]>
yoyoIcy and SigureMo authored Sep 18, 2023
1 parent 9a71a59 commit 1b31658
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions python/paddle/sparse/nn/functional/transformer.py
@@ -46,50 +46,50 @@ def attention(
     ``d`` represents ``head_dim`` .

     Args:
-        query(DenseTensor): `query` in the Attention module. 4D Tensor with float32 or float64.
-        key(DenseTensor): `key` in the Attention module. 4D Tensor with float32 or float64.
-        value(DenseTensor): `value` in the Attention module. 4D Tensor with float32 or float64.
-        sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
-            is `[batch_size*num_heads, seq_len, seq_len]` . `nnz` of each batch must be the same.
+        query (DenseTensor): `query` in the Attention module. 4D Tensor with float32 or float64.
+        key (DenseTensor): `key` in the Attention module. 4D Tensor with float32 or float64.
+        value (DenseTensor): `value` in the Attention module. 4D Tensor with float32 or float64.
+        sparse_mask (SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
+            is `[batch_size*num_heads, seq_len, seq_len]`. `nnz` of each batch must be the same.
             dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
-        key_padding_mask(DenseTensor, optional): The key padding mask tensor in the Attention module.
+        key_padding_mask (DenseTensor, optional): The key padding mask tensor in the Attention module.
             2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64. Default: None.
-        attn_mask(DenseTensor, optional): The attention mask tensor in the Attention module.
+        attn_mask (DenseTensor, optional): The attention mask tensor in the Attention module.
             2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64. Default: None.
-        name(str, optional): The default value is None. Normally there is no need for user
-            to set this property. For more information, please refer to
-            :ref:`api_guide_Name`.
+        name (str, optional): The default value is None. Normally there is no need for user
+            to set this property. For more information, please refer to :ref:`api_guide_Name`.

     Returns:
         4D tensor with shape: [batch_size, num_heads, seq_len, head_dim]. dtype is same with input.

     Examples:
         .. code-block:: python

-            # required: gpu
-            import paddle
+            >>> # doctest: +REQUIRES(env:GPU)
+            >>> import paddle
+            >>> paddle.device.set_device('gpu')

-            batch_size = 16
-            num_heads = 16
-            seq_len = 512
-            head_dim = 32
+            >>> batch_size = 16
+            >>> num_heads = 16
+            >>> seq_len = 512
+            >>> head_dim = 32

-            query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
-            key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
-            value = paddle.rand([batch_size, num_heads, seq_len, head_dim])
+            >>> query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
+            >>> key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
+            >>> value = paddle.rand([batch_size, num_heads, seq_len, head_dim])

-            query.stop_gradient = False
-            key.stop_gradient = False
-            value.stop_gradient = False
+            >>> query.stop_gradient = False
+            >>> key.stop_gradient = False
+            >>> value.stop_gradient = False

-            mask = paddle.nn.functional.dropout(paddle.ones([seq_len, seq_len])).expand([batch_size, num_heads, seq_len, seq_len])
-            sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()
+            >>> mask = paddle.nn.functional.dropout(paddle.ones([seq_len, seq_len])).expand([batch_size, num_heads, seq_len, seq_len])
+            >>> sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()

-            kp_mask = paddle.randint(0, 2, [batch_size, seq_len])
-            attn_mask = paddle.randint(0, 2, [seq_len, seq_len])
+            >>> kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype(paddle.float32)
+            >>> attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype(paddle.float32)

-            output = paddle.sparse.nn.functional.attention(query, key, value, sp_mask, kp_mask, attn_mask)
-            output.backward()
+            >>> output = paddle.sparse.nn.functional.attention(query, key, value, sp_mask, kp_mask, attn_mask)
+            >>> output.backward()
     """
     return _C_ops.sparse_fused_attention(
         query, key, value, sparse_mask, key_padding_mask, attn_mask
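For reference, the reformatted doctest corresponds to roughly the following standalone script. This is only a sketch built from the example in the diff; it assumes a GPU build of PaddlePaddle. The .astype(paddle.float32) calls are the "cast to float32" fix mentioned in the commit message: paddle.randint returns integer tensors, while the mask arguments must be float32 or float64.

# Sketch of the updated example as a plain script (assumes a GPU build of PaddlePaddle).
import paddle

paddle.device.set_device('gpu')

batch_size, num_heads, seq_len, head_dim = 16, 16, 512, 32

# Dense query/key/value tensors that require gradients.
query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
value = paddle.rand([batch_size, num_heads, seq_len, head_dim])
query.stop_gradient = False
key.stop_gradient = False
value.stop_gradient = False

# Random sparse attention layout stored as a CSR tensor with dense shape
# [batch_size * num_heads, seq_len, seq_len]; every batch shares the same nnz.
mask = paddle.nn.functional.dropout(
    paddle.ones([seq_len, seq_len])
).expand([batch_size, num_heads, seq_len, seq_len])
sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()

# paddle.randint produces integer tensors, so cast the masks to float32
# (this is the "cast to float32" change in this commit).
kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype(paddle.float32)
attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype(paddle.float32)

output = paddle.sparse.nn.functional.attention(
    query, key, value, sp_mask, kp_mask, attn_mask
)
output.backward()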
