Skip to content

Commit

Permalink
[GPU] Optimize Condition operation by better integrating its subnetwo…
Browse files Browse the repository at this point in the history
…rk primitives with the main network
  • Loading branch information
sshlyapn committed Nov 14, 2023
1 parent c451a94 commit 2bf4ca1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/condition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ void condition_inst::postprocess_output_memory(network::ptr executed_net, cldnn:
for (auto out_mem_map : branch.output_map) {
auto out_mem_idx = out_mem_map.first;
auto inner_out_id = out_mem_map.second;
auto mem_ptr = executed_net->get_output(inner_out_id).get_memory();
auto mem_ptr = executed_net->get_output_memory(inner_out_id);
if (mem_ptr) {
auto layout = _impl_params->get_output_layout(out_mem_idx);
GPU_DEBUG_LOG << "Reshape output from " << mem_ptr->get_layout().to_short_string()
Expand Down
24 changes: 16 additions & 8 deletions src/plugins/intel_gpu/src/graph/impls/common/condition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@ struct condition_impl : typed_primitive_impl<condition> {
}

event::ptr execute_impl(const std::vector<event::ptr>& events, condition_inst& instance) override {
for (auto& a : events) {
a->wait();
}
auto ev = instance.get_network().get_stream().create_user_event(false);
// Wait for condition statement event only, and pass all other events to sub-network directly
if (!events.empty())
events[0]->wait();

auto& stream = instance.get_network().get_stream();
auto ev = stream.create_user_event(false);
set_node_params(instance.get_node());

auto pred = condition_inst::get_pred_from_memory(instance.pred_memory_ptr(), instance.get_network().get_stream());
auto pred = condition_inst::get_pred_from_memory(instance.pred_memory_ptr(), stream);
network::ptr executed_net = pred ? instance.get_net_true() : instance.get_net_false();
auto branch = pred ? instance.get_branch_true() : instance.get_branch_false();
executed_net->set_shape_predictor(instance.get_network().get_shape_predictor());
Expand All @@ -62,16 +64,22 @@ struct condition_impl : typed_primitive_impl<condition> {
}
}

executed_net->execute({});
// Ignore condition statement event
std::vector<event::ptr> sub_net_events(events.begin() + 1, events.end());
auto sub_net_results = executed_net->execute(sub_net_events);

// Update output layout of impl_param in condition_inst
instance.update_output_layout();

// Set output memory of condition_inst to inner network output memory after inner network execution
instance.postprocess_output_memory(executed_net, branch);

ev->set();
return ev;
std::vector<event::ptr> output_events;
for (auto& output : sub_net_results)
if (output.second.get_event() != nullptr)
output_events.push_back(output.second.get_event());

return stream.group_events(output_events);
}

static std::unique_ptr<primitive_impl> create(const condition_node& arg, const kernel_impl_params&) {
Expand Down

0 comments on commit 2bf4ca1

Please sign in to comment.