Converting EmbeddingLookup from hybrid to ordinary op.

PiperOrigin-RevId: 246354602
This commit is contained in:
A. Unique TensorFlower 2019-05-02 11:18:42 -07:00 committed by TensorFlower Gardener
parent ba6f907546
commit eac6722a0f
5 changed files with 52 additions and 24 deletions

View File

@ -878,8 +878,9 @@ cc_test(
srcs = ["embedding_lookup_test.cc"],
deps = [
":builtin_ops",
":test_util",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"//tensorflow/lite/kernels/internal:tensor",
"@com_google_googletest//:gtest",
],
)

View File

@ -69,9 +69,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return context->ResizeTensor(context, output, outputSize);
}
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* lookup, const TfLiteTensor* value,
TfLiteTensor* output) {
TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* lookup, const TfLiteTensor* value,
TfLiteTensor* output) {
const int row_size = SizeOfDimension(value, 0);
const int row_bytes = value->bytes / row_size;
@ -138,10 +138,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
switch (value->type) {
case kTfLiteFloat32:
return EvalFloat(context, node, lookup, value, output);
return EvalSimple(context, node, lookup, value, output);
case kTfLiteUInt8:
case kTfLiteInt8:
return EvalHybrid(context, node, lookup, value, output);
if (output->type == kTfLiteFloat32) {
return EvalHybrid(context, node, lookup, value, output);
} else {
return EvalSimple(context, node, lookup, value, output);
}
default:
context->ReportError(context, "Type not currently supported.");
return kTfLiteError;

View File

@ -21,6 +21,7 @@ License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
@ -36,10 +37,11 @@ class BaseEmbeddingLookupOpModel : public SingleOpModel {
public:
BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
std::initializer_list<int> weight_shape,
TensorType weight_type = TensorType_FLOAT32) {
TensorType weight_type = TensorType_FLOAT32,
TensorType output_type = TensorType_FLOAT32) {
input_ = AddInput(TensorType_INT32);
weight_ = AddInput(weight_type);
output_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(output_type);
SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
BuildInterpreter({index_shape, weight_shape});
}
@ -48,7 +50,10 @@ class BaseEmbeddingLookupOpModel : public SingleOpModel {
PopulateTensor(input_, data);
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
protected:
int input_;
@ -60,15 +65,17 @@ class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
public:
using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
template <typename T>
void Set3DWeightMatrix(const std::function<T(int, int, int)>& function) {
TfLiteTensor* tensor = interpreter_->tensor(weight_);
int rows = tensor->dims->data[0];
int columns = tensor->dims->data[1];
int features = tensor->dims->data[2];
T* data = GetTensorData<T>(tensor);
for (int i = 0; i < rows; i++) {
for (int j = 0; j < columns; j++) {
for (int k = 0; k < features; k++) {
tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
data[(i * columns + j) * features + k] = function(i, j, k);
}
}
}
@ -96,12 +103,12 @@ class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
TEST(EmbeddingLookupOpTest, SimpleTest) {
EmbeddingLookupOpModel m({3}, {3, 2, 4});
m.SetInput({1, 0, 2});
m.Set3DWeightMatrix(
[](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
m.Set3DWeightMatrix<float>(
[](int i, int j, int k) -> float { return i + j / 10.0f + k / 100.0f; });
m.Invoke();
EXPECT_THAT(m.GetOutput(),
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear({
1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
@ -120,7 +127,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestUint8) {
m.Invoke();
EXPECT_THAT(m.GetOutput(),
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
@ -141,7 +148,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTestUint8) {
m.Invoke();
EXPECT_THAT(m.GetOutput(),
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
@ -162,7 +169,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTestUint8) {
m.Invoke();
EXPECT_THAT(m.GetOutput(),
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
@ -183,7 +190,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestInt8) {
m.Invoke();
EXPECT_THAT(m.GetOutput(),
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
@ -204,7 +211,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTestInt8) {
m.Invoke();
EXPECT_THAT(m.GetOutput(),
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
@ -225,7 +232,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTestInt8) {
m.Invoke();
EXPECT_THAT(m.GetOutput(),
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
@ -235,6 +242,22 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTestInt8) {
kTestTolerance)));
}
TEST(EmbeddingLookupHybridOpTest, Simple3DTestQuantized) {
EmbeddingLookupOpModel m({3}, {3, 2, 4}, TensorType_UINT8, TensorType_INT8);
m.SetInput({1, 0, 2});
m.Set3DWeightMatrix<uint8_t>(
[](int i, int j, int k) -> uint8_t { return 100 * i + 10 * j + k; });
m.Invoke();
EXPECT_THAT(m.GetOutput<int8_t>(),
ElementsAreArray({
100, 101, 102, 103, 110, 111, 112, 113, // Row 1
0, 1, 2, 3, 10, 11, 12, 13, // Row 0
200, 201, 202, 203, 210, 211, 212, 213, // Row 2
}));
}
} // namespace
} // namespace tflite

View File

@ -211,7 +211,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* max_version */ 2);
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP(),
/* min_version */ 1,
/* max_version */ 2);
/* max_version */ 3);
AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
Register_EMBEDDING_LOOKUP_SPARSE());
AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),

View File

@ -128,7 +128,6 @@ bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
} else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
builtin_op_code == BuiltinOperator_CONV_2D ||
builtin_op_code == BuiltinOperator_SVDF ||
builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
builtin_op_code == BuiltinOperator_RNN ||
builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
@ -262,7 +261,6 @@ void UpdateInt8OperatorVersions(ModelT* model) {
for (int i = 0; i < model->operator_codes.size(); ++i) {
const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
op_code == BuiltinOperator_RNN ||
op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
@ -271,6 +269,7 @@ void UpdateInt8OperatorVersions(ModelT* model) {
} else if (op_code == BuiltinOperator_FULLY_CONNECTED ||
op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
op_code == BuiltinOperator_LSTM) {
model->operator_codes[i]->version = 3;
}
@ -286,7 +285,8 @@ bool IsQuantizationPassThroughOps(
const OperatorT* consumer_op = consumer_op_infos.front().op;
const BuiltinOperator op_code =
model->operator_codes[consumer_op->opcode_index]->builtin_code;
return op_code == BuiltinOperator_GATHER;
return op_code == BuiltinOperator_GATHER ||
op_code == BuiltinOperator_EMBEDDING_LOOKUP;
}
// Copies quantization parameters from input to output and returns consumers of