ruy: support dst == int32 for neon out of order case.
PiperOrigin-RevId: 247154463
This commit is contained in:
parent
34b0836381
commit
8dc0be226b
@ -37,6 +37,8 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_
|
||||
#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "profiling/instrumentation.h"
|
||||
#include "tensorflow/lite/experimental/ruy/common.h"
|
||||
#include "tensorflow/lite/experimental/ruy/context.h"
|
||||
@ -95,6 +97,21 @@ void EnforceZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point,
|
||||
rhs_zero_point != std::numeric_limits<RhsScalar>::lowest());
|
||||
}
|
||||
|
||||
// Checks, via RUY_DCHECK, that a spec targeting a std::int32_t destination
// has been left untouched.
//
// A std::int32_t destination means the caller is asking for the raw
// accumulators. In that case zero_point and all the other dequantize fields
// don't make sense and should not be set: the zero point and both multiplier
// fields must be 0, the per-channel multiplier pointers must be null, and the
// clamp bounds must be the min/max of std::int32_t.
//
// When Spec::DstScalar is any other type, nothing is checked.
template <typename Spec, typename DstScalar>
void EnforceDstSpecSupport(const Spec& spec, DstScalar dst_zero_point) {
  using SpecDstScalar = typename Spec::DstScalar;
  using Int32Limits = std::numeric_limits<std::int32_t>;

  if (std::is_same<SpecDstScalar, std::int32_t>::value) {
    // Raw accumulators carry no destination zero point.
    RUY_DCHECK(dst_zero_point == 0);

    // Clamp bounds must still span the whole std::int32_t range.
    RUY_DCHECK(spec.clamp_max == Int32Limits::max());
    RUY_DCHECK(spec.clamp_min == Int32Limits::min());

    // Neither the uniform multiplier nor the per-channel ones may be set.
    RUY_DCHECK(spec.multiplier_fixedpoint == 0);
    RUY_DCHECK(spec.multiplier_exponent == 0);
    RUY_DCHECK(spec.multiplier_fixedpoint_perchannel == nullptr);
    RUY_DCHECK(spec.multiplier_exponent_perchannel == nullptr);
  }
}
|
||||
|
||||
inline bool IsColMajorTrMul(const DMatrix& lhs, const DMatrix& rhs,
|
||||
const DMatrix& dst) {
|
||||
return IsColMajor(lhs.layout) && IsColMajor(rhs.layout) &&
|
||||
@ -152,8 +169,10 @@ void PopulateTrMulParams(TrMulParams* params) {
|
||||
}
|
||||
|
||||
// If DstScalar is std::int32_t, the user wants to read the accumulator
|
||||
// results directly, fallback to Path::kStandardCpp.
|
||||
if (std::is_same<DstScalar, std::int32_t>::value) {
|
||||
// results directly; unless this is the Neon path, we fall back to
|
||||
// Path::kStandardCpp.
|
||||
if (std::is_same<DstScalar, std::int32_t>::value &&
|
||||
ThePath != Path::kNeon) {
|
||||
fallback_to_standard_cpp = true;
|
||||
}
|
||||
}
|
||||
@ -367,6 +386,7 @@ void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
|
||||
EnforceLayoutSupport<Spec>(lhs.layout, rhs.layout, dst->layout);
|
||||
EnforceZeroPointSupport<Spec>(lhs.zero_point, rhs.zero_point,
|
||||
dst->zero_point);
|
||||
EnforceDstSpecSupport<Spec>(spec, dst->zero_point);
|
||||
|
||||
// This should be a constant, for a given machine and CompiledPaths.
|
||||
// There is a back door to override it for testing, but in production it will
|
||||
|
@ -24,6 +24,7 @@ namespace ruy {
|
||||
#define RUY_ASM_LABEL_STORE_UINT8 91
|
||||
#define RUY_ASM_LABEL_STORE_INT8 92
|
||||
#define RUY_ASM_LABEL_STORE_INT16 93
|
||||
#define RUY_ASM_LABEL_STORE_INT32 94
|
||||
#define RUY_ASM_LABEL_AFTER_STORE 99
|
||||
|
||||
#define RUY_OFFSET_BIAS 0
|
||||
@ -49,8 +50,8 @@ namespace ruy {
|
||||
#define RUY_OFFSET_DST_STRIDE 112
|
||||
#define RUY_OFFSET_DEPTH 116
|
||||
#define RUY_OFFSET_CLAMP_MIN 120
|
||||
#define RUY_OFFSET_CLAMP_MAX 122
|
||||
#define RUY_OFFSET_FLAGS 124
|
||||
#define RUY_OFFSET_CLAMP_MAX 124
|
||||
#define RUY_OFFSET_FLAGS 128
|
||||
|
||||
template <typename Params>
|
||||
void CheckOffsetsInKernelParams8bit(const Params&) {
|
||||
@ -476,6 +477,12 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) {
|
||||
"sub v17.4s, v17.4s, v11.4s\n"
|
||||
"sub v18.4s, v18.4s, v11.4s\n"
|
||||
"sub v19.4s, v19.4s, v11.4s\n"
|
||||
|
||||
// If the destination is int32, it means the user asks for the raw
|
||||
// accumulators, no need for us to downquantize the value.
|
||||
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
|
||||
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
|
||||
|
||||
"402:\n"
|
||||
|
||||
// At this point we have computed the final int32 values. Now we
|
||||
@ -924,6 +931,108 @@ void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) {
|
||||
RUY_MAKE_ZERO(v16)
|
||||
RUY_MAKE_ZERO(v17)
|
||||
|
||||
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
|
||||
|
||||
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
|
||||
|
||||
// Since the store type is the same as the accum type, no need for
|
||||
// downcast. There's also no need for clamp by min/max.
|
||||
|
||||
// At this point, v20 -- v31 aren't used anymore for the current block,
|
||||
// so we can start clearing these accumulators for the next block
|
||||
// (next iteration of the main loop).
|
||||
RUY_MAKE_ZERO(v20)
|
||||
RUY_MAKE_ZERO(v21)
|
||||
RUY_MAKE_ZERO(v22)
|
||||
RUY_MAKE_ZERO(v23)
|
||||
RUY_MAKE_ZERO(v24)
|
||||
RUY_MAKE_ZERO(v25)
|
||||
RUY_MAKE_ZERO(v26)
|
||||
RUY_MAKE_ZERO(v27)
|
||||
RUY_MAKE_ZERO(v28)
|
||||
RUY_MAKE_ZERO(v29)
|
||||
RUY_MAKE_ZERO(v30)
|
||||
RUY_MAKE_ZERO(v31)
|
||||
|
||||
// Compute how much of the 4x4 block of destination 32bit values that
|
||||
// we have computed, fit in the destination matrix. Typically, all of
|
||||
// it fits, but when the destination matrix shape is not a multiple
|
||||
// of 4x4, there are some 4x4 blocks along the boundaries that do
|
||||
// not fit entirely.
|
||||
"sub w1, %w[dst_rows], %w[row]\n"
|
||||
"sub w2, %w[dst_cols], %w[col]\n"
|
||||
"mov w3, #4\n"
|
||||
"cmp w1, #4\n"
|
||||
// Compute w1 = how many rows of the 4x4 block fit
|
||||
"csel w1, w1, w3, le\n"
|
||||
"cmp w2, #4\n"
|
||||
// Compute w2 = how many cols of the 4x4 block fit
|
||||
"csel w2, w2, w3, le\n"
|
||||
|
||||
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
|
||||
"cmp w1, w3\n"
|
||||
"ccmp w2, w3, 0, eq\n"
|
||||
"mov x4, %[dst_ptr]\n"
|
||||
// Yes, all of the 4x4 block fits, go to fast path.
|
||||
"beq 30f\n"
|
||||
// Not all of the 4x4 block fits.
|
||||
// Store to dst_tmp_buf
|
||||
"str q16, [%[dst_tmp_buf], #0]\n"
|
||||
"str q17, [%[dst_tmp_buf], #16]\n"
|
||||
"str q18, [%[dst_tmp_buf], #32]\n"
|
||||
"str q19, [%[dst_tmp_buf], #48]\n"
|
||||
// Slow loop copying from dst_tmp_buf to dst.
|
||||
"mov x3, %[dst_tmp_buf]\n"
|
||||
"mov w6, #0\n"
|
||||
"50:\n"
|
||||
"mov w5, #0\n"
|
||||
"51:\n"
|
||||
"ldr w7, [x3, x5, lsl #2]\n"
|
||||
"str w7, [x4, x5, lsl #2]\n"
|
||||
"add w5, w5, #1\n"
|
||||
"cmp w5, w1\n"
|
||||
"blt 51b\n"
|
||||
"add w6, w6, #1\n"
|
||||
"add x3, x3, #16\n"
|
||||
"add x4, x4, x11\n"
|
||||
"cmp w6, w2\n"
|
||||
"blt 50b\n"
|
||||
"b 31f\n"
|
||||
"30:\n"
|
||||
// Yes, all of the 4x4 block fits.
|
||||
"mov x3, x4\n"
|
||||
"st1 {v16.s}[0], [x3], #4\n"
|
||||
"add x4, x4, x11\n"
|
||||
"st1 {v16.s}[1], [x3], #4\n"
|
||||
"st1 {v16.s}[2], [x3], #4\n"
|
||||
"st1 {v16.s}[3], [x3], #4\n"
|
||||
"mov x3, x4\n"
|
||||
"st1 {v17.s}[0], [x3], #4\n"
|
||||
"add x4, x4, x11\n"
|
||||
"st1 {v17.s}[1], [x3], #4\n"
|
||||
"st1 {v17.s}[2], [x3], #4\n"
|
||||
"st1 {v17.s}[3], [x3], #4\n"
|
||||
"mov x3, x4\n"
|
||||
"st1 {v18.s}[0], [x3], #4\n"
|
||||
"add x4, x4, x11\n"
|
||||
"st1 {v18.s}[1], [x3], #4\n"
|
||||
"st1 {v18.s}[2], [x3], #4\n"
|
||||
"st1 {v18.s}[3], [x3], #4\n"
|
||||
"mov x3, x4\n"
|
||||
"st1 {v19.s}[0], [x3], #4\n"
|
||||
"add x4, x4, x11\n"
|
||||
"st1 {v19.s}[1], [x3], #4\n"
|
||||
"st1 {v19.s}[2], [x3], #4\n"
|
||||
"st1 {v19.s}[3], [x3], #4\n"
|
||||
"31:\n"
|
||||
|
||||
"add %[dst_ptr], %[dst_ptr], #16\n"
|
||||
|
||||
RUY_MAKE_ZERO(v16)
|
||||
RUY_MAKE_ZERO(v17)
|
||||
RUY_MAKE_ZERO(v18)
|
||||
RUY_MAKE_ZERO(v19)
|
||||
|
||||
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
|
||||
|
||||
// For the next block: perform the first few multiply-adds on the data
|
||||
|
@ -248,8 +248,8 @@ struct KernelParams8bit {
|
||||
std::int32_t rhs_stride;
|
||||
std::int32_t dst_stride;
|
||||
std::int32_t depth;
|
||||
std::int16_t clamp_min;
|
||||
std::int16_t clamp_max;
|
||||
std::int32_t clamp_min;
|
||||
std::int32_t clamp_max;
|
||||
std::uint8_t flags;
|
||||
std::uint8_t dst_type_id;
|
||||
const std::int32_t zero_data[LhsCols] = {0};
|
||||
@ -348,7 +348,9 @@ struct Kernel<Path::kNeon, std::int8_t, std::int8_t, DstScalar,
|
||||
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
|
||||
MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col,
|
||||
dst, ¶ms);
|
||||
if (__builtin_expect(tuning == Tuning::kInOrder, true)) {
|
||||
// TODO(renjieliu): Support the InOrder path when dest is std::int32_t.
|
||||
if (__builtin_expect(tuning == Tuning::kInOrder, true) &&
|
||||
!std::is_same<DstScalar, std::int32_t>::value) {
|
||||
Kernel8bitNeonInOrder(params);
|
||||
} else {
|
||||
Kernel8bitNeonOutOfOrder(params);
|
||||
|
@ -1396,11 +1396,14 @@ void MakeSpecClampFields(const Matrix<LhsScalar>& lhs,
|
||||
spec_unclamped.multiplier_exponent_perchannel =
|
||||
spec->multiplier_exponent_perchannel;
|
||||
Mul<Path::kReference>(lhs, rhs, spec_unclamped, &context, &unclamped_dst);
|
||||
std::sort(unclamped_dst_data.begin(), unclamped_dst_data.end());
|
||||
const int clamp_count = static_cast<int>(std::floor(kClampRatio * size));
|
||||
RUY_CHECK_LT(clamp_count, size);
|
||||
spec->clamp_min = unclamped_dst_data[clamp_count];
|
||||
spec->clamp_max = unclamped_dst_data[size - 1 - clamp_count];
|
||||
// If dst is std::int32_t, no need to set the clamp min/max.
|
||||
if (!std::is_same<typename Spec::DstScalar, std::int32_t>::value) {
|
||||
std::sort(unclamped_dst_data.begin(), unclamped_dst_data.end());
|
||||
const int clamp_count = static_cast<int>(std::floor(kClampRatio * size));
|
||||
RUY_CHECK_LT(clamp_count, size);
|
||||
spec->clamp_min = unclamped_dst_data[clamp_count];
|
||||
spec->clamp_max = unclamped_dst_data[size - 1 - clamp_count];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename LhsScalar, typename RhsScalar, typename SpecType>
|
||||
@ -1409,7 +1412,12 @@ void TestSet<LhsScalar, RhsScalar, SpecType>::MakeZeroPoints() {
|
||||
if (!use_specified_zero_points) {
|
||||
MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point);
|
||||
MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point);
|
||||
MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point);
|
||||
// If destination is std::int32_t, no dst_zero_point is necessary.
|
||||
if (std::is_same<DstScalar, std::int32_t>::value) {
|
||||
dst_zero_point = 0;
|
||||
} else {
|
||||
MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point);
|
||||
}
|
||||
}
|
||||
life_stage = LifeStage::kHasZeroPoints;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user