When constant folding Case nodes, propagate known output shapes to the "_output_shapes" attribute, to aid in shape inference.
PiperOrigin-RevId: 315996338 Change-Id: I2da7e935dacea11511229588bfd143ad46c1e5ac
This commit is contained in:
parent
07c8612582
commit
603e328a1c
@ -2385,8 +2385,9 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
|
||||
if (node->op() != "Case") return false;
|
||||
const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
|
||||
if (output_idx_node == nullptr ||
|
||||
!CheckAttrExists(*output_idx_node, "value").ok())
|
||||
!CheckAttrExists(*output_idx_node, "value").ok()) {
|
||||
return false;
|
||||
}
|
||||
Tensor output_idx_t;
|
||||
if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
|
||||
return false;
|
||||
@ -2401,8 +2402,22 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
|
||||
}
|
||||
auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
|
||||
*new_func = func_list.func(output_idx);
|
||||
call_node.mutable_attr()->erase("branches");
|
||||
|
||||
// Move the output shape of the branch to _output_shapes if it is known.
|
||||
const auto& output_shape_list =
|
||||
(*node->mutable_attr())["output_shapes"].list();
|
||||
if (output_shape_list.shape_size() > output_idx) {
|
||||
TensorShapeProto* new_output_shape =
|
||||
(*call_node.mutable_attr())["_output_shapes"]
|
||||
.mutable_list()
|
||||
->add_shape();
|
||||
*new_output_shape =
|
||||
std::move(node->attr().at("output_shapes").list().shape(output_idx));
|
||||
}
|
||||
|
||||
call_node.mutable_attr()->erase("output_shapes");
|
||||
call_node.mutable_attr()->erase("branches");
|
||||
|
||||
*node = std::move(call_node);
|
||||
return true;
|
||||
}
|
||||
|
@ -4095,54 +4095,83 @@ TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
|
||||
TEST_F(ConstantFoldingTest, SimplifyCase) {
|
||||
using test::function::NDef;
|
||||
|
||||
// Build a graph to compute y = Case(1, x, XTimesTwo(x), NonZero(x))
|
||||
GrapplerItem item;
|
||||
constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
AttrValue branches;
|
||||
auto* f = branches.mutable_list()->add_func();
|
||||
f->set_name("XTimesTwo");
|
||||
(*f->mutable_attr())["T"].set_type(DT_FLOAT);
|
||||
auto* g = branches.mutable_list()->add_func();
|
||||
*g = *f;
|
||||
g->set_name("NonZero");
|
||||
for (int index = 0; index < 2; ++index) {
|
||||
// Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x))
|
||||
GrapplerItem item;
|
||||
constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
AttrValue branches;
|
||||
auto* f = branches.mutable_list()->add_func();
|
||||
f->set_name("XTimesTwo");
|
||||
(*f->mutable_attr())["T"].set_type(DT_FLOAT);
|
||||
auto* g = branches.mutable_list()->add_func();
|
||||
*g = *f;
|
||||
g->set_name("NonZero");
|
||||
|
||||
const Tensor kOne = test::AsScalar<int32>(1);
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("one", "Const", {}, {{"value", kOne}, {"dtype", DT_INT32}},
|
||||
kDevice),
|
||||
NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
NDef("case", "Case", {"one", "x"},
|
||||
{{"Tin", DataTypeSlice{DT_FLOAT}},
|
||||
{"Tout", DataTypeSlice{DT_FLOAT}},
|
||||
{"branches", branches}},
|
||||
kDevice),
|
||||
NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
test::function::NonZero(),
|
||||
});
|
||||
VLOG(1) << "Before: " << item.graph.DebugString();
|
||||
// Add a pair of somewhat arbitrary output shapes to
|
||||
// test that they are correctly propagated to the _output_shapes
|
||||
// attribute.
|
||||
AttrValue output_shapes;
|
||||
// The first shape is a scalar.
|
||||
output_shapes.mutable_list()->add_shape();
|
||||
// The second shape is unknown.
|
||||
TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
|
||||
g_shape->set_unknown_rank(true);
|
||||
|
||||
item.fetch = {"y"};
|
||||
const Tensor kTwo = test::AsScalar<float>(2.0f);
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
|
||||
const Tensor kZero = test::AsScalar<int32>(0);
|
||||
const Tensor kOne = test::AsScalar<int32>(1);
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("one", "Const", {},
|
||||
{{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
|
||||
kDevice),
|
||||
NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
NDef("case", "Case", {"one", "x"},
|
||||
{{"Tin", DataTypeSlice{DT_FLOAT}},
|
||||
{"Tout", DataTypeSlice{DT_FLOAT}},
|
||||
{"branches", branches},
|
||||
{"output_shapes", output_shapes}},
|
||||
kDevice),
|
||||
NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
|
||||
// FunctionLib
|
||||
{
|
||||
test::function::XTimesTwo(),
|
||||
test::function::NonZero(),
|
||||
});
|
||||
VLOG(1) << "Before: " << item.graph.DebugString();
|
||||
|
||||
ConstantFolding optimizer(/*cpu_device=*/nullptr);
|
||||
GraphDef optimized_graph;
|
||||
TF_ASSERT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
|
||||
VLOG(1) << "After: " << optimized_graph.DebugString();
|
||||
item.fetch = {"y"};
|
||||
const Tensor kTwo = test::AsScalar<float>(2.0f);
|
||||
auto tensors_expected =
|
||||
EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
|
||||
|
||||
int pco_count = 0;
|
||||
for (const auto& node : optimized_graph.node()) {
|
||||
EXPECT_NE(node.op(), "Case");
|
||||
if (node.op() == "PartitionedCall") ++pco_count;
|
||||
ConstantFolding optimizer(/*cpu_device=*/nullptr);
|
||||
GraphDef optimized_graph;
|
||||
TF_ASSERT_OK(
|
||||
optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
|
||||
VLOG(1) << "After: " << optimized_graph.DebugString();
|
||||
|
||||
int pco_count = 0;
|
||||
for (const auto& node : optimized_graph.node()) {
|
||||
EXPECT_NE(node.op(), "Case");
|
||||
if (node.op() == "PartitionedCall") {
|
||||
++pco_count;
|
||||
const auto& shape_list = node.attr().at("_output_shapes").list();
|
||||
ASSERT_EQ(shape_list.shape_size(), 1);
|
||||
EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
|
||||
if (index == 0) {
|
||||
EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
|
||||
EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
|
||||
} else {
|
||||
EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
|
||||
EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
|
||||
}
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(pco_count, 1);
|
||||
|
||||
auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
|
||||
ASSERT_EQ(tensors.size(), tensors_expected.size());
|
||||
test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
|
||||
}
|
||||
EXPECT_EQ(pco_count, 1);
|
||||
|
||||
auto tensors = EvaluateNodes(optimized_graph, item.fetch, {{"x", kTwo}});
|
||||
ASSERT_EQ(tensors.size(), tensors_expected.size());
|
||||
test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, SimplifySelect) {
|
||||
|
Loading…
Reference in New Issue
Block a user