When constant folding Case nodes, propagate known output shapes to the "_output_shapes" attribute, to aid in shape inference.

PiperOrigin-RevId: 315996338
Change-Id: I2da7e935dacea11511229588bfd143ad46c1e5ac
This commit is contained in:
A. Unique TensorFlower 2020-06-11 16:01:36 -07:00 committed by TensorFlower Gardener
parent 07c8612582
commit 603e328a1c
2 changed files with 89 additions and 45 deletions

View File

@ -2385,8 +2385,9 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
if (node->op() != "Case") return false; if (node->op() != "Case") return false;
const NodeDef* output_idx_node = node_map_->GetNode(node->input(0)); const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
if (output_idx_node == nullptr || if (output_idx_node == nullptr ||
!CheckAttrExists(*output_idx_node, "value").ok()) !CheckAttrExists(*output_idx_node, "value").ok()) {
return false; return false;
}
Tensor output_idx_t; Tensor output_idx_t;
if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor())) if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
return false; return false;
@ -2401,8 +2402,22 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
} }
auto* new_func = (*call_node.mutable_attr())["f"].mutable_func(); auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
*new_func = func_list.func(output_idx); *new_func = func_list.func(output_idx);
call_node.mutable_attr()->erase("branches");
// Move the output shape of the branch to _output_shapes if it is known.
const auto& output_shape_list =
(*node->mutable_attr())["output_shapes"].list();
if (output_shape_list.shape_size() > output_idx) {
TensorShapeProto* new_output_shape =
(*call_node.mutable_attr())["_output_shapes"]
.mutable_list()
->add_shape();
*new_output_shape =
std::move(node->attr().at("output_shapes").list().shape(output_idx));
}
call_node.mutable_attr()->erase("output_shapes"); call_node.mutable_attr()->erase("output_shapes");
call_node.mutable_attr()->erase("branches");
*node = std::move(call_node); *node = std::move(call_node);
return true; return true;
} }

View File

@ -4095,7 +4095,8 @@ TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
TEST_F(ConstantFoldingTest, SimplifyCase) { TEST_F(ConstantFoldingTest, SimplifyCase) {
using test::function::NDef; using test::function::NDef;
// Build a graph to compute y = Case(1, x, XTimesTwo(x), NonZero(x)) for (int index = 0; index < 2; ++index) {
// Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x))
GrapplerItem item; GrapplerItem item;
constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0"; constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
AttrValue branches; AttrValue branches;
@ -4106,15 +4107,28 @@ TEST_F(ConstantFoldingTest, SimplifyCase) {
*g = *f; *g = *f;
g->set_name("NonZero"); g->set_name("NonZero");
// Add a pair of somewhat arbitrary output shapes to
// test that they are correctly propagated to the _output_shapes
// attribute.
AttrValue output_shapes;
// The first shape is a scalar.
output_shapes.mutable_list()->add_shape();
// The second shape is unknown.
TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
g_shape->set_unknown_rank(true);
const Tensor kZero = test::AsScalar<int32>(0);
const Tensor kOne = test::AsScalar<int32>(1); const Tensor kOne = test::AsScalar<int32>(1);
item.graph = test::function::GDef( item.graph = test::function::GDef(
{NDef("one", "Const", {}, {{"value", kOne}, {"dtype", DT_INT32}}, {NDef("one", "Const", {},
{{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
kDevice), kDevice),
NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
NDef("case", "Case", {"one", "x"}, NDef("case", "Case", {"one", "x"},
{{"Tin", DataTypeSlice{DT_FLOAT}}, {{"Tin", DataTypeSlice{DT_FLOAT}},
{"Tout", DataTypeSlice{DT_FLOAT}}, {"Tout", DataTypeSlice{DT_FLOAT}},
{"branches", branches}}, {"branches", branches},
{"output_shapes", output_shapes}},
kDevice), kDevice),
NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)}, NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
// FunctionLib // FunctionLib
@ -4126,17 +4140,31 @@ TEST_F(ConstantFoldingTest, SimplifyCase) {
item.fetch = {"y"}; item.fetch = {"y"};
const Tensor kTwo = test::AsScalar<float>(2.0f); const Tensor kTwo = test::AsScalar<float>(2.0f);
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}}); auto tensors_expected =
EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
ConstantFolding optimizer(/*cpu_device=*/nullptr); ConstantFolding optimizer(/*cpu_device=*/nullptr);
GraphDef optimized_graph; GraphDef optimized_graph;
TF_ASSERT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph)); TF_ASSERT_OK(
optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
VLOG(1) << "After: " << optimized_graph.DebugString(); VLOG(1) << "After: " << optimized_graph.DebugString();
int pco_count = 0; int pco_count = 0;
for (const auto& node : optimized_graph.node()) { for (const auto& node : optimized_graph.node()) {
EXPECT_NE(node.op(), "Case"); EXPECT_NE(node.op(), "Case");
if (node.op() == "PartitionedCall") ++pco_count; if (node.op() == "PartitionedCall") {
++pco_count;
const auto& shape_list = node.attr().at("_output_shapes").list();
ASSERT_EQ(shape_list.shape_size(), 1);
EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
if (index == 0) {
EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
} else {
EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
}
}
} }
EXPECT_EQ(pco_count, 1); EXPECT_EQ(pco_count, 1);
@ -4144,6 +4172,7 @@ TEST_F(ConstantFoldingTest, SimplifyCase) {
ASSERT_EQ(tensors.size(), tensors_expected.size()); ASSERT_EQ(tensors.size(), tensors_expected.size());
test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]); test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
} }
}
TEST_F(ConstantFoldingTest, SimplifySelect) { TEST_F(ConstantFoldingTest, SimplifySelect) {
for (bool scalar_pred : {true, false}) { for (bool scalar_pred : {true, false}) {