TFLite GPU: Added some tests for model_builder.

PiperOrigin-RevId: 242247595
This commit is contained in:
Juhyun Lee 2019-04-05 23:05:30 -07:00 committed by TensorFlower Gardener
parent f0cccb0705
commit a4b0dc1acb
4 changed files with 126 additions and 4 deletions

View File

@ -77,7 +77,15 @@ cc_library(
], ],
) )
# TODO(impjdi): Add unit test for model_builder. cc_test(
name = "model_builder_test",
srcs = ["model_builder_test.cc"],
deps = [
":model_builder",
"//tensorflow/lite/c:c_api_internal",
"@com_google_googletest//:gtest_main",
],
)
cc_library( cc_library(
name = "model_transformer", name = "model_transformer",

View File

@ -314,7 +314,7 @@ class ObjectReader {
} }
Value<TensorRefFloat32>* value = graph_->NewValue(); Value<TensorRefFloat32>* value = graph_->NewValue();
RETURN_IF_ERROR( RETURN_IF_ERROR(
ConvertTfliteTensorToTensorRef(tflite_tensor, &value->tensor)); ConvertTfLiteTensorToTensorRef(tflite_tensor, &value->tensor));
value->tensor.ref = tensor_idx; value->tensor.ref = tensor_idx;
(*tensor_to_value_)[tensor_idx] = value; (*tensor_to_value_)[tensor_idx] = value;
} }
@ -1848,7 +1848,7 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
} // namespace } // namespace
Status ConvertTfliteTensorToTensorRef(const TfLiteTensor& tflite_tensor, Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
TensorRefFloat32* tensor_ref) { TensorRefFloat32* tensor_ref) {
tensor_ref->type = ToDataType(tflite_tensor.type); tensor_ref->type = ToDataType(tflite_tensor.type);
const TfLiteIntArray* dims = tflite_tensor.dims; const TfLiteIntArray* dims = tflite_tensor.dims;

View File

@ -36,7 +36,8 @@ Status BuildModel(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params, const TfLiteDelegateParams* delegate_params,
GraphFloat32* graph); GraphFloat32* graph);
Status ConvertTfliteTensorToTensorRef(const TfLiteTensor& tflite_tensor, // Module-internal converter, exposed for unit testing purpose only.
Status ConvertTfLiteTensorToTensorRef(const TfLiteTensor& tflite_tensor,
TensorRefFloat32* tensor_ref); TensorRefFloat32* tensor_ref);
} // namespace gpu } // namespace gpu

View File

@ -0,0 +1,113 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/c/c_api_internal.h"
namespace tflite {
namespace gpu {
namespace {
TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank0) {
TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteFloat32;
tflite_tensor.dims = TfLiteIntArrayCreate(1);
tflite_tensor.dims->data[0] = 4;
TensorRefFloat32 tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
ASSERT_TRUE(status.ok());
EXPECT_EQ(tensor_ref.type, DataType::FLOAT32);
EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 1, 1));
}
TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank1) {
TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteInt32;
tflite_tensor.dims = TfLiteIntArrayCreate(2);
tflite_tensor.dims->data[0] = 4;
tflite_tensor.dims->data[1] = 5;
TensorRefFloat32 tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
ASSERT_TRUE(status.ok());
EXPECT_EQ(tensor_ref.type, DataType::INT32);
EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 1, 5));
}
TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank2) {
TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteInt64;
tflite_tensor.dims = TfLiteIntArrayCreate(3);
tflite_tensor.dims->data[0] = 4;
tflite_tensor.dims->data[1] = 5;
tflite_tensor.dims->data[2] = 6;
TensorRefFloat32 tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
ASSERT_TRUE(status.ok());
EXPECT_EQ(tensor_ref.type, DataType::INT64);
EXPECT_EQ(tensor_ref.shape, BHWC(4, 1, 5, 6));
}
TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefSucceedsForRank3) {
TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteUInt8;
tflite_tensor.dims = TfLiteIntArrayCreate(4);
tflite_tensor.dims->data[0] = 4;
tflite_tensor.dims->data[1] = 5;
tflite_tensor.dims->data[2] = 6;
tflite_tensor.dims->data[3] = 7;
TensorRefFloat32 tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
ASSERT_TRUE(status.ok());
EXPECT_EQ(tensor_ref.type, DataType::UINT8);
EXPECT_EQ(tensor_ref.shape, BHWC(4, 5, 6, 7));
}
TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankLT0) {
TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteFloat32;
tflite_tensor.dims = TfLiteIntArrayCreate(0);
TensorRefFloat32 tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
// TODO(b/130054481): Cover scalar.
EXPECT_FALSE(status.ok());
}
TEST(ModelBuilderTest, ConvertTfLiteTensorToTensorRefFailsForRankGT3) {
TfLiteTensor tflite_tensor;
tflite_tensor.type = TfLiteType::kTfLiteFloat32;
tflite_tensor.dims = TfLiteIntArrayCreate(5);
TensorRefFloat32 tensor_ref;
const auto status =
ConvertTfLiteTensorToTensorRef(tflite_tensor, &tensor_ref);
TfLiteIntArrayFree(tflite_tensor.dims);
EXPECT_FALSE(status.ok());
}
} // namespace
} // namespace gpu
} // namespace tflite