Skip to content

Commit

Permalink
[arm-cpu] fix conv+hardswish in int8-int8 compute diff (#8996) (#9001)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenjiaoAngel authored May 10, 2022
1 parent a280a0a commit 6f20460
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 2 deletions.
25 changes: 23 additions & 2 deletions lite/backends/arm/math/gemv_arm_int8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,8 @@ inline void write_gemv_out(const int* in,
float32x4_t valpha = vdupq_n_f32(alpha);
float32x4_t voffset = vdupq_n_f32(offset);
float32x4_t vthreshold = vdupq_n_f32(threshold);
float32x4_t vmax = vdupq_n_f32(-127.f);

#ifdef __aarch64__
asm volatile(
"cmp %w[cnt], #1\n"
Expand Down Expand Up @@ -1252,6 +1254,11 @@ inline void write_gemv_out(const int* in,
"fmul v0.4s, v4.4s, v5.4s\n"
"fmin v6.4s, v6.4s, %[vthreshold].4s\n"
"fmul v3.4s, v6.4s, v7.4s\n"
// clamp out to >= -127: lanes that fail the compare are replaced by vmax (-127)
"fcmge v4.4s, v0.4s, %[vmax].4s\n"
"fcmge v5.4s, v3.4s, %[vmax].4s\n"
"bif v0.16b, %[vmax].16b, v4.16b\n"
"bif v3.16b, %[vmax].16b, v5.16b\n"
// fp32 -> int32
"fcvtas v4.4s, v0.4s\n"
"fcvtas v5.4s, v3.4s\n"
Expand Down Expand Up @@ -1279,6 +1286,9 @@ inline void write_gemv_out(const int* in,
"fmax v4.4s, v4.4s, %[vzero].4s\n"
"fmin v4.4s, v4.4s, %[vthreshold].4s\n"
"fmul v0.4s, v4.4s, v5.4s\n"
// clamp out to >= -127: lanes that fail the compare are replaced by vmax (-127)
"fcmge v4.4s, v0.4s, %[vmax].4s\n"
"bif v0.16b, %[vmax].16b, v4.16b\n"
// fp32 -> int32
"fcvtas v4.4s, v0.4s\n"
// int32 -> int16
Expand All @@ -1298,7 +1308,8 @@ inline void write_gemv_out(const int* in,
[vzero] "w"(vzero),
[valpha] "w"(valpha),
[voffset] "w"(voffset),
[vthreshold] "w"(vthreshold)
[vthreshold] "w"(vthreshold),
[vmax] "w"(vmax)
: "cc",
"memory",
"v0",
Expand Down Expand Up @@ -1349,6 +1360,11 @@ inline void write_gemv_out(const int* in,
"vbif q13, %q[vfive], q10\n"
"vadd.f32 q5, q5, q12\n"
"vadd.f32 q8, q8, q13\n"
// clamp data to >= -127: lanes that fail the compare are replaced by vmax (-127)
"vcge.f32 q7, q5, %q[vmax]\n"
"vcge.f32 q9, q8, %q[vmax]\n"
"vbif q5, %q[vmax], q7\n"
"vbif q8, %q[vmax], q9\n"
// fp32 -> int32
"vcvt.s32.f32 q7, q5\n"
"vcvt.s32.f32 q9, q8\n"
Expand Down Expand Up @@ -1380,6 +1396,9 @@ inline void write_gemv_out(const int* in,
"vcge.f32 q7, q5, %q[vzero]\n"
"vbif q12, %q[vfive], q7\n"
"vadd.f32 q5, q5, q12\n"
// clamp data to >= -127: lanes that fail the compare are replaced by vmax (-127)
"vcge.f32 q7, q5, %q[vmax]\n"
"vbif q5, %q[vmax], q7\n"
// fp32 -> int32
"vcvt.s32.f32 q7, q5\n"
// int32 -> int16
Expand All @@ -1400,7 +1419,8 @@ inline void write_gemv_out(const int* in,
[valpha] "w"(valpha),
[voffset] "w"(voffset),
[vthreshold] "w"(vthreshold),
[vfive] "w"(vfive)
[vfive] "w"(vfive),
[vmax] "w"(vmax)
: "cc",
"memory",
"q4",
Expand Down Expand Up @@ -1624,6 +1644,7 @@ bool gemv_int8_trans_oth(const int8_t* A,
memset(zerobuf, 0, sizeof(float) * (M + 16));
const float* bias_ptr = is_bias ? bias : zerobuf;
float six = alpha;

#ifdef __aarch64__
int cnt = N >> 3;
int tail = N & 7;
Expand Down
1 change: 1 addition & 0 deletions lite/kernels/arm/conv_direct.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kInt8)>(
}
//! update hardswish parameter
if (act_param.active_type == lite_api::ActivationType::kHardSwish) {
act_param.hard_swish_scale = act_param.hard_swish_scale / out_scale;
act_param.hard_swish_offset = act_param.hard_swish_offset / out_scale;
act_param.hard_swish_threshold = act_param.hard_swish_threshold / out_scale;
}
Expand Down
2 changes: 2 additions & 0 deletions lite/kernels/arm/conv_gemmlike.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
//! update hardswish parameter
if (param.activation_param.active_type ==
lite_api::ActivationType::kHardSwish) {
param.activation_param.hard_swish_scale =
param.activation_param.hard_swish_scale / param.output_scale;
param.activation_param.hard_swish_offset =
param.activation_param.hard_swish_offset / param.output_scale;
param.activation_param.hard_swish_threshold =
Expand Down
2 changes: 2 additions & 0 deletions lite/kernels/arm/conv_transpose_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ void Conv2DTransposeCompute<PRECISION(kInt8),
//! update hardswish parameter
if (param.activation_param.active_type ==
lite_api::ActivationType::kHardSwish) {
param.activation_param.hard_swish_scale =
param.activation_param.hard_swish_scale / param.output_scale;
param.activation_param.hard_swish_offset =
param.activation_param.hard_swish_offset / param.output_scale;
param.activation_param.hard_swish_threshold =
Expand Down
2 changes: 2 additions & 0 deletions lite/kernels/arm/conv_winograd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ void WinogradConv<PRECISION(kInt8), OutType>::ReInitWhenNeeded() {
//! update hardswish parameter
if (param.activation_param.active_type ==
lite_api::ActivationType::kHardSwish) {
param.activation_param.hard_swish_scale =
param.activation_param.hard_swish_scale / param.output_scale;
param.activation_param.hard_swish_offset =
param.activation_param.hard_swish_offset / output_scale;
param.activation_param.hard_swish_threshold =
Expand Down

0 comments on commit 6f20460

Please sign in to comment.