Add capability to insert intermediate tensors into a model. This is to support quantized lstm.
PiperOrigin-RevId: 278713731 Change-Id: Id38a708db4e6678dc7bfce46dd66a8b66480dc35
This commit is contained in:
parent
3006f06929
commit
2b78c1e722
@ -12,6 +12,36 @@ exports_files(glob([
|
||||
"testdata/*.bin",
|
||||
]))
|
||||
|
||||
cc_library(
|
||||
name = "add_intermediate_tensors",
|
||||
srcs = ["add_intermediate_tensors.cc"],
|
||||
hdrs = ["add_intermediate_tensors.h"],
|
||||
deps = [
|
||||
":operator_property",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/core/api",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "add_intermediate_tensors_test",
|
||||
srcs = ["add_intermediate_tensors_test.cc"],
|
||||
tags = [
|
||||
"tflite_not_portable_android",
|
||||
"tflite_not_portable_ios",
|
||||
],
|
||||
deps = [
|
||||
":add_intermediate_tensors",
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "quantization_utils",
|
||||
srcs = ["quantization_utils.cc"],
|
||||
|
76
tensorflow/lite/tools/optimize/add_intermediate_tensors.cc
Normal file
76
tensorflow/lite/tools/optimize/add_intermediate_tensors.cc
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/optimize/add_intermediate_tensors.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/lite/tools/optimize/operator_property.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace optimize {
|
||||
namespace {
|
||||
|
||||
void MakeTensor(const string& name, std::unique_ptr<TensorT>* tensor) {
|
||||
TensorT* tensor_raw = new TensorT;
|
||||
tensor_raw->name = name;
|
||||
tensor_raw->shape = {0};
|
||||
tensor_raw->type = TensorType_FLOAT32;
|
||||
|
||||
tensor->reset(tensor_raw);
|
||||
}
|
||||
|
||||
string CreateTensorName(int op_index, int tensor_index) {
|
||||
return "intermediate_" + std::to_string(op_index) + "_" +
|
||||
std::to_string(tensor_index);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TfLiteStatus AddIntemediateTensorsToFusedOp(
|
||||
flatbuffers::FlatBufferBuilder* builder, ModelT* model) {
|
||||
// Process the model.
|
||||
for (int subgraph_idx = 0; subgraph_idx < model->subgraphs.size();
|
||||
++subgraph_idx) {
|
||||
SubGraphT* subgraph = model->subgraphs.at(subgraph_idx).get();
|
||||
for (size_t op_idx = 0; op_idx < subgraph->operators.size(); op_idx++) {
|
||||
// Find LSTM
|
||||
OperatorT* op = subgraph->operators[op_idx].get();
|
||||
operator_property::OperatorProperty property =
|
||||
operator_property::GetOperatorProperty(model, subgraph_idx, op_idx);
|
||||
if (property.intermediates.empty()) {
|
||||
continue;
|
||||
}
|
||||
// Add tensors.
|
||||
const int next_tensor_index = subgraph->tensors.size();
|
||||
const int num_intermediates = property.intermediates.size();
|
||||
for (int i = 0; i < num_intermediates; ++i) {
|
||||
std::unique_ptr<TensorT> intermediate_tensor;
|
||||
auto name = CreateTensorName(op_idx, i);
|
||||
MakeTensor(name, &intermediate_tensor);
|
||||
subgraph->tensors.push_back(std::move(intermediate_tensor));
|
||||
op->intermediates.push_back(next_tensor_index + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Export the model.
|
||||
flatbuffers::Offset<Model> output_model_location =
|
||||
Model::Pack(*builder, model);
|
||||
FinishModelBuffer(*builder, output_model_location);
|
||||
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
32
tensorflow/lite/tools/optimize/add_intermediate_tensors.h
Normal file
32
tensorflow/lite/tools/optimize/add_intermediate_tensors.h
Normal file
@ -0,0 +1,32 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_ADD_INTERMEDIATE_TENSORS_H_
|
||||
#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_ADD_INTERMEDIATE_TENSORS_H_
|
||||
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace optimize {
|
||||
|
||||
// Going through the model and add intermediates tensors if the ops have any.
|
||||
TfLiteStatus AddIntemediateTensorsToFusedOp(
|
||||
flatbuffers::FlatBufferBuilder* builder, ModelT* input_model);
|
||||
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_ADD_INTERMEDIATE_TENSORS_H_
|
@ -0,0 +1,82 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/lite/tools/optimize/add_intermediate_tensors.h"
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace optimize {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
|
||||
TEST(LstmPreprocess, Add2Tensors) {
|
||||
// Create a model with 1 lstm layer.
|
||||
auto model = absl::make_unique<ModelT>();
|
||||
auto subgraph = absl::make_unique<tflite::SubGraphT>();
|
||||
auto tensor = absl::make_unique<TensorT>();
|
||||
auto buffer = absl::make_unique<tflite::BufferT>();
|
||||
auto lstm_op_code = absl::make_unique<OperatorCodeT>();
|
||||
auto lstm_op = absl::make_unique<OperatorT>();
|
||||
|
||||
tensor->name = "lstm_tensor0";
|
||||
tensor->shape = {2, 3, 4};
|
||||
tensor->type = TensorType_FLOAT32;
|
||||
lstm_op_code->builtin_code = BuiltinOperator_LSTM;
|
||||
lstm_op_code->version = 2;
|
||||
lstm_op->opcode_index = 0;
|
||||
lstm_op->inputs = {0};
|
||||
lstm_op->outputs = {0};
|
||||
|
||||
model->subgraphs.push_back(std::move(subgraph));
|
||||
model->subgraphs[0]->operators.push_back(std::move(lstm_op));
|
||||
model->subgraphs[0]->tensors.push_back(std::move(tensor));
|
||||
model->operator_codes.push_back(std::move(lstm_op_code));
|
||||
model->buffers.push_back(std::move(buffer));
|
||||
|
||||
// Add 2 tensors.
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
tflite::optimize::AddIntemediateTensorsToFusedOp(&builder, model.get());
|
||||
|
||||
// Verify results.
|
||||
EXPECT_EQ(model->operator_codes.size(), 1);
|
||||
EXPECT_EQ(model->subgraphs.size(), 1);
|
||||
EXPECT_EQ(model->subgraphs[0]->operators.size(), 1);
|
||||
EXPECT_EQ(model->subgraphs[0]->tensors.size(), 6);
|
||||
EXPECT_EQ(model->buffers.size(), 1);
|
||||
|
||||
EXPECT_EQ(model->operator_codes[0]->builtin_code, BuiltinOperator_LSTM);
|
||||
EXPECT_EQ(model->subgraphs[0]->tensors[0]->name, "lstm_tensor0");
|
||||
EXPECT_EQ(model->subgraphs[0]->tensors[1]->name, "intermediate_0_0");
|
||||
EXPECT_EQ(model->subgraphs[0]->tensors[2]->name, "intermediate_0_1");
|
||||
EXPECT_EQ(model->subgraphs[0]->tensors[3]->name, "intermediate_0_2");
|
||||
EXPECT_EQ(model->subgraphs[0]->tensors[4]->name, "intermediate_0_3");
|
||||
EXPECT_EQ(model->subgraphs[0]->tensors[5]->name, "intermediate_0_4");
|
||||
EXPECT_THAT(model->subgraphs[0]->operators[0]->inputs, ElementsAreArray({0}));
|
||||
EXPECT_THAT(model->subgraphs[0]->operators[0]->outputs,
|
||||
ElementsAreArray({0}));
|
||||
EXPECT_THAT(model->subgraphs[0]->operators[0]->intermediates,
|
||||
ElementsAreArray({1, 2, 3, 4, 5}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace optimize
|
||||
} // namespace tflite
|
||||
|
||||
int main(int argc, char** argv) { return RUN_ALL_TESTS(); }
|
Loading…
Reference in New Issue
Block a user