Skip to content

Commit

Permalink
Fix Device Event Creation (#57574)
Browse files Browse the repository at this point in the history
* Fix Device Event Creation

* Fix Device Event Test

* fix test
  • Loading branch information
eee4017 authored Sep 26, 2023
1 parent a6f1fbf commit d55bb44
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
12 changes: 8 additions & 4 deletions paddle/fluid/distributed/collective/process_group_nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(const Place& place,
bool sync_op,
bool use_calc_stream)
: TaskStream(rank, comm_type, sync_op, use_calc_stream),
- comm_event_(place),
+ comm_event_(place, platform::GenerateDeviceEventFlag()),
task_place_(place) {}

ProcessGroupNCCL::NCCLTask::~NCCLTask() = default;
Expand Down Expand Up @@ -506,7 +506,9 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
auto nccl_comm_ctx = this->GetCommContext();
comm_ctx->set_nccl_comm(nccl_comm_ctx->GetNcclComm());

- place_to_calc_event_.emplace(place_key, place);
+ place_to_calc_event_.emplace(
+ place_key,
+ platform::DeviceEvent(place, platform::GenerateDeviceEventFlag()));
place_to_calc_ctx_.emplace(place_key, calc_ctx);
place_to_comm_ctx_.emplace(place_key, std::move(comm_ctx));

Expand Down Expand Up @@ -592,7 +594,7 @@ ProcessGroupNCCL::NCCLTask::NCCLTask(
CommType CommType,
const std::vector<phi::DenseTensor>& inputs)
: TaskStream(rank, inputs, CommType),
- comm_event_(places[0]),
+ comm_event_(places[0], platform::GenerateDeviceEventFlag()),
task_place_(places[0]) {}

// create NCCLManager cache for places_key
Expand Down Expand Up @@ -636,7 +638,9 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
GroupEnd();

// TODO(sunyilun): for compatibility, will be removed later
- place_to_calc_event_.emplace(places_key, places[0]);
+ place_to_calc_event_.emplace(
+ places_key,
+ platform::DeviceEvent(places[0], platform::GenerateDeviceEventFlag()));
place_to_calc_ctx_.emplace(
places_key,
static_cast<phi::GPUContext*>(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/platform/device_event_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ enum EventStatus {

class DeviceEvent {
public:
- explicit DeviceEvent(const platform::Place& place, unsigned int flag = 0)
+ explicit DeviceEvent(const platform::Place& place, unsigned int flag)
: event_(), place_(place), flag_(flag) {
type_id_ = DeviceTypeToId(platform::Place2DeviceType(place));
PADDLE_ENFORCE_LT(type_id_,
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/platform/device_event_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ TEST(DeviceEvent, CUDA) {

ASSERT_NE(context, nullptr);
// case 1. test for event_creator
- DeviceEvent event(place);
+ DeviceEvent event(place, paddle::platform::GenerateDeviceEventFlag());
ASSERT_NE(event.GetEvent().get(), nullptr);
bool status = event.Query();
ASSERT_EQ(status, true);
Expand Down Expand Up @@ -86,7 +86,7 @@ TEST(DeviceEvent, CUDA) {

ASSERT_NE(context, nullptr);
// case 1. test for event_creator
- DeviceEvent event(place);
+ DeviceEvent event(place, paddle::platform::GenerateDeviceEventFlag());
ASSERT_NE(event.GetEvent().get(), nullptr);
bool status = event.Query();
ASSERT_EQ(status, true);
Expand Down Expand Up @@ -127,7 +127,7 @@ TEST(DeviceEvent, CUDA) {
TEST(DeviceEvent, CPU) {
using paddle::platform::CPUPlace;
auto place = CPUPlace();
- DeviceEvent event(place);
+ DeviceEvent event(place, paddle::platform::GenerateDeviceEventFlag());
auto& pool = DeviceContextPool::Instance();
auto* context = pool.Get(place);

Expand Down

0 comments on commit d55bb44

Please sign in to comment.