[Grappler] Simplify Select nodes where the predicates have been constant folded to all true or all false.
PiperOrigin-RevId: 314439636 Change-Id: I693dbc29360005982e75b110d5d2bd7cf8e4914c
This commit is contained in:
parent
5bca7e22aa
commit
a8b9c83afe
|
@ -2032,6 +2032,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
|||
SET_AND_RETURN_IF_MODIFIED(
|
||||
ConstantPushDownBiasAdd(properties, optimized_graph, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
SimplifySelect(*properties, optimized_graph, node));
|
||||
|
||||
graph_modified_ = graph_modified_cached;
|
||||
return Status::OK();
|
||||
|
@ -2405,6 +2407,40 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
|
||||
GraphDef* optimized_graph, NodeDef* node) {
|
||||
if (!IsSelect(*node)) return false;
|
||||
// Replace node with Identity if no broadcasting is involved.
|
||||
// TODO(b/155503011): Add support for broadcast.
|
||||
const std::vector<OpInfo::TensorProperties>& input_props =
|
||||
properties.GetInputProperties(node->name());
|
||||
if (input_props.size() < 3) return false;
|
||||
const TensorShapeProto& predicate_shape = input_props[0].shape();
|
||||
const bool predicate_is_scalar =
|
||||
!predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0;
|
||||
if (!ShapesSymbolicallyEqual(input_props[1], input_props[2]) ||
|
||||
!(ShapesSymbolicallyEqual(input_props[0], input_props[1]) ||
|
||||
predicate_is_scalar)) {
|
||||
return false;
|
||||
}
|
||||
const NodeDef* predicate_node = node_map_->GetNode(node->input(0));
|
||||
const bool is_all_true = IsOnes(*predicate_node);
|
||||
const bool is_all_false = IsZeros(*predicate_node);
|
||||
if (!is_all_true && !is_all_false) {
|
||||
return false;
|
||||
}
|
||||
const int live_input_idx = is_all_true ? 1 : 2;
|
||||
const int ignored_input_idx = is_all_true ? 2 : 1;
|
||||
node->set_op("Identity");
|
||||
*node->mutable_input(0) =
|
||||
AddControlDependency(node->input(0), optimized_graph, node_map_.get());
|
||||
*node->mutable_input(ignored_input_idx) = AddControlDependency(
|
||||
node->input(ignored_input_idx), optimized_graph, node_map_.get());
|
||||
node->mutable_input()->SwapElements(0, live_input_idx);
|
||||
DedupControlInputs(node);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
|
||||
NodeDef* node) {
|
||||
if (!IsEnter(*node) || node->input_size() == 0 ||
|
||||
|
@ -3771,6 +3807,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
|
|||
/*include_output_tensor_values=*/true);
|
||||
|
||||
const bool can_use_shape_info = s.ok();
|
||||
VLOG(1) << "can_use_shape_info = " << can_use_shape_info;
|
||||
|
||||
absl::flat_hash_set<string> nodes_to_not_simplify;
|
||||
if (can_use_shape_info) {
|
||||
|
|
|
@ -282,6 +282,10 @@ class ConstantFolding : public GraphOptimizer {
|
|||
// Simplify a Case operation where the output_idx is known.
|
||||
bool SimplifyCase(GraphDef* optimized_graph, NodeDef* node);
|
||||
|
||||
// Simplify a Select operation where the predicates are all true or all false.
|
||||
bool SimplifySelect(const GraphProperties& properties,
|
||||
GraphDef* optimized_graph, NodeDef* node);
|
||||
|
||||
// Removes Reverse op over dimensions with size 1.
|
||||
Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
|
||||
GraphDef* optimized_graph, NodeDef* node);
|
||||
|
|
|
@ -4145,6 +4145,72 @@ TEST_F(ConstantFoldingTest, SimplifyCase) {
|
|||
test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, SimplifySelect) {
|
||||
for (bool scalar_pred : {true, false}) {
|
||||
for (bool pred_val : {true, false}) {
|
||||
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
|
||||
std::unique_ptr<Tensor> if_t;
|
||||
if (scalar_pred) {
|
||||
if_t.reset(new Tensor(DT_BOOL, TensorShape()));
|
||||
} else {
|
||||
if_t.reset(new Tensor(DT_BOOL, TensorShape({2, 2})));
|
||||
}
|
||||
for (int i = 0; i < (scalar_pred ? 1 : 4); ++i) {
|
||||
if_t->flat<bool>()(i) = pred_val;
|
||||
}
|
||||
Output if_ = ops::Const(scope.WithOpName("if"), *if_t);
|
||||
Output then_ =
|
||||
ops::Placeholder(scope.WithOpName("then"), DT_FLOAT,
|
||||
ops::Placeholder::Shape(TensorShape({2, 2})));
|
||||
Output else_ =
|
||||
ops::Placeholder(scope.WithOpName("else"), DT_FLOAT,
|
||||
ops::Placeholder::Shape(TensorShape({2, 2})));
|
||||
Output select =
|
||||
ops::SelectV2(scope.WithOpName("select"), if_, then_, else_);
|
||||
Output id = ops::Identity(scope.WithOpName("id"), select);
|
||||
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
|
||||
item.fetch = {"id"};
|
||||
|
||||
const Tensor kOne =
|
||||
test::AsTensor<float>({1.0f, 1.0f, 1.0f, 1.0f}, TensorShape({2, 2}));
|
||||
const Tensor kTwo =
|
||||
test::AsTensor<float>({2.0f, 2.0f, 2.0f, 2.0f}, TensorShape({2, 2}));
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
|
||||
{{"then", kOne}, {"else", kTwo}});
|
||||
|
||||
// Use aggressive mode to force the shape inference to propagate
|
||||
// placeholder shapes.
|
||||
ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
|
||||
/*cpu_device=*/nullptr);
|
||||
GraphDef optimized_graph;
|
||||
TF_EXPECT_OK(
|
||||
optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
|
||||
|
||||
ASSERT_EQ(optimized_graph.node_size(), 5);
|
||||
bool found = false;
|
||||
for (const auto& node : optimized_graph.node()) {
|
||||
if (node.name() == "select") {
|
||||
found = true;
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 3);
|
||||
EXPECT_EQ(node.input(0), pred_val ? "then" : "else");
|
||||
EXPECT_EQ(node.input(1), pred_val ? "^if" : "^then");
|
||||
EXPECT_EQ(node.input(2), pred_val ? "^else" : "^if");
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(found);
|
||||
|
||||
auto tensors = EvaluateNodes(optimized_graph, item.fetch,
|
||||
{{"then", kOne}, {"else", kTwo}});
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
|
Loading…
Reference in New Issue