When constant folding Case nodes, propagate known output shapes to the "_output_shapes" attribute, to aid in shape inference.

PiperOrigin-RevId: 315996338
Change-Id: I2da7e935dacea11511229588bfd143ad46c1e5ac
Author:    A. Unique TensorFlower
Date:      2020-06-11 16:01:36 -07:00
Committer: TensorFlower Gardener
Parent:    07c8612582
Commit:    603e328a1c
2 changed files with 89 additions and 45 deletions
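In brief, the change moves one attribute. When SimplifyCase rewrites a Case node whose branch index is a constant into a PartitionedCall, the selected branch's entry in the Case node's "output_shapes" list attribute is copied onto the call node's "_output_shapes" attribute, where shape inference can use it as a hint. Below is a minimal sketch of that move with a hypothetical helper name; the committed code inlines this logic in ConstantFolding::SimplifyCase (second hunk below).

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Hypothetical helper, for illustration only: copy the shape recorded for the
// selected branch in the Case node's "output_shapes" list attribute onto the
// replacement call node's "_output_shapes" attribute.
void PropagateBranchShape(const tensorflow::NodeDef& case_node, int output_idx,
                          tensorflow::NodeDef* call_node) {
  const auto it = case_node.attr().find("output_shapes");
  if (it == case_node.attr().end()) return;  // No shapes were recorded.
  const auto& shapes = it->second.list();
  if (shapes.shape_size() <= output_idx) return;  // List too short; skip.
  tensorflow::AttrValue& hint = (*call_node->mutable_attr())["_output_shapes"];
  *hint.mutable_list()->add_shape() = shapes.shape(output_idx);
}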

--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -2385,8 +2385,9 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
   if (node->op() != "Case") return false;
   const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
   if (output_idx_node == nullptr ||
-      !CheckAttrExists(*output_idx_node, "value").ok())
+      !CheckAttrExists(*output_idx_node, "value").ok()) {
     return false;
+  }
   Tensor output_idx_t;
   if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
     return false;
@@ -2401,8 +2402,22 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
   }
   auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
   *new_func = func_list.func(output_idx);
-  call_node.mutable_attr()->erase("branches");
+  // Move the output shape of the branch to _output_shapes if it is known.
+  const auto& output_shape_list =
+      (*node->mutable_attr())["output_shapes"].list();
+  if (output_shape_list.shape_size() > output_idx) {
+    TensorShapeProto* new_output_shape =
+        (*call_node.mutable_attr())["_output_shapes"]
+            .mutable_list()
+            ->add_shape();
+    *new_output_shape =
+        std::move(node->attr().at("output_shapes").list().shape(output_idx));
+  }
+  call_node.mutable_attr()->erase("output_shapes");
+  call_node.mutable_attr()->erase("branches");
   *node = std::move(call_node);
   return true;
 }
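After the rewrite, the hint is an ordinary attribute on the PartitionedCall node and can be inspected like any other. A hedged sketch of what a consumer (or a test) might check follows; the function name is illustrative, not a TensorFlow API, and it mirrors the assertions added to the SimplifyCase test below.

#include "tensorflow/core/framework/node_def.pb.h"

// Illustrative check: does the rewritten call node carry exactly one fully
// known scalar shape in "_output_shapes"? A known scalar has rank 0 (no dims)
// and is not marked unknown_rank.
bool HasKnownScalarOutputShape(const tensorflow::NodeDef& call_node) {
  const auto it = call_node.attr().find("_output_shapes");
  if (it == call_node.attr().end()) return false;
  const auto& shapes = it->second.list();
  return shapes.shape_size() == 1 && !shapes.shape(0).unknown_rank() &&
         shapes.shape(0).dim_size() == 0;
}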

--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -4095,54 +4095,83 @@ TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
 TEST_F(ConstantFoldingTest, SimplifyCase) {
   using test::function::NDef;
-  // Build a graph to compute y = Case(1, x, XTimesTwo(x), NonZero(x))
-  GrapplerItem item;
-  constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
-  AttrValue branches;
-  auto* f = branches.mutable_list()->add_func();
-  f->set_name("XTimesTwo");
-  (*f->mutable_attr())["T"].set_type(DT_FLOAT);
-  auto* g = branches.mutable_list()->add_func();
-  *g = *f;
-  g->set_name("NonZero");
+  for (int index = 0; index < 2; ++index) {
+    // Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x))
+    GrapplerItem item;
+    constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
+    AttrValue branches;
+    auto* f = branches.mutable_list()->add_func();
+    f->set_name("XTimesTwo");
+    (*f->mutable_attr())["T"].set_type(DT_FLOAT);
+    auto* g = branches.mutable_list()->add_func();
+    *g = *f;
+    g->set_name("NonZero");
-  const Tensor kOne = test::AsScalar<int32>(1);
-  item.graph = test::function::GDef(
-      {NDef("one", "Const", {}, {{"value", kOne}, {"dtype", DT_INT32}},
-            kDevice),
-       NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
-       NDef("case", "Case", {"one", "x"},
-            {{"Tin", DataTypeSlice{DT_FLOAT}},
-             {"Tout", DataTypeSlice{DT_FLOAT}},
-             {"branches", branches}},
-            kDevice),
-       NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
-      // FunctionLib
-      {
-          test::function::XTimesTwo(),
-          test::function::NonZero(),
-      });
-  VLOG(1) << "Before: " << item.graph.DebugString();
+    // Add a pair of somewhat arbitrary output shapes to test that they are
+    // correctly propagated to the _output_shapes attribute.
+    AttrValue output_shapes;
+    // The first shape is a scalar.
+    output_shapes.mutable_list()->add_shape();
+    // The second shape is unknown.
+    TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
+    g_shape->set_unknown_rank(true);
-  item.fetch = {"y"};
-  const Tensor kTwo = test::AsScalar<float>(2.0f);
-  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
+    const Tensor kZero = test::AsScalar<int32>(0);
+    const Tensor kOne = test::AsScalar<int32>(1);
+    item.graph = test::function::GDef(
+        {NDef("one", "Const", {},
+              {{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
+              kDevice),
+         NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+         NDef("case", "Case", {"one", "x"},
+              {{"Tin", DataTypeSlice{DT_FLOAT}},
+               {"Tout", DataTypeSlice{DT_FLOAT}},
+               {"branches", branches},
+               {"output_shapes", output_shapes}},
+              kDevice),
+         NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
+        // FunctionLib
+        {
+            test::function::XTimesTwo(),
+            test::function::NonZero(),
+        });
+    VLOG(1) << "Before: " << item.graph.DebugString();
-  ConstantFolding optimizer(/*cpu_device=*/nullptr);
-  GraphDef optimized_graph;
-  TF_ASSERT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
-  VLOG(1) << "After: " << optimized_graph.DebugString();
+    item.fetch = {"y"};
+    const Tensor kTwo = test::AsScalar<float>(2.0f);
+    auto tensors_expected =
+        EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
-  int pco_count = 0;
-  for (const auto& node : optimized_graph.node()) {
-    EXPECT_NE(node.op(), "Case");
-    if (node.op() == "PartitionedCall") ++pco_count;
+    ConstantFolding optimizer(/*cpu_device=*/nullptr);
+    GraphDef optimized_graph;
+    TF_ASSERT_OK(
+        optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
+    VLOG(1) << "After: " << optimized_graph.DebugString();
+    int pco_count = 0;
+    for (const auto& node : optimized_graph.node()) {
+      EXPECT_NE(node.op(), "Case");
+      if (node.op() == "PartitionedCall") {
+        ++pco_count;
+        const auto& shape_list = node.attr().at("_output_shapes").list();
+        ASSERT_EQ(shape_list.shape_size(), 1);
+        EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
+        if (index == 0) {
+          EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
+          EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
+        } else {
+          EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
+          EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
+        }
+      }
+    }
+    EXPECT_EQ(pco_count, 1);
+    auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
+    ASSERT_EQ(tensors.size(), tensors_expected.size());
+    test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
   }
-  EXPECT_EQ(pco_count, 1);
-  auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
-  ASSERT_EQ(tensors.size(), tensors_expected.size());
-  test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
 }
 TEST_F(ConstantFoldingTest, SimplifySelect) {