Add optimization to reorder redundant reshapes around unary ops.

PiperOrigin-RevId: 358239215
Change-Id: Ie0723ad4322fd0099185bcf5c92258ea797f917c
A. Unique TensorFlower 2021-02-18 12:33:01 -08:00 committed by TensorFlower Gardener
parent a98039c09e
commit 9378d6c3db
4 changed files with 0 additions and 193 deletions

@@ -2142,93 +2142,6 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
}
};

// Reorder redundant reshapes around a single unary element-wise op, i.e.,
//
// input -> reshape A -> unary -> reshape B -> output
//
// becomes
//
// input -> unary -> reshape A -> reshape B -> output
//
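// For example, with the [1, 300, 300, 1] input used in the unit tests below,
//
//   input -> Reshape([1, 90000, 1]) -> Sigmoid -> Reshape([1, 300, 300, 1])
//
// becomes
//
//   input -> Sigmoid -> Reshape([1, 90000, 1]) -> Reshape([1, 300, 300, 1]).
//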
// We conservatively consider reshapes to be redundant only if:
// 1) The input shape of A is equal to the output shape of B.
// 2) Both A and unary have a single output.
//
// A later pass (RemoveRedundantReshapeOrBroadcastTo) will remove both reshapes.
//
class ReorderRedundantReshapeAroundUnary : public ArithmeticOptimizerStage {
public:
explicit ReorderRedundantReshapeAroundUnary(
const GraphOptimizerContext& ctx,
const ArithmeticOptimizerContext& ctx_ext)
: ArithmeticOptimizerStage("ReorderRedundantReshapeAroundUnary", ctx,
ctx_ext) {}
~ReorderRedundantReshapeAroundUnary() override = default;
bool IsSupported(const NodeDef* node) const override {
return IsReshape(*node) && !IsInPreserveSet(*node);
}

Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
// Check that we have a chain of (reshape -> unary -> reshape), with no
// additional outputs on either the first reshape or the unary op.
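// Note that the pattern is matched from the output side: `head` is reshape B
// (nearest the output) and `tail` is reshape A (nearest the input).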
NodeDef* head = node;
if (!IsReshape(*head) || IsInPreserveSet(*head)) {
return Status::OK();
}
NodeDef* unary;
TF_RETURN_IF_ERROR(GetInputNode(head->input(0), &unary));
if (!IsUnaryElementWise(*unary) || IsInPreserveSet(*unary) ||
NumNonControlOutputs(*unary, *ctx().node_map) != 1) {
return Status::OK();
}
NodeDef* tail;
TF_RETURN_IF_ERROR(GetInputNode(unary->input(0), &tail));
if (!IsReshape(*tail) || IsInPreserveSet(*tail) ||
NumNonControlOutputs(*tail, *ctx().node_map) != 1) {
return Status::OK();
}
// The reshapes are a no-op if the input and output shapes match.
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &input));
if (!InputMatchesOutputShape(*input, *head)) {
VLOG(3) << "Input and output shapes are unequal: input=" << input->name()
<< ", output=" << head->name();
return Status::OK();
}
// Swap `unary` and `tail` reshape
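// so that input -> tail -> unary -> head becomes
// input -> unary -> tail -> head. Each set_input below is mirrored by a
// NodeMap update to keep the graph's connectivity bookkeeping consistent.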
unary->set_input(0, input->name());
ctx().node_map->UpdateInput(unary->name(), tail->name(), input->name());
tail->set_input(0, unary->name());
ctx().node_map->UpdateInput(tail->name(), input->name(), unary->name());
head->set_input(0, tail->name());
ctx().node_map->UpdateInput(head->name(), unary->name(), tail->name());
*simplified_node_name = node->name();
AddToOptimizationQueue(node);
return Status::OK();
}

private:
// Returns whether the input shape of the first op matches the output shape of
// the second op.
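// Conservatively returns false if inferred shape properties are unavailable
// for either node.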
bool InputMatchesOutputShape(const NodeDef& input, const NodeDef& output) {
const OpInfo::TensorProperties* input_props;
const OpInfo::TensorProperties* output_props;
if (!GetTensorProperties(input.name(), &input_props).ok() ||
!GetTensorProperties(output.name(), &output_props).ok()) {
return false;
}
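// ShapesSymbolicallyEqual only matches unknown dimensions that carry the
// same symbolic id, so shapes with fully unknown dimensions never compare
// equal.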
return ShapesSymbolicallyEqual(input_props->shape(), output_props->shape());
}
};

// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorder data (such as reshape and
// transpose). For example, we can optimize
@@ -3919,8 +3832,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
if (options_.reorder_cast_like_and_value_preserving)
pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
if (options_.reorder_redundant_reshape_around_unary)
pipeline.AddStage<ReorderRedundantReshapeAroundUnary>(ctx, ctx_ext);
if (options_.simplify_aggregation)
pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
if (options_.hoist_cwise_unary_chains)

@@ -78,7 +78,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool remove_redundant_cast = true;
bool remove_redundant_reshape = true;
bool reorder_cast_like_and_value_preserving = true;
bool reorder_redundant_reshape_around_unary = true;
bool replace_mul_with_tile = true;
bool replace_mul_with_square = true;
bool simplify_aggregation = true;

@@ -1003,103 +1003,6 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
test::ExpectTensorNear<complex64>(tensors[0], tensors_expected[0], 1e-6);
}

TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeAroundUnary) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 300, 300, 1}));
Output reshape0 = ops::Reshape(s.WithOpName("Reshape0"), inputs,
ops::Const(s, {1, 90000, 1}, {3}));
Output unary = ops::Sigmoid(s, reshape0);
Output reshape1 = ops::Reshape(s.WithOpName("Reshape1"), unary,
ops::Const(s, {1, 300, 300, 1}, {4}));
Output outputs = ops::Identity(s.WithOpName("outputs"), reshape1);
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 300, 300, 1}));
auto expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", t}});
ASSERT_EQ(expected.size(), 1);
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyReorderRedundantReshapeAroundUnary(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
// The reshapes should be removed after pruning.
EnableOnlyRemoveRedundantReshape(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(CountOpNodes(output, "Reshape"), 0);
auto actual = EvaluateNodes(output, item.fetch, {{"Placeholder", t}});
ASSERT_EQ(actual.size(), 1);
test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}

TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeAroundUnaryNotOutput) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 300, 300, 1}));
Output reshape0 = ops::Reshape(s, inputs, ops::Const(s, {1, 90000, 1}, {3}));
Output unary = ops::Sigmoid(s.WithOpName("sigmoid"), reshape0);
Output reshape1 =
ops::Reshape(s, unary, ops::Const(s, {1, 300, 300, 1}, {4}));
Output outputs = ops::Identity(s.WithOpName("output"), reshape1);
GrapplerItem item;
item.fetch = {"output"};
item.keep_ops = {"sigmoid"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 300, 300, 1}));
auto expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", t}});
ASSERT_EQ(expected.size(), 1);
// The reshapes should not be reordered since the unary op is a keep op.
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyReorderRedundantReshapeAroundUnary(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
EnableOnlyRemoveRedundantReshape(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
auto actual = EvaluateNodes(output, item.fetch, {{"Placeholder", t}});
ASSERT_EQ(actual.size(), 1);
test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}

TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeAroundUnaryNotIdentity) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output inputs =
ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 300, 300, 1}));
Output reshape0 = ops::Reshape(s, inputs, ops::Const(s, {1, 90000, 1}, {3}));
Output unary = ops::Sigmoid(s, reshape0);
// [1, 300, 300, 1] is not equivalent to [1, 300, 1, 300]
Output reshape1 =
ops::Reshape(s, unary, ops::Const(s, {1, 300, 1, 300}, {4}));
Output outputs = ops::Identity(s.WithOpName("outputs"), reshape1);
GrapplerItem item;
item.fetch = {"outputs"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 300, 300, 1}));
auto expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", t}});
ASSERT_EQ(expected.size(), 1);
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyReorderRedundantReshapeAroundUnary(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
EnableOnlyRemoveRedundantReshape(&optimizer);
OptimizeTwiceAndPrune(&optimizer, &item, &output);
EXPECT_EQ(CountOpNodes(output, "Reshape"), 2);
auto actual = EvaluateNodes(output, item.fetch, {{"Placeholder", t}});
ASSERT_EQ(actual.size(), 1);
test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6);
}

TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) {
for (bool is_broadcastto : {false, true}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();

@@ -138,12 +138,6 @@ class ArithmeticOptimizerTest : public GrapplerTest {
optimizer->options_.remove_redundant_cast = true;
}

void EnableOnlyReorderRedundantReshapeAroundUnary(
ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.reorder_redundant_reshape_around_unary = true;
}

void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) {
DisableAllStages(optimizer);
optimizer->options_.remove_redundant_reshape = true;