Converting EmbeddingLookup from hybrid to ordinary op.
PiperOrigin-RevId: 246354602
commit eac6722a0f (parent ba6f907546)
@@ -878,8 +878,9 @@ cc_test(
     srcs = ["embedding_lookup_test.cc"],
     deps = [
         ":builtin_ops",
         ":test_util",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/kernels:test_util",
+        "//tensorflow/lite/kernels/internal:tensor",
         "@com_google_googletest//:gtest",
     ],
 )
@@ -69,9 +69,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return context->ResizeTensor(context, output, outputSize);
 }
 
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
-                       const TfLiteTensor* lookup, const TfLiteTensor* value,
-                       TfLiteTensor* output) {
+TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
+                        const TfLiteTensor* lookup, const TfLiteTensor* value,
+                        TfLiteTensor* output) {
   const int row_size = SizeOfDimension(value, 0);
   const int row_bytes = value->bytes / row_size;
 
@@ -138,10 +138,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, 0);
   switch (value->type) {
     case kTfLiteFloat32:
-      return EvalFloat(context, node, lookup, value, output);
+      return EvalSimple(context, node, lookup, value, output);
     case kTfLiteUInt8:
     case kTfLiteInt8:
-      return EvalHybrid(context, node, lookup, value, output);
+      if (output->type == kTfLiteFloat32) {
+        return EvalHybrid(context, node, lookup, value, output);
+      } else {
+        return EvalSimple(context, node, lookup, value, output);
+      }
     default:
       context->ReportError(context, "Type not currently supported.");
       return kTfLiteError;
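The body of EvalSimple is outside this hunk. A minimal sketch of the copy-based evaluation the dispatch above relies on, eliding includes and some of the kernel's usual argument checks; the error message and exact checks here are illustrative, not taken from the file:

// Sketch only (the real body is not part of this hunk): copy one weight row
// per lookup index, byte for byte. Because nothing is interpreted as float,
// the same routine serves float32 weights and the quantized-in/quantized-out
// path that the new int8-output branch above falls back to.
TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node,
                        const TfLiteTensor* lookup, const TfLiteTensor* value,
                        TfLiteTensor* output) {
  const int row_size = SizeOfDimension(value, 0);
  const int row_bytes = value->bytes / row_size;

  for (int i = 0; i < SizeOfDimension(lookup, 0); i++) {
    const int idx = lookup->data.i32[i];
    if (idx < 0 || idx >= row_size) {
      context->ReportError(context, "Embedding index out of bounds.");
      return kTfLiteError;
    }
    // Raw byte copy of the selected weight row into the i-th output row.
    memcpy(output->data.raw + i * row_bytes,
           value->data.raw + idx * row_bytes, row_bytes);
  }
  return kTfLiteOk;
}

This is what makes the op "ordinary" rather than hybrid: when the output is not float32, quantized rows are gathered as-is instead of being dequantized through EvalHybrid.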
@@ -21,6 +21,7 @@ License.
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
 #include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/kernels/test_util.h"
 #include "tensorflow/lite/model.h"
@@ -36,10 +37,11 @@ class BaseEmbeddingLookupOpModel : public SingleOpModel {
  public:
   BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
                              std::initializer_list<int> weight_shape,
-                             TensorType weight_type = TensorType_FLOAT32) {
+                             TensorType weight_type = TensorType_FLOAT32,
+                             TensorType output_type = TensorType_FLOAT32) {
     input_ = AddInput(TensorType_INT32);
     weight_ = AddInput(weight_type);
-    output_ = AddOutput(TensorType_FLOAT32);
+    output_ = AddOutput(output_type);
     SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
     BuildInterpreter({index_shape, weight_shape});
   }
@@ -48,7 +50,10 @@ class BaseEmbeddingLookupOpModel : public SingleOpModel {
     PopulateTensor(input_, data);
   }
 
-  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+  template <typename T>
+  std::vector<T> GetOutput() {
+    return ExtractVector<T>(output_);
+  }
 
  protected:
   int input_;
@@ -60,15 +65,17 @@ class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
  public:
   using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel;
 
-  void Set3DWeightMatrix(const std::function<float(int, int, int)>& function) {
+  template <typename T>
+  void Set3DWeightMatrix(const std::function<T(int, int, int)>& function) {
     TfLiteTensor* tensor = interpreter_->tensor(weight_);
     int rows = tensor->dims->data[0];
     int columns = tensor->dims->data[1];
     int features = tensor->dims->data[2];
+    T* data = GetTensorData<T>(tensor);
     for (int i = 0; i < rows; i++) {
       for (int j = 0; j < columns; j++) {
         for (int k = 0; k < features; k++) {
-          tensor->data.f[(i * columns + j) * features + k] = function(i, j, k);
+          data[(i * columns + j) * features + k] = function(i, j, k);
         }
       }
     }
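Together with the templated GetOutput<T>() above, a single test model now drives both the float path and the new quantized path. Condensed from the Simple3DTestQuantized test added further down (shapes and values are that test's, not new):

// Weight tensor declared as uint8, output tensor as int8: with a non-float
// output the kernel copies the quantized rows through instead of
// dequantizing them.
EmbeddingLookupOpModel m({3}, {3, 2, 4}, TensorType_UINT8, TensorType_INT8);
m.SetInput({1, 0, 2});
m.Set3DWeightMatrix<uint8_t>(
    [](int i, int j, int k) -> uint8_t { return 100 * i + 10 * j + k; });
m.Invoke();
std::vector<int8_t> out = m.GetOutput<int8_t>();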
@@ -96,12 +103,12 @@ class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
 TEST(EmbeddingLookupOpTest, SimpleTest) {
   EmbeddingLookupOpModel m({3}, {3, 2, 4});
   m.SetInput({1, 0, 2});
-  m.Set3DWeightMatrix(
-      [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; });
+  m.Set3DWeightMatrix<float>(
+      [](int i, int j, int k) -> float { return i + j / 10.0f + k / 100.0f; });
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear({
                   1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
                   0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13,  // Row 0
@@ -120,7 +127,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestUint8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -141,7 +148,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTestUint8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -162,7 +169,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTestUint8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -183,7 +190,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestInt8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -204,7 +211,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTestInt8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -225,7 +232,7 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTestInt8) {
 
   m.Invoke();
 
-  EXPECT_THAT(m.GetOutput(),
+  EXPECT_THAT(m.GetOutput<float>(),
               ElementsAreArray(ArrayFloatNear(
                   {
                       1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13,  // Row 1
@@ -235,6 +242,22 @@ TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTestInt8) {
                   kTestTolerance)));
 }
 
+TEST(EmbeddingLookupHybridOpTest, Simple3DTestQuantized) {
+  EmbeddingLookupOpModel m({3}, {3, 2, 4}, TensorType_UINT8, TensorType_INT8);
+  m.SetInput({1, 0, 2});
+  m.Set3DWeightMatrix<uint8_t>(
+      [](int i, int j, int k) -> uint8_t { return 100 * i + 10 * j + k; });
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput<int8_t>(),
+              ElementsAreArray({
+                  100, 101, 102, 103, 110, 111, 112, 113,  // Row 1
+                  0, 1, 2, 3, 10, 11, 12, 13,              // Row 0
+                  200, 201, 202, 203, 210, 211, 212, 213,  // Row 2
+              }));
+}
+
 }  // namespace
 }  // namespace tflite
 
@@ -211,7 +211,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
              /* max_version */ 2);
   AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP(),
              /* min_version */ 1,
-             /* max_version */ 2);
+             /* max_version */ 3);
   AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
              Register_EMBEDDING_LOOKUP_SPARSE());
   AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED(),
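Clients that assemble their own MutableOpResolver instead of using BuiltinOpResolver need to mirror this version bump, or models whose EMBEDDING_LOOKUP op was written out as version 3 (int8 output) will fail to resolve. A minimal sketch, assuming the standard MutableOpResolver and builtin-kernel headers; the function name is illustrative:

#include "tensorflow/lite/kernels/builtin_op_kernels.h"  // Register_EMBEDDING_LOOKUP()
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"

// Sketch only: a selective resolver that registers just this op, accepting
// every version up to the new maximum of 3.
tflite::MutableOpResolver BuildEmbeddingLookupResolver() {
  tflite::MutableOpResolver resolver;
  resolver.AddBuiltin(tflite::BuiltinOperator_EMBEDDING_LOOKUP,
                      tflite::ops::builtin::Register_EMBEDDING_LOOKUP(),
                      /* min_version */ 1,
                      /* max_version */ 3);
  return resolver;
}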
@@ -128,7 +128,6 @@ bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code,
   } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED ||
              builtin_op_code == BuiltinOperator_CONV_2D ||
              builtin_op_code == BuiltinOperator_SVDF ||
-             builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
              builtin_op_code == BuiltinOperator_RNN ||
              builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
              builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
@@ -262,7 +261,6 @@ void UpdateInt8OperatorVersions(ModelT* model) {
   for (int i = 0; i < model->operator_codes.size(); ++i) {
     const BuiltinOperator& op_code = model->operator_codes[i]->builtin_code;
     if (op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
-        op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
         op_code == BuiltinOperator_RNN ||
         op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
         op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
@@ -271,6 +269,7 @@ void UpdateInt8OperatorVersions(ModelT* model) {
 
     } else if (op_code == BuiltinOperator_FULLY_CONNECTED ||
                op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
+               op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
                op_code == BuiltinOperator_LSTM) {
       model->operator_codes[i]->version = 3;
     }
@@ -286,7 +285,8 @@ bool IsQuantizationPassThroughOps(
   const OperatorT* consumer_op = consumer_op_infos.front().op;
   const BuiltinOperator op_code =
       model->operator_codes[consumer_op->opcode_index]->builtin_code;
-  return op_code == BuiltinOperator_GATHER;
+  return op_code == BuiltinOperator_GATHER ||
+         op_code == BuiltinOperator_EMBEDDING_LOOKUP;
 }
 
 // Copies quantization parameters from input to output and returns consumers of
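Marking EMBEDDING_LOOKUP as a pass-through op means the weight quantizer propagates the quantization parameters of the (now quantized) value input onto the op's output rather than treating the op as a hybrid one to be rewritten. A rough illustration of that propagation over the flatbuffer object API; the helper name and call site are illustrative only, not part of this change:

#include <memory>

#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {

// Illustrative helper (not from this change): copy the scale/zero-point of a
// pass-through op's input tensor onto its output tensor so that downstream
// consumers see identical quantization parameters.
void PassThroughQuantization(const TensorT& input, TensorT* output) {
  if (input.quantization == nullptr) return;
  auto params = std::make_unique<QuantizationParametersT>();
  params->scale = input.quantization->scale;
  params->zero_point = input.quantization->zero_point;
  params->min = input.quantization->min;
  params->max = input.quantization->max;
  output->quantization = std::move(params);
}

}  // namespace tflite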