When constant folding Case nodes, propagate known output shapes to the "_output_shapes" attribute, to aid in shape inference.
PiperOrigin-RevId: 315996338
Change-Id: I2da7e935dacea11511229588bfd143ad46c1e5ac

commit 603e328a1c
parent 07c8612582
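Conceptually, the change takes the shape recorded for the selected branch in the Case node's "output_shapes" list attribute and re-attaches it as "_output_shapes" on the call node that replaces the folded Case, so later shape inference can start from the branch's recorded shape instead of an unknown one. The following is a minimal sketch of that step, not the actual Grappler code: PropagateCaseOutputShape is a hypothetical helper name, while the proto accessors are the standard generated AttrValue/NodeDef APIs.

// Minimal sketch (hypothetical helper; the real logic lives inline in
// ConstantFolding::SimplifyCase): move the recorded shape of the chosen
// branch from "output_shapes" on the Case node to "_output_shapes" on the
// replacement call node.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

namespace tensorflow {

void PropagateCaseOutputShape(const NodeDef& case_node, int output_idx,
                              NodeDef* call_node) {
  auto it = case_node.attr().find("output_shapes");
  if (it == case_node.attr().end()) return;  // No shapes were recorded.
  const AttrValue::ListValue& shapes = it->second.list();
  if (shapes.shape_size() <= output_idx) return;  // This shape is not known.
  TensorShapeProto* propagated = (*call_node->mutable_attr())["_output_shapes"]
                                     .mutable_list()
                                     ->add_shape();
  *propagated = shapes.shape(output_idx);
}

}  // namespace tensorflow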
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -2385,8 +2385,9 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
   if (node->op() != "Case") return false;
   const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
   if (output_idx_node == nullptr ||
-      !CheckAttrExists(*output_idx_node, "value").ok())
+      !CheckAttrExists(*output_idx_node, "value").ok()) {
     return false;
+  }
   Tensor output_idx_t;
   if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
     return false;
@@ -2401,8 +2402,22 @@ bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
   }
   auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
   *new_func = func_list.func(output_idx);
-  call_node.mutable_attr()->erase("branches");
+
+  // Move the output shape of the branch to _output_shapes if it is known.
+  const auto& output_shape_list =
+      (*node->mutable_attr())["output_shapes"].list();
+  if (output_shape_list.shape_size() > output_idx) {
+    TensorShapeProto* new_output_shape =
+        (*call_node.mutable_attr())["_output_shapes"]
+            .mutable_list()
+            ->add_shape();
+    *new_output_shape =
+        std::move(node->attr().at("output_shapes").list().shape(output_idx));
+  }
+
   call_node.mutable_attr()->erase("output_shapes");
+  call_node.mutable_attr()->erase("branches");
+
   *node = std::move(call_node);
   return true;
 }
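For context on how SimplifyCase arrives at call_node in the first place: it reads the branch index from the Case node's constant first input and swaps in the matching function from the "branches" attribute. The sketch below is a rough distillation of that selection, not the actual Grappler code; SelectCaseBranch is a hypothetical name, and it assumes the index tensor is a scalar int32 that has already been constant-folded.

// Rough sketch of the branch selection that precedes the shape propagation
// above (hypothetical helper for illustration only).
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

// Returns false if the Case node cannot be simplified.
bool SelectCaseBranch(const NodeDef& idx_const, const NodeDef& case_node,
                      NodeDef* call_node) {
  Tensor idx_t;
  if (!idx_t.FromProto(idx_const.attr().at("value").tensor())) return false;
  const int output_idx = idx_t.scalar<int32>()();  // Assumes scalar int32.
  const auto& func_list = case_node.attr().at("branches").list();
  if (output_idx < 0 || output_idx >= func_list.func_size()) return false;
  // Record the chosen branch as the function called by the new node.
  *(*call_node->mutable_attr())["f"].mutable_func() =
      func_list.func(output_idx);
  return true;
}

}  // namespace tensorflow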
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -4095,7 +4095,8 @@ TEST_F(ConstantFoldingTest, BitcastDenormalFloats) {
 TEST_F(ConstantFoldingTest, SimplifyCase) {
   using test::function::NDef;
 
-  // Build a graph to compute y = Case(1, x, XTimesTwo(x), NonZero(x))
+  for (int index = 0; index < 2; ++index) {
+    // Build a graph to compute y = Case(index, x, XTimesTwo(x), NonZero(x))
     GrapplerItem item;
     constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
     AttrValue branches;
@@ -4106,15 +4107,28 @@ TEST_F(ConstantFoldingTest, SimplifyCase) {
     *g = *f;
     g->set_name("NonZero");
 
+    // Add a pair of somewhat arbitrary output shapes to
+    // test that they are correctly propagated to the _output_shapes
+    // attribute.
+    AttrValue output_shapes;
+    // The first shape is a scalar.
+    output_shapes.mutable_list()->add_shape();
+    // The second shape is unknown.
+    TensorShapeProto* g_shape = output_shapes.mutable_list()->add_shape();
+    g_shape->set_unknown_rank(true);
+
+    const Tensor kZero = test::AsScalar<int32>(0);
     const Tensor kOne = test::AsScalar<int32>(1);
     item.graph = test::function::GDef(
-        {NDef("one", "Const", {}, {{"value", kOne}, {"dtype", DT_INT32}},
+        {NDef("one", "Const", {},
+              {{"value", index == 0 ? kZero : kOne}, {"dtype", DT_INT32}},
               kDevice),
          NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
          NDef("case", "Case", {"one", "x"},
               {{"Tin", DataTypeSlice{DT_FLOAT}},
                {"Tout", DataTypeSlice{DT_FLOAT}},
-               {"branches", branches}},
+               {"branches", branches},
+               {"output_shapes", output_shapes}},
               kDevice),
          NDef("y", "Identity", {"case"}, {{"T", DT_FLOAT}}, kDevice)},
         // FunctionLib
@@ -4126,17 +4140,31 @@ TEST_F(ConstantFoldingTest, SimplifyCase) {
 
     item.fetch = {"y"};
     const Tensor kTwo = test::AsScalar<float>(2.0f);
-    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
+    auto tensors_expected =
+        EvaluateNodes(item.graph, item.fetch, {{"x", kTwo}});
 
     ConstantFolding optimizer(/*cpu_device=*/nullptr);
     GraphDef optimized_graph;
-    TF_ASSERT_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
+    TF_ASSERT_OK(
+        optimizer.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
     VLOG(1) << "After: " << optimized_graph.DebugString();
 
     int pco_count = 0;
     for (const auto& node : optimized_graph.node()) {
       EXPECT_NE(node.op(), "Case");
-      if (node.op() == "PartitionedCall") ++pco_count;
+      if (node.op() == "PartitionedCall") {
+        ++pco_count;
+        const auto& shape_list = node.attr().at("_output_shapes").list();
+        ASSERT_EQ(shape_list.shape_size(), 1);
+        EXPECT_EQ(shape_list.shape(0).dim_size(), 0);
+        if (index == 0) {
+          EXPECT_EQ(node.attr().at("f").func().name(), "XTimesTwo");
+          EXPECT_EQ(shape_list.shape(0).unknown_rank(), false);
+        } else {
+          EXPECT_EQ(node.attr().at("f").func().name(), "NonZero");
+          EXPECT_EQ(shape_list.shape(0).unknown_rank(), true);
+        }
+      }
     }
     EXPECT_EQ(pco_count, 1);
 
@@ -4144,6 +4172,7 @@ TEST_F(ConstantFoldingTest, SimplifyCase) {
     ASSERT_EQ(tensors.size(), tensors_expected.size());
     test::ExpectTensorEqual<float>(tensors[0], tensors_expected[0]);
   }
+}
 
 TEST_F(ConstantFoldingTest, SimplifySelect) {
   for (bool scalar_pred : {true, false}) {
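The test now runs both branch indices: for index 0 the folded call must name XTimesTwo and carry the known scalar shape, and for index 1 it must name NonZero with an unknown-rank shape. As a usage-style sketch of the consumer side, the snippet below shows how the propagated hint could be read back off the rewritten node in an optimized GraphDef; LogOutputShapeHints is a hypothetical function written only for illustration.

// Minimal consumer-side sketch: read the "_output_shapes" hint that the
// optimizer now attaches to the PartitionedCall produced from a Case node.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Prints rank info for each recorded output shape of `node`, if any.
void LogOutputShapeHints(const NodeDef& node) {
  auto it = node.attr().find("_output_shapes");
  if (it == node.attr().end()) return;  // No hint was propagated.
  for (const TensorShapeProto& shape : it->second.list().shape()) {
    if (shape.unknown_rank()) {
      LOG(INFO) << node.name() << ": output shape has unknown rank";
    } else {
      LOG(INFO) << node.name() << ": output rank " << shape.dim_size();
    }
  }
}

}  // namespace tensorflow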