[Grappler] Simplify Select nodes where the predicates have been constant folded to all true or all false.

PiperOrigin-RevId: 314439636
Change-Id: I693dbc29360005982e75b110d5d2bd7cf8e4914c
Author:    A. Unique TensorFlower
Date:      2020-06-02 17:51:00 -07:00
Committer: TensorFlower Gardener
Commit:    a8b9c83afe (parent 5bca7e22aa)

3 changed files with 107 additions and 0 deletions
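
In graph terms, the rewrite replaces a Select/SelectV2 node whose predicate has been
constant folded to all true (or all false) with an Identity of the surviving branch,
keeping the predicate and the dead branch only as control dependencies. A minimal
before/after sketch, using the same C++ ops API as the new test below (node names are
illustrative, not taken from the commit):

    // Before: the predicate is a constant that folds to all true.
    tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
    Output pred = ops::Const(scope.WithOpName("pred"), true, TensorShape({}));
    Output then_ = ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                                    ops::Placeholder::Shape(TensorShape({2, 2})));
    Output else_ = ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                                    ops::Placeholder::Shape(TensorShape({2, 2})));
    Output select = ops::SelectV2(scope.WithOpName("select"), pred, then_, else_);

    // After constant folding, the "select" node is rewritten roughly to:
    //   name: "select"  op: "Identity"
    //   input: "then"   (the live branch moves to input 0)
    //   input: "^pred"  (predicate demoted to a control dependency)
    //   input: "^else"  (dead branch demoted to a control dependency)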

tensorflow/core/grappler/optimizers/constant_folding.cc

@@ -2032,6 +2032,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
  SET_AND_RETURN_IF_MODIFIED(
      ConstantPushDownBiasAdd(properties, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifySelect(*properties, optimized_graph, node));
  graph_modified_ = graph_modified_cached;
  return Status::OK();
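
The new pass is registered in SimplifyNode through the file's SET_AND_RETURN_IF_MODIFIED
pattern: when a rewrite fires, the graph_modified_ bit is set and SimplifyNode returns
immediately, so at most one simplification is applied per call. A sketch of that pattern,
assuming a conventional expansion (the actual macro is defined earlier in
constant_folding.cc and may differ in detail):

    // Sketch only: record that the graph changed and stop after the first
    // rewrite that reports a modification.
    #define SET_AND_RETURN_IF_MODIFIED(EXPR) \
      if ((EXPR)) {                          \
        graph_modified_ = true;              \
        return Status::OK();                 \
      }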
@@ -2405,6 +2407,40 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
  return true;
}
bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
                                     GraphDef* optimized_graph, NodeDef* node) {
  if (!IsSelect(*node)) return false;
  // Replace the node with Identity if no broadcasting is involved: the two
  // branches must have symbolically equal shapes, and the predicate must
  // either match that shape or be a scalar.
  // TODO(b/155503011): Add support for broadcast.
  const std::vector<OpInfo::TensorProperties>& input_props =
      properties.GetInputProperties(node->name());
  if (input_props.size() < 3) return false;
  const TensorShapeProto& predicate_shape = input_props[0].shape();
  const bool predicate_is_scalar =
      !predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0;
  if (!ShapesSymbolicallyEqual(input_props[1], input_props[2]) ||
      !(ShapesSymbolicallyEqual(input_props[0], input_props[1]) ||
        predicate_is_scalar)) {
    return false;
  }
  // Only rewrite if the predicate has been constant folded to all true or
  // all false.
  const NodeDef* predicate_node = node_map_->GetNode(node->input(0));
  const bool is_all_true = IsOnes(*predicate_node);
  const bool is_all_false = IsZeros(*predicate_node);
  if (!is_all_true && !is_all_false) {
    return false;
  }
  const int live_input_idx = is_all_true ? 1 : 2;
  const int ignored_input_idx = is_all_true ? 2 : 1;
  // Turn the node into an Identity of the live branch. The predicate and the
  // dead branch are demoted to control dependencies to preserve execution
  // ordering, and the live input is swapped into position 0.
  node->set_op("Identity");
  *node->mutable_input(0) =
      AddControlDependency(node->input(0), optimized_graph, node_map_.get());
  *node->mutable_input(ignored_input_idx) = AddControlDependency(
      node->input(ignored_input_idx), optimized_graph, node_map_.get());
  node->mutable_input()->SwapElements(0, live_input_idx);
  DedupControlInputs(node);
  return true;
}
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
                                             NodeDef* node) {
  if (!IsEnter(*node) || node->input_size() == 0 ||
@@ -3771,6 +3807,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
      /*include_output_tensor_values=*/true);
  const bool can_use_shape_info = s.ok();
  VLOG(1) << "can_use_shape_info = " << can_use_shape_info;
  absl::flat_hash_set<string> nodes_to_not_simplify;
  if (can_use_shape_info) {

tensorflow/core/grappler/optimizers/constant_folding.h

@@ -282,6 +282,10 @@ class ConstantFolding : public GraphOptimizer {
  // Simplify a Case operation where the output_idx is known.
  bool SimplifyCase(GraphDef* optimized_graph, NodeDef* node);
  // Simplify a Select operation where the predicate has been constant folded
  // to all true or all false.
  bool SimplifySelect(const GraphProperties& properties,
                      GraphDef* optimized_graph, NodeDef* node);
  // Removes Reverse op over dimensions with size 1.
  Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node);

tensorflow/core/grappler/optimizers/constant_folding_test.cc

@@ -4145,6 +4145,72 @@ TEST_F(ConstantFoldingTest, SimplifyCase) {
  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
}
TEST_F(ConstantFoldingTest, SimplifySelect) {
  // Cover scalar and full-shape predicates, each all true and all false.
  for (bool scalar_pred : {true, false}) {
    for (bool pred_val : {true, false}) {
      tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
      std::unique_ptr<Tensor> if_t;
      if (scalar_pred) {
        if_t.reset(new Tensor(DT_BOOL, TensorShape()));
      } else {
        if_t.reset(new Tensor(DT_BOOL, TensorShape({2, 2})));
      }
      for (int i = 0; i < (scalar_pred ? 1 : 4); ++i) {
        if_t->flat<bool>()(i) = pred_val;
      }
      Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
      Output then_ =
          ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output else_ =
          ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
                           ops::Placeholder::Shape(TensorShape({2, 2})));
      Output select =
          ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
      Output id = ops::Identity(scope.WithOpName("id"), select);

      GrapplerItem item;
      TF_CHECK_OK(scope.ToGraphDef(&item.graph));
      item.fetch = {"id"};

      const Tensor kOne =
          test::AsTensor<float>({1.0f, 1.0f, 1.0f, 1.0f}, TensorShape({2, 2}));
      const Tensor kTwo =
          test::AsTensor<float>({2.0f, 2.0f, 2.0f, 2.0f}, TensorShape({2, 2}));
      auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
                                            {{"then", kOne}, {"else", kTwo}});

      // Use aggressive mode to force the shape inference to propagate
      // placeholder shapes.
      ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
                                /*cpu_device=*/nullptr);
      GraphDef optimized_graph;
      TF_EXPECT_OK(
          optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
      ASSERT_EQ(optimized_graph.node_size(), 5);

      // The Select node should have been rewritten to an Identity of the
      // live branch, with the predicate and the dead branch kept as control
      // dependencies.
      bool found = false;
      for (const auto& node : optimized_graph.node()) {
        if (node.name() == "select") {
          found = true;
          EXPECT_EQ(node.op(), "Identity");
          ASSERT_EQ(node.input_size(), 3);
          EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
          EXPECT_EQ(node.input(1), pred_val ? "^if" : "^then");
          EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
        }
      }
      EXPECT_TRUE(found);

      // The optimized graph must produce the same result as the original.
      auto tensors = EvaluateNodes(optimized_graph, item.fetch,
                                   {{"then", kOne}, {"else", kTwo}});
      ASSERT_EQ(tensors.size(), 1);
      ASSERT_EQ(tensors_expected.size(), 1);
      test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
    }
  }
}
}  // namespace
}  // namespace grappler
}  // namespace tensorflow