[xdoctest][task 365-370] reformat example code with google style in `paddle/incubate/nn/functional/`, `paddle/incubate/optimizer/` (PaddlePaddle#58178)

* [Doctest]fix No.365-370, test=docs_preview

* Apply suggestions from code review

* Update python/paddle/incubate/optimizer/lars_momentum.py

---------

Co-authored-by: Nyakku Shigure <[email protected]>
ooooo-create and SigureMo authored Nov 3, 2023
1 parent 23b88f3 commit 808239e
Showing 5 changed files with 106 additions and 103 deletions.
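The conversion applied in all five files follows one pattern: plain `.. code-block:: python` examples gated by a `# required: gpu` comment become Google-style/xdoctest examples with `>>> `/`... ` prompts and a `# doctest: +REQUIRES(env:GPU)` directive, GPU examples gain an explicit `paddle.device.set_device('gpu')` call, and the static-graph optimizer examples switch from `paddle.base` to the public `paddle.static` APIs with `paddle.enable_static()`. A minimal before/after sketch of that pattern (the tensor and op below are placeholders, not taken from the changed files):

    # Before: plain code-block example, gated by a comment
    # required: gpu
    import paddle
    x = paddle.randn(shape=[2, 4])
    y = paddle.nn.functional.relu(x)

    # After: Google-style / xdoctest example with prompts and a REQUIRES directive
    >>> # doctest: +REQUIRES(env:GPU)
    >>> import paddle
    >>> paddle.device.set_device('gpu')
    >>> x = paddle.randn(shape=[2, 4])
    >>> y = paddle.nn.functional.relu(x)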
17 changes: 9 additions & 8 deletions python/paddle/incubate/nn/functional/fused_layer_norm.py
@@ -58,14 +58,15 @@ def fused_layer_norm(
Examples:
.. code-block:: python
# required: gpu
import paddle
paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16)
paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32)
paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32)
epsilon = 1e-6
paddle_layernorm = paddle.incubate.nn.functional.fused_layer_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle
>>> paddle.device.set_device('gpu')
>>> paddle_x = paddle.cast(paddle.randn(shape=[32, 256]), dtype=paddle.float16)
>>> paddle_weight = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32)
>>> paddle_bias = paddle.cast(paddle.randn(shape=[256]), dtype=paddle.float32)
>>> epsilon = 1e-6
>>> paddle_layernorm = paddle.incubate.nn.functional.fused_layer_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
"""

if in_dynamic_mode():
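For readers unfamiliar with the fused op, a plain layer norm over the last dimension gives a rough reference for what the example above computes. This is an editor's sketch, not part of the commit: it assumes the trailing `1` in the fused call marks the axis where normalization begins, and it uses float32 throughout so the standard API accepts the weight and bias.

    import paddle

    # Reference (non-fused) computation, assuming normalization starts at axis 1,
    # i.e. a standard layer norm over the last dimension of a [32, 256] input.
    x = paddle.randn(shape=[32, 256], dtype='float32')
    weight = paddle.randn(shape=[256], dtype='float32')
    bias = paddle.randn(shape=[256], dtype='float32')
    ref = paddle.nn.functional.layer_norm(x, x.shape[1:], weight=weight, bias=bias, epsilon=1e-6)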
23 changes: 12 additions & 11 deletions python/paddle/incubate/nn/functional/masked_multihead_attention.py
@@ -71,21 +71,22 @@ def masked_multihead_attention(
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.incubate.nn.functional as F
>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle
>>> import paddle.incubate.nn.functional as F
>>> paddle.device.set_device('gpu')
# input: [batch_size, 3 * num_head * dim_head]
x = paddle.rand(shape=(2, 3 * 32 * 128), dtype="float32")
>>> # input: [batch_size, 3 * num_head * dim_head]
>>> x = paddle.rand(shape=(2, 3 * 32 * 128), dtype="float32")
# src_mask: [batch_size, 1, 1, sequence_length]
src_mask = paddle.rand(shape=(2, 1, 1, 10), dtype="float32")
>>> # src_mask: [batch_size, 1, 1, sequence_length]
>>> src_mask = paddle.rand(shape=(2, 1, 1, 10), dtype="float32")
# cache_kv: [2, batch_size, num_head, max_seq_len, dim_head]
cache_kv = paddle.rand(shape=(2, 2, 32, 64, 128), dtype="float32")
>>> # cache_kv: [2, batch_size, num_head, max_seq_len, dim_head]
>>> cache_kv = paddle.rand(shape=(2, 2, 32, 64, 128), dtype="float32")
output = F.masked_multihead_attention(
x, src_mask=src_mask, cache_kv=cache_kv)
>>> output = F.masked_multihead_attention(
... x, src_mask=src_mask, cache_kv=cache_kv)
"""

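The comments in the new example only give flattened shapes. The sketch below (editor's illustration, not from the commit) unpacks the fused QKV input into separate query, key and value pieces; the [batch, 3, num_head, dim_head] packing order is an assumption, since the hunk does not show how the kernel actually lays the data out.

    import paddle

    batch, num_head, dim_head = 2, 32, 128
    x = paddle.rand(shape=(batch, 3 * num_head * dim_head), dtype="float32")
    # Assumed packing: q, k and v stored contiguously along a size-3 axis.
    qkv = x.reshape([batch, 3, num_head, dim_head])
    q, k, v = paddle.split(qkv, 3, axis=1)  # each piece: [batch, 1, num_head, dim_head]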
@@ -54,38 +54,40 @@ def variable_length_memory_efficient_attention(
Examples:
.. code-block:: python
# required: gpu
import math
import paddle
from paddle.incubate.nn.functional import variable_length_memory_efficient_attention
batch = 1
num_head = 8
seq_len = 256
head_size = 32
dtype = paddle.float16
query = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
key = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
value = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
seq_lens = paddle.to_tensor([seq_len, ] * batch, dtype='int32')
mask = paddle.randn([batch, 1, seq_len, seq_len], dtype=dtype)
scale = float(1.0 / math.sqrt(head_size))
def naive_attention_impl(query, key, value, mask, scale):
qk_res = paddle.matmul(query, key, transpose_y=True)
attention = qk_res * scale
attention = attention + mask
softmax_result = paddle.nn.functional.softmax(attention, -1)
result = paddle.matmul(softmax_result, value)
return result
out = naive_attention_impl(query, key, value, mask, scale)
# equals to: out = variable_length_memory_efficient_attention(query, key, value, seq_lens, seq_lens, mask, scale)
print(out.shape) # [batch, seq_len, num_head, head_size]
>>> # doctest: +REQUIRES(env:GPU)
>>> import math
>>> import paddle
>>> from paddle.incubate.nn.functional import variable_length_memory_efficient_attention
>>> paddle.device.set_device('gpu')
>>> batch = 1
>>> num_head = 8
>>> seq_len = 256
>>> head_size = 32
>>> dtype = paddle.float16
>>> query = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> key = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> value = paddle.randn([batch, num_head, seq_len, head_size], dtype=dtype)
>>> seq_lens = paddle.to_tensor([seq_len, ] * batch, dtype='int32')
>>> mask = paddle.randn([batch, 1, seq_len, seq_len], dtype=dtype)
>>> scale = float(1.0 / math.sqrt(head_size))
>>> def naive_attention_impl(query, key, value, mask, scale):
... qk_res = paddle.matmul(query, key, transpose_y=True)
... attention = qk_res * scale
... attention = attention + mask
... softmax_result = paddle.nn.functional.softmax(attention, -1)
... result = paddle.matmul(softmax_result, value)
... return result
>>> out = naive_attention_impl(query, key, value, mask, scale)
>>> # equals to: out = variable_length_memory_efficient_attention(query, key, value, seq_lens, seq_lens, mask, scale)
>>> print(out.shape) # [batch, num_head, seq_len, head_size]
[1, 8, 256, 32]
"""
if scale is None:
head_size = query.shape[3]
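The new example asserts (in a comment) that the naive implementation matches the fused kernel. One hedged way to check that on a GPU machine, reusing `query`, `key`, `value`, `seq_lens`, `mask`, `scale` and `out` from the example above; it assumes the fused call returns a single tensor with the same layout as the naive result, and the tolerances are loose because the inputs are float16:

    # Editor's sketch: compare the fused kernel against the naive reference.
    fused_out = variable_length_memory_efficient_attention(
        query, key, value, seq_lens, seq_lens, mask, scale)
    print(paddle.allclose(paddle.cast(out, 'float32'),
                          paddle.cast(fused_out, 'float32'),
                          rtol=1e-2, atol=1e-2))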
68 changes: 34 additions & 34 deletions python/paddle/incubate/optimizer/gradient_merge.py
@@ -50,40 +50,40 @@ class GradientMergeOptimizer:
Examples:
.. code-block:: python
import paddle
import paddle.base as base
import numpy as np
def gen_data(batch_size):
return {"x": np.random.random(size=(batch_size, 32)).astype('float32'),
"y": np.random.random(size=(batch_size, 1)).astype('int64')}
def mlp(input_x, input_y, hid_dim=128, label_dim=2):
fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
cost = paddle.nn.functional.cross_entropy(
input=prediction, label=input_y,
reduction='none', use_softmax=False
)
sum_cost = paddle.mean(cost)
return sum_cost, fc_1, prediction
input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
cost, fc_1, pred = mlp(input_x, input_y)
sgd = paddle.optimizer.Adam(learning_rate=0.01)
sgd = paddle.incubate.optimizer.GradientMergeOptimizer(sgd, k_steps=4, avg=True)
sgd.minimize(cost)
place = base.CPUPlace()
exe = base.Executor(place)
exe.run(base.default_startup_program())
for i in range(10):
cost_val = exe.run(feed=gen_data(32),
program=base.default_main_program(),
fetch_list=[cost.name])
print("step=%d, cost=%f" % (i, cost_val[0]))
>>> import paddle
>>> import numpy as np
>>> paddle.enable_static()
>>> def gen_data(batch_size):
... return {"x": np.random.random(size=(batch_size, 32)).astype('float32'),
... "y": np.random.random(size=(batch_size, 1)).astype('int64')}
>>> def mlp(input_x, input_y, hid_dim=128, label_dim=2):
... fc_1 = paddle.static.nn.fc(x=input_x, size=hid_dim)
... prediction = paddle.static.nn.fc(x=[fc_1], size=label_dim, activation='softmax')
... cost = paddle.nn.functional.cross_entropy(
... input=prediction, label=input_y,
... reduction='none', use_softmax=False
... )
... sum_cost = paddle.mean(cost)
... return sum_cost, fc_1, prediction
>>> input_x = paddle.static.data(name="x", shape=[-1,32], dtype='float32')
>>> input_y = paddle.static.data(name="y", shape=[-1,1], dtype='int64')
>>> cost, fc_1, pred = mlp(input_x, input_y)
>>> sgd = paddle.optimizer.Adam(learning_rate=0.01)
>>> sgd = paddle.incubate.optimizer.GradientMergeOptimizer(sgd, k_steps=4, avg=True)
>>> sgd.minimize(cost)
>>> place = paddle.CPUPlace()
>>> exe = paddle.static.Executor(place)
>>> exe.run(paddle.static.default_startup_program())
>>> for i in range(10):
... cost_val = exe.run(feed=gen_data(32),
... program=paddle.static.default_main_program(),
... fetch_list=[cost.name])
... print("step=%d, cost=%f" % (i, cost_val[0]))
"""

GRAD_MERGE_COND_NAME = "grad_merge_cond_name"
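For context on what the example configures: GradientMergeOptimizer accumulates gradients over `k_steps` minimize calls and applies a single update, averaging the accumulated gradients when `avg=True`, so the effective batch size is roughly `k_steps` times the feed batch size. A framework-free sketch of that idea (illustration only, not Paddle's implementation):

    def make_merged_step(apply_update, k_steps=4, avg=True):
        """Accumulate k_steps gradients, then apply one (optionally averaged) update."""
        state = {"acc": 0.0, "count": 0}

        def step(grad):
            state["acc"] += grad
            state["count"] += 1
            if state["count"] == k_steps:
                merged = state["acc"] / k_steps if avg else state["acc"]
                apply_update(merged)
                state["acc"], state["count"] = 0.0, 0

        return step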
35 changes: 17 additions & 18 deletions python/paddle/incubate/optimizer/lars_momentum.py
@@ -63,24 +63,23 @@ class LarsMomentumOptimizer(Optimizer):
Examples:
.. code-block:: python
import paddle
import paddle.base as base
import numpy as np
paddle.enable_static()
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
inp = paddle.static.data(
name="inp", shape=[2, 2], dtype='float32')
out = paddle.static.nn.fc(inp, size=3)
out = paddle.sum(out)
optimizer = base.optimizer.LarsMomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer.minimize(out)
exe = base.Executor(base.CPUPlace())
exe.run(base.default_startup_program())
exe.run(
feed={"inp": np_inp},
fetch_list=[out.name])
>>> import paddle
>>> import numpy as np
>>> paddle.enable_static()
>>> np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
>>> inp = paddle.static.data(
... name="inp", shape=[2, 2], dtype='float32')
>>> out = paddle.static.nn.fc(inp, size=3)
>>> out = paddle.sum(out)
>>> optimizer = paddle.incubate.optimizer.LarsMomentumOptimizer(learning_rate=0.001, momentum=0.9)
>>> optimizer.minimize(out)
>>> exe = paddle.static.Executor(paddle.CPUPlace())
>>> exe.run(paddle.static.default_startup_program())
>>> exe.run(
... feed={"inp": np_inp},
... fetch_list=[out.name])
"""
_velocity_acc_str = "velocity"

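For context, LARS rescales the learning rate per parameter tensor by the ratio of the weight norm to the gradient norm before applying a momentum update. A minimal NumPy sketch of one step (editor's illustration of the usual formulation; the `lars_coeff`, `weight_decay` and `epsilon` values are placeholders, not necessarily Paddle's defaults):

    import numpy as np

    def lars_momentum_step(param, grad, velocity, lr=0.001, momentum=0.9,
                           lars_coeff=0.001, weight_decay=0.0005, epsilon=0.0):
        w_norm = np.linalg.norm(param)
        g_norm = np.linalg.norm(grad)
        # Layer-wise adaptive rate: scale lr by ||w|| / (||g|| + wd * ||w||).
        local_lr = lr * lars_coeff * w_norm / (g_norm + weight_decay * w_norm + epsilon)
        velocity = momentum * velocity + local_lr * (grad + weight_decay * param)
        param = param - velocity
        return param, velocity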
