diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 67df01a3658..9e16b3e2101 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2142,93 +2142,6 @@ class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage { } }; -// Reorder redundant reshapes around a single unary element-wise op, i.e., -// -// input -> reshape A -> unary -> reshape B -> output -// -// becomes -// -// input -> unary -> reshape A -> reshape B -> output -// -// We conservatively consider reshapes to be redundant only if: -// 1) The input shape of A is equal to the output shape of B. -// 2) Both A and unary have a single output. -// -// A later pass (RemoveRedundantReshapeOrBroadcastTo) will remove both reshapes -// -class ReorderRedundantReshapeAroundUnary : public ArithmeticOptimizerStage { - public: - explicit ReorderRedundantReshapeAroundUnary( - const GraphOptimizerContext& ctx, - const ArithmeticOptimizerContext& ctx_ext) - : ArithmeticOptimizerStage("ReorderRedundantReshapeAroundUnary", ctx, - ctx_ext) {} - - ~ReorderRedundantReshapeAroundUnary() override = default; - - bool IsSupported(const NodeDef* node) const override { - return IsReshape(*node) && !IsInPreserveSet(*node); - } - - Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - // Check that we have a chain of (reshape -> unary -> reshape), with no - // additional outputs on either the first reshape or unary op - NodeDef* head = node; - if (!IsReshape(*head) || IsInPreserveSet(*head)) { - return Status::OK(); - } - - NodeDef* unary; - TF_RETURN_IF_ERROR(GetInputNode(head->input(0), &unary)); - if (!IsUnaryElementWise(*unary) || IsInPreserveSet(*unary) || - NumNonControlOutputs(*unary, *ctx().node_map) != 1) { - return Status::OK(); - } - - NodeDef* tail; - TF_RETURN_IF_ERROR(GetInputNode(unary->input(0), &tail)); - if (!IsReshape(*tail) || IsInPreserveSet(*tail) || - NumNonControlOutputs(*tail, *ctx().node_map) != 1) { - return Status::OK(); - } - - // The reshapes are a no-op if the input and output shapes match - NodeDef* input; - TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &input)); - if (!InputMatchesOutputShape(*input, *head)) { - VLOG(3) << "Input and output shapes are unequal: input=" << input->name() - << ", output=" << head->name(); - return Status::OK(); - } - - // Swap `unary` and `tail` reshape - unary->set_input(0, input->name()); - ctx().node_map->UpdateInput(unary->name(), tail->name(), input->name()); - tail->set_input(0, unary->name()); - ctx().node_map->UpdateInput(tail->name(), input->name(), unary->name()); - head->set_input(0, tail->name()); - ctx().node_map->UpdateInput(head->name(), unary->name(), tail->name()); - - *simplified_node_name = node->name(); - AddToOptimizationQueue(node); - return Status::OK(); - } - - private: - // Returns whether the input shape of the first op matches the output shape of - // the second op. - bool InputMatchesOutputShape(const NodeDef& input, const NodeDef& output) { - const OpInfo::TensorProperties* input_props; - const OpInfo::TensorProperties* output_props; - if (!GetTensorProperties(input.name(), &input_props).ok() || - !GetTensorProperties(output.name(), &output_props).ok()) { - return false; - } - - return ShapesSymbolicallyEqual(input_props->shape(), output_props->shape()); - } -}; - // Fold a multiply of a scalar into the following convolution. This folding // can jump across nodes that merely reorders data (such as reshape and // transpose). For example, we can optimize @@ -3919,8 +3832,6 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext); if (options_.reorder_cast_like_and_value_preserving) pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext); - if (options_.reorder_redundant_reshape_around_unary) - pipeline.AddStage<ReorderRedundantReshapeAroundUnary>(ctx, ctx_ext); if (options_.simplify_aggregation) pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext); if (options_.hoist_cwise_unary_chains) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 373e5004d8d..d9f03ef38c1 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -78,7 +78,6 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_redundant_cast = true; bool remove_redundant_reshape = true; bool reorder_cast_like_and_value_preserving = true; - bool reorder_redundant_reshape_around_unary = true; bool replace_mul_with_tile = true; bool replace_mul_with_square = true; bool simplify_aggregation = true; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 293e017bbf5..154196c954f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -1003,103 +1003,6 @@ TEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) { test::ExpectTensorNear<complex64>(tensors[0], tensors_expected[0], 1e-6); } -TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeAroundUnary) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output inputs = - ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 300, 300, 1})); - Output reshape0 = ops::Reshape(s.WithOpName("Reshape0"), inputs, - ops::Const(s, {1, 90000, 1}, {3})); - Output unary = ops::Sigmoid(s, reshape0); - Output reshape1 = ops::Reshape(s.WithOpName("Reshape1"), unary, - ops::Const(s, {1, 300, 300, 1}, {4})); - Output outputs = ops::Identity(s.WithOpName("outputs"), reshape1); - - GrapplerItem item; - item.fetch = {"outputs"}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - auto t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 300, 300, 1})); - auto expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", t}}); - ASSERT_EQ(expected.size(), 1); - - GraphDef output; - ArithmeticOptimizer optimizer; - EnableOnlyReorderRedundantReshapeAroundUnary(&optimizer); - OptimizeTwiceAndPrune(&optimizer, &item, &output); - EXPECT_EQ(CountOpNodes(output, "Reshape"), 2); - - // Reshapes should be removed after pruning - EnableOnlyRemoveRedundantReshape(&optimizer); - OptimizeTwiceAndPrune(&optimizer, &item, &output); - EXPECT_EQ(CountOpNodes(output, "Reshape"), 0); - - auto actual = EvaluateNodes(output, item.fetch, {{"Placeholder", t}}); - ASSERT_EQ(actual.size(), 1); - test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6); -} - -TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeAroundUnaryNotOutput) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output inputs = - ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 300, 300, 1})); - Output reshape0 = ops::Reshape(s, inputs, ops::Const(s, {1, 90000, 1}, {3})); - Output unary = ops::Sigmoid(s.WithOpName("sigmoid"), reshape0); - Output reshape1 = - ops::Reshape(s, unary, ops::Const(s, {1, 300, 300, 1}, {4})); - Output outputs = ops::Identity(s.WithOpName("output"), reshape1); - - GrapplerItem item; - item.fetch = {"output"}; - item.keep_ops = {"sigmoid"}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - auto t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 300, 300, 1})); - auto expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", t}}); - ASSERT_EQ(expected.size(), 1); - - // Reshape should not be moved since unary is a keep op - GraphDef output; - ArithmeticOptimizer optimizer; - EnableOnlyReorderRedundantReshapeAroundUnary(&optimizer); - OptimizeTwiceAndPrune(&optimizer, &item, &output); - EnableOnlyRemoveRedundantReshape(&optimizer); - OptimizeTwiceAndPrune(&optimizer, &item, &output); - - EXPECT_EQ(CountOpNodes(output, "Reshape"), 2); - auto actual = EvaluateNodes(output, item.fetch, {{"Placeholder", t}}); - ASSERT_EQ(actual.size(), 1); - test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6); -} - -TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeAroundUnaryNotIdentity) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output inputs = - ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({1, 300, 300, 1})); - Output reshape0 = ops::Reshape(s, inputs, ops::Const(s, {1, 90000, 1}, {3})); - Output unary = ops::Sigmoid(s, reshape0); - // [1, 300, 300, 1] is not equivalent to [1, 300, 1, 300] - Output reshape1 = - ops::Reshape(s, unary, ops::Const(s, {1, 300, 1, 300}, {4})); - Output outputs = ops::Identity(s.WithOpName("outputs"), reshape1); - - GrapplerItem item; - item.fetch = {"outputs"}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - auto t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 300, 300, 1})); - auto expected = EvaluateNodes(item.graph, item.fetch, {{"Placeholder", t}}); - ASSERT_EQ(expected.size(), 1); - - GraphDef output; - ArithmeticOptimizer optimizer; - EnableOnlyReorderRedundantReshapeAroundUnary(&optimizer); - OptimizeTwiceAndPrune(&optimizer, &item, &output); - EnableOnlyRemoveRedundantReshape(&optimizer); - OptimizeTwiceAndPrune(&optimizer, &item, &output); - - EXPECT_EQ(CountOpNodes(output, "Reshape"), 2); - auto actual = EvaluateNodes(output, item.fetch, {{"Placeholder", t}}); - ASSERT_EQ(actual.size(), 1); - test::ExpectTensorNear<float>(actual[0], expected[0], 1e-6); -} - TEST_F(ArithmeticOptimizerTest, RemoveRedundantReshapeIdentityReshape) { for (bool is_broadcastto : {false, true}) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h index f2600cf79ad..71c7ef564ae 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h @@ -138,12 +138,6 @@ class ArithmeticOptimizerTest : public GrapplerTest { optimizer->options_.remove_redundant_cast = true; } - void EnableOnlyReorderRedundantReshapeAroundUnary( - ArithmeticOptimizer* optimizer) { - DisableAllStages(optimizer); - optimizer->options_.reorder_redundant_reshape_around_unary = true; - } - void EnableOnlyRemoveRedundantReshape(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.remove_redundant_reshape = true;