lite: enable int8 for op sparse_to_dense

Test: bazel run tensorflow/lite/kernels:sparse_to_dense_test
This commit is contained in:
王振华 (Zhenhua WANG) 2019-06-12 15:27:08 +08:00
parent 8211365f9e
commit 63463d853e
2 changed files with 31 additions and 11 deletions

View File

@@ -170,10 +170,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(
context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64);
TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 ||
output_shape->type == kTfLiteInt64);
output_shape->type == kTfLiteInt64);
TF_LITE_ENSURE(context, values->type == kTfLiteInt32 ||
values->type == kTfLiteInt64 ||
values->type == kTfLiteFloat32);
values->type == kTfLiteInt64 ||
values->type == kTfLiteInt8 ||
values->type == kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, values->type, default_value->type);
// Ensure dimensions match.
@@ -232,7 +233,8 @@ TfLiteStatus EvalForIndexType(TfLiteContext* context, TfLiteNode* node,
}
default:
context->ReportError(
context, "Type %d is currently not supported by sparse to dense.",
context,
"Indice type %d is currently not supported by sparse to dense.",
indices->type);
return kTfLiteError;
}
@@ -242,7 +244,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
// Currently only supports int8, float32, int32 and int64.
switch (values->type) {
case kTfLiteFloat32:
return EvalForIndexType<float>(context, node, indices);
@@ -250,9 +251,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalForIndexType<int32_t>(context, node, indices);
case kTfLiteInt64:
return EvalForIndexType<int64_t>(context, node, indices);
case kTfLiteInt8:
return EvalForIndexType<int8_t>(context, node, indices);
default:
context->ReportError(
context, "Type %d is currently not supported by sparse to dense.",
context,
"Value type %d is currently not supported by sparse to dense.",
values->type);
return kTfLiteError;
}

View File

@@ -100,6 +100,21 @@ TEST(SparseToDenseOpModelTest, TwoDimensionsTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
}
TEST(SparseToDenseOpModelTest, Int64IndexTest) {
SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT64,
TensorType_FLOAT32);
m.PopulateTensor<int64_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
m.PopulateTensor<float>(m.values(), {2, 4, 6});
m.Invoke();
EXPECT_THAT(
m.GetOutput(),
ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1}));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
}
TEST(SparseToDenseOpModelTest, DefaultValueTest) {
SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT32,
TensorType_FLOAT32);
@@ -145,12 +160,12 @@ TEST(SparseToDenseOpModelTest, Int64ValueTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
}
TEST(SparseToDenseOpModelTest, Int64IndexTest) {
SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT64,
TensorType_FLOAT32);
m.PopulateTensor<int64_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
TEST(SparseToDenseOpModelTest, Int8ValueTest) {
SparseToDenseOpModel<int8_t> m({3, 3}, {3}, {3}, -1, TensorType_INT32,
TensorType_INT8);
m.PopulateTensor<int32_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
m.PopulateTensor<float>(m.values(), {2, 4, 6});
m.PopulateTensor<int8_t>(m.values(), {2, 4, 6});
m.Invoke();
EXPECT_THAT(
@@ -160,5 +175,6 @@ TEST(SparseToDenseOpModelTest, Int64IndexTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
}
} // namespace
} // namespace tflite