fix device id issue for xpu eager mode (#48076)
* fix device id issue for xpu eager

The XPU device id is not set correctly in eager mode, so variables stay on dev0 unless
XPUDeviceGuard is called. On every node with rank != 0 this leads to the error:
"NotImplementedError: (Unimplemented) Place Place(xpu:0) is not supported."

* fix typo

* fix pybind error
XiaociZhang authored Nov 18, 2022
1 parent 14a6e67 commit 3b18d96
Showing 3 changed files with 11 additions and 1 deletion.
1 change: 1 addition & 0 deletions paddle/fluid/distributed/collective/ProcessGroupBKCL.cc
@@ -105,6 +105,7 @@ void ProcessGroupBKCL::BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id) {

void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place,
const std::string& place_key) {
+  platform::XPUDeviceGuard guard(place.GetDeviceId());
BKCLUniqueId bkcl_id;
if (rank_ == 0) {
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_get_unique_id(&bkcl_id));
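
The guard added in this hunk makes the device named by `place` current while the BKCL environment is created. Device guards of this kind are usually written as RAII objects: select the device on construction, restore the previous one on destruction. The following is a minimal sketch of that pattern only, not the Paddle implementation. The names `FakeDeviceGuard`, `SetCurrentDevice`, `GetCurrentDevice`, and the thread-local `g_current_device` are hypothetical stand-ins for the real XPU runtime.

```cpp
#include <cassert>

// Hypothetical stand-in for the device runtime: one "current device" per thread.
thread_local int g_current_device = 0;

int GetCurrentDevice() { return g_current_device; }

void SetCurrentDevice(int id) {
  assert(id >= 0);  // device ids are assumed to be non-negative
  g_current_device = id;
}

// Illustrative RAII guard (not the Paddle class): makes `id` the current
// device for the lifetime of the object, then restores the previous device.
class FakeDeviceGuard {
 public:
  explicit FakeDeviceGuard(int id) : prev_(GetCurrentDevice()) {
    if (prev_ != id) SetCurrentDevice(id);
  }
  ~FakeDeviceGuard() { SetCurrentDevice(prev_); }

  // A guard owns a scope; copying it would restore the device twice.
  FakeDeviceGuard(const FakeDeviceGuard&) = delete;
  FakeDeviceGuard& operator=(const FakeDeviceGuard&) = delete;

 private:
  int prev_;
};

// Returns the device that is current while a guard for `id` is alive.
int DeviceSeenInsideGuard(int id) {
  FakeDeviceGuard guard(id);
  return GetCurrentDevice();
}
```

Without such a guard, code in the function body would run against whatever device happened to be current, which matches the dev0 symptom described in the commit message.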
@@ -128,6 +128,15 @@ def FindParsingFunctionFromAttributeType(atype):
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace."));
#endif
}}
+ if (paddle::platform::is_xpu_place(place)) {{
+ #if defined(PADDLE_WITH_XPU)
+ phi::backends::xpu::SetXPUDeviceId(place.device);
+ VLOG(4) <<"CurrentDeviceId: " << phi::backends::xpu::GetXPUCurrentDeviceId() << " from " << (int)place.device;
+ #else
+ PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
+ "PaddlePaddle should compile with XPU if use XPUPlace."));
+ #endif
+ }}
"""
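
The branch added in this hunk appears to sit inside a code template (note the doubled braces), so the emitted C++ selects the XPU device named by the place, or fails when the build lacks XPU support. The sketch below illustrates that dispatch in isolation, under stated assumptions: `DEMO_WITH_XPU`, `DemoPlace`, `PlaceKind`, `ActivateDeviceFor`, and `g_demo_device` are hypothetical names invented for the example and are not part of Paddle.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical build flag; a real build system would define or omit it.
#define DEMO_WITH_XPU 1

// Hypothetical place type: a device kind plus a device index.
enum class PlaceKind { kCPU, kXPU };
struct DemoPlace {
  PlaceKind kind;
  int device;
};

// Assumed per-thread "current device", standing in for the device runtime.
thread_local int g_demo_device = 0;

int CurrentDemoDevice() { return g_demo_device; }

// Makes the device named by `place` current. Places that are not XPU are
// left alone; an XPU place in a build without XPU support is an error.
void ActivateDeviceFor(const DemoPlace& place) {
  if (place.kind != PlaceKind::kXPU) return;
#if defined(DEMO_WITH_XPU)
  assert(place.device >= 0);  // device ids are assumed to be non-negative
  g_demo_device = place.device;
#else
  throw std::runtime_error("built without XPU support");
#endif
}
```

Calling `ActivateDeviceFor({PlaceKind::kXPU, 2})` makes device 2 current, while a CPU place leaves the current device unchanged.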
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/distributed_py.cc
@@ -1284,7 +1284,7 @@ void BindDistributed(py::module *m) {
auto processGroupBKCL =
py::class_<distributed::ProcessGroupBKCL,
std::shared_ptr<distributed::ProcessGroupBKCL>>(
- *m, "ProcessGroupBKCL", ProcessGroup)
+ *m, "ProcessGroupBKCL", ProcessGroupStream)
.def(py::init<const std::shared_ptr<distributed::Store> &,
int,
int,