diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 509c13ae161..33281cc58fb 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -953,14 +953,14 @@ in the batch dimensions and broadcasting. }]; let arguments = (ins - TFL_TensorOf<[F32]>:$x, - TFL_TensorOf<[F32]>:$y, + TFL_TensorOf<[F32, QI8]>:$x, + TFL_TensorOf<[F32, QI8]>:$y, DefaultValuedAttr<BoolAttr, "false">:$adj_x, DefaultValuedAttr<BoolAttr, "false">:$adj_y ); let results = (outs - TFL_TensorOf<[F32]>:$output + TFL_TensorOf<[F32, QI8]>:$output ); let hasOptions = 1; diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc index 8bc23c9c94a..a414a226504 100644 --- a/tensorflow/lite/kernels/batch_matmul.cc +++ b/tensorflow/lite/kernels/batch_matmul.cc @@ -19,6 +19,7 @@ limitations under the License. #include <algorithm> #include <cstdint> +#include <limits> #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" @@ -52,6 +53,14 @@ enum KernelType { }; struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multiplier plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; // The index of the temporary tensors where we store transposed LHS/RHS. int scratch_tensor_index; bool rhs_transposed; @@ -274,6 +283,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { OpContext op_context(context, node); TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); + OpData* op_data = reinterpret_cast<OpData*>(node->user_data); bool adj_x = op_context.params->adj_x; bool adj_y = op_context.params->adj_y; @@ -282,7 +292,24 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* rhs_data = GetInput(context, node, kInputRHSTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_TYPES_EQ(context, lhs_data->type, kTfLiteFloat32); + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (lhs_data->type == kTfLiteInt8) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, lhs_data, rhs_data, output, &real_multiplier)); + int exponent; + QuantizeMultiplier(real_multiplier, &op_data->output_multiplier, &exponent); + op_data->output_shift = exponent; + // BatchMatMul has no fused activation functions. Therefore, set + // output activation min and max to min and max of int8_t type, + // respectively. + op_data->output_activation_min = std::numeric_limits<int8_t>::min(); + op_data->output_activation_max = std::numeric_limits<int8_t>::max(); + } + + TF_LITE_ENSURE(context, lhs_data->type == kTfLiteFloat32 || + lhs_data->type == kTfLiteInt8); TF_LITE_ENSURE(context, rhs_data->type == kTfLiteFloat32 || rhs_data->type == kTfLiteInt8); // Support dimensions between 2 and 4, inclusive. 
@@ -433,6 +460,41 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data, return kTfLiteOk; } +template <KernelType kernel_type> +TfLiteStatus EvalInt8(TfLiteContext* context, const OpData* data, + const RuntimeShape& lhs_shape, const TfLiteTensor* lhs, + const RuntimeShape& rhs_shape, const TfLiteTensor* rhs, + const RuntimeShape& output_shape, TfLiteTensor* output) { + // Reuse params struct from FullyConnected Op. + FullyConnectedParams op_params; + int32_t input_offset = -lhs->params.zero_point; + int32_t filter_offset = -rhs->params.zero_point; + int32_t output_offset = output->params.zero_point; + op_params.input_offset = input_offset; + op_params.weights_offset = filter_offset; + op_params.output_offset = output_offset; + op_params.output_multiplier = data->output_multiplier; + op_params.output_shift = data->output_shift; + op_params.quantized_activation_min = data->output_activation_min; + op_params.quantized_activation_max = data->output_activation_max; + op_params.lhs_cacheable = IsConstantTensor(lhs); + op_params.rhs_cacheable = IsConstantTensor(rhs); + + if (kernel_type == kReference) { + reference_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs), + lhs_shape, GetTensorData<int8_t>(lhs), + GetTensorShape(output), + GetTensorData<int8_t>(output)); + } else { + optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs), + lhs_shape, GetTensorData<int8_t>(lhs), + GetTensorShape(output), + GetTensorData<int8_t>(output), + CpuBackendContext::GetFromContext(context)); + } + return kTfLiteOk; +} + template <KernelType kernel_type> TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, OpData* data, const RuntimeShape& lhs_shape, @@ -448,25 +510,39 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, return EvalHybrid<kernel_type>( context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized, scaling_factors, accum_scratch, row_sums, input_offsets, output); + } else if (lhs->type == kTfLiteInt8) { + return EvalInt8<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs, + GetTensorShape(output), output); } else { - TF_LITE_KERNEL_LOG(context, - "Currently only hybrid quantization is supported.\n"); + TF_LITE_KERNEL_LOG( + context, "Currently only hybrid and int8 quantization is supported.\n"); return kTfLiteError; } return kTfLiteOk; } -TfLiteTensor* GetRhs(TfLiteContext* context, TfLiteNode* node, - const TfLiteTensor* rhs) { +TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* rhs) { TfLiteTensor* transposed_rhs = GetTemporary(context, node, 1); if (rhs->type == kTfLiteInt8) { - // Get the quantization params from the weights tensors. + // Get the quantization params from the RHS tensor. transposed_rhs->params.scale = rhs->params.scale; transposed_rhs->params.zero_point = rhs->params.zero_point; } return transposed_rhs; } +TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lhs) { + TfLiteTensor* transposed_lhs = GetTemporary(context, node, 0); + if (lhs->type == kTfLiteInt8) { + // Get the quantization params from the LHS tensor. 
+ transposed_lhs->params.scale = lhs->params.scale; + transposed_lhs->params.zero_point = lhs->params.zero_point; + } + return transposed_lhs; +} + // Perform a batch matrix multiply on // LHS <..., A, B> X RHS<..., B, C> // where the leading dimensions of LHS and RHS obey broadcasting rules @@ -491,8 +567,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bool adj_y = op_context.params->adj_y; bool adj_x = op_context.params->adj_x; - const TfLiteTensor* rhs_tensor = adj_y ? rhs : GetRhs(context, node, rhs); - const TfLiteTensor* lhs_tensor = adj_x ? GetTemporary(context, node, 0) : lhs; + const TfLiteTensor* rhs_tensor = adj_y ? rhs : GetTempRhs(context, node, rhs); + const TfLiteTensor* lhs_tensor = adj_x ? GetTempLhs(context, node, lhs) : lhs; if (!adj_y) { // TODO(b/154760341) Constant tensors should already be transposed, but // we transpose once if necessary for now. diff --git a/tensorflow/lite/kernels/batch_matmul_test.cc b/tensorflow/lite/kernels/batch_matmul_test.cc index 5e52479f49b..98df8ebe3db 100644 --- a/tensorflow/lite/kernels/batch_matmul_test.cc +++ b/tensorflow/lite/kernels/batch_matmul_test.cc @@ -24,8 +24,19 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { + +namespace ops { +namespace builtin { + +TfLiteRegistration* Register_BATCH_MATMUL_REF(); +TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED(); + +} // namespace builtin +} // namespace ops + namespace { +using ::testing::ElementsAre; using ::testing::ElementsAreArray; template <typename T> @@ -53,7 +64,20 @@ class BatchMatMulOpModel : public SingleOpModel { int output_id_; }; -TEST(BatchMatMulOpModelTest, Float32Test_Simple) { +const auto kKernelMap = new std::map<string, TfLiteRegistration*>({ + {"Reference", ops::builtin::Register_BATCH_MATMUL_REF()}, + {"GenericOptimized", + ops::builtin::Register_BATCH_MATMUL_GENERIC_OPTIMIZED()}, +}); + +class BatchMatMulOpTest : public SingleOpTest { + protected: + const std::map<string, TfLiteRegistration*>& GetKernelMap() override { + return *kKernelMap; + } +}; + +TEST_P(BatchMatMulOpTest, Float32Test_Simple) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 3}}, {TensorType_FLOAT32, {1, 3, 4}}); model.PopulateTensor<float>(model.lhs(), {1, 2, 3, 4, 5, 6}); @@ -65,7 +89,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Simple) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_SimpleRHSAdjoint) { +TEST_P(BatchMatMulOpTest, Float32Test_SimpleRHSAdjoint) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 2, 3}}, {TensorType_FLOAT32, {1, 4, 3}}, false, true); model.PopulateTensor<float>(model.lhs(), {1, 2, 3, 4, 5, 6}); @@ -77,7 +101,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_SimpleRHSAdjoint) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_SimpleLHSAdjoint) { +TEST_P(BatchMatMulOpTest, Float32Test_SimpleLHSAdjoint) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {1, 3, 2}}, {TensorType_FLOAT32, {1, 3, 4}}, true, false); model.PopulateTensor<float>(model.lhs(), {1, 4, 2, 5, 3, 6}); @@ -89,7 +113,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_SimpleLHSAdjoint) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_BatchSizeTwo) { +TEST_P(BatchMatMulOpTest, Float32Test_BatchSizeTwo) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}}, {TensorType_FLOAT32, {2, 3, 4}}); 
model.PopulateTensor<float>(model.lhs(), @@ -105,7 +129,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_BatchSizeTwo) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_Broadcast) { +TEST_P(BatchMatMulOpTest, Float32Test_Broadcast) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 2, 3}}, {TensorType_FLOAT32, {3, 4}}); model.PopulateTensor<float>(model.lhs(), @@ -121,7 +145,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_BroadcastLHSAdjoint) { +TEST_P(BatchMatMulOpTest, Float32Test_BroadcastLHSAdjoint) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 3, 2}}, {TensorType_FLOAT32, {3, 4}}, true, false); model.PopulateTensor<float>(model.lhs(), @@ -137,7 +161,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_BroadcastLHSAdjoint) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2) { +TEST_P(BatchMatMulOpTest, Float32Test_Broadcast2) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 3, 2}}, {TensorType_FLOAT32, {3, 2, 4}}); model.PopulateTensor<float>(model.lhs(), @@ -161,7 +185,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2LHSAdjoint) { +TEST_P(BatchMatMulOpTest, Float32Test_Broadcast2LHSAdjoint) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 2, 3}}, {TensorType_FLOAT32, {3, 2, 4}}, true, false); model.PopulateTensor<float>(model.lhs(), @@ -185,7 +209,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2LHSAdjoint) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2RHSAdjoint) { +TEST_P(BatchMatMulOpTest, Float32Test_Broadcast2RHSAdjoint) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 3, 2}}, {TensorType_FLOAT32, {3, 4, 2}}, false, true); model.PopulateTensor<float>(model.lhs(), @@ -208,7 +232,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2RHSAdjoint) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2BothAdjoint) { +TEST_P(BatchMatMulOpTest, Float32Test_Broadcast2BothAdjoint) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {2, 1, 2, 3}}, {TensorType_FLOAT32, {3, 4, 2}}, true, true); model.PopulateTensor<float>(model.lhs(), @@ -231,7 +255,7 @@ TEST(BatchMatMulOpModelTest, Float32Test_Broadcast2BothAdjoint) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 3, 3, 4})); } -TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFromRHS) { +TEST_P(BatchMatMulOpTest, Float32Test_BroadcastFromRHS) { BatchMatMulOpModel<float> model({TensorType_FLOAT32, {4, 5}}, {TensorType_FLOAT32, {3, 1, 5, 2}}); model.PopulateTensor<float>( @@ -251,6 +275,10 @@ TEST(BatchMatMulOpModelTest, Float32Test_BroadcastFromRHS) { EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1, 4, 2})); } +INSTANTIATE_TEST_SUITE_P( + BatchMatMulOpTest, BatchMatMulOpTest, + ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); + // In the hybrid model the weights are quantized int8. But the input // and output are expected to be in float precision. 
class HybridAsymmetricBatchMatMulOpModel : public SingleOpModel { @@ -304,7 +332,14 @@ class HybridAsymmetricBatchMatMulOpModel : public SingleOpModel { int input_size_; }; -TEST(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) { +class HybridAsymmetricBatchMatMulOpTest : public SingleOpTest { + protected: + const std::map<string, TfLiteRegistration*>& GetKernelMap() override { + return *kKernelMap; + } +}; + +TEST_P(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) { HybridAsymmetricBatchMatMulOpModel m( /*units=*/3, /*batches=*/2, /*lhs=*/{TensorType_FLOAT32, {2, 10}}, @@ -335,7 +370,7 @@ TEST(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); } -TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) { +TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) { HybridAsymmetricBatchMatMulOpModel m( /*units=*/3, /*batches=*/2, /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}}, @@ -366,7 +401,7 @@ TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3})); } -TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) { +TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) { HybridAsymmetricBatchMatMulOpModel m( /*units=*/9, /*batches=*/2, /*lhs=*/{TensorType_FLOAT32, {2, 2, 10}}, @@ -401,7 +436,7 @@ TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 9})); } -TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) { +TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) { HybridAsymmetricBatchMatMulOpModel m( /*units=*/3, /*batches=*/2, /*lhs=*/{TensorType_FLOAT32, {2, 10}}, @@ -431,5 +466,96 @@ TEST(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3})); } +INSTANTIATE_TEST_SUITE_P( + HybridAsymmetricBatchMatMulOpTest, HybridAsymmetricBatchMatMulOpTest, + ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); + +class QuantizedBatchMatMulOpModel : public SingleOpModel { + public: + QuantizedBatchMatMulOpModel(int units, int batches, const TensorData& lhs, + const TensorData& output = {TensorType_INT8}, + bool adj_x = false, bool adj_y = false) + : units_(units), batches_(batches) { + int total_input_size = 1; + for (size_t i = 0; i < lhs.shape.size(); ++i) { + total_input_size *= lhs.shape[i]; + } + input_size_ = total_input_size / batches_; + + lhs_id_ = AddInput(lhs); + rhs_id_ = AddInput({lhs.type, {input_size_, units_}, lhs.min, lhs.max}); + + output_id_ = AddOutput(output); + + SetBuiltinOp(BuiltinOperator_BATCH_MATMUL, + BuiltinOptions_BatchMatMulOptions, + CreateBatchMatMulOptions(builder_, adj_x, adj_y).Union()); + BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)}); + } + + template <typename T> + void SetWeights(const std::vector<float>& data) { + QuantizeAndPopulate<T>(rhs_id_, data); + } + + template <typename T> + void SetInput(const std::vector<float>& data) { + QuantizeAndPopulate<T>(lhs_id_, data); + } + + template <typename T> + std::vector<T> GetOutput() { + return ExtractVector<T>(output_id_); + } + + template <typename T> + std::vector<float> GetDequantizedOutput() { + return Dequantize<T>(ExtractVector<T>(output_id_), GetScale(output_id_), + GetZeroPoint(output_id_)); + } + + protected: + int lhs_id_; + int rhs_id_; + int output_id_; + int units_; + 
int batches_; + int input_size_; +}; + +class QuantizedBatchMatMulOpTest : public SingleOpTest { + protected: + const std::map<string, TfLiteRegistration*>& GetKernelMap() override { + return *kKernelMap; + } +}; + +TEST_P(QuantizedBatchMatMulOpTest, SimpleTestQuantizedInt8) { + QuantizedBatchMatMulOpModel m( + /*units=*/3, /*batches*/ 2, + /*lhs=*/{TensorType_INT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_INT8, {}, -127, 128}); + + m.SetWeights<int8_t>({ + 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, + 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, + }); + + m.SetInput<int8_t>({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput<int8_t>(), + ElementsAreArray(ArrayFloatNear({23, 23, 23, 57, 57, 57}))); + EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAre(22, 22, 22, 56, 56, 56)); +} + +INSTANTIATE_TEST_SUITE_P( + QuantizedBatchMatMulOpTest, QuantizedBatchMatMulOpTest, + ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); + } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h index 24b5012304f..5e622154d60 100644 --- a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h @@ -272,6 +272,112 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, } } +inline void BatchMatMul(const FullyConnectedParams& params, + const RuntimeShape& lhs_shape, const int8_t* lhs_data, + const RuntimeShape& rhs_shape, const int8_t* rhs_data, + const RuntimeShape& output_shape, int8_t* output_data, + CpuBackendContext* context) { + using ::tflite::cpu_backend_gemm::Gemm; + using ::tflite::cpu_backend_gemm::GemmParams; + using ::tflite::cpu_backend_gemm::MatrixParams; + + const RuntimeShape extended_lhs_shape = + RuntimeShape::ExtendedShape(5, lhs_shape); + const RuntimeShape extended_rhs_shape = + RuntimeShape::ExtendedShape(5, rhs_shape); + + // Determine which dimension is the broadcast dimension. + auto broadcast_dim = [](int lhs_dim, int rhs_dim) { + if (lhs_dim == rhs_dim) return lhs_dim; + if (lhs_dim == 1) return rhs_dim; + TFLITE_DCHECK_EQ(rhs_dim, 1); + return lhs_dim; + }; + + // Compute the "extent" for iterating on this dimension. + // If we are broadcasting, then don't advance (i.e return 0). + auto extent = [](const RuntimeShape& shape, int x) { + if (shape.Dims(x) == 1) { + return 0; + } + int prod = 1; + for (int i = x + 1; i < shape.DimensionsCount(); ++i) { + prod *= shape.Dims(i); + } + return prod; + }; + + const int batch_dim0 = + broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0)); + const int batch_dim1 = + broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1)); + const int batch_dim2 = + broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2)); + + const int lhs_ext0 = extent(extended_lhs_shape, 0); + const int lhs_ext1 = extent(extended_lhs_shape, 1); + const int lhs_ext2 = extent(extended_lhs_shape, 2); + const int rhs_ext0 = extent(extended_rhs_shape, 0); + const int rhs_ext1 = extent(extended_rhs_shape, 1); + const int rhs_ext2 = extent(extended_rhs_shape, 2); + + // Set params for each matrix multiply. 
+ const int lhs_rows = extended_lhs_shape.Dims(3); + const int rhs_cols = extended_rhs_shape.Dims(4); + const int accum_depth = extended_lhs_shape.Dims(4); + + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + MatrixParams<int8_t> lhs_params; + lhs_params.order = cpu_backend_gemm::Order::kRowMajor; + lhs_params.rows = lhs_rows; + lhs_params.cols = accum_depth; + lhs_params.zero_point = -filter_offset; + + MatrixParams<int8_t> rhs_params; + rhs_params.order = cpu_backend_gemm::Order::kColMajor; + rhs_params.rows = accum_depth; + rhs_params.cols = rhs_cols; + rhs_params.zero_point = -input_offset; + + MatrixParams<int8_t> dst_params; + dst_params.order = cpu_backend_gemm::Order::kColMajor; + dst_params.rows = lhs_rows; + dst_params.cols = rhs_cols; + dst_params.zero_point = output_offset; + + for (int b0 = 0; b0 < batch_dim0; ++b0) { + const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0); + const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0); + for (int b1 = 0; b1 < batch_dim1; ++b1) { + const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; + const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; + for (int b2 = 0; b2 < batch_dim2; ++b2) { + const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; + const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; + int8_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + + b1 * batch_dim2 + b2) * + lhs_rows * rhs_cols; + + GemmParams<int32_t, int8_t> gemm_params; + gemm_params.clamp_min = output_activation_min; + gemm_params.clamp_max = output_activation_max; + gemm_params.multiplier_fixedpoint = output_multiplier; + gemm_params.multiplier_exponent = output_shift; + cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2, + dst_params, out_ptr, gemm_params, context); + } + } + } +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h index 1394bd9da64..05caefaca5d 100644 --- a/tensorflow/lite/kernels/internal/reference/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h @@ -217,6 +217,99 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, } } +inline void BatchMatMul(const FullyConnectedParams& params, + const RuntimeShape& lhs_shape, const int8_t* lhs_data, + const RuntimeShape& rhs_shape, const int8_t* rhs_data, + const RuntimeShape& output_shape, int8_t* output_data) { + const RuntimeShape extended_lhs_shape = + RuntimeShape::ExtendedShape(5, lhs_shape); + const RuntimeShape extended_rhs_shape = + RuntimeShape::ExtendedShape(5, rhs_shape); + + // Determine which dimension is the broadcast dimension. + auto broadcast_dim = [](int lhs_dim, int rhs_dim) { + if (lhs_dim == rhs_dim) return lhs_dim; + if (lhs_dim == 1) return rhs_dim; + TFLITE_DCHECK_EQ(rhs_dim, 1); + return lhs_dim; + }; + + // Compute the "extent" for iterating on this dimension. + // If we are broadcasting, then don't advance (i.e return 0). 
+ auto extent = [](const RuntimeShape& shape, int x) { + if (shape.Dims(x) == 1) { + return 0; + } + int prod = 1; + for (int i = x + 1; i < shape.DimensionsCount(); ++i) { + prod *= shape.Dims(i); + } + return prod; + }; + + const int batch_dim0 = + broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0)); + const int batch_dim1 = + broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1)); + const int batch_dim2 = + broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2)); + + const int lhs_ext0 = extent(extended_lhs_shape, 0); + const int lhs_ext1 = extent(extended_lhs_shape, 1); + const int lhs_ext2 = extent(extended_lhs_shape, 2); + const int rhs_ext0 = extent(extended_rhs_shape, 0); + const int rhs_ext1 = extent(extended_rhs_shape, 1); + const int rhs_ext2 = extent(extended_rhs_shape, 2); + + // Set params for each matrix multiply. + const int lhs_rows = extended_lhs_shape.Dims(3); + const int rhs_cols = extended_rhs_shape.Dims(4); + const int accum_depth = extended_lhs_shape.Dims(4); + + const int32 input_offset = params.input_offset; + const int32 filter_offset = params.weights_offset; + const int32 output_offset = params.output_offset; + const int32 output_multiplier = params.output_multiplier; + const int output_shift = params.output_shift; + const int32 output_activation_min = params.quantized_activation_min; + const int32 output_activation_max = params.quantized_activation_max; + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + for (int b0 = 0; b0 < batch_dim0; ++b0) { + const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0); + const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0); + for (int b1 = 0; b1 < batch_dim1; ++b1) { + const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; + const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; + for (int b2 = 0; b2 < batch_dim2; ++b2) { + const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; + const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; + int8_t* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + + b1 * batch_dim2 + b2) * + lhs_rows * rhs_cols; + + for (int j = 0; j < rhs_cols; ++j) { + for (int i = 0; i < lhs_rows; ++i) { + int32_t total = 0; + for (int k = 0; k < accum_depth; ++k) { + int32 lhs_val = lhs_ptr2[accum_depth * i + k]; + int32 rhs_val = rhs_ptr2[accum_depth * j + k]; + total += (lhs_val + filter_offset) * (rhs_val + input_offset); + } + total = MultiplyByQuantizedMultiplier(total, output_multiplier, + output_shift); + total += output_offset; + total = std::max(total, output_activation_min); + total = std::min(total, output_activation_max); + const int idx = lhs_rows * j + i; + out_ptr[idx] = static_cast<int8_t>(total); + } + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 90688a2aa1f..c3a4aaad16d 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -289,7 +289,9 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SCATTER_ND, Register_SCATTER_ND()); AddBuiltin(BuiltinOperator_DENSIFY, Register_DENSIFY()); AddBuiltin(BuiltinOperator_SEGMENT_SUM, Register_SEGMENT_SUM()); - AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL()); + AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL(), + /* min_version = */ 1, + /* max_version = */ 2); AddCustom("NumericVerify", tflite::ops::custom::Register_NUMERIC_VERIFY()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so 
that // custom ops aren't always included by default. diff --git a/tensorflow/lite/tools/optimize/operator_property.cc b/tensorflow/lite/tools/optimize/operator_property.cc index 8a0cbca29e2..f2cb98ef31a 100644 --- a/tensorflow/lite/tools/optimize/operator_property.cc +++ b/tensorflow/lite/tools/optimize/operator_property.cc @@ -88,6 +88,12 @@ OperatorProperty GetOperatorProperty(const ModelT* model, int subgraph_index, property.restrict_same_input_output_scale = true; property.version = 2; break; + case BuiltinOperator_BATCH_MATMUL: { + property.inputs = {{0, {}}, {1, {}}}; + property.outputs = {{0, {}}}; + property.version = 2; + break; + } case BuiltinOperator_BATCH_TO_SPACE_ND: case BuiltinOperator_SPACE_TO_BATCH_ND: case BuiltinOperator_SPACE_TO_DEPTH: diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 118e2d420f8..a97b9da47f1 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_split.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace { @@ -518,6 +519,7 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { case BuiltinOperator_LESS: case BuiltinOperator_LESS_EQUAL: case BuiltinOperator_SELECT: + case BuiltinOperator_BATCH_MATMUL: if (op_sig.input_types.at(0) == TensorType_INT8) { return 2; } diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 92a7001606f..36976354685 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -58,6 +58,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_AVERAGE_POOL_2D, 2}, "1.14.0"}, {{BuiltinOperator_AVERAGE_POOL_2D, 3}, kPendingReleaseVersion}, {{BuiltinOperator_BATCH_MATMUL, 1}, kPendingReleaseVersion}, + {{BuiltinOperator_BATCH_MATMUL, 2}, kPendingReleaseVersion}, {{BuiltinOperator_CONV_2D, 1}, "1.5.0"}, {{BuiltinOperator_CONV_2D, 2}, "1.14.0"}, {{BuiltinOperator_CONV_2D, 3}, "1.14.0"},
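
A note for reviewers on the requantization path added in Prepare() and EvalInt8() above: the real multiplier lhs_scale * rhs_scale / output_scale is decomposed into a Q31 fixed-point multiplier plus a shift, and each int32 accumulator is scaled by that multiplier, offset by the output zero point, and clamped to the int8 range (there is no fused activation, so the full [-128, 127] range is used). The sketch below is a standalone illustration, not TFLite code: the helper names and the concrete scales and zero points are invented, and it uses double arithmetic for clarity where the real kernels use 64-bit fixed-point math via QuantizeMultiplier / MultiplyByQuantizedMultiplier.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Decompose a positive real multiplier into a Q31 fixed-point multiplier and a
// shift, analogous in spirit to TFLite's QuantizeMultiplier().
void QuantizeMultiplierSketch(double real_multiplier,
                              int32_t* quantized_multiplier, int* shift) {
  if (real_multiplier == 0.0) {
    *quantized_multiplier = 0;
    *shift = 0;
    return;
  }
  // real_multiplier = q * 2^shift with q in [0.5, 1).
  const double q = std::frexp(real_multiplier, shift);
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // Handle rounding up to exactly 1.0.
    q_fixed /= 2;
    ++*shift;
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
}

// Apply the multiplier to an int32 accumulator. Double arithmetic is used here
// for readability only.
int32_t RequantizeSketch(int32_t acc, int32_t multiplier, int shift) {
  const double real =
      static_cast<double>(multiplier) / (1LL << 31) * std::ldexp(1.0, shift);
  return static_cast<int32_t>(std::round(acc * real));
}

int main() {
  // Hypothetical quantization parameters for the LHS, RHS and output tensors.
  const double lhs_scale = 0.5, rhs_scale = 0.25, output_scale = 1.0;
  const int32_t lhs_zero_point = -1, rhs_zero_point = 0, output_zero_point = -1;

  // Same scaling rule the Prepare() change relies on.
  const double real_multiplier = lhs_scale * rhs_scale / output_scale;
  int32_t output_multiplier;
  int output_shift;
  QuantizeMultiplierSketch(real_multiplier, &output_multiplier, &output_shift);

  // One output element: accumulate (lhs - lhs_zero_point) * (rhs - rhs_zero_point)
  // in int32, requantize, add the output zero point, then clamp to int8.
  const int32_t input_offset = -lhs_zero_point;
  const int32_t filter_offset = -rhs_zero_point;
  const int8_t lhs_vals[] = {10, 20, 30};
  const int8_t rhs_vals[] = {1, 2, 3};
  int32_t acc = 0;
  for (int k = 0; k < 3; ++k) {
    acc += (lhs_vals[k] + input_offset) * (rhs_vals[k] + filter_offset);
  }
  int32_t out =
      RequantizeSketch(acc, output_multiplier, output_shift) + output_zero_point;
  out = std::min<int32_t>(
      std::max<int32_t>(out, std::numeric_limits<int8_t>::min()),
      std::numeric_limits<int8_t>::max());
  std::printf("quantized output = %d\n", static_cast<int>(out));
  return 0;
}
```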
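A quick check of the expected values in the new SimpleTestQuantizedInt8 test, assuming TFLite's usual derivation of quantization parameters from the declared min/max: the output range [-127, 128] gives scale (128 + 127) / 255 = 1.0 and zero point -1. For batch 0, the real dot product of the input {1, 2, ..., 8, -9, -10} with the first weight column {1, 2, ..., 10} is 1 + 4 + 9 + 16 + 25 + 36 + 49 + 64 - 81 - 100 = 23, which quantizes to 23 / 1.0 + (-1) = 22 and dequantizes back to (22 + 1) * 1.0 = 23; hence ElementsAre(22, ...) for the raw int8 output and ArrayFloatNear({23, ...}) for the dequantized output. Batch 1 gives 57 real and 56 quantized in the same way.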