Add a graph transformation to identify and fuse success transposes that are noops
PiperOrigin-RevId: 262034930
This commit is contained in:
parent
74cb38c373
commit
bef807a6d6
@ -222,6 +222,7 @@ cc_library(
|
||||
"graph_transformations/quantize.cc",
|
||||
"graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc",
|
||||
"graph_transformations/remove_final_dequantize_op.cc",
|
||||
"graph_transformations/remove_successive_transpose.cc",
|
||||
"graph_transformations/remove_tensorflow_assert.cc",
|
||||
"graph_transformations/remove_tensorflow_identity.cc",
|
||||
"graph_transformations/remove_trivial_binary.cc",
|
||||
|
@ -159,6 +159,7 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
|
||||
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
|
||||
DECLARE_GRAPH_TRANSFORMATION(Quantize)
|
||||
DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
|
||||
DECLARE_GRAPH_TRANSFORMATION(RemoveSuccesiveTranspose)
|
||||
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
|
||||
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
|
||||
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
|
||||
|
@ -0,0 +1,95 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
#include "tensorflow/lite/toco/tooling_util.h"
|
||||
|
||||
namespace toco {
|
||||
|
||||
namespace {
|
||||
|
||||
bool TransformsToIdentity(std::vector<int> const& perm1,
|
||||
std::vector<int> const& perm2) {
|
||||
if (perm2.size() != perm1.size() || perm1.empty()) {
|
||||
return false;
|
||||
}
|
||||
// perm1 is the order of the indices after first transpose. When perm1 is
|
||||
// reordered according to perm2, if the result is simple increasing sequence
|
||||
// i.e., range(0, perm1.size()), then the two transposes cancel each other.
|
||||
for (int i = 0; i < perm1.size(); ++i) {
|
||||
if (perm1[i] < 0 || perm1[i] >= perm1.size() || perm2[i] < 0 ||
|
||||
perm2[i] >= perm1.size()) {
|
||||
return false;
|
||||
}
|
||||
if (perm1[perm2[i]] != i) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ReplaceOpInputsWith(Model* model, const string& lookfor,
|
||||
const string& replacewith) {
|
||||
for (const auto& op : model->operators) {
|
||||
for (int i = 0; i < op->inputs.size(); ++i) {
|
||||
if (op->inputs[i] == lookfor) {
|
||||
op->inputs[i] = replacewith;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
::tensorflow::Status RemoveSuccesiveTranspose::Run(Model* model,
|
||||
std::size_t op_index,
|
||||
bool* modified) {
|
||||
*modified = false;
|
||||
auto op = model->operators.begin() + op_index;
|
||||
if (op->get()->type != OperatorType::kTranspose) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
TransposeOperator* t_op = static_cast<TransposeOperator*>(op->get());
|
||||
if (CountOpsWithInput(*model, t_op->outputs[0]) != 1) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
Operator* next = GetOpWithInput(*model, t_op->outputs[0]);
|
||||
if (!next || next->type != OperatorType::kTranspose) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
TransposeOperator* t_next = static_cast<TransposeOperator*>(next);
|
||||
if (!CountOpsWithInput(*model, t_next->outputs[0])) {
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
if (TransformsToIdentity(t_op->perm, t_next->perm)) {
|
||||
// Find the input tensor that uses the results of transpose t_next, then
|
||||
// make it point to the input of t_op, effectively isolating both the
|
||||
// transposes from the graph.
|
||||
ReplaceOpInputsWith(model, t_next->outputs[0], t_op->inputs[0]);
|
||||
DeleteOpAndArrays(model, t_next);
|
||||
DeleteOpAndArrays(model, t_op);
|
||||
*modified = true;
|
||||
}
|
||||
|
||||
return ::tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
} // namespace toco
|
@ -31,6 +31,17 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "remove_successive_transpose_test",
|
||||
srcs = ["remove_successive_transpose_test.cc"],
|
||||
deps = [
|
||||
"//tensorflow/lite/toco:graph_transformations",
|
||||
"//tensorflow/lite/toco:model",
|
||||
"//tensorflow/lite/toco:tooling_util",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "resolve_constant_concatenation_test",
|
||||
srcs = ["resolve_constant_concatenation_test.cc"],
|
||||
|
@ -0,0 +1,147 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
|
||||
#include "tensorflow/lite/toco/model.h"
|
||||
#include "tensorflow/lite/toco/tooling_util.h"
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::Test;
|
||||
|
||||
class RemoveSuccessiveTransposeTest : public Test {
|
||||
protected:
|
||||
RemoveSuccessiveTransposeTest() {}
|
||||
|
||||
void SetUp() override { model_.reset(new toco::Model); }
|
||||
|
||||
void CreateArray(const std::string& name, const std::vector<int>& shape) {
|
||||
toco::Array& array = model_->GetOrCreateArray(name);
|
||||
array.data_type = toco::ArrayDataType::kFloat;
|
||||
toco::Shape* array_shape = array.mutable_shape();
|
||||
*(array_shape->mutable_dims()) = shape;
|
||||
}
|
||||
|
||||
void CreateConstantArray(const std::string& name,
|
||||
const std::vector<int>& shape,
|
||||
const std::vector<float>& data) {
|
||||
CreateArray(name, shape);
|
||||
toco::Array& array = model_->GetOrCreateArray(name);
|
||||
auto& array_buffer = array.GetMutableBuffer<toco::ArrayDataType::kFloat>();
|
||||
int bufsize = 1;
|
||||
for (int dim : shape) {
|
||||
bufsize *= dim;
|
||||
}
|
||||
array_buffer.data.resize(bufsize);
|
||||
float* buf_ptr = array_buffer.data.data();
|
||||
for (int i = 0; i < bufsize; ++i) {
|
||||
buf_ptr[i] = data[i];
|
||||
}
|
||||
}
|
||||
|
||||
void CreateGraph(const std::vector<int>& perm1,
|
||||
const std::vector<int>& perm2) {
|
||||
CreateArray("InputA", {2, 2});
|
||||
CreateArray("InputB", {2, 2});
|
||||
CreateArray("Input", {2, 2});
|
||||
CreateArray("InputTranspose", {2, 2});
|
||||
CreateArray("InputTransposeTranspose", {2, 2});
|
||||
CreateArray("InputTransposeTransposePlusB", {2, 2});
|
||||
|
||||
auto* add_op = new toco::AddOperator;
|
||||
add_op->inputs = {"InputA", "InputB"};
|
||||
add_op->outputs = {"Input"};
|
||||
model_->operators.push_back(std::unique_ptr<toco::Operator>(add_op));
|
||||
|
||||
auto* transpose_op = new toco::TransposeOperator;
|
||||
transpose_op->inputs = {"Input"};
|
||||
transpose_op->perm = perm1;
|
||||
transpose_op->outputs = {"InputTranspose"};
|
||||
model_->operators.push_back(std::unique_ptr<toco::Operator>(transpose_op));
|
||||
|
||||
auto* transpose2_op = new toco::TransposeOperator;
|
||||
transpose2_op->inputs = {"InputTranspose"};
|
||||
transpose2_op->perm = perm2;
|
||||
transpose2_op->outputs = {"InputTransposeTranspose"};
|
||||
model_->operators.push_back(std::unique_ptr<toco::Operator>(transpose2_op));
|
||||
|
||||
auto* add2_op = new toco::AddOperator;
|
||||
add2_op->inputs = {"InputTransposeTranspose", "InputB"};
|
||||
add2_op->outputs = {"InputTransposeTransposePlusB"};
|
||||
model_->operators.push_back(std::unique_ptr<toco::Operator>(add2_op));
|
||||
}
|
||||
|
||||
std::unique_ptr<toco::Model> model_;
|
||||
};
|
||||
|
||||
TEST_F(RemoveSuccessiveTransposeTest, RemoveTranspose) {
|
||||
// Creating a model.
|
||||
CreateGraph({1, 0}, {1, 0});
|
||||
|
||||
toco::RemoveSuccesiveTranspose transformation;
|
||||
bool modified;
|
||||
ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok());
|
||||
EXPECT_TRUE(modified);
|
||||
|
||||
ASSERT_EQ(model_->operators.size(), 2);
|
||||
ASSERT_EQ(model_->operators[0]->type, toco::OperatorType::kAdd);
|
||||
ASSERT_EQ(model_->operators[1]->type, toco::OperatorType::kAdd);
|
||||
ASSERT_EQ(model_->operators[1]->inputs[0], model_->operators[0]->outputs[0]);
|
||||
}
|
||||
|
||||
TEST_F(RemoveSuccessiveTransposeTest, DontRemoveNotIdentityTranspose) {
|
||||
// Creating a model.
|
||||
CreateGraph({0, 2, 1}, {1, 0, 2});
|
||||
|
||||
toco::RemoveSuccesiveTranspose transformation;
|
||||
bool modified;
|
||||
ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok());
|
||||
EXPECT_FALSE(modified);
|
||||
}
|
||||
|
||||
TEST_F(RemoveSuccessiveTransposeTest, DontRemoveTransposeOutputUnused) {
|
||||
CreateArray("InputA", {2, 2});
|
||||
CreateArray("InputB", {2, 2});
|
||||
CreateArray("Input", {2, 2});
|
||||
CreateArray("InputTranspose", {2, 2});
|
||||
CreateArray("InputTransposeTranspose", {2, 2});
|
||||
|
||||
auto* add_op = new toco::AddOperator;
|
||||
add_op->inputs = {"InputA", "InputB"};
|
||||
add_op->outputs = {"Input"};
|
||||
model_->operators.push_back(std::unique_ptr<toco::Operator>(add_op));
|
||||
|
||||
auto* transpose_op = new toco::TransposeOperator;
|
||||
transpose_op->inputs = {"Input"};
|
||||
transpose_op->perm = {0, 2, 1};
|
||||
transpose_op->outputs = {"InputTranspose"};
|
||||
model_->operators.push_back(std::unique_ptr<toco::Operator>(transpose_op));
|
||||
|
||||
auto* transpose2_op = new toco::TransposeOperator;
|
||||
transpose2_op->inputs = {"InputTranspose"};
|
||||
transpose2_op->perm = {0, 2, 1};
|
||||
transpose2_op->outputs = {"InputTransposeTranspose"};
|
||||
model_->operators.push_back(std::unique_ptr<toco::Operator>(transpose2_op));
|
||||
|
||||
toco::RemoveSuccesiveTranspose transformation;
|
||||
bool modified;
|
||||
ASSERT_TRUE(transformation.Run(model_.get(), /*op_index=*/1, &modified).ok());
|
||||
EXPECT_FALSE(modified);
|
||||
}
|
||||
} // namespace
|
@ -67,6 +67,7 @@ void MakeGeneralGraphTransformationsSet(
|
||||
transformations->Add(new PropagateActivationFunctionIntoConstants);
|
||||
transformations->Add(new PropagateArrayDataTypes);
|
||||
transformations->Add(new PropagateFixedSizes);
|
||||
transformations->Add(new RemoveSuccesiveTranspose);
|
||||
transformations->Add(new RemoveTensorFlowAssert);
|
||||
transformations->Add(new RemoveTensorFlowIdentity);
|
||||
transformations->Add(new RemoveTrivialConcatenation);
|
||||
|
Loading…
Reference in New Issue
Block a user