ruy: support dst == int32, currently only kStandardCpp path.
PiperOrigin-RevId: 246974269
This commit is contained in:
parent
e92ca4ad29
commit
4624a9ee5f
@ -339,6 +339,7 @@ ruy_benchmark(
|
|||||||
("i8", "i8", "i32", "u8"),
|
("i8", "i8", "i32", "u8"),
|
||||||
("i8", "i8", "i32", "i8"),
|
("i8", "i8", "i32", "i8"),
|
||||||
("u8", "u8", "i32", "i16"),
|
("u8", "u8", "i32", "i16"),
|
||||||
|
("i8", "i8", "i32", "i32"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -353,6 +354,7 @@ ruy_test(
|
|||||||
("i8", "i8", "i32", "i8"),
|
("i8", "i8", "i32", "i8"),
|
||||||
("i8", "u8", "i32", "i8"),
|
("i8", "u8", "i32", "i8"),
|
||||||
("u8", "u8", "i32", "i16"),
|
("u8", "u8", "i32", "i16"),
|
||||||
|
("i8", "i8", "i32", "i32"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -364,6 +366,7 @@ ruy_test(
|
|||||||
("u8", "u8", "i32", "u8"),
|
("u8", "u8", "i32", "u8"),
|
||||||
("i8", "i8", "i32", "i8"),
|
("i8", "i8", "i32", "i8"),
|
||||||
("u8", "u8", "i32", "i16"),
|
("u8", "u8", "i32", "i16"),
|
||||||
|
("i8", "i8", "i32", "i32"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -141,15 +141,28 @@ template <Path ThePath, typename LhsScalar, typename RhsScalar,
|
|||||||
void PopulateTrMulParams(TrMulParams* params) {
|
void PopulateTrMulParams(TrMulParams* params) {
|
||||||
static_assert((ThePath & Path::kReference) == Path::kNone,
|
static_assert((ThePath & Path::kReference) == Path::kNone,
|
||||||
"Path::kReference should not do TrMul");
|
"Path::kReference should not do TrMul");
|
||||||
// The optimized code paths only handle a very specific set of layouts.
|
// The optimized code paths don't handle the full generality of Ruy's API.
|
||||||
// Fall back to Path::kStandardCpp if needed.
|
// Fall back to Path::kStandardCpp if necessary.
|
||||||
|
bool fallback_to_standard_cpp = false;
|
||||||
if (ThePath != Path::kStandardCpp) {
|
if (ThePath != Path::kStandardCpp) {
|
||||||
|
// The optimized code paths currently only handle the case of all matrices
|
||||||
|
// being column major.
|
||||||
if (!IsColMajorTrMul(params->lhs, params->rhs, params->dst)) {
|
if (!IsColMajorTrMul(params->lhs, params->rhs, params->dst)) {
|
||||||
|
fallback_to_standard_cpp = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If DstScalar is std::int32_t, means user want to get from accumulator
|
||||||
|
// results directly, fallback to Path::kStandardCpp.
|
||||||
|
if (std::is_same<DstScalar, std::int32_t>::value) {
|
||||||
|
fallback_to_standard_cpp = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fallback_to_standard_cpp) {
|
||||||
PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
|
PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
|
||||||
Spec>(params);
|
Spec>(params);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
using PackedLhsScalar = PackedType<ThePath, LhsScalar>;
|
using PackedLhsScalar = PackedType<ThePath, LhsScalar>;
|
||||||
using PackedRhsScalar = PackedType<ThePath, RhsScalar>;
|
using PackedRhsScalar = PackedType<ThePath, RhsScalar>;
|
||||||
|
@ -197,6 +197,7 @@ RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
|
|||||||
#define RUY_ASM_TYPE_ID_UINT8 1
|
#define RUY_ASM_TYPE_ID_UINT8 1
|
||||||
#define RUY_ASM_TYPE_ID_INT8 2
|
#define RUY_ASM_TYPE_ID_INT8 2
|
||||||
#define RUY_ASM_TYPE_ID_INT16 3
|
#define RUY_ASM_TYPE_ID_INT16 3
|
||||||
|
#define RUY_ASM_TYPE_ID_INT32 4
|
||||||
|
|
||||||
template <typename DstScalar>
|
template <typename DstScalar>
|
||||||
struct DstTypeId {};
|
struct DstTypeId {};
|
||||||
@ -216,9 +217,14 @@ struct DstTypeId<std::int16_t> {
|
|||||||
static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
|
static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct DstTypeId<std::int32_t> {
|
||||||
|
static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
|
||||||
|
};
|
||||||
|
|
||||||
template <int LhsCols, int RhsCols>
|
template <int LhsCols, int RhsCols>
|
||||||
struct KernelParams8bit {
|
struct KernelParams8bit {
|
||||||
static constexpr int kMaxDstTypeSize = 2;
|
static constexpr int kMaxDstTypeSize = 4;
|
||||||
|
|
||||||
const std::int32_t* bias;
|
const std::int32_t* bias;
|
||||||
const std::int32_t* lhs_sums;
|
const std::int32_t* lhs_sums;
|
||||||
|
Loading…
Reference in New Issue
Block a user