From b89f5b55556e567377544b25c67457899f458bc3 Mon Sep 17 00:00:00 2001 From: co63oc Date: Wed, 5 Jul 2023 16:23:29 +0800 Subject: [PATCH 1/2] Add api_difference doc --- .../api_difference/cuda/torch.cuda.device.md | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md new file mode 100644 index 00000000000..9dba11321f8 --- /dev/null +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md @@ -0,0 +1,40 @@ +## [参数不一致]torch.cuda.device + +### [torch.cuda.device](https://pytorch.org/docs/1.13/generated/torch.cuda.device.html#torch.cuda.device) + +```python +torch.cuda.device(device) +``` + +### [paddle.CUDAPlace](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/CUDAPlace_cn.html) + +```python +paddle.CUDAPlace(id) +``` + +其中 Pytorch 与 Paddle 的参数支持类型不一致,具体如下: + +### 参数映射 + +| PyTorch | PaddlePaddle | 备注 | +| ------- | ------------ | -------------------------------------------------------------------------------- | +| device | id | GPU 的设备 ID, Pytorch 支持 torch.device 和 int,Paddle 支持 int,需要进行转写。 | + +### 转写示例 + +#### device: 获取 device 参数,对其取 device.index 值 + +```python +# Pytorch 写法 +torch.cuda.device(torch.device('cuda').index) + +# Paddle 写法 +paddle.CUDAPlace(0) + +# 增加 index +# Pytorch 写法 +torch.cuda.device(torch.device('cuda', index=1).index) + +# Paddle 写法 +paddle.CUDAPlace(1) +``` From 6779681909a6b8c8818b60fd2a50ffef3cebe542 Mon Sep 17 00:00:00 2001 From: co63oc Date: Wed, 5 Jul 2023 17:02:20 +0800 Subject: [PATCH 2/2] Fix --- .../api_difference/cuda/torch.cuda.device.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md index 9dba11321f8..1ba1e7da761 100644 --- a/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md +++ b/docs/guides/model_convert/convert_from_pytorch/api_difference/cuda/torch.cuda.device.md @@ -26,15 +26,15 @@ paddle.CUDAPlace(id) ```python # Pytorch 写法 -torch.cuda.device(torch.device('cuda').index) +torch.cuda.device(torch.device('cuda')) # Paddle 写法 paddle.CUDAPlace(0) # 增加 index # Pytorch 写法 -torch.cuda.device(torch.device('cuda', index=1).index) +torch.cuda.device(torch.device('cuda', index=index)) # Paddle 写法 -paddle.CUDAPlace(1) +paddle.CUDAPlace(index) ```