STT-tensorflow/tensorflow/lite/testing/selective_build_test.cc
Thai Nguyen 57f589f66c Selective build rule for C
Usage:
tflite_custom_c_library(
  name = "custom_c_api",
  models = ["model1", "model2"],
)
PiperOrigin-RevId: 347516727
Change-Id: Id9b6b0de9f720aa32351de2605dac73f6bb28d63
2020-12-14 19:04:03 -08:00

81 lines
2.6 KiB
C++

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/create_op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/model_builder.h"
namespace tflite {
bool RunWithRandomInputs(const std::string& filename) {
std::unique_ptr<tflite::FlatBufferModel> model =
tflite::FlatBufferModel::BuildFromFile(filename.c_str());
// Build the interpreter
std::unique_ptr<MutableOpResolver> resolver = CreateOpResolver();
std::unique_ptr<tflite::Interpreter> interpreter;
if (tflite::InterpreterBuilder(*model, *resolver)(&interpreter) !=
kTfLiteOk) {
LOG(ERROR) << "Could not initialize interpreter for TFLite model.";
return false;
}
// Resize input tensors, if desired.
if (interpreter->AllocateTensors() != kTfLiteOk) {
LOG(ERROR) << "Could not allocate tensor.";
return false;
}
// Fill the random data.
std::vector<std::vector<uint8_t>> sample;
for (int tensor_idx : interpreter->inputs()) {
auto tensor = interpreter->tensor(tensor_idx);
std::vector<uint8_t> data(tensor->bytes);
for (auto it = data.begin(); it != data.end(); ++it) {
*it = random();
}
sample.push_back(data);
tensor->data.raw = reinterpret_cast<char*>(sample.rbegin()->data());
}
// Running inference.
if (interpreter->Invoke() != kTfLiteOk) {
LOG(ERROR) << "Failed to run the model.";
return false;
}
return true;
}
// Runs the add test model once with random inputs and expects success.
TEST(SelectiveBuiltTest, AddModel) {
  const std::string add_model_path{
      "third_party/tensorflow/lite/testdata/add.bin"};
  const bool ran_ok = RunWithRandomInputs(add_model_path);
  EXPECT_THAT(ran_ok, true);
}
// Runs the LSTM test model once with random inputs and expects success.
TEST(SelectiveBuiltTest, LSTMModel) {
  const std::string lstm_model_path{
      "third_party/tensorflow/lite/testdata/lstm.bin"};
  const bool ran_ok = RunWithRandomInputs(lstm_model_path);
  EXPECT_THAT(ran_ok, true);
}
} // namespace tflite
// Test entry point: hands the command line to GoogleTest, then runs every
// registered test and returns the result of RUN_ALL_TESTS() as the exit code.
int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}