ruy: support dst == int32; currently only on the kStandardCpp path.

PiperOrigin-RevId: 246974269
This commit is contained in:
Renjie Liu 2019-05-07 01:09:46 -07:00 committed by TensorFlower Gardener
parent e92ca4ad29
commit 4624a9ee5f
3 changed files with 28 additions and 6 deletions

View File

@ -339,6 +339,7 @@ ruy_benchmark(
("i8", "i8", "i32", "u8"),
("i8", "i8", "i32", "i8"),
("u8", "u8", "i32", "i16"),
("i8", "i8", "i32", "i32"),
],
)
@ -353,6 +354,7 @@ ruy_test(
("i8", "i8", "i32", "i8"),
("i8", "u8", "i32", "i8"),
("u8", "u8", "i32", "i16"),
("i8", "i8", "i32", "i32"),
],
)
@ -364,6 +366,7 @@ ruy_test(
("u8", "u8", "i32", "u8"),
("i8", "i8", "i32", "i8"),
("u8", "u8", "i32", "i16"),
("i8", "i8", "i32", "i32"),
],
)

View File

@ -141,14 +141,27 @@ template <Path ThePath, typename LhsScalar, typename RhsScalar,
void PopulateTrMulParams(TrMulParams* params) {
static_assert((ThePath & Path::kReference) == Path::kNone,
"Path::kReference should not do TrMul");
// The optimized code paths only handle a very specific set of layouts.
// Fall back to Path::kStandardCpp if needed.
// The optimized code paths don't handle the full generality of Ruy's API.
// Fall back to Path::kStandardCpp if necessary.
bool fallback_to_standard_cpp = false;
if (ThePath != Path::kStandardCpp) {
// The optimized code paths currently only handle the case of all matrices
// being column major.
if (!IsColMajorTrMul(params->lhs, params->rhs, params->dst)) {
PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
Spec>(params);
return;
fallback_to_standard_cpp = true;
}
// If DstScalar is std::int32_t, the user wants the accumulator results
// directly, so fall back to Path::kStandardCpp.
if (std::is_same<DstScalar, std::int32_t>::value) {
fallback_to_standard_cpp = true;
}
}
if (fallback_to_standard_cpp) {
PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
Spec>(params);
return;
}
using PackedLhsScalar = PackedType<ThePath, LhsScalar>;

View File

@ -197,6 +197,7 @@ RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
// Integer identifiers for the supported destination scalar types. Each
// DstTypeId<T> specialization exposes one of these values as kValue.
// NOTE(review): the "ASM" in the names suggests these IDs are consumed by the
// assembly kernels -- confirm against the kernel sources.
#define RUY_ASM_TYPE_ID_UINT8 1
#define RUY_ASM_TYPE_ID_INT8 2
#define RUY_ASM_TYPE_ID_INT16 3
#define RUY_ASM_TYPE_ID_INT32 4
// Primary template: intentionally empty, so it has no kValue member. Only the
// explicit specializations for specific destination types define kValue.
template <typename DstScalar>
struct DstTypeId {};
@ -216,9 +217,14 @@ struct DstTypeId<std::int16_t> {
static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
};
// Specialization for std::int32_t destinations: exposes RUY_ASM_TYPE_ID_INT32
// as the compile-time constant kValue.
template <>
struct DstTypeId<std::int32_t> {
static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
};
template <int LhsCols, int RhsCols>
struct KernelParams8bit {
static constexpr int kMaxDstTypeSize = 2;
static constexpr int kMaxDstTypeSize = 4;
const std::int32_t* bias;
const std::int32_t* lhs_sums;