Skip to content

Commit

Permalink
fix torch.Tensor.scatter.md (#6099)
Browse files Browse the repository at this point in the history
* fix torch.Tensor.scatter.md

* add torch.scatter.md

* scatter add reduce

* scatter reduce

* fix torch.scatter.md
  • Loading branch information
LokeZhou authored Aug 22, 2023
1 parent f0a6139 commit d953f25
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
### [torch.Tensor.scatter](https://pytorch.org/docs/stable/generated/torch.Tensor.scatter.html#torch.Tensor.scatter)

```python
torch.Tensor.scatter(dim, index, src)
torch.Tensor.scatter(dim, index, src, reduce=None)
```

### [paddle.Tensor.put_along_axis](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/Tensor_cn.html#put-along-axis-arr-index-value-axis-reduce-assign)

```python
paddle.Tensor.put_along_axis(index, value, axis, reduce="assign")
paddle.Tensor.put_along_axis(indices, values, axis, reduce="assign")

```

Expand All @@ -19,6 +19,6 @@ paddle.Tensor.put_along_axis(index, value, axis, reduce="assign")
| PyTorch | PaddlePaddle | 备注 |
| ------- | ------------ | ------- |
| dim | axis | 表示在哪一个维度 scatter ,仅参数名不一致。 |
| index | index | 表示输入的索引张量,仅参数名不一致。 |
| src | value | 表示需要插入的值,仅参数名不一致。 |
| - | reduce | 归约操作类型,PyTorch 无此参数, Paddle 保持默认即可|
| index | indices | 表示输入的索引张量,仅参数名不一致。 |
| src | values | 表示需要插入的值,仅参数名不一致。 |
| reduce | reduce | 归约操作类型 |
Original file line number Diff line number Diff line change
@@ -1,31 +1,44 @@
## [ 仅 paddle 参数更多 ]torch.scatter
## [torch 参数更多]torch.scatter

### [torch.scatter](https://pytorch.org/docs/stable/generated/torch.scatter.html)
### [torch.scatter](https://pytorch.org/docs/2.0/generated/torch.scatter.html?highlight=torch+scatter#torch.scatter)

```python
torch.scatter(input,
dim,
index,
src)
torch.scatter(input,dim, index, src, reduce=None,out=None)
```

### [paddle.put_along_axis](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/put_along_axis_cn.html)
### [paddle.put_along_axis](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/put_along_axis_cn.html#cn-api-paddle-tensor-put-along-axis)

```python
paddle.put_along_axis(arr,
indices,
values,
axis,
reduce='assign')
paddle.put_along_axis(arr,indices, values, axis, reduce="assign")

```

其中 Paddle 相比 Pytorch 支持更多其他参数,具体如下:
其中 Paddle 相比 PyTorch 支持更多其他参数,具体如下:

### 参数映射
| PyTorch | PaddlePaddle | 备注 |
| ------- | ------------ | ------- |
| input | arr | 表示输入的 Tensor ,仅参数名不一致。 |
| dim | axis | 表示在哪一个维度 scatter ,仅参数名不一致。 |
| index | indices | 表示输入的索引张量,仅参数名不一致。 |
| src | values | 表示需要插入的值,仅参数名不一致。 |
| reduce | reduce | 归约操作类型 。 |
| out | - | 表示输出的 Tensor,Paddle 无此参数,需要转写。 |


### 转写示例

### 参数差异
| PyTorch | PaddlePaddle | 备注 |
| ------------- | ------------ | ------------------------------------------------------ |
| input | arr | 表示输入 Tensor ,仅参数名不一致。 |
| dim | axis | 表示在哪一个维度 scatter ,仅参数名不一致。 |
| index | indices | 表示输入的索引张量,仅参数名不一致。 |
| src | values | 表示需要插入的值,仅参数名不一致。 |
| - | reduce | 表示插入 values 时的计算方式,PyTorch 无此参数,Paddle 保持默认即可。 |
#### out:指定输出
```python
# Pytorch 写法
index = torch.tensor([[0],[1],[2]])
input = torch.zeros(3, 5)
out = torch.zeros(3, 5)
torch.scatter(input,1, index, 1.0,out=out)

# Paddle 写法
index = paddle.to_tensor(data=[[0], [1], [2]])
input = paddle.zeros(shape=[3, 5])
out = paddle.zeros(shape=[3, 5])
paddle.assign(paddle.put_along_axis(input, 1, index, 1.0), output=out)
```

0 comments on commit d953f25

Please sign in to comment.