From 63463d853e7a26183d02008d23037da48f7747f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=8C=AF=E5=8D=8E=20=28Zhenhua=20WANG=29?= Date: Wed, 12 Jun 2019 15:27:08 +0800 Subject: [PATCH] lite: enable int8 for op sparse_to_dense Test: bazel run tensorflow/lite/kernels:sparse_to_dense_test --- tensorflow/lite/kernels/sparse_to_dense.cc | 16 +++++++----- .../lite/kernels/sparse_to_dense_test.cc | 26 +++++++++++++++---- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/kernels/sparse_to_dense.cc b/tensorflow/lite/kernels/sparse_to_dense.cc index 74eef2d698e..0e49d503ad5 100644 --- a/tensorflow/lite/kernels/sparse_to_dense.cc +++ b/tensorflow/lite/kernels/sparse_to_dense.cc @@ -170,10 +170,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE( context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64); TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 || - output_shape->type == kTfLiteInt64); + output_shape->type == kTfLiteInt64); TF_LITE_ENSURE(context, values->type == kTfLiteInt32 || - values->type == kTfLiteInt64 || - values->type == kTfLiteFloat32); + values->type == kTfLiteInt64 || + values->type == kTfLiteInt8 || + values->type == kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, values->type, default_value->type); // Ensure dimensions match. @@ -232,7 +233,8 @@ TfLiteStatus EvalForIndexType(TfLiteContext* context, TfLiteNode* node, } default: context->ReportError( - context, "Type %d is currently not supported by sparse to dense.", + context, + "Indice type %d is currently not supported by sparse to dense.", indices->type); return kTfLiteError; } @@ -242,7 +244,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); - // Currently only supports float32, int32 and int64. switch (values->type) { case kTfLiteFloat32: return EvalForIndexType(context, node, indices); @@ -250,9 +251,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return EvalForIndexType(context, node, indices); case kTfLiteInt64: return EvalForIndexType(context, node, indices); + case kTfLiteInt8: + return EvalForIndexType(context, node, indices); default: context->ReportError( - context, "Type %d is currently not supported by sparse to dense.", + context, + "Value type %d is currently not supported by sparse to dense.", values->type); return kTfLiteError; } diff --git a/tensorflow/lite/kernels/sparse_to_dense_test.cc b/tensorflow/lite/kernels/sparse_to_dense_test.cc index 4a5ce6a36b5..27d733afd3c 100644 --- a/tensorflow/lite/kernels/sparse_to_dense_test.cc +++ b/tensorflow/lite/kernels/sparse_to_dense_test.cc @@ -100,6 +100,21 @@ TEST(SparseToDenseOpModelTest, TwoDimensionsTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); } +TEST(SparseToDenseOpModelTest, Int64IndexTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT64, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + TEST(SparseToDenseOpModelTest, DefaultValueTest) { SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, TensorType_FLOAT32); @@ -145,12 +160,12 @@ TEST(SparseToDenseOpModelTest, Int64ValueTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); } -TEST(SparseToDenseOpModelTest, Int64IndexTest) { - SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT64, - TensorType_FLOAT32); - m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); +TEST(SparseToDenseOpModelTest, Int8ValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_INT8); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); m.PopulateTensor(m.output_shape(), {3, 3, 3}); - m.PopulateTensor(m.values(), {2, 4, 6}); + m.PopulateTensor(m.values(), {2, 4, 6}); m.Invoke(); EXPECT_THAT( @@ -160,5 +175,6 @@ TEST(SparseToDenseOpModelTest, Int64IndexTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); } + } // namespace } // namespace tflite