ruy: support dst == int32, currently only kStandardCpp path.

PiperOrigin-RevId: 246974269
This commit is contained in:
Renjie Liu 2019-05-07 01:09:46 -07:00 committed by TensorFlower Gardener
parent e92ca4ad29
commit 4624a9ee5f
3 changed files with 28 additions and 6 deletions

View File

@ -339,6 +339,7 @@ ruy_benchmark(
("i8", "i8", "i32", "u8"), ("i8", "i8", "i32", "u8"),
("i8", "i8", "i32", "i8"), ("i8", "i8", "i32", "i8"),
("u8", "u8", "i32", "i16"), ("u8", "u8", "i32", "i16"),
("i8", "i8", "i32", "i32"),
], ],
) )
@ -353,6 +354,7 @@ ruy_test(
("i8", "i8", "i32", "i8"), ("i8", "i8", "i32", "i8"),
("i8", "u8", "i32", "i8"), ("i8", "u8", "i32", "i8"),
("u8", "u8", "i32", "i16"), ("u8", "u8", "i32", "i16"),
("i8", "i8", "i32", "i32"),
], ],
) )
@ -364,6 +366,7 @@ ruy_test(
("u8", "u8", "i32", "u8"), ("u8", "u8", "i32", "u8"),
("i8", "i8", "i32", "i8"), ("i8", "i8", "i32", "i8"),
("u8", "u8", "i32", "i16"), ("u8", "u8", "i32", "i16"),
("i8", "i8", "i32", "i32"),
], ],
) )

View File

@ -141,14 +141,27 @@ template <Path ThePath, typename LhsScalar, typename RhsScalar,
void PopulateTrMulParams(TrMulParams* params) { void PopulateTrMulParams(TrMulParams* params) {
static_assert((ThePath & Path::kReference) == Path::kNone, static_assert((ThePath & Path::kReference) == Path::kNone,
"Path::kReference should not do TrMul"); "Path::kReference should not do TrMul");
// The optimized code paths only handle a very specific set of layouts. // The optimized code paths don't handle the full generality of Ruy's API.
// Fall back to Path::kStandardCpp if needed. // Fall back to Path::kStandardCpp if necessary.
bool fallback_to_standard_cpp = false;
if (ThePath != Path::kStandardCpp) { if (ThePath != Path::kStandardCpp) {
// The optimized code paths currently only handle the case of all matrices
// being column major.
if (!IsColMajorTrMul(params->lhs, params->rhs, params->dst)) { if (!IsColMajorTrMul(params->lhs, params->rhs, params->dst)) {
PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar, fallback_to_standard_cpp = true;
Spec>(params);
return;
} }
// If DstScalar is std::int32_t, the user wants the raw accumulator results
// directly, so fall back to Path::kStandardCpp.
if (std::is_same<DstScalar, std::int32_t>::value) {
fallback_to_standard_cpp = true;
}
}
if (fallback_to_standard_cpp) {
PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
Spec>(params);
return;
} }
using PackedLhsScalar = PackedType<ThePath, LhsScalar>; using PackedLhsScalar = PackedType<ThePath, LhsScalar>;

View File

@ -197,6 +197,7 @@ RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
#define RUY_ASM_TYPE_ID_UINT8 1 #define RUY_ASM_TYPE_ID_UINT8 1
#define RUY_ASM_TYPE_ID_INT8 2 #define RUY_ASM_TYPE_ID_INT8 2
#define RUY_ASM_TYPE_ID_INT16 3 #define RUY_ASM_TYPE_ID_INT16 3
#define RUY_ASM_TYPE_ID_INT32 4
template <typename DstScalar> template <typename DstScalar>
struct DstTypeId {}; struct DstTypeId {};
@ -216,9 +217,14 @@ struct DstTypeId<std::int16_t> {
static constexpr int kValue = RUY_ASM_TYPE_ID_INT16; static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
}; };
// Specialization for a std::int32_t destination: exposes the type id
// RUY_ASM_TYPE_ID_INT32 as kValue.
template <>
struct DstTypeId<std::int32_t> {
static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
};
template <int LhsCols, int RhsCols> template <int LhsCols, int RhsCols>
struct KernelParams8bit { struct KernelParams8bit {
static constexpr int kMaxDstTypeSize = 2; static constexpr int kMaxDstTypeSize = 4;
const std::int32_t* bias; const std::int32_t* bias;
const std::int32_t* lhs_sums; const std::int32_t* lhs_sums;