Remove cancelling pairs of transposes that are separated by a non-branching chain of ops that preserve value, order, and shape. Off by default.

PiperOrigin-RevId: 196183111
This commit is contained in:
A. Unique TensorFlower 2018-05-10 15:43:55 -07:00 committed by TensorFlower Gardener
parent ff7f7a566b
commit f7e24ab111
2 changed files with 89 additions and 16 deletions

View File

@ -254,6 +254,17 @@ NodeDef* GetTailOfValuePreservingChain(
is_value_preserving_non_branching);
}
// Walks a chain starting at `node` via GetTailOfChain (control inputs are not
// followed) and returns the node that GetTailOfChain reports as the tail.
//
// A node is accepted as part of the chain only when all of these hold:
//   * its name is not listed in `nodes_to_preserve`,
//   * IsIdempotent() is true for it, and
//   * NumNonControlOutputs() reports exactly one output for it, i.e. the
//     chain does not branch at that node.
//
// NOTE(review): the exact traversal direction and the tail returned when the
// very first node fails the predicate are defined by GetTailOfChain, which is
// declared elsewhere in this file — confirm there.
NodeDef* GetTailOfIdempotentChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto extends_chain = [&node_map,
                        &nodes_to_preserve](const NodeDef& candidate) -> bool {
    // Nodes the caller asked to keep must never be bypassed.
    if (nodes_to_preserve.count(candidate.name()) != 0) {
      return false;
    }
    if (!IsIdempotent(candidate)) {
      return false;
    }
    // More (or fewer) than one data consumer means the chain branches or ends.
    return NumNonControlOutputs(candidate, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        extends_chain);
}
// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
@ -1149,21 +1160,27 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
public:
explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
const ArithmeticOptimizerContext& ctx_ext)
: ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
const ArithmeticOptimizerContext& ctx_ext,
RewriterConfig::Toggle opt_level)
: ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext),
opt_level_(opt_level) {}
~RemoveIdentityTranspose() override = default;
bool IsSupported(const NodeDef* node) const override {
return IsTranspose(*node) || IsConjugateTranspose(*node);
}
// TODO(rmlarsen): Forward control dependencies on the bypassed
// transpose nodes.
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* tail = node;
// TODO(rmlarsen): Enable in regular mode after May 15, 2018.
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
*ctx().nodes_to_preserve);
}
NodeDef* first_transpose;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
NodeDef* node_perm;
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
if (!IsConstant(*node_perm)) {
@ -1171,17 +1188,30 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
}
std::vector<int64> node_perm_values;
TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
if (input->op() == node->op()) {
if (first_transpose->op() == node->op()) {
// Remove pairs of transposes that cancel each other.
NodeDef* input_perm;
TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm));
if (!IsConstant(*input_perm)) {
NodeDef* first_transpose_perm;
TF_RETURN_IF_ERROR(
GetInputNode(first_transpose->input(1), &first_transpose_perm));
if (!IsConstant(*first_transpose_perm)) {
return Status::OK();
}
std::vector<int64> input_perm_values;
TF_RETURN_IF_ERROR(GetPermutation(*input_perm, &input_perm_values));
if (AreInversePermutations(node_perm_values, input_perm_values)) {
*simplified_node_name = input->input(0);
std::vector<int64> first_transpose_perm_values;
TF_RETURN_IF_ERROR(
GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
if (AreInversePermutations(node_perm_values,
first_transpose_perm_values)) {
if (tail == node) {
// Bypass adjacent pair.
*simplified_node_name = first_transpose->input(0);
} else {
// Bypass pair connected through chain.
tail->set_input(0, first_transpose->input(0));
ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
first_transpose->input(0));
ForwardControlDependencies(tail, {first_transpose});
*simplified_node_name = node->input(0);
}
}
} else {
// Remove simple identity transposes.
@ -1231,6 +1261,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
}
return true;
}
RewriterConfig::Toggle opt_level_;
};
// Remove redundant Bitcasts.
@ -2401,7 +2433,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.minimize_broadcasts && can_use_shapes)
pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
if (options_.remove_identity_transpose && can_use_shapes)
pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext, opt_level_);
if (options_.remove_redundant_bitcast)
pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
if (options_.remove_redundant_cast)

View File

@ -1122,7 +1122,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
Output perm3 = ops::Const(s.WithOpName("perm2"), {0, 1, 2, 3}, {4});
Output perm3 = ops::Const(s.WithOpName("perm3"), {0, 1, 2, 3}, {4});
Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm1);
Output transpose2 =
ops::Transpose(s.WithOpName("transpose2"), transpose1, perm2);
@ -1248,6 +1248,47 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
EXPECT_EQ(6, output.node_size());
}
// Checks that, in AGGRESSIVE mode, two transposes whose permutations are
// inverses of each other ({0, 2, 3, 1} and {0, 3, 1, 2}) are removed even
// though an Identity node ("id") sits between them, and that the control
// dependency attached to the first transpose ends up on "id".
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();

  // Graph: inputs -> transpose1 -> id -> transpose2 -> id1.
  Output shape_const =
      ops::Const(scope.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
  Output random_inputs =
      ops::RandomUniform(scope.WithOpName("inputs"), shape_const, DT_FLOAT);
  Output forward_perm =
      ops::Const(scope.WithOpName("perm1"), {0, 2, 3, 1}, {4});
  Output inverse_perm =
      ops::Const(scope.WithOpName("perm2"), {0, 3, 1, 2}, {4});
  // transpose1 additionally has a control dependency on perm2; the
  // expectations below require it to survive as "^perm2" on "id".
  Output first_transpose = ops::Transpose(
      scope.WithOpName("transpose1").WithControlDependencies(inverse_perm),
      random_inputs, forward_perm);
  Output chain_identity =
      ops::Identity(scope.WithOpName("id"), first_transpose);
  Output second_transpose = ops::Transpose(scope.WithOpName("transpose2"),
                                           chain_identity, inverse_perm);
  Output fetched_identity =
      ops::Identity(scope.WithOpName("id1"), second_transpose);

  GrapplerItem item;
  item.fetch = {"id1"};
  TF_CHECK_OK(scope.ToGraphDef(&item.graph));

  // Run only the RemoveIdentityTranspose stage, at the aggressive level.
  GraphDef output;
  ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
  EnableOnlyRemoveIdentityTranspose(&optimizer);
  OptimizeAndPrune(&optimizer, &item, &output);

  std::set<string> nodes_after_optimization;
  for (const NodeDef& node : output.node()) {
    const auto& node_name = node.name();
    nodes_after_optimization.insert(node_name);
    if (node_name == "id") {
      // "id" now reads straight from "inputs" and carries the forwarded
      // control dependency.
      EXPECT_EQ(2, node.input_size());
      EXPECT_EQ("inputs", node.input(0));
      EXPECT_EQ("^perm2", node.input(1));
    } else if (node_name == "id1") {
      // "id1" now reads straight from "id".
      EXPECT_EQ(1, node.input_size());
      EXPECT_EQ("id", node.input(0));
    }
  }

  // Both transposes and perm1 are gone; perm2 remains in the graph.
  const std::set<string> expected_nodes(
      {"id", "id1", "inputs_shape", "inputs", "perm2"});
  EXPECT_EQ(nodes_after_optimization, expected_nodes);
}
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,