Full int8 quantization BatchMatMul

PiperOrigin-RevId: 317304259
Change-Id: Icf96d9d129db30b965e36f5c8befd27762b173b2
parent 0c7e61d660
commit 9e7d5ef6f2
@@ -953,14 +953,14 @@ in the batch dimensions and broadcasting.
   }];

   let arguments = (ins
-    TFL_TensorOf<[F32]>:$x,
-    TFL_TensorOf<[F32]>:$y,
+    TFL_TensorOf<[F32, QI8]>:$x,
+    TFL_TensorOf<[F32, QI8]>:$y,
     DefaultValuedAttr<BoolAttr, "false">:$adj_x,
     DefaultValuedAttr<BoolAttr, "false">:$adj_y
   );

   let results = (outs
-    TFL_TensorOf<[F32]>:$output
+    TFL_TensorOf<[F32, QI8]>:$output
   );

   let hasOptions = 1;
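Background note (not part of the diff): the QI8 operand type added above denotes TFLite's 8-bit affine quantization, in which a real value r is represented by an int8 value q as r = scale * (q - zero_point). A minimal sketch of that mapping, with illustrative helper names:

    // Sketch of int8 affine (de)quantization; names are illustrative.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    int8_t QuantizeAffine(float r, float scale, int32_t zero_point) {
      const int32_t q =
          static_cast<int32_t>(std::round(r / scale)) + zero_point;
      return static_cast<int8_t>(std::min(127, std::max(-128, q)));
    }

    float DequantizeAffine(int8_t q, float scale, int32_t zero_point) {
      return scale * (static_cast<int32_t>(q) - zero_point);
    }

Allowing QI8 on $x, $y, and $output lets the converter keep BatchMatMul in the quantized domain instead of wrapping it in dequantize/quantize pairs.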
@@ -19,6 +19,7 @@ limitations under the License.

 #include <algorithm>
 #include <cstdint>
+#include <limits>

 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
@@ -52,6 +53,14 @@ enum KernelType {
 };

 struct OpData {
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t output_multiplier;
+  int output_shift;
+  // The range of the fused activation layer. For example for kNone and
+  // uint8_t these would be 0 and 255.
+  int32_t output_activation_min;
+  int32_t output_activation_max;
   // The index of the temporary tensors where we store transposed LHS/RHS.
   int scratch_tensor_index;
   bool rhs_transposed;
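The comment above describes TFLite's standard requantization scheme: a real multiplier M is stored as a Q31 fixed-point value plus a power-of-two exponent, so M ~= output_multiplier / 2^31 * 2^output_shift. A sketch of the decomposition that QuantizeMultiplier performs (the shipped helper also handles M == 0 and a few saturation edge cases):

    // Sketch: decompose m as q * 2^shift with q in [0.5, 1), then store q
    // as a rounded Q31 integer. Assumes m > 0.
    #include <cmath>
    #include <cstdint>

    void QuantizeMultiplierSketch(double m, int32_t* quantized_multiplier,
                                  int* shift) {
      const double q = std::frexp(m, shift);  // m == q * 2^shift
      int64_t q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
      if (q_fixed == (1ll << 31)) {  // Rounding carried q up to exactly 1.0.
        q_fixed /= 2;
        ++*shift;
      }
      *quantized_multiplier = static_cast<int32_t>(q_fixed);
    }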
@@ -274,6 +283,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

   OpContext op_context(context, node);
   TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

   bool adj_x = op_context.params->adj_x;
   bool adj_y = op_context.params->adj_y;
@@ -282,7 +292,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

-  TF_LITE_ENSURE_TYPES_EQ(context, lhs_data->type, kTfLiteFloat32);
+  // Note that quantized inference requires that all tensors have their
+  // parameters set. This is usually done during quantized training.
+  if (lhs_data->type == kTfLiteInt8) {
+    double real_multiplier = 0.0;
+    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+        context, lhs_data, rhs_data, output, &real_multiplier));
+    int exponent;
+    QuantizeMultiplier(real_multiplier, &op_data->output_multiplier, &exponent);
+    op_data->output_shift = exponent;
+    // BatchMatMul has no fused activation functions. Therefore, set
+    // output activation min and max to min and max of int8_t type,
+    // respectively.
+    op_data->output_activation_min = std::numeric_limits<int8_t>::min();
+    op_data->output_activation_max = std::numeric_limits<int8_t>::max();
+  }
+
+  TF_LITE_ENSURE(context, lhs_data->type == kTfLiteFloat32 ||
+                              lhs_data->type == kTfLiteInt8);
   TF_LITE_ENSURE(context, rhs_data->type == kTfLiteFloat32 ||
                               rhs_data->type == kTfLiteInt8);
   // Support dimensions between 2 and 4, inclusive.
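Because the int8 kernel accumulates (q - z) products in int32, the only scale left to apply is lhs_scale * rhs_scale, and mapping the accumulator into the output's quantized domain divides by output_scale. That quotient is the "real multiplier" requested from GetQuantizedConvolutionMultipler above (the convolution helper is reused, with the RHS in the filter's role). In effect:

    // What the helper computes, in effect (sketch; the shipped helper also
    // sanity-checks the resulting value).
    double BatchMatMulRealMultiplier(double lhs_scale, double rhs_scale,
                                     double output_scale) {
      return lhs_scale * rhs_scale / output_scale;
    }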
@@ -433,6 +460,41 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data,
   return kTfLiteOk;
 }

+template <KernelType kernel_type>
+TfLiteStatus EvalInt8(TfLiteContext* context, const OpData* data,
+                      const RuntimeShape& lhs_shape, const TfLiteTensor* lhs,
+                      const RuntimeShape& rhs_shape, const TfLiteTensor* rhs,
+                      const RuntimeShape& output_shape, TfLiteTensor* output) {
+  // Reuse params struct from FullyConnected Op.
+  FullyConnectedParams op_params;
+  int32_t input_offset = -lhs->params.zero_point;
+  int32_t filter_offset = -rhs->params.zero_point;
+  int32_t output_offset = output->params.zero_point;
+  op_params.input_offset = input_offset;
+  op_params.weights_offset = filter_offset;
+  op_params.output_offset = output_offset;
+  op_params.output_multiplier = data->output_multiplier;
+  op_params.output_shift = data->output_shift;
+  op_params.quantized_activation_min = data->output_activation_min;
+  op_params.quantized_activation_max = data->output_activation_max;
+  op_params.lhs_cacheable = IsConstantTensor(lhs);
+  op_params.rhs_cacheable = IsConstantTensor(rhs);
+
+  if (kernel_type == kReference) {
+    reference_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs),
+                               lhs_shape, GetTensorData<int8_t>(lhs),
+                               GetTensorShape(output),
+                               GetTensorData<int8_t>(output));
+  } else {
+    optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs),
+                               lhs_shape, GetTensorData<int8_t>(lhs),
+                               GetTensorShape(output),
+                               GetTensorData<int8_t>(output),
+                               CpuBackendContext::GetFromContext(context));
+  }
+  return kTfLiteOk;
+}
+
 template <KernelType kernel_type>
 TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                            OpData* data, const RuntimeShape& lhs_shape,
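Two conventions in EvalInt8 are worth noting. First, input_offset and weights_offset hold the negated zero points, so the inner loops can add an offset instead of subtracting a zero point; per output element the int8 path computes, in effect,

    q_out[i][j] = clamp(z_out + M * sum_k (q_lhs[i][k] - z_lhs) * (q_rhs[k][j] - z_rhs))

with M applied via the fixed-point multiplier/shift pair prepared above and clamp to [-128, 127]. Second, the op's RHS is passed in the kernel's lhs_shape/lhs_data position, apparently mirroring how the FullyConnected kernels treat the weights matrix.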
@@ -448,25 +510,39 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
     return EvalHybrid<kernel_type>(
         context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized,
         scaling_factors, accum_scratch, row_sums, input_offsets, output);
+  } else if (lhs->type == kTfLiteInt8) {
+    return EvalInt8<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs,
+                                 GetTensorShape(output), output);
   } else {
-    TF_LITE_KERNEL_LOG(context,
-                       "Currently only hybrid quantization is supported.\n");
+    TF_LITE_KERNEL_LOG(
+        context, "Currently only hybrid and int8 quantization is supported.\n");
     return kTfLiteError;
   }
   return kTfLiteOk;
 }

-TfLiteTensor* GetRhs(TfLiteContext* context, TfLiteNode* node,
-                     const TfLiteTensor* rhs) {
+TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
+                         const TfLiteTensor* rhs) {
   TfLiteTensor* transposed_rhs = GetTemporary(context, node, 1);
   if (rhs->type == kTfLiteInt8) {
-    // Get the quantization params from the weights tensors.
+    // Get the quantization params from the RHS tensor.
     transposed_rhs->params.scale = rhs->params.scale;
     transposed_rhs->params.zero_point = rhs->params.zero_point;
   }
   return transposed_rhs;
 }
+
+TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node,
+                         const TfLiteTensor* lhs) {
+  TfLiteTensor* transposed_lhs = GetTemporary(context, node, 0);
+  if (lhs->type == kTfLiteInt8) {
+    // Get the quantization params from the LHS tensor.
+    transposed_lhs->params.scale = lhs->params.scale;
+    transposed_lhs->params.zero_point = lhs->params.zero_point;
+  }
+  return transposed_lhs;
+}

 // Perform a batch matrix multiply on
 // LHS <..., A, B> X RHS<..., B, C>
 // where the leading dimensions of LHS and RHS obey broadcasting rules
@@ -491,8 +567,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   bool adj_y = op_context.params->adj_y;
   bool adj_x = op_context.params->adj_x;

-  const TfLiteTensor* rhs_tensor = adj_y ? rhs : GetRhs(context, node, rhs);
-  const TfLiteTensor* lhs_tensor = adj_x ? GetTemporary(context, node, 0) : lhs;
+  const TfLiteTensor* rhs_tensor = adj_y ? rhs : GetTempRhs(context, node, rhs);
+  const TfLiteTensor* lhs_tensor = adj_x ? GetTempLhs(context, node, lhs) : lhs;
   if (!adj_y) {
     // TODO(b/154760341) Constant tensors should already be transposed, but
     // we transpose once if necessary for now.
@@ -24,8 +24,19 @@ limitations under the License.
 #include "tensorflow/lite/schema/schema_generated.h"

 namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_BATCH_MATMUL_REF();
+TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED();
+
+}  // namespace builtin
+}  // namespace ops
+
 namespace {

+using ::testing::ElementsAre;
 using ::testing::ElementsAreArray;

 template <typename T>
@@ -53,7 +64,20 @@ class BatchMatMulOpModel : public SingleOpModel {
   int output_id_;
 };

-TEST(BatchMatMulOpModelTest, Float32Test_Simple) {
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
+    {"Reference", ops::builtin::Register_BATCH_MATMUL_REF()},
+    {"GenericOptimized",
+     ops::builtin::Register_BATCH_MATMUL_GENERIC_OPTIMIZED()},
+});
+
+class BatchMatMulOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMap;
+  }
+};
+
+TEST_P(BatchMatMulOpTest, Float32Test_Simple) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 3}},
                                   {TensorType_FLOAT32, {1, 3, 4}});
   model.PopulateTensor<float>(model.lhs(), {1, 2, 3, 4, 5, 6});
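The kernel map and the SingleOpTest fixture turn the cases below from plain TESTs into value-parameterized tests, so each one runs once per registered kernel ("Reference" and "GenericOptimized"). A generic sketch of the gtest pattern in use, with illustrative names not from this commit:

    #include <string>
    #include <gtest/gtest.h>

    class KernelTagTest : public ::testing::TestWithParam<std::string> {};

    // Runs once per parameter value supplied below.
    TEST_P(KernelTagTest, RunsPerKernel) {
      EXPECT_FALSE(GetParam().empty());
    }

    INSTANTIATE_TEST_SUITE_P(
        Kernels, KernelTagTest,
        ::testing::Values("Reference", "GenericOptimized"));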
@@ -65,7 +89,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Simple) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_SimpleRHSAdjoint) {
+TEST_P(BatchMatMulOpTest, Float32Test_SimpleRHSAdjoint) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 3}},
                                   {TensorType_FLOAT32, {1, 4, 3}}, false, true);
   model.PopulateTensor<float>(model.lhs(), {1, 2, 3, 4, 5, 6});
@@ -77,7 +101,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_SimpleRHSAdjoint) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_SimpleLHSAdjoint) {
+TEST_P(BatchMatMulOpTest, Float32Test_SimpleLHSAdjoint) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 3, 2}},
                                   {TensorType_FLOAT32, {1, 3, 4}}, true, false);
   model.PopulateTensor<float>(model.lhs(), {1, 4, 2, 5, 3, 6});
@@ -89,7 +113,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_SimpleLHSAdjoint) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_BatchSizeTwo) {
+TEST_P(BatchMatMulOpTest, Float32Test_BatchSizeTwo) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
                                   {TensorType_FLOAT32, {2, 3, 4}});
   model.PopulateTensor<float>(model.lhs(),
@@ -105,7 +129,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_BatchSizeTwo) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_Broadcast) {
+TEST_P(BatchMatMulOpTest, Float32Test_Broadcast) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}},
                                   {TensorType_FLOAT32, {3, 4}});
   model.PopulateTensor<float>(model.lhs(),
@@ -121,7 +145,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_BroadcastLHSAdjoint) {
+TEST_P(BatchMatMulOpTest, Float32Test_BroadcastLHSAdjoint) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 3, 2}},
                                   {TensorType_FLOAT32, {3, 4}}, true, false);
   model.PopulateTensor<float>(model.lhs(),
@@ -137,7 +161,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_BroadcastLHSAdjoint) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2) {
+TEST_P(BatchMatMulOpTest, Float32Test_Broadcast2) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 3, 2}},
                                   {TensorType_FLOAT32, {3, 2, 4}});
   model.PopulateTensor<float>(model.lhs(),
@@ -161,7 +185,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2LHSAdjoint) {
+TEST_P(BatchMatMulOpTest, Float32Test_Broadcast2LHSAdjoint) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 2, 3}},
                                   {TensorType_FLOAT32, {3, 2, 4}}, true, false);
   model.PopulateTensor<float>(model.lhs(),
@@ -185,7 +209,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2LHSAdjoint) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2RHSAdjoint) {
+TEST_P(BatchMatMulOpTest, Float32Test_Broadcast2RHSAdjoint) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 3, 2}},
                                   {TensorType_FLOAT32, {3, 4, 2}}, false, true);
   model.PopulateTensor<float>(model.lhs(),
@@ -208,7 +232,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2RHSAdjoint) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2BothAdjoint) {
+TEST_P(BatchMatMulOpTest, Float32Test_Broadcast2BothAdjoint) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 2, 3}},
                                   {TensorType_FLOAT32, {3, 4, 2}}, true, true);
   model.PopulateTensor<float>(model.lhs(),
@@ -231,7 +255,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2BothAdjoint) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4}));
 }

-TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFromRHS) {
+TEST_P(BatchMatMulOpTest, Float32Test_BroadcastFromRHS) {
   BatchMatMulOpModel<float> model({TensorType_FLOAT32, {4, 5}},
                                   {TensorType_FLOAT32, {3, 1, 5, 2}});
   model.PopulateTensor<float>(
@@ -251,6 +275,10 @@ TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFromRHS) {
   EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1, 4, 2}));
 }

+INSTANTIATE_TEST_SUITE_P(
+    BatchMatMulOpTest, BatchMatMulOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
 // In the hybrid model the weights are quantized int8. But the input
 // and output are expected to be in float precision.
 class HybridAsymmetricBatchMatMulOpModel : public SingleOpModel {
@@ -304,7 +332,14 @@ class HybridAsymmetricBatchMatMulOpModel : public SingleOpModel {
   int input_size_;
 };

-TEST(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {
+class HybridAsymmetricBatchMatMulOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMap;
+  }
+};
+
+TEST_P(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {
   HybridAsymmetricBatchMatMulOpModel m(
       /*units=*/3, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 10}},
@@ -335,7 +370,7 @@ TEST(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
 }

-TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {
+TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {
   HybridAsymmetricBatchMatMulOpModel m(
       /*units=*/3, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
@@ -366,7 +401,7 @@ TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
 }

-TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
+TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
   HybridAsymmetricBatchMatMulOpModel m(
       /*units=*/9, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}},
@@ -401,7 +436,7 @@ TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 9}));
 }

-TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {
+TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {
   HybridAsymmetricBatchMatMulOpModel m(
       /*units=*/3, /*batches=*/2,
       /*lhs=*/{TensorType_FLOAT32, {2, 10}},
@@ -431,5 +466,96 @@ TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {
   EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
 }

+INSTANTIATE_TEST_SUITE_P(
+    HybridAsymmetricBatchMatMulOpTest, HybridAsymmetricBatchMatMulOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
+class QuantizedBatchMatMulOpModel : public SingleOpModel {
+ public:
+  QuantizedBatchMatMulOpModel(int units, int batches, const TensorData& lhs,
+                              const TensorData& output = {TensorType_INT8},
+                              bool adj_x = false, bool adj_y = false)
+      : units_(units), batches_(batches) {
+    int total_input_size = 1;
+    for (size_t i = 0; i < lhs.shape.size(); ++i) {
+      total_input_size *= lhs.shape[i];
+    }
+    input_size_ = total_input_size / batches_;
+
+    lhs_id_ = AddInput(lhs);
+    rhs_id_ = AddInput({lhs.type, {input_size_, units_}, lhs.min, lhs.max});
+
+    output_id_ = AddOutput(output);
+
+    SetBuiltinOp(BuiltinOperator_BATCH_MATMUL,
+                 BuiltinOptions_BatchMatMulOptions,
+                 CreateBatchMatMulOptions(builder_, adj_x, adj_y).Union());
+    BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)});
+  }
+
+  template <typename T>
+  void SetWeights(const std::vector<float>& data) {
+    QuantizeAndPopulate<T>(rhs_id_, data);
+  }
+
+  template <typename T>
+  void SetInput(const std::vector<float>& data) {
+    QuantizeAndPopulate<T>(lhs_id_, data);
+  }
+
+  template <typename T>
+  std::vector<T> GetOutput() {
+    return ExtractVector<T>(output_id_);
+  }
+
+  template <typename T>
+  std::vector<float> GetDequantizedOutput() {
+    return Dequantize<T>(ExtractVector<T>(output_id_), GetScale(output_id_),
+                         GetZeroPoint(output_id_));
+  }
+
+ protected:
+  int lhs_id_;
+  int rhs_id_;
+  int output_id_;
+  int units_;
+  int batches_;
+  int input_size_;
+};
+
+class QuantizedBatchMatMulOpTest : public SingleOpTest {
+ protected:
+  const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+    return *kKernelMap;
+  }
+};
+
+TEST_P(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt8) {
+  QuantizedBatchMatMulOpModel m(
+      /*units=*/3, /*batches*/ 2,
+      /*lhs=*/{TensorType_INT8, {2, 10}, -63.5, 64},
+      /*output=*/{TensorType_INT8, {}, -127, 128});
+
+  m.SetWeights<int8_t>({
+      1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5,
+      6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
+  });
+
+  m.SetInput<int8_t>({
+      1, 2, 3, 4, 5, 6, 7, 8, -9, -10,  // b = 0
+      1, 2, 3, 4, 5, 6, 7, -8, 9, -10,  // b = 1
+  });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetDequantizedOutput<int8_t>(),
+              ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57})));
+  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(22, 22, 22, 56, 56, 56));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    QuantizedBatchMatMulOpTest, QuantizedBatchMatMulOpTest,
+    ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
 }  // namespace
 }  // namespace tflite
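The expectations in SimpleTestQuantizedInt8 can be verified by hand. The weights form a 10x3 matrix whose three columns each equal (1, 2, ..., 10), so for batch 0 every output unit is 1*1 + 2*2 + ... + 8*8 - 9*9 - 10*10 = 23, and for batch 1 it is 1*1 + ... + 7*7 - 8*8 + 9*9 - 10*10 = 57. The output range {-127, 128} yields scale (128 - (-127)) / 255 = 1.0 and zero point -1, so the stored int8 values are 23 - 1 = 22 and 57 - 1 = 56, matching ElementsAre(22, 22, 22, 56, 56, 56) and the dequantized values near 23 and 57.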
@@ -272,6 +272,112 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
   }
 }

+inline void BatchMatMul(const FullyConnectedParams& params,
+                        const RuntimeShape& lhs_shape, const int8_t* lhs_data,
+                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
+                        const RuntimeShape& output_shape, int8_t* output_data,
+                        CpuBackendContext* context) {
+  using ::tflite::cpu_backend_gemm::Gemm;
+  using ::tflite::cpu_backend_gemm::GemmParams;
+  using ::tflite::cpu_backend_gemm::MatrixParams;
+
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(5, lhs_shape);
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(5, rhs_shape);
+
+  // Determine which dimension is the broadcast dimension.
+  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
+    if (lhs_dim == rhs_dim) return lhs_dim;
+    if (lhs_dim == 1) return rhs_dim;
+    TFLITE_DCHECK_EQ(rhs_dim, 1);
+    return lhs_dim;
+  };
+
+  // Compute the "extent" for iterating on this dimension.
+  // If we are broadcasting, then don't advance (i.e return 0).
+  auto extent = [](const RuntimeShape& shape, int x) {
+    if (shape.Dims(x) == 1) {
+      return 0;
+    }
+    int prod = 1;
+    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
+      prod *= shape.Dims(i);
+    }
+    return prod;
+  };
+
+  const int batch_dim0 =
+      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
+  const int batch_dim1 =
+      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
+  const int batch_dim2 =
+      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
+
+  const int lhs_ext0 = extent(extended_lhs_shape, 0);
+  const int lhs_ext1 = extent(extended_lhs_shape, 1);
+  const int lhs_ext2 = extent(extended_lhs_shape, 2);
+  const int rhs_ext0 = extent(extended_rhs_shape, 0);
+  const int rhs_ext1 = extent(extended_rhs_shape, 1);
+  const int rhs_ext2 = extent(extended_rhs_shape, 2);
+
+  // Set params for each matrix multiply.
+  const int lhs_rows = extended_lhs_shape.Dims(3);
+  const int rhs_cols = extended_rhs_shape.Dims(4);
+  const int accum_depth = extended_lhs_shape.Dims(4);
+
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+  MatrixParams<int8_t> lhs_params;
+  lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
+  lhs_params.rows = lhs_rows;
+  lhs_params.cols = accum_depth;
+  lhs_params.zero_point = -filter_offset;
+
+  MatrixParams<int8_t> rhs_params;
+  rhs_params.order = cpu_backend_gemm::Order::kColMajor;
+  rhs_params.rows = accum_depth;
+  rhs_params.cols = rhs_cols;
+  rhs_params.zero_point = -input_offset;
+
+  MatrixParams<int8_t> dst_params;
+  dst_params.order = cpu_backend_gemm::Order::kColMajor;
+  dst_params.rows = lhs_rows;
+  dst_params.cols = rhs_cols;
+  dst_params.zero_point = output_offset;
+
+  for (int b0 = 0; b0 < batch_dim0; ++b0) {
+    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
+    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
+    for (int b1 = 0; b1 < batch_dim1; ++b1) {
+      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
+      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
+      for (int b2 = 0; b2 < batch_dim2; ++b2) {
+        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
+        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
+        int8_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
+                                         b1 * batch_dim2 + b2) *
+                                            lhs_rows * rhs_cols;
+
+        GemmParams<int32_t, int8_t> gemm_params;
+        gemm_params.clamp_min = output_activation_min;
+        gemm_params.clamp_max = output_activation_max;
+        gemm_params.multiplier_fixedpoint = output_multiplier;
+        gemm_params.multiplier_exponent = output_shift;
+        cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2,
+                               dst_params, out_ptr, gemm_params, context);
+      }
+    }
+  }
+}

 }  // namespace optimized_ops
 }  // namespace tflite
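The broadcast_dim/extent lambdas implement NumPy-style batch broadcasting over the three leading dimensions of the 5-D extended shapes: a size-1 dimension advances by extent 0, so its data is reused. A standalone sketch of the arithmetic on assumed shapes, chosen for illustration:

    // Standalone sketch of the broadcast bookkeeping above, on plain arrays.
    #include <cstdio>

    int BroadcastDim(int lhs_dim, int rhs_dim) {
      return lhs_dim == 1 ? rhs_dim : lhs_dim;  // dims must match or be 1
    }

    // Elements to skip per step of dimension x, or 0 when broadcasting.
    int Extent(const int dims[5], int x) {
      if (dims[x] == 1) return 0;
      int prod = 1;
      for (int i = x + 1; i < 5; ++i) prod *= dims[i];
      return prod;
    }

    int main() {
      const int lhs[5] = {2, 1, 3, 4, 5};  // batch dims {2, 1, 3}, 4x5 matrix
      const int rhs[5] = {1, 1, 3, 5, 6};  // batch dims {1, 1, 3}, 5x6 matrix
      for (int d = 0; d < 3; ++d) {
        std::printf("batch_dim%d = %d, lhs_ext = %d, rhs_ext = %d\n", d,
                    BroadcastDim(lhs[d], rhs[d]), Extent(lhs, d),
                    Extent(rhs, d));
      }
      // Prints: batch_dim0 = 2, lhs_ext = 60, rhs_ext = 0
      //         batch_dim1 = 1, lhs_ext = 0, rhs_ext = 0
      //         batch_dim2 = 3, lhs_ext = 20, rhs_ext = 30
      return 0;
    }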
@@ -217,6 +217,99 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data,
   }
 }

+inline void BatchMatMul(const FullyConnectedParams& params,
+                        const RuntimeShape& lhs_shape, const int8_t* lhs_data,
+                        const RuntimeShape& rhs_shape, const int8_t* rhs_data,
+                        const RuntimeShape& output_shape, int8_t* output_data) {
+  const RuntimeShape extended_lhs_shape =
+      RuntimeShape::ExtendedShape(5, lhs_shape);
+  const RuntimeShape extended_rhs_shape =
+      RuntimeShape::ExtendedShape(5, rhs_shape);
+
+  // Determine which dimension is the broadcast dimension.
+  auto broadcast_dim = [](int lhs_dim, int rhs_dim) {
+    if (lhs_dim == rhs_dim) return lhs_dim;
+    if (lhs_dim == 1) return rhs_dim;
+    TFLITE_DCHECK_EQ(rhs_dim, 1);
+    return lhs_dim;
+  };
+
+  // Compute the "extent" for iterating on this dimension.
+  // If we are broadcasting, then don't advance (i.e return 0).
+  auto extent = [](const RuntimeShape& shape, int x) {
+    if (shape.Dims(x) == 1) {
+      return 0;
+    }
+    int prod = 1;
+    for (int i = x + 1; i < shape.DimensionsCount(); ++i) {
+      prod *= shape.Dims(i);
+    }
+    return prod;
+  };
+
+  const int batch_dim0 =
+      broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0));
+  const int batch_dim1 =
+      broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1));
+  const int batch_dim2 =
+      broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2));
+
+  const int lhs_ext0 = extent(extended_lhs_shape, 0);
+  const int lhs_ext1 = extent(extended_lhs_shape, 1);
+  const int lhs_ext2 = extent(extended_lhs_shape, 2);
+  const int rhs_ext0 = extent(extended_rhs_shape, 0);
+  const int rhs_ext1 = extent(extended_rhs_shape, 1);
+  const int rhs_ext2 = extent(extended_rhs_shape, 2);
+
+  // Set params for each matrix multiply.
+  const int lhs_rows = extended_lhs_shape.Dims(3);
+  const int rhs_cols = extended_rhs_shape.Dims(4);
+  const int accum_depth = extended_lhs_shape.Dims(4);
+
+  const int32 input_offset = params.input_offset;
+  const int32 filter_offset = params.weights_offset;
+  const int32 output_offset = params.output_offset;
+  const int32 output_multiplier = params.output_multiplier;
+  const int output_shift = params.output_shift;
+  const int32 output_activation_min = params.quantized_activation_min;
+  const int32 output_activation_max = params.quantized_activation_max;
+  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+  for (int b0 = 0; b0 < batch_dim0; ++b0) {
+    const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0);
+    const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0);
+    for (int b1 = 0; b1 < batch_dim1; ++b1) {
+      const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1;
+      const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1;
+      for (int b2 = 0; b2 < batch_dim2; ++b2) {
+        const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2;
+        const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2;
+        int8_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) +
+                                         b1 * batch_dim2 + b2) *
+                                            lhs_rows * rhs_cols;
+
+        for (int j = 0; j < rhs_cols; ++j) {
+          for (int i = 0; i < lhs_rows; ++i) {
+            int32_t total = 0;
+            for (int k = 0; k < accum_depth; ++k) {
+              int32 lhs_val = lhs_ptr2[accum_depth * i + k];
+              int32 rhs_val = rhs_ptr2[accum_depth * j + k];
+              total += (lhs_val + filter_offset) * (rhs_val + input_offset);
+            }
+            total = MultiplyByQuantizedMultiplier(total, output_multiplier,
+                                                  output_shift);
+            total += output_offset;
+            total = std::max(total, output_activation_min);
+            total = std::min(total, output_activation_max);
+            const int idx = lhs_rows * j + i;
+            out_ptr[idx] = static_cast<int8_t>(total);
+          }
+        }
+      }
+    }
+  }
+}

 }  // namespace reference_ops
 }  // namespace tflite
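The requantization step above relies on MultiplyByQuantizedMultiplier, which applies the Q31 multiplier and shift produced in Prepare. A simplified sketch of its behavior (the shipped version uses gemmlowp's saturating rounding-doubling multiply with round-to-nearest on the shift; this sketch folds both into one int64 rounding and assumes shift < 31, which holds for matmul rescaling):

    #include <cstdint>

    // Computes approximately round(x * multiplier / 2^31 * 2^shift).
    // shift > 0 means left shift, shift < 0 means right shift (TFLite's
    // convention). Ignores the saturation handling of the real helper.
    int32_t MultiplyByQuantizedMultiplierSketch(int32_t x, int32_t multiplier,
                                                int shift) {
      const int total_right_shift = 31 - shift;  // > 0 given shift < 31
      const int64_t rounding = int64_t{1} << (total_right_shift - 1);
      const int64_t result =
          static_cast<int64_t>(x) * static_cast<int64_t>(multiplier) + rounding;
      return static_cast<int32_t>(result >> total_right_shift);
    }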
@@ -289,7 +289,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND());
   AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY());
   AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM());
-  AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL());
+  AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(),
+             /* min_version = */ 1,
+             /* max_version = */ 2);
   AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY());
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
@@ -88,6 +88,12 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index,
       property.restrict_same_input_output_scale = true;
       property.version = 2;
       break;
+    case BuiltinOperator_BATCH_MATMUL: {
+      property.inputs = {{0, {}}, {1, {}}};
+      property.outputs = {{0, {}}};
+      property.version = 2;
+      break;
+    }
     case BuiltinOperator_BATCH_TO_SPACE_ND:
     case BuiltinOperator_SPACE_TO_BATCH_ND:
     case BuiltinOperator_SPACE_TO_DEPTH:
@@ -24,6 +24,7 @@ limitations under the License.
 #include "absl/strings/str_split.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
+#include "tensorflow/lite/schema/schema_generated.h"

 namespace tflite {
 namespace {
@@ -518,6 +519,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
     case BuiltinOperator_LESS:
     case BuiltinOperator_LESS_EQUAL:
     case BuiltinOperator_SELECT:
+    case BuiltinOperator_BATCH_MATMUL:
       if (op_sig.input_types.at(0) == TensorType_INT8) {
         return 2;
       }
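With BATCH_MATMUL added to this fall-through group, the op shares the existing version logic: an int8 first input selects version 2, and anything else stays at version 1. In effect:

    // Effective version selection for BATCH_MATMUL after this change (sketch).
    int BatchMatMulVersion(TensorType first_input_type) {
      return first_input_type == TensorType_INT8 ? 2 : 1;
    }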
@@ -58,6 +58,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
           {{BuiltinOperator_AVERAGE_POOL_2D, 2}, "1.14.0"},
           {{BuiltinOperator_AVERAGE_POOL_2D, 3}, kPendingReleaseVersion},
           {{BuiltinOperator_BATCH_MATMUL, 1}, kPendingReleaseVersion},
+          {{BuiltinOperator_BATCH_MATMUL, 2}, kPendingReleaseVersion},
           {{BuiltinOperator_CONV_2D, 1}, "1.5.0"},
           {{BuiltinOperator_CONV_2D, 2}, "1.14.0"},
           {{BuiltinOperator_CONV_2D, 3}, "1.14.0"},