Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Miotk <[email protected]>
  • Loading branch information
michal-miotk committed Nov 21, 2024
1 parent e1fcc01 commit 9fa6fff
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 91 deletions.
1 change: 0 additions & 1 deletion src/plugins/intel_gpu/include/intel_gpu/primitives/rnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ struct RNNParams : public primitive_base<PType> {
offset_order(offset_order),
direction(direction) {
std::vector<std::string> pids{initial_hidden_state.pid, initial_cell_state.pid, W.pid, R.pid, B.pid, seq_lenghts.pid, out1_prim_id, out2_prim_id};
assert(direction == ov::op::RecurrentSequenceDirection::FORWARD || direction == ov::op::RecurrentSequenceDirection::REVERSE);
for (auto pid : pids) {
if (!pid.empty()) {
primitive_base<PType>::input.push_back(pid);
Expand Down
23 changes: 12 additions & 11 deletions src/plugins/intel_gpu/src/graph/impls/onednn/lstm_seq_onednn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ struct lstm_seq_onednn : typed_primitive_onednn_impl<lstm_seq> {
}

{
auto& output = instance.input_memory(7);
auto offset = onednn::get_offset(instance.get_input_layout(7), _pd.dnnl::primitive_desc_base::dst_desc(1));
auto& output = instance.output_memory(1);
auto offset = onednn::get_offset(instance.get_output_layout(1), _pd.dnnl::primitive_desc_base::dst_desc(1));
auto mem = output.get_onednn_memory(_pd.dnnl::primitive_desc_base::dst_desc(1), offset);
args.insert({DNNL_ARG_DST_ITER, mem});
}

{
auto& output = instance.input_memory(8);
auto offset = onednn::get_offset(instance.get_input_layout(8), _pd.dnnl::primitive_desc_base::dst_desc(2));
auto& output = instance.output_memory(2);
auto offset = onednn::get_offset(instance.get_output_layout(2), _pd.dnnl::primitive_desc_base::dst_desc(2));
auto mem = output.get_onednn_memory(_pd.dnnl::primitive_desc_base::dst_desc(2), offset);
args.insert({DNNL_ARG_DST_ITER_C, mem});
}
Expand Down Expand Up @@ -134,34 +134,35 @@ struct lstm_seq_onednn : typed_primitive_onednn_impl<lstm_seq> {
const dnnl::primitive_attr& attr,
ov::op::RecurrentSequenceDirection direction) {
auto prim = impl_params.typed_desc<lstm_seq>();
auto num_dir = static_cast<size_t>(prim->num_directions());
const auto& src_shape = impl_params.get_input_layout(0).get_shape();
auto mod_src_shape = src_shape;
std::swap(mod_src_shape[0], mod_src_shape[1]);
auto input_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(0).clone_with_other_shape(mod_src_shape), dnnl::memory::format_tag::abc);
auto initial_hidden_shape_mod = impl_params.get_input_layout(1).get_shape();
initial_hidden_shape_mod = { 1, 1, initial_hidden_shape_mod[0], initial_hidden_shape_mod[2] };
initial_hidden_shape_mod = { 1, num_dir, initial_hidden_shape_mod[0], initial_hidden_shape_mod[2] };
auto initial_hidden = onednn::layout_to_memory_desc(impl_params.get_input_layout(1).clone_with_other_shape(initial_hidden_shape_mod));
auto initial_cell = onednn::layout_to_memory_desc(impl_params.get_input_layout(2).clone_with_other_shape(initial_hidden_shape_mod));
auto W_shape_mod = impl_params.get_input_layout(3).get_shape();
W_shape_mod = {1, 1, W_shape_mod[2], 4, W_shape_mod[1]/4};
W_shape_mod = {1, num_dir, W_shape_mod[2], 4, W_shape_mod[1]/4};
auto w_layout = impl_params.get_input_layout(3).clone_with_other_shape(W_shape_mod);
w_layout.format = cldnn::format::bfzyx;
auto W_md = onednn::layout_to_memory_desc(w_layout);
auto R_shape_mod = impl_params.get_input_layout(4).get_shape();
R_shape_mod = {1, 1, R_shape_mod[2], 4, R_shape_mod[1]/4};
R_shape_mod = {1, num_dir, R_shape_mod[2], 4, R_shape_mod[1]/4};
auto r_layout = impl_params.get_input_layout(4).clone_with_other_shape(R_shape_mod);
r_layout.format = cldnn::format::bfzyx;
auto R_md = onednn::layout_to_memory_desc(r_layout);
auto B_shape_mod = impl_params.get_input_layout(5).get_shape();
B_shape_mod = {1, 1, 4, B_shape_mod[1]/4};
B_shape_mod = {1, num_dir, 4, B_shape_mod[1]/4};
auto b_layout = impl_params.get_input_layout(5).clone_with_other_shape(B_shape_mod);
b_layout.format = cldnn::format::bfyx;
auto B_md = onednn::layout_to_memory_desc(b_layout);
auto out_shape = impl_params.get_output_layout().get_shape();
out_shape = {out_shape[2], out_shape[0], out_shape[3], 1};
out_shape = {out_shape[2], out_shape[0], out_shape[3]*num_dir};
auto output_md = onednn::layout_to_memory_desc(impl_params.get_output_layout().clone_with_other_shape(out_shape), dnnl::memory::format_tag::abc);
auto output1_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(7).clone_with_other_shape(initial_hidden_shape_mod));
auto output2_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(7).clone_with_other_shape(initial_hidden_shape_mod));
auto output1_md = onednn::layout_to_memory_desc(impl_params.get_output_layout(1).clone_with_other_shape(initial_hidden_shape_mod));
auto output2_md = onednn::layout_to_memory_desc(impl_params.get_output_layout(2).clone_with_other_shape(initial_hidden_shape_mod));
OPENVINO_ASSERT(input_md.get_format_kind() != dnnl::memory::format_kind::any,
"[GPU] The format kind of the input memory descriptor of onednn lstm_seq cannot be 'any'.");
OPENVINO_ASSERT(output_md.get_format_kind() != dnnl::memory::format_kind::any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,26 @@

#ifdef SEQUENCE
#define GET_IN0_IDX(b, f, y) INPUT0_GET_INDEX(b, f, y, 0)
#define GET_IN1_IDX(b, f, y) INPUT1_GET_INDEX(b, f, y, 0)
#define GET_IN3_IDX(b, f) INPUT3_GET_INDEX(0, b, f, 0)
#define GET_IN4_IDX(b, f) INPUT4_GET_INDEX(0, b, f, 0)
#if DIRECTION == 2
#define GET_IN1_IDX(b, f, y) INPUT1_GET_INDEX(b, f, y, 0)
#define GET_IN2_IDX(b, f, y) INPUT2_GET_INDEX(b, f, y, 0)
#define GET_IN3_IDX(b, f, y) INPUT3_GET_INDEX(b, f, y, 0)
#define GET_IN4_IDX(b, f, y) INPUT4_GET_INDEX(b, f, y, 0)
#define GET_IN5_IDX(b, f) INPUT5_GET_INDEX(b, f, 0, 0)
#else
#define GET_IN1_IDX(b, f, y) INPUT1_GET_INDEX(b, 0, y, 0)
#define GET_IN2_IDX(b, f, y) INPUT2_GET_INDEX(b, 0, y, 0)
#define GET_IN3_IDX(b, f, y) INPUT3_GET_INDEX(0, f, y, 0)
#define GET_IN4_IDX(b, f, y) INPUT4_GET_INDEX(0, f, y, 0)
#define GET_IN5_IDX(b, f) INPUT5_GET_INDEX(0, f, 0, 0)
#endif
#else
#define GET_IN0_IDX(b, f, y) INPUT0_GET_INDEX(b, y, 0, 0)
#define GET_IN0_IDX(b, f, y) INPUT0_GET_INDEX(b, y, 0, 0)
#define GET_IN1_IDX(b, f, y) INPUT1_GET_INDEX(b, y, 0, 0)
#define GET_IN3_IDX(b, f) INPUT3_GET_INDEX(b, f, 0, 0)
#define GET_IN4_IDX(b, f) INPUT4_GET_INDEX(b, f, 0, 0)
#define GET_IN2_IDX(b, f, y) INPUT2_GET_INDEX(b, y, 0, 0)
#define GET_IN3_IDX(b, f, y) INPUT3_GET_INDEX(f, y, 0, 0)
#define GET_IN4_IDX(b, f, y) INPUT4_GET_INDEX(f, y, 0, 0)
#define GET_IN5_IDX(b, f) INPUT5_GET_INDEX(f, 0, 0, 0)
#endif

KERNEL(lstm_cell_and_seq_bfyx)(
Expand Down Expand Up @@ -51,7 +63,7 @@ KERNEL(lstm_cell_and_seq_bfyx)(
const uint real_seq_length = 1;
#endif
#if DIRECTION == 2
for(uint dir=0;dir<DIRECTION+1;dir++) {
for(uint dir=0;dir<DIRECTION;dir++) {
#else
uint dir = DIRECTION;
#endif
Expand All @@ -77,23 +89,35 @@ KERNEL(lstm_cell_and_seq_bfyx)(
const uint weight_idx = hidden_idx+weight_offsets[k];
uint hblock_num = HIDDEN_SIZE/VEC_SIZE;
unroll_for(uint j=0;j<hblock_num;++j) {
INPUT4_TYPE_VEC r_block = READ_VEC(0, &R[GET_IN4_IDX(weight_idx, j*VEC_SIZE)]);
INPUT4_TYPE_VEC r_block = READ_VEC(0, &R[GET_IN4_IDX(dir, weight_idx, j*VEC_SIZE)]);
if(i==0){
INPUT1_TYPE_VEC initial_block = READ_VEC(0, &initial_hidden_state[GET_IN1_IDX(b, 0, j*VEC_SIZE)]);
INPUT1_TYPE_VEC initial_block = READ_VEC(0, &initial_hidden_state[GET_IN1_IDX(b, dir, j*VEC_SIZE)]);
hidden_result += dot(initial_block, r_block);
}else{
#ifdef SEQUENCE
OUTPUT_TYPE_VEC h_block = READ_VEC(0, &hidden_history[OUTPUT_GET_INDEX(b, 0, prev_idx, j*VEC_SIZE)]);
#if DIRECTION == 2
OUTPUT_TYPE_VEC h_block = READ_VEC(0, &hidden_history[OUTPUT_GET_INDEX(b, dir, prev_idx, j*VEC_SIZE)]);
#else
OUTPUT_TYPE_VEC h_block = READ_VEC(0, &hidden_history[OUTPUT_GET_INDEX(b, 0, prev_idx, j*VEC_SIZE)]);
#endif
hidden_result += dot(h_block, r_block);
#endif
}
}
unroll_for(uint j=hblock_num*VEC_SIZE;j<HIDDEN_SIZE;++j) {
if(i==0){
hidden_result += initial_hidden_state[GET_IN1_IDX(b, 0, j)]*R[GET_IN4_IDX(weight_idx, j)];
#if DIRECTION == 2
hidden_result += initial_hidden_state[GET_IN1_IDX(b, dir, j)]*R[GET_IN4_IDX(dir, weight_idx, j)];
#else
hidden_result += initial_hidden_state[GET_IN1_IDX(b, 0, j)]*R[GET_IN4_IDX(dir, weight_idx, j)];
#endif
}else{
#ifdef SEQUENCE
hidden_result += hidden_history[OUTPUT_GET_INDEX(b, 0, prev_idx, j)]*R[GET_IN4_IDX(weight_idx, j)];
#if DIRECTION == 2
hidden_result += hidden_history[OUTPUT_GET_INDEX(b, dir, prev_idx, j)]*R[GET_IN4_IDX(dir, weight_idx, j)];
#else
hidden_result += hidden_history[OUTPUT_GET_INDEX(b, 0, prev_idx, j)]*R[GET_IN4_IDX(dir, weight_idx, j)];
#endif
#endif
}
}
Expand All @@ -107,22 +131,18 @@ KERNEL(lstm_cell_and_seq_bfyx)(
} else {
x_block = READ_VEC(0, &x[GET_IN0_IDX(b, i, j*VEC_SIZE)]);
}
INPUT3_TYPE_VEC w_block = READ_VEC(0, &W[GET_IN3_IDX(weight_idx, j*VEC_SIZE)]);
INPUT3_TYPE_VEC w_block = READ_VEC(0, &W[GET_IN3_IDX(dir, weight_idx, j*VEC_SIZE)]);
input_result += dot(x_block, w_block);
}

unroll_for(uint j=block_num*VEC_SIZE;j<INPUT_SIZE;++j) { //leftovers
if (dir == 1) {
input_result += x[GET_IN0_IDX(b, real_seq_length-1-i, j)]*W[GET_IN3_IDX(weight_idx, j)];
input_result += x[GET_IN0_IDX(b, real_seq_length-1-i, j)]*W[GET_IN3_IDX(dir, weight_idx, j)];
} else {
input_result += x[GET_IN0_IDX(b, i, j)]*W[GET_IN3_IDX(weight_idx, j)];
input_result += x[GET_IN0_IDX(b, i, j)]*W[GET_IN3_IDX(dir, weight_idx, j)];
}
}
#ifdef SEQUENCE
gate_output[k] = hidden_result + input_result + TO_ACCUMULATOR_TYPE(B[INPUT5_GET_INDEX(0, weight_idx, 0, 0)]);
#else
gate_output[k] = hidden_result + input_result + TO_ACCUMULATOR_TYPE(B[INPUT5_GET_INDEX(weight_idx, 0, 0, 0)]);
#endif
gate_output[k] = hidden_result + input_result + TO_ACCUMULATOR_TYPE(B[GET_IN5_IDX(dir, weight_idx)]);
switch(k){
case 0:
case 1:
Expand All @@ -138,11 +158,7 @@ KERNEL(lstm_cell_and_seq_bfyx)(
}
ACCUMULATOR_TYPE temp_cell_state;
if (i==0){
#ifdef SEQUENCE
temp_cell_state = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX(b, 0, hidden_idx, 0)] + gate_output[1]*gate_output[2];
#else
temp_cell_state = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX(b, hidden_idx, 0, 0)] + gate_output[1]*gate_output[2];
#endif
temp_cell_state = gate_output[0]*initial_cell_state[GET_IN2_IDX(b, dir, hidden_idx)] + gate_output[1]*gate_output[2];
}else{
temp_cell_state *= gate_output[0];
temp_cell_state += gate_output[1]*gate_output[2];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,30 @@

#include "include/batch_headers/fetch_data.cl"

#ifdef SEQUENCE
#define GET_IN0_IDX(b, f, y) INPUT0_GET_INDEX(b, f, y, 0)
#if DIRECTION == 2
#define GET_IN1_IDX(b, f, y) INPUT1_GET_INDEX(b, f, y, 0)
#define GET_IN2_IDX(b, f, y) INPUT2_GET_INDEX(b, f, y, 0)
#define GET_IN3_IDX(b, f, y) INPUT3_GET_INDEX(b, f, y, 0)
#define GET_IN4_IDX(b, f, y) INPUT4_GET_INDEX(b, f, y, 0)
#define GET_IN5_IDX(b, f) INPUT5_GET_INDEX(b, f, 0, 0)
#else
#define GET_IN1_IDX(b, f, y) INPUT1_GET_INDEX(b, 0, y, 0)
#define GET_IN2_IDX(b, f, y) INPUT2_GET_INDEX(b, 0, y, 0)
#define GET_IN3_IDX(b, f, y) INPUT3_GET_INDEX(0, f, y, 0)
#define GET_IN4_IDX(b, f, y) INPUT4_GET_INDEX(0, f, y, 0)
#define GET_IN5_IDX(b, f) INPUT5_GET_INDEX(0, f, 0, 0)
#endif
#else
#define GET_IN0_IDX(b, f, y) INPUT0_GET_INDEX(b, y, 0, 0)
#define GET_IN1_IDX(b, f, y) INPUT1_GET_INDEX(b, y, 0, 0)
#define GET_IN2_IDX(b, f, y) INPUT2_GET_INDEX(b, y, 0, 0)
#define GET_IN3_IDX(b, f, y) INPUT3_GET_INDEX(f, y, 0, 0)
#define GET_IN4_IDX(b, f, y) INPUT4_GET_INDEX(f, y, 0, 0)
#define GET_IN5_IDX(b, f) INPUT5_GET_INDEX(f, 0, 0, 0)
#endif

KERNEL(lstm_cell_and_seq_ref)(
const __global INPUT0_TYPE* x,
const __global INPUT1_TYPE* initial_hidden_state,
Expand Down Expand Up @@ -31,7 +55,7 @@ KERNEL(lstm_cell_and_seq_ref)(
const uint real_seq_length = 1;
#endif
#if DIRECTION == 2
for(uint dir=0;dir<DIRECTION+1;dir++) {
for(uint dir=0;dir<DIRECTION;dir++) {
#else
uint dir = DIRECTION;
#endif
Expand All @@ -57,34 +81,27 @@ KERNEL(lstm_cell_and_seq_ref)(
const uint weight_idx = hidden_idx+weight_offsets[k];
unroll_for(uint j=0;j<HIDDEN_SIZE;++j) {
if(i==0){
#ifdef SEQUENCE
hidden_result += initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, 0, j, 0)]*R[INPUT4_GET_INDEX_SAFE(0, weight_idx, j, 0)];
#else
hidden_result += initial_hidden_state[INPUT1_GET_INDEX_SAFE(b, j, 0, 0)]*R[INPUT4_GET_INDEX_SAFE(weight_idx, j, 0, 0)];
#endif
hidden_result += initial_hidden_state[GET_IN1_IDX(b, dir, j)]*R[GET_IN4_IDX(dir, weight_idx, j)];
}else{
#ifdef SEQUENCE
hidden_result += hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, prev_idx, j)]*R[INPUT4_GET_INDEX_SAFE(0, weight_idx, j, 0)];
#if DIRECTION == 2
hidden_result += hidden_history[OUTPUT_GET_INDEX_SAFE(b, dir, prev_idx, j)]*R[GET_IN4_IDX(dir, weight_idx, j)];
#else
hidden_result += hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, prev_idx, j)]*R[GET_IN4_IDX(0, weight_idx, j)];
#endif
#endif
}
}

unroll_for(uint j=0;j<INPUT_SIZE;++j) {
if (dir == 1) { //reverse
input_result += x[INPUT0_GET_INDEX_SAFE(b, real_seq_length-1-i, j, 0)]*W[INPUT3_GET_INDEX_SAFE(0, weight_idx, j, 0)];
input_result += x[GET_IN0_IDX(b, real_seq_length-1-i, j)]*W[GET_IN3_IDX(dir, weight_idx, j)];
} else {
#ifdef SEQUENCE
input_result += x[INPUT0_GET_INDEX_SAFE(b, i, j, 0)]*W[INPUT3_GET_INDEX_SAFE(0, weight_idx, j, 0)];
#else
input_result += x[INPUT0_GET_INDEX_SAFE(b, j, 0, 0)]*W[INPUT3_GET_INDEX_SAFE(weight_idx, j, 0, 0)];
#endif
input_result += x[GET_IN0_IDX(b, i, j)]*W[GET_IN3_IDX(dir, weight_idx, j)];
}
}
#ifdef SEQUENCE
gate_output[k] = hidden_result + input_result + TO_ACCUMULATOR_TYPE(B[INPUT5_GET_INDEX_SAFE(0, weight_idx, 0, 0)]);
#else
gate_output[k] = hidden_result + input_result + TO_ACCUMULATOR_TYPE(B[INPUT5_GET_INDEX_SAFE(weight_idx, 0, 0, 0)]);
#endif
gate_output[k] = hidden_result + input_result + TO_ACCUMULATOR_TYPE(B[GET_IN5_IDX(dir, weight_idx)]);

switch(k){
case 0:
case 1:
Expand All @@ -100,21 +117,16 @@ KERNEL(lstm_cell_and_seq_ref)(
}
ACCUMULATOR_TYPE temp_cell_state;
if (i==0){
#ifdef SEQUENCE
temp_cell_state = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, 0, hidden_idx, 0)] + gate_output[1]*gate_output[2];
#else
temp_cell_state = gate_output[0]*initial_cell_state[INPUT2_GET_INDEX_SAFE(b, hidden_idx, 0, 0)] + gate_output[1]*gate_output[2];
#endif
temp_cell_state = gate_output[0]*initial_cell_state[GET_IN2_IDX(b, dir, hidden_idx)] + gate_output[1]*gate_output[2];
}else{
temp_cell_state *= gate_output[0];
temp_cell_state += gate_output[1]*gate_output[2];
}

#if DIRECTION == 1 //reverse
const uint cur_history_idx = real_seq_length - 1 - i ;
#else
const uint cur_history_idx = i;
#endif
uint cur_history_idx = i;
if (dir == 1) { //reverse
cur_history_idx = real_seq_length - 1 - i ;
}
#ifdef SEQUENCE
#if DIRECTION == 2
hidden_state[OUTPUT1_GET_INDEX_SAFE(b, dir, hidden_idx, 0)] = gate_output[3]*ACTIVATION_H(temp_cell_state, ACTIVATION_PARAMS_H);
Expand All @@ -129,6 +141,7 @@ KERNEL(lstm_cell_and_seq_ref)(
hidden_history[OUTPUT_GET_INDEX_SAFE(b, dir, cur_history_idx, hidden_idx)] = hidden_state[OUTPUT1_GET_INDEX_SAFE(b, dir, hidden_idx, 0)];
#else // DIRECTION == 2
hidden_history[OUTPUT_GET_INDEX_SAFE(b, 0, cur_history_idx, hidden_idx)] = hidden_state[OUTPUT1_GET_INDEX_SAFE(b, 0, hidden_idx, 0)];
#endif
#endif
if(i==real_seq_length-1){
#ifdef SEQUENCE
Expand Down
Loading

0 comments on commit 9fa6fff

Please sign in to comment.