lite: enable int8 for op sparse_to_dense
Test: bazel run tensorflow/lite/kernels:sparse_to_dense_test
This commit is contained in:
parent
8211365f9e
commit
63463d853e
@ -170,10 +170,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
TF_LITE_ENSURE(
|
||||
context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64);
|
||||
TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 ||
|
||||
output_shape->type == kTfLiteInt64);
|
||||
output_shape->type == kTfLiteInt64);
|
||||
TF_LITE_ENSURE(context, values->type == kTfLiteInt32 ||
|
||||
values->type == kTfLiteInt64 ||
|
||||
values->type == kTfLiteFloat32);
|
||||
values->type == kTfLiteInt64 ||
|
||||
values->type == kTfLiteInt8 ||
|
||||
values->type == kTfLiteFloat32);
|
||||
TF_LITE_ENSURE_EQ(context, values->type, default_value->type);
|
||||
|
||||
// Ensure dimensions match.
|
||||
@ -232,7 +233,8 @@ TfLiteStatus EvalForIndexType(TfLiteContext* context, TfLiteNode* node,
|
||||
}
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Type %d is currently not supported by sparse to dense.",
|
||||
context,
|
||||
"Indice type %d is currently not supported by sparse to dense.",
|
||||
indices->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
@ -242,7 +244,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor);
|
||||
const TfLiteTensor* values = GetInput(context, node, kValueInputTensor);
|
||||
|
||||
// Currently only supports float32, int32 and int64.
|
||||
switch (values->type) {
|
||||
case kTfLiteFloat32:
|
||||
return EvalForIndexType<float>(context, node, indices);
|
||||
@ -250,9 +251,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return EvalForIndexType<int32_t>(context, node, indices);
|
||||
case kTfLiteInt64:
|
||||
return EvalForIndexType<int64_t>(context, node, indices);
|
||||
case kTfLiteInt8:
|
||||
return EvalForIndexType<int8_t>(context, node, indices);
|
||||
default:
|
||||
context->ReportError(
|
||||
context, "Type %d is currently not supported by sparse to dense.",
|
||||
context,
|
||||
"Value type %d is currently not supported by sparse to dense.",
|
||||
values->type);
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
@ -100,6 +100,21 @@ TEST(SparseToDenseOpModelTest, TwoDimensionsTest) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
|
||||
}
|
||||
|
||||
TEST(SparseToDenseOpModelTest, Int64IndexTest) {
|
||||
SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT64,
|
||||
TensorType_FLOAT32);
|
||||
m.PopulateTensor<int64_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
|
||||
m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
|
||||
m.PopulateTensor<float>(m.values(), {2, 4, 6});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(
|
||||
m.GetOutput(),
|
||||
ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1}));
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
|
||||
}
|
||||
|
||||
TEST(SparseToDenseOpModelTest, DefaultValueTest) {
|
||||
SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT32,
|
||||
TensorType_FLOAT32);
|
||||
@ -145,12 +160,12 @@ TEST(SparseToDenseOpModelTest, Int64ValueTest) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
|
||||
}
|
||||
|
||||
TEST(SparseToDenseOpModelTest, Int64IndexTest) {
|
||||
SparseToDenseOpModel<float> m({3, 3}, {3}, {3}, -1, TensorType_INT64,
|
||||
TensorType_FLOAT32);
|
||||
m.PopulateTensor<int64_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
|
||||
TEST(SparseToDenseOpModelTest, Int8ValueTest) {
|
||||
SparseToDenseOpModel<int8_t> m({3, 3}, {3}, {3}, -1, TensorType_INT32,
|
||||
TensorType_INT8);
|
||||
m.PopulateTensor<int32_t>(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1});
|
||||
m.PopulateTensor<int32_t>(m.output_shape(), {3, 3, 3});
|
||||
m.PopulateTensor<float>(m.values(), {2, 4, 6});
|
||||
m.PopulateTensor<int8_t>(m.values(), {2, 4, 6});
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(
|
||||
@ -160,5 +175,6 @@ TEST(SparseToDenseOpModelTest, Int64IndexTest) {
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3}));
|
||||
}
|
||||
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
Loading…
Reference in New Issue
Block a user