Remove cancelling pairs of transposes that are separated by a non-branching chain of ops that preserve value, order, and shape. Off by default.
PiperOrigin-RevId: 196183111
This commit is contained in:
parent
ff7f7a566b
commit
f7e24ab111
@ -254,6 +254,17 @@ NodeDef* GetTailOfValuePreservingChain(
|
||||
is_value_preserving_non_branching);
|
||||
}
|
||||
|
||||
// Returns whatever GetTailOfChain yields when started at `node`, with control
// inputs not followed, using a predicate that only accepts nodes that are
// (a) not listed in `nodes_to_preserve`, (b) idempotent according to
// IsIdempotent(), and (c) have exactly one non-control output in `node_map`.
// NOTE(review): GetTailOfChain is defined outside this view; presumably it
// walks the chain of first inputs while the predicate holds and returns the
// last node reached -- confirm against its definition.
NodeDef* GetTailOfIdempotentChain(
|
||||
const NodeDef& node, const NodeMap& node_map,
|
||||
const std::unordered_set<string>& nodes_to_preserve) {
|
||||
// Predicate for a chain link. Note that the lambda parameter `node`
// shadows the enclosing function's `node` parameter; inside the lambda the
// name refers to the candidate node being tested.
auto is_idempotent_non_branching = [&](const NodeDef& node) {
|
||||
return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
|
||||
IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
|
||||
};
|
||||
// Here `node` is the function parameter again (the starting node).
return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
|
||||
is_idempotent_non_branching);
|
||||
}
|
||||
|
||||
// Graph optimizer context extension specific to ArithmeticOptimizer.
|
||||
struct ArithmeticOptimizerContext {
|
||||
explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
|
||||
@ -1149,21 +1160,27 @@ class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
|
||||
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
|
||||
public:
|
||||
explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
|
||||
const ArithmeticOptimizerContext& ctx_ext)
|
||||
: ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
|
||||
const ArithmeticOptimizerContext& ctx_ext,
|
||||
RewriterConfig::Toggle opt_level)
|
||||
: ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext),
|
||||
opt_level_(opt_level) {}
|
||||
~RemoveIdentityTranspose() override = default;
|
||||
|
||||
bool IsSupported(const NodeDef* node) const override {
|
||||
return IsTranspose(*node) || IsConjugateTranspose(*node);
|
||||
}
|
||||
|
||||
// TODO(rmlarsen): Forward control dependencies on the bypassed
|
||||
// transpose nodes.
|
||||
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
|
||||
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
|
||||
NodeDef* tail = node;
|
||||
// TODO(rmlarsen): Enable in regular mode after May 15, 2018.
|
||||
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
|
||||
tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
|
||||
*ctx().nodes_to_preserve);
|
||||
}
|
||||
NodeDef* first_transpose;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
|
||||
|
||||
NodeDef* input;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
|
||||
NodeDef* node_perm;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
|
||||
if (!IsConstant(*node_perm)) {
|
||||
@ -1171,17 +1188,30 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
|
||||
}
|
||||
std::vector<int64> node_perm_values;
|
||||
TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
|
||||
if (input->op() == node->op()) {
|
||||
if (first_transpose->op() == node->op()) {
|
||||
// Remove pairs of transposes that cancel each other.
|
||||
NodeDef* input_perm;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(input->input(1), &input_perm));
|
||||
if (!IsConstant(*input_perm)) {
|
||||
NodeDef* first_transpose_perm;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetInputNode(first_transpose->input(1), &first_transpose_perm));
|
||||
if (!IsConstant(*first_transpose_perm)) {
|
||||
return Status::OK();
|
||||
}
|
||||
std::vector<int64> input_perm_values;
|
||||
TF_RETURN_IF_ERROR(GetPermutation(*input_perm, &input_perm_values));
|
||||
if (AreInversePermutations(node_perm_values, input_perm_values)) {
|
||||
*simplified_node_name = input->input(0);
|
||||
std::vector<int64> first_transpose_perm_values;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
|
||||
if (AreInversePermutations(node_perm_values,
|
||||
first_transpose_perm_values)) {
|
||||
if (tail == node) {
|
||||
// Bypass adjacent pair.
|
||||
*simplified_node_name = first_transpose->input(0);
|
||||
} else {
|
||||
// Bypass pair connected through chain.
|
||||
tail->set_input(0, first_transpose->input(0));
|
||||
ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
|
||||
first_transpose->input(0));
|
||||
ForwardControlDependencies(tail, {first_transpose});
|
||||
*simplified_node_name = node->input(0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Remove simple identity transposes.
|
||||
@ -1231,6 +1261,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
RewriterConfig::Toggle opt_level_;
|
||||
};
|
||||
|
||||
// Remove redundant Bitcasts.
|
||||
@ -2401,7 +2433,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
|
||||
if (options_.minimize_broadcasts && can_use_shapes)
|
||||
pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
|
||||
if (options_.remove_identity_transpose && can_use_shapes)
|
||||
pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
|
||||
pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext, opt_level_);
|
||||
if (options_.remove_redundant_bitcast)
|
||||
pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
|
||||
if (options_.remove_redundant_cast)
|
||||
|
@ -1122,7 +1122,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposes) {
|
||||
ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
|
||||
Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
|
||||
Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
|
||||
Output perm3 = ops::Const(s.WithOpName("perm2"), {0, 1, 2, 3}, {4});
|
||||
Output perm3 = ops::Const(s.WithOpName("perm3"), {0, 1, 2, 3}, {4});
|
||||
Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm1);
|
||||
Output transpose2 =
|
||||
ops::Transpose(s.WithOpName("transpose2"), transpose1, perm2);
|
||||
@ -1248,6 +1248,47 @@ TEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
|
||||
EXPECT_EQ(6, output.node_size());
|
||||
}
|
||||
|
||||
// Checks that a pair of mutually cancelling transposes separated by an
// Identity node is removed. Graph under test:
//   inputs -> transpose1(perm1, ^perm2) -> id -> transpose2(perm2) -> id1
// After optimization and pruning, "id" is expected to read "inputs" directly
// and to carry the control dependency that transpose1 had on perm2.
TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output inputs_shape =
|
||||
ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
|
||||
Output inputs =
|
||||
ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
|
||||
// perm1 = {0, 2, 3, 1} and perm2 = {0, 3, 1, 2} are inverse permutations of
// each other, so applying both in sequence restores the original layout.
Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
|
||||
Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
|
||||
// transpose1 is given a control dependency on perm2 so the test can verify
// that the dependency is forwarded when transpose1 is bypassed.
Output transpose1 = ops::Transpose(
|
||||
s.WithOpName("transpose1").WithControlDependencies(perm2), inputs, perm1);
|
||||
// The Identity node "id" is the chain separating the two transposes.
Output identity = ops::Identity(s.WithOpName("id"), transpose1);
|
||||
Output transpose2 =
|
||||
ops::Transpose(s.WithOpName("transpose2"), identity, perm2);
|
||||
Output id1 = ops::Identity(s.WithOpName("id1"), transpose2);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"id1"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
GraphDef output;
|
||||
// The optimizer is built with RewriterConfig::AGGRESSIVE and only the
// RemoveIdentityTranspose stage is enabled.
ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
|
||||
EnableOnlyRemoveIdentityTranspose(&optimizer);
|
||||
OptimizeAndPrune(&optimizer, &item, &output);
|
||||
|
||||
// Collect the surviving node names while checking the rewired inputs.
std::set<string> nodes_after_optimization;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
nodes_after_optimization.insert(node.name());
|
||||
if (node.name() == "id") {
|
||||
// "id" must now consume "inputs" directly plus the forwarded "^perm2"
// control input.
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("inputs", node.input(0));
|
||||
EXPECT_EQ("^perm2", node.input(1));
|
||||
}
|
||||
if (node.name() == "id1") {
|
||||
// "id1" must consume "id" directly, i.e. transpose2 has been bypassed.
EXPECT_EQ(1, node.input_size());
|
||||
EXPECT_EQ("id", node.input(0));
|
||||
}
|
||||
}
|
||||
// Both transposes and perm1 must be gone; perm2 is still expected in the
// output (it is referenced by the "^perm2" control input checked above).
EXPECT_EQ(nodes_after_optimization,
|
||||
std::set<string>({"id", "id1", "inputs_shape", "inputs", "perm2"}));
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
|
||||
|
Loading…
x
Reference in New Issue
Block a user