From a8b9c83afe4cd0894c7d274ee3b667b2b93facbc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 2 Jun 2020 17:51:00 -0700 Subject: [PATCH] [Grappler] Simplify Select nodes where the predicates have been constant folded to all true or all false. PiperOrigin-RevId: 314439636 Change-Id: I693dbc29360005982e75b110d5d2bd7cf8e4914c --- .../grappler/optimizers/constant_folding.cc | 37 +++++++++++ .../grappler/optimizers/constant_folding.h | 4 ++ .../optimizers/constant_folding_test.cc | 66 +++++++++++++++++++ 3 files changed, 107 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index d0942471f13..89cdb308992 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -2032,6 +2032,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, SET_AND_RETURN_IF_MODIFIED( ConstantPushDownBiasAdd(properties, optimized_graph, node)); SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node)); + SET_AND_RETURN_IF_MODIFIED( + SimplifySelect(*properties, optimized_graph, node)); graph_modified_ = graph_modified_cached; return Status::OK(); @@ -2405,6 +2407,40 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) { return true; } +bool ConstantFolding::SimplifySelect(const GraphProperties& properties, + GraphDef* optimized_graph, NodeDef* node) { + if (!IsSelect(*node)) return false; + // Replace node with Identity if no broadcasting is involved. + // TODO(b/155503011): Add support for broadcast. + const std::vector& input_props = + properties.GetInputProperties(node->name()); + if (input_props.size() < 3) return false; + const TensorShapeProto& predicate_shape = input_props[0].shape(); + const bool predicate_is_scalar = + !predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0; + if (!ShapesSymbolicallyEqual(input_props[1], input_props[2]) || + !(ShapesSymbolicallyEqual(input_props[0], input_props[1]) || + predicate_is_scalar)) { + return false; + } + const NodeDef* predicate_node = node_map_->GetNode(node->input(0)); + const bool is_all_true = IsOnes(*predicate_node); + const bool is_all_false = IsZeros(*predicate_node); + if (!is_all_true && !is_all_false) { + return false; + } + const int live_input_idx = is_all_true ? 1 : 2; + const int ignored_input_idx = is_all_true ? 2 : 1; + node->set_op("Identity"); + *node->mutable_input(0) = + AddControlDependency(node->input(0), optimized_graph, node_map_.get()); + *node->mutable_input(ignored_input_idx) = AddControlDependency( + node->input(ignored_input_idx), optimized_graph, node_map_.get()); + node->mutable_input()->SwapElements(0, live_input_idx); + DedupControlInputs(node); + return true; +} + bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node) { if (!IsEnter(*node) || node->input_size() == 0 || @@ -3771,6 +3807,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, /*include_output_tensor_values=*/true); const bool can_use_shape_info = s.ok(); + VLOG(1) << "can_use_shape_info = " << can_use_shape_info; absl::flat_hash_set nodes_to_not_simplify; if (can_use_shape_info) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 88784339816..7a06cfc1e1a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -282,6 +282,10 @@ class ConstantFolding : public GraphOptimizer { // Simplify a Case operation where the output_idx is known. bool SimplifyCase(GraphDef* optimized_graph, NodeDef* node); + // Simplify a Select operation where the predicates are all true or all false. + bool SimplifySelect(const GraphProperties& properties, + GraphDef* optimized_graph, NodeDef* node); + // Removes Reverse op over dimensions with size 1. Status RemoveReverse(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 7e4a698fff6..1d8899de989 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -4145,6 +4145,72 @@ TEST_F(ConstantFoldingTest, SimplifyCase) { test::ExpectTensorEqual(tensors[0], tensors_expected[0]); } +TEST_F(ConstantFoldingTest, SimplifySelect) { + for (bool scalar_pred : {true, false}) { + for (bool pred_val : {true, false}) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + std::unique_ptr if_t; + if (scalar_pred) { + if_t.reset(new Tensor(DT_BOOL, TensorShape())); + } else { + if_t.reset(new Tensor(DT_BOOL, TensorShape({2, 2}))); + } + for (int i = 0; i < (scalar_pred ? 1 : 4); ++i) { + if_t->flat()(i) = pred_val; + } + Output if_ = ops::Const(scope.WithOpName("if"), *if_t); + Output then_ = + ops::Placeholder(scope.WithOpName("then"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({2, 2}))); + Output else_ = + ops::Placeholder(scope.WithOpName("else"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({2, 2}))); + Output select = + ops::SelectV2(scope.WithOpName("select"), if_, then_, else_); + Output id = ops::Identity(scope.WithOpName("id"), select); + + GrapplerItem item; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + item.fetch = {"id"}; + + const Tensor kOne = + test::AsTensor({1.0f, 1.0f, 1.0f, 1.0f}, TensorShape({2, 2})); + const Tensor kTwo = + test::AsTensor({2.0f, 2.0f, 2.0f, 2.0f}, TensorShape({2, 2})); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, + {{"then", kOne}, {"else", kTwo}}); + + // Use aggressive mode to force the shape inference to propagate + // placeholder shapes. + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + /*cpu_device=*/nullptr); + GraphDef optimized_graph; + TF_EXPECT_OK( + optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph)); + + ASSERT_EQ(optimized_graph.node_size(), 5); + bool found = false; + for (const auto& node : optimized_graph.node()) { + if (node.name() == "select") { + found = true; + EXPECT_EQ(node.op(), "Identity"); + ASSERT_EQ(node.input_size(), 3); + EXPECT_EQ(node.input(0), pred_val ? "then" : "else"); + EXPECT_EQ(node.input(1), pred_val ? "^if" : "^then"); + EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if"); + } + } + EXPECT_TRUE(found); + + auto tensors = EvaluateNodes(optimized_graph, item.fetch, + {{"then", kOne}, {"else", kTwo}}); + ASSERT_EQ(tensors.size(), 1); + ASSERT_EQ(tensors_expected.size(), 1); + test::ExpectTensorEqual(tensors[0], tensors_expected[0]); + } + } +} + } // namespace } // namespace grappler } // namespace tensorflow