ruy: support dst == int32, currently only kStandardCpp path.
PiperOrigin-RevId: 246974269
This commit is contained in:
parent
e92ca4ad29
commit
4624a9ee5f
@ -339,6 +339,7 @@ ruy_benchmark(
|
||||
("i8", "i8", "i32", "u8"),
|
||||
("i8", "i8", "i32", "i8"),
|
||||
("u8", "u8", "i32", "i16"),
|
||||
("i8", "i8", "i32", "i32"),
|
||||
],
|
||||
)
|
||||
|
||||
@ -353,6 +354,7 @@ ruy_test(
|
||||
("i8", "i8", "i32", "i8"),
|
||||
("i8", "u8", "i32", "i8"),
|
||||
("u8", "u8", "i32", "i16"),
|
||||
("i8", "i8", "i32", "i32"),
|
||||
],
|
||||
)
|
||||
|
||||
@ -364,6 +366,7 @@ ruy_test(
|
||||
("u8", "u8", "i32", "u8"),
|
||||
("i8", "i8", "i32", "i8"),
|
||||
("u8", "u8", "i32", "i16"),
|
||||
("i8", "i8", "i32", "i32"),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -141,14 +141,27 @@ template <Path ThePath, typename LhsScalar, typename RhsScalar,
|
||||
void PopulateTrMulParams(TrMulParams* params) {
|
||||
static_assert((ThePath & Path::kReference) == Path::kNone,
|
||||
"Path::kReference should not do TrMul");
|
||||
// The optimized code paths only handle a very specific set of layouts.
|
||||
// Fall back to Path::kStandardCpp if needed.
|
||||
// The optimized code paths don't handle the full generality of Ruy's API.
|
||||
// Fall back to Path::kStandardCpp if necessary.
|
||||
bool fallback_to_standard_cpp = false;
|
||||
if (ThePath != Path::kStandardCpp) {
|
||||
// The optimized code paths currently only handle the case of all matrices
|
||||
// being column major.
|
||||
if (!IsColMajorTrMul(params->lhs, params->rhs, params->dst)) {
|
||||
PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
|
||||
Spec>(params);
|
||||
return;
|
||||
fallback_to_standard_cpp = true;
|
||||
}
|
||||
|
||||
// If DstScalar is std::int32_t, means user want to get from accumulator
|
||||
// results directly, fallback to Path::kStandardCpp.
|
||||
if (std::is_same<DstScalar, std::int32_t>::value) {
|
||||
fallback_to_standard_cpp = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (fallback_to_standard_cpp) {
|
||||
PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
|
||||
Spec>(params);
|
||||
return;
|
||||
}
|
||||
|
||||
using PackedLhsScalar = PackedType<ThePath, LhsScalar>;
|
||||
|
@ -197,6 +197,7 @@ RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
|
||||
#define RUY_ASM_TYPE_ID_UINT8 1
|
||||
#define RUY_ASM_TYPE_ID_INT8 2
|
||||
#define RUY_ASM_TYPE_ID_INT16 3
|
||||
#define RUY_ASM_TYPE_ID_INT32 4
|
||||
|
||||
template <typename DstScalar>
|
||||
struct DstTypeId {};
|
||||
@ -216,9 +217,14 @@ struct DstTypeId<std::int16_t> {
|
||||
static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DstTypeId<std::int32_t> {
|
||||
static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
|
||||
};
|
||||
|
||||
template <int LhsCols, int RhsCols>
|
||||
struct KernelParams8bit {
|
||||
static constexpr int kMaxDstTypeSize = 2;
|
||||
static constexpr int kMaxDstTypeSize = 4;
|
||||
|
||||
const std::int32_t* bias;
|
||||
const std::int32_t* lhs_sums;
|
||||
|
Loading…
Reference in New Issue
Block a user