diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index 64255523000..99770c39a1e 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -878,8 +878,9 @@ cc_test(
     srcs = ["embedding_lookup_test.cc"],
     deps = [
         ":builtin_ops",
+        ":test_util",
         "//tensorflow/lite:framework",
-        "//tensorflow/lite/kernels:test_util",
+        "//tensorflow/lite/kernels/internal:tensor",
         "@com_google_googletest//:gtest",
     ],
 )
diff --git a/tensorflow/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc
index 3f1d62389f4..8a285f6622d 100644
--- a/tensorflow/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/lite/kernels/embedding_lookup.cc
@@ -69,9 +69,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return context->ResizeTensor(context, output, outputSize);
 }
 
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
-                       const TfLiteTensor* lookup, const TfLiteTensor* value,
-                       TfLiteTensor* output) {
+TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
+                        const TfLiteTensor* lookup, const TfLiteTensor* value,
+                        TfLiteTensor* output) {
   const int row_size = SizeOfDimension(value, 0);
   const int row_bytes = value->bytes / row_size;
 
@@ -138,10 +138,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, 0);
   switch (value->type) {
     case kTfLiteFloat32:
-      return EvalFloat(context, node, lookup, value, output);
+      return EvalSimple(context, node, lookup, value, output);
     case kTfLiteUInt8:
     case kTfLiteInt8:
-      return EvalHybrid(context, node, lookup, value, output);
+      if (output->type == kTfLiteFloat32) {
+        return EvalHybrid(context, node, lookup, value, output);
+      } else {
+        return EvalSimple(context, node, lookup, value, output);
+      }
     default:
       context->ReportError(context, "Type not currently supported.");
       return kTfLiteError;
diff --git a/tensorflow/lite/kernels/embedding_lookup_test.cc b/tensorflow/lite/kernels/embedding_lookup_test.cc
index 2462ff26933..cf90ed08aa6 100644
--- a/tensorflow/lite/kernels/embedding_lookup_test.cc
+++ b/tensorflow/lite/kernels/embedding_lookup_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/kernels/test_util.h"
 #include "tensorflow/lite/model.h"
@@ -36,10 +37,11 @@ class BaseEmbeddingLookupOpModel : public SingleOpModel {
  public:
   BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
                              std::initializer_list<int> weight_shape,
-                             TensorType weight_type = TensorType_FLOAT32) {
+                             TensorType weight_type = TensorType_FLOAT32,
+                             TensorType output_type = TensorType_FLOAT32) {
     input_ = AddInput(TensorType_INT32);
     weight_ = AddInput(weight_type);
-    output_ = AddOutput(TensorType_FLOAT32);
+    output_ = AddOutput(output_type);
     SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
     BuildInterpreter({index_shape, weight_shape});
   }
@@ -48,7 +50,10 @@ class BaseEmbeddingLookupOpModel : public SingleOpModel {
     PopulateTensor(input_, data);
   }
 
-  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+  template <typename T>
+  std::vector<T> GetOutput() {
+    return ExtractVector<T>(output_);
+  }
 
  protected:
   int input_;
@@ -60,15 +65,17 @@ class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
  public:
   using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
 
-  void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
+  template <typename T>
+  void Set3DWeightMatrix(const std::function<T(int, int, int)>& function) {
     TfLiteTensor* tensor = interpreter_->tensor(weight_);
     int rows = tensor->dims->data[0];
     int columns = tensor->dims->data[1];
     int features = tensor->dims->data[2];
+    T* data = GetTensorData<T>(tensor);
     for (int i = 0; i < rows; i++) {
       for (int j = 0; j < columns; j++) {
         for (int k = 0; k < features; k++) {
-          tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
+          data[(i * columns + j) * features + k] = function(i, j, k);
         }
       }
     }
@@ -96,12 +103,12 @@ class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
 TEST(EmbeddingLookupOpTest, SimpleTest) {
   EmbeddingLookupOpModel m({3}, {3, 2, 4});
   m.SetInput({1, 0, 2});
-  m.Set3DWeightMatrix(
-      [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+  m.Set3DWeightMatrix<float>(
+      [](int i, int j, int k) -> float { return i + j / 10.0f + k / 100.0f; });
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear({
                   1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
                   0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
@@ -120,7 +127,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestUint8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -141,7 +148,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTestUint8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -162,7 +169,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTestUint8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -183,7 +190,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestInt8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -204,7 +211,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTestInt8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -225,7 +232,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTestInt8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -235,6 +242,22 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTestInt8) {
                       kTestTolerance)));
 }
 
+TEST(EmbeddingLookupHybridOpTest, Simple3DTestQuantized) {
+  EmbeddingLookupOpModel m({3}, {3, 2, 4}, TensorType_UINT8, TensorType_UINT8);
+  m.SetInput({1, 0, 2});
+  m.Set3DWeightMatrix<uint8_t>(
+      [](int i, int j, int k) -> uint8_t { return 100 * i + 10 * j + k; });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput<uint8_t>(),
+              ElementsAreArray({
+                  100, 101, 102, 103, 110, 111, 112, 113,  // Row 1
+                  0,   1,   2,   3,   10,  11,  12,  13,   // Row 0
+                  200, 201, 202, 203, 210, 211, 212, 213,  // Row 2
+              }));
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index 35d93429d35..b527d927812 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -211,7 +211,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version */ 2);
   AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP(),
              /* min_version */ 1,
-             /* max_version */ 2);
+             /* max_version */ 3);
   AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
              Register_EMBEDDING_LOOKUP_SPARSE());
   AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
diff --git a/tensorflow/lite/tools/optimize/quantize_weights.cc b/tensorflow/lite/tools/optimize/quantize_weights.cc
index 935dc01de7f..dda825393a5 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights.cc
@@ -128,7 +128,6 @@ bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
   } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
              builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_SVDF ||
-             builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
              builtin_op_code == BuiltinOperator_RNN ||
              builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
              builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
@@ -262,7 +261,6 @@ void UpdateInt8OperatorVersions(ModelT* model) {
   for (int i = 0; i < model->operator_codes.size(); ++i) {
     const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
     if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
-        op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
         op_code == BuiltinOperator_RNN ||
         op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
         op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
@@ -271,6 +269,7 @@ void UpdateInt8OperatorVersions(ModelT* model) {
       model->operator_codes[i]->version = 2;
     } else if (op_code == BuiltinOperator_FULLY_CONNECTED ||
                op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
+               op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
                op_code == BuiltinOperator_LSTM) {
       model->operator_codes[i]->version = 3;
     }
@@ -286,7 +285,8 @@ bool IsQuantizationPassThroughOps(
   const OperatorT* consumer_op = consumer_op_infos.front().op;
   const BuiltinOperator op_code =
       model->operator_codes[consumer_op->opcode_index]->builtin_code;
-  return op_code == BuiltinOperator_GATHER;
+  return op_code == BuiltinOperator_GATHER ||
+         op_code == BuiltinOperator_EMBEDDING_LOOKUP;
 }
 
 // Copies quantization parameters from input to output and returns consumers of
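---

Notes for reviewers (context, not part of the patch):

The rename EvalFloat -> EvalSimple is sound because the function body (unchanged
by this patch; only its first two lines appear in the context above) never
interprets individual elements, it copies whole rows byte-for-byte. A sketch of
what that body looks like, inferred from the two context lines shown; treat the
exact bounds check and error message as assumptions:

TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
                        const TfLiteTensor* lookup, const TfLiteTensor* value,
                        TfLiteTensor* output) {
  const int row_size = SizeOfDimension(value, 0);
  const int row_bytes = value->bytes / row_size;  // bytes per embedding row

  for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
    const int idx = lookup->data.i32[i];
    if (idx >= row_size || idx < 0) {
      context->ReportError(context, "Embedding Lookup: index out of bounds.");
      return kTfLiteError;
    }
    // Raw byte copy: the element type never matters, only the row width in
    // bytes, which is why one path can serve float32, uint8, and int8.
    memcpy(output->data.raw + i * row_bytes,
           value->data.raw + idx * row_bytes, row_bytes);
  }
  return kTfLiteOk;
}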
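The dispatch in Eval now encodes the intended semantics: with quantized weights,
a float32 output tensor still selects EvalHybrid, which dequantizes rows on the
fly using the weight tensor's quantization scale, while a quantized output
tensor selects EvalSimple and returns the stored quantized values untouched.
That is why the Simple3DTestQuantized expectations are exactly the raw weight
bytes. A self-contained, compilable illustration of the raw-copy behavior
outside TFLite (hypothetical helper name, same data as the test):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Standalone analogue of EvalSimple: gather rows of a [rows x row_elems]
// uint8 table by index, copying raw bytes (no dequantization).
std::vector<uint8_t> LookupRows(const std::vector<uint8_t>& table,
                                int row_elems, const std::vector<int>& ids) {
  std::vector<uint8_t> out(ids.size() * row_elems);
  for (size_t i = 0; i < ids.size(); ++i) {
    std::memcpy(out.data() + i * row_elems,
                table.data() + ids[i] * row_elems, row_elems);
  }
  return out;
}

int main() {
  // 3 rows x (2 x 4) elements, filled like the test: 100*i + 10*j + k.
  std::vector<uint8_t> table;
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 4; ++k) table.push_back(100 * i + 10 * j + k);

  // Prints 100..113 (row 1), 0..13 (row 0), 200..213 (row 2), matching the
  // expected output of Simple3DTestQuantized.
  for (uint8_t v : LookupRows(table, 8, {1, 0, 2})) std::cout << int(v) << ' ';
  std::cout << '\n';
}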
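The changes outside the kernel keep the tooling consistent with the new
behavior: quantize_weights no longer classifies EMBEDDING_LOOKUP as a
hybrid-evaluation op, instead passing quantization parameters through it the
same way it already did for GATHER, and it stamps such ops as version 3. That
matches the max_version bump to 3 in register.cc, so runtimes that only
register versions 1-2 will refuse models that depend on the quantized-output
path rather than silently misinterpreting them.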