ruy: support dst == int32 for neon out of order case.

PiperOrigin-RevId: 247154463
This commit is contained in:
Renjie Liu 2019-05-07 22:59:23 -07:00 committed by TensorFlower Gardener
parent 34b0836381
commit 8dc0be226b
4 changed files with 152 additions and 13 deletions

View File

@ -37,6 +37,8 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_ #ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_ #define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_
#include <limits>
#include "profiling/instrumentation.h" #include "profiling/instrumentation.h"
#include "tensorflow/lite/experimental/ruy/common.h" #include "tensorflow/lite/experimental/ruy/common.h"
#include "tensorflow/lite/experimental/ruy/context.h" #include "tensorflow/lite/experimental/ruy/context.h"
@ -95,6 +97,21 @@ void EnforceZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point,
rhs_zero_point != std::numeric_limits<RhsScalar>::lowest()); rhs_zero_point != std::numeric_limits<RhsScalar>::lowest());
} }
// Sanity-checks the spec and destination zero point for the case where the
// spec's destination scalar type is std::int32_t.
//
// An int32 destination means the caller wants the raw accumulators, so none
// of the requantization parameters are meaningful: the destination zero
// point must be 0, the clamp bounds must span the full int32 range, and both
// the uniform and the per-channel multiplier fields must be left unset.
// Each requirement is asserted with RUY_DCHECK. For any other destination
// type this function performs no checks.
template <typename Spec, typename DstScalar>
void EnforceDstSpecSupport(const Spec& spec, DstScalar dst_zero_point) {
  constexpr bool kWantsRawAccumulators =
      std::is_same<typename Spec::DstScalar, std::int32_t>::value;
  if (kWantsRawAccumulators) {
    using Int32Limits = std::numeric_limits<std::int32_t>;
    // No zero point is applied to raw accumulators.
    RUY_DCHECK(dst_zero_point == 0);
    // Clamping must be a no-op over the whole int32 range.
    RUY_DCHECK(spec.clamp_max == Int32Limits::max());
    RUY_DCHECK(spec.clamp_min == Int32Limits::min());
    // No multiplier may be set, neither uniform nor per-channel.
    RUY_DCHECK(spec.multiplier_fixedpoint == 0);
    RUY_DCHECK(spec.multiplier_exponent == 0);
    RUY_DCHECK(spec.multiplier_fixedpoint_perchannel == nullptr);
    RUY_DCHECK(spec.multiplier_exponent_perchannel == nullptr);
  }
}
inline bool IsColMajorTrMul(const DMatrix& lhs, const DMatrix& rhs, inline bool IsColMajorTrMul(const DMatrix& lhs, const DMatrix& rhs,
const DMatrix& dst) { const DMatrix& dst) {
return IsColMajor(lhs.layout) && IsColMajor(rhs.layout) && return IsColMajor(lhs.layout) && IsColMajor(rhs.layout) &&
@ -152,8 +169,10 @@ void PopulateTrMulParams(TrMulParams* params) {
} }
// If DstScalar is std::int32_t, means user want to get from accumulator // If DstScalar is std::int32_t, means user want to get from accumulator
// results directly, fallback to Path::kStandardCpp. // results directly, if it's not Neon path, will fallback to
if (std::is_same<DstScalar, std::int32_t>::value) { // Path::kStandardCpp.
if (std::is_same<DstScalar, std::int32_t>::value &&
ThePath != Path::kNeon) {
fallback_to_standard_cpp = true; fallback_to_standard_cpp = true;
} }
} }
@ -367,6 +386,7 @@ void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
EnforceLayoutSupport<Spec>(lhs.layout, rhs.layout, dst->layout); EnforceLayoutSupport<Spec>(lhs.layout, rhs.layout, dst->layout);
EnforceZeroPointSupport<Spec>(lhs.zero_point, rhs.zero_point, EnforceZeroPointSupport<Spec>(lhs.zero_point, rhs.zero_point,
dst->zero_point); dst->zero_point);
EnforceDstSpecSupport<Spec>(spec, dst->zero_point);
// This should be a constant, for a given machine and CompiledPaths. // This should be a constant, for a given machine and CompiledPaths.
// There is a back door to override it for testing, but in production it will // There is a back door to override it for testing, but in production it will

View File

@ -24,6 +24,7 @@ namespace ruy {
#define RUY_ASM_LABEL_STORE_UINT8 91 #define RUY_ASM_LABEL_STORE_UINT8 91
#define RUY_ASM_LABEL_STORE_INT8 92 #define RUY_ASM_LABEL_STORE_INT8 92
#define RUY_ASM_LABEL_STORE_INT16 93 #define RUY_ASM_LABEL_STORE_INT16 93
#define RUY_ASM_LABEL_STORE_INT32 94
#define RUY_ASM_LABEL_AFTER_STORE 99 #define RUY_ASM_LABEL_AFTER_STORE 99
#define RUY_OFFSET_BIAS 0 #define RUY_OFFSET_BIAS 0
@ -49,8 +50,8 @@ namespace ruy {
#define RUY_OFFSET_DST_STRIDE 112 #define RUY_OFFSET_DST_STRIDE 112
#define RUY_OFFSET_DEPTH 116 #define RUY_OFFSET_DEPTH 116
#define RUY_OFFSET_CLAMP_MIN 120 #define RUY_OFFSET_CLAMP_MIN 120
#define RUY_OFFSET_CLAMP_MAX 122 #define RUY_OFFSET_CLAMP_MAX 124
#define RUY_OFFSET_FLAGS 124 #define RUY_OFFSET_FLAGS 128
template <typename Params> template <typename Params>
void CheckOffsetsInKernelParams8bit(const Params&) { void CheckOffsetsInKernelParams8bit(const Params&) {
@ -476,6 +477,12 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) {
"sub v17.4s, v17.4s, v11.4s\n" "sub v17.4s, v17.4s, v11.4s\n"
"sub v18.4s, v18.4s, v11.4s\n" "sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v11.4s\n" "sub v19.4s, v19.4s, v11.4s\n"
// If the destination is int32, it means the user asks for the raw
// accumulators, no need for us to downquantize the value.
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n" "402:\n"
// At this point we have computed the final int32 values. Now we // At this point we have computed the final int32 values. Now we
@ -924,6 +931,108 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) {
RUY_MAKE_ZERO(v16) RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17) RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// At this point, v20 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Compute how much of the 4x4 block of destination 32bit values that
// we have computed, fit in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
"str q17, [%[dst_tmp_buf], #16]\n"
"str q18, [%[dst_tmp_buf], #32]\n"
"str q19, [%[dst_tmp_buf], #48]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
"mov w5, #0\n"
"51:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #16\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
"mov x3, x4\n"
"st1 {v16.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v16.s}[1], [x3], #4\n"
"st1 {v16.s}[2], [x3], #4\n"
"st1 {v16.s}[3], [x3], #4\n"
"mov x3, x4\n"
"st1 {v17.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v17.s}[1], [x3], #4\n"
"st1 {v17.s}[2], [x3], #4\n"
"st1 {v17.s}[3], [x3], #4\n"
"mov x3, x4\n"
"st1 {v18.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v18.s}[1], [x3], #4\n"
"st1 {v18.s}[2], [x3], #4\n"
"st1 {v18.s}[3], [x3], #4\n"
"mov x3, x4\n"
"st1 {v19.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v19.s}[1], [x3], #4\n"
"st1 {v19.s}[2], [x3], #4\n"
"st1 {v19.s}[3], [x3], #4\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// For the next block: perform the first few multiply-adds on the data // For the next block: perform the first few multiply-adds on the data

View File

@ -248,8 +248,8 @@ struct KernelParams8bit {
std::int32_t rhs_stride; std::int32_t rhs_stride;
std::int32_t dst_stride; std::int32_t dst_stride;
std::int32_t depth; std::int32_t depth;
std::int16_t clamp_min; std::int32_t clamp_min;
std::int16_t clamp_max; std::int32_t clamp_max;
std::uint8_t flags; std::uint8_t flags;
std::uint8_t dst_type_id; std::uint8_t dst_type_id;
const std::int32_t zero_data[LhsCols] = {0}; const std::int32_t zero_data[LhsCols] = {0};
@ -348,7 +348,9 @@ struct Kernel<Path::kNeon, std::int8_t, std::int8_t, DstScalar,
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
dst, &params); dst, &params);
if (__builtin_expect(tuning == Tuning::kInOrder, true)) { // TODO(renjieliu): Support InOrder path for dest is std::int32_t case.
if (__builtin_expect(tuning == Tuning::kInOrder, true) &&
!std::is_same<DstScalar, std::int32_t>::value) {
Kernel8bitNeonInOrder(params); Kernel8bitNeonInOrder(params);
} else { } else {
Kernel8bitNeonOutOfOrder(params); Kernel8bitNeonOutOfOrder(params);

View File

@ -1396,12 +1396,15 @@ void MakeSpecClampFields(const Matrix<LhsScalar>& lhs,
spec_unclamped.multiplier_exponent_perchannel = spec_unclamped.multiplier_exponent_perchannel =
spec->multiplier_exponent_perchannel; spec->multiplier_exponent_perchannel;
Mul<Path::kReference>(lhs, rhs, spec_unclamped, &context, &unclamped_dst); Mul<Path::kReference>(lhs, rhs, spec_unclamped, &context, &unclamped_dst);
// If dst is std::int32_t, no need to set the clamp min/max.
if (!std::is_same<typename Spec::DstScalar, std::int32_t>::value) {
std::sort(unclamped_dst_data.begin(), unclamped_dst_data.end()); std::sort(unclamped_dst_data.begin(), unclamped_dst_data.end());
const int clamp_count = static_cast<int>(std::floor(kClampRatio * size)); const int clamp_count = static_cast<int>(std::floor(kClampRatio * size));
RUY_CHECK_LT(clamp_count, size); RUY_CHECK_LT(clamp_count, size);
spec->clamp_min = unclamped_dst_data[clamp_count]; spec->clamp_min = unclamped_dst_data[clamp_count];
spec->clamp_max = unclamped_dst_data[size - 1 - clamp_count]; spec->clamp_max = unclamped_dst_data[size - 1 - clamp_count];
} }
}
template <typename LhsScalar, typename RhsScalar, typename SpecType> template <typename LhsScalar, typename RhsScalar, typename SpecType>
void TestSet<LhsScalar, RhsScalar, SpecType>::MakeZeroPoints() { void TestSet<LhsScalar, RhsScalar, SpecType>::MakeZeroPoints() {
@ -1409,8 +1412,13 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::MakeZeroPoints() {
if (!use_specified_zero_points) { if (!use_specified_zero_points) {
MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point); MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point);
MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point); MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point);
// If destination is std::int32_t, no dst_zero_point is necessary.
if (std::is_same<DstScalar, std::int32_t>::value) {
dst_zero_point = 0;
} else {
MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point); MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point);
} }
}
life_stage = LifeStage::kHasZeroPoints; life_stage = LifeStage::kHasZeroPoints;
} }