Merge pull request #28057 from benbarsdell:fix-const-fold-dyn-shape-noop-reduction

PiperOrigin-RevId: 246892307
This commit is contained in:
TensorFlower Gardener 2019-05-06 15:04:25 -07:00
commit a83582dc61
3 changed files with 67 additions and 28 deletions

View File

@ -2382,20 +2382,28 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
return false;
}
bool ConstantFolding::IsReductionCandidateForSimplification(
const NodeDef& node, const GraphProperties& properties,
TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
bool* is_single_element_op) const {
// Returns true iff `node` is a reduction op whose reduction indices
// (input 1) come from a genuinely constant node. On success, sets
// *indices_is_empty to whether that constant holds zero elements, i.e.
// the set of axes to reduce over is empty.
// NOTE(review): this span is a diff rendering; the single-line
// `if (!IsReallyConstant(...))` below appears to be the pre-change text and
// the two-line condition following it the post-change text — confirm against
// the applied file before relying on either.
bool ConstantFolding::IsReductionWithConstantIndices(
const NodeDef& node, bool* indices_is_empty) const {
// Ensure it's an appropriate Reduce node.
if (!IsReduction(node) || node.input_size() < 2) {
return false;
}
// Ensure that the axes to reduce by are constant.
NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
if (!IsReallyConstant(*reductions_indices)) {
// The added "value" check guards the attr().at("value") lookup below
// against a constant node that carries no value attribute.
if (!IsReallyConstant(*reductions_indices) ||
!reductions_indices->attr().count("value")) {
return false;
}
// An axes tensor with zero elements means "reduce over no dimensions",
// i.e. the reduction is a no-op.
const TensorShapeProto& reduction_indices_shape =
reductions_indices->attr().at("value").tensor().tensor_shape();
*indices_is_empty = TensorShape(reduction_indices_shape).num_elements() == 0;
return true;
}
bool ConstantFolding::IsReductionCandidateForSimplification(
const NodeDef& node, const GraphProperties& properties,
TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
bool* is_single_element_op) const {
// Get the properties of the input & output tensors and check if they both
// contain a single element.
if (!properties.HasInputProperties(node.name()) ||
@ -2460,9 +2468,34 @@ bool ConstantFolding::IsReductionSimplifiableToIdentity(
return simplifiable;
}
bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
  // Rewrite the reduction node in place as an Identity op, which later
  // passes can optimize further. Returns false (node untouched) when the
  // output dtype cannot be determined.
  //
  // Determine the output dtype first: prefer the node's "T" attr; Any/All
  // reductions carry no "T" and always produce booleans.
  DataType output_type;
  const bool has_type_attr = node->attr().count("T") != 0;
  if (has_type_attr) {
    output_type = node->attr().at("T").type();
  } else {
    if (!IsAny(*node) && !IsAll(*node)) {
      return false;
    }
    output_type = DT_BOOL;
  }
  node->set_op("Identity");
  node->clear_attr();
  (*node->mutable_attr())["T"].set_type(output_type);
  // Demote the reduction-indices input to a control dependency so the
  // Identity keeps a single data input but still runs after the indices node.
  *node->mutable_input(1) = AsControlDependency(node->input(1));
  return true;
}
bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
const GraphProperties& properties,
NodeDef* node) {
bool indices_is_empty = false;
if (!IsReductionWithConstantIndices(*node, &indices_is_empty)) {
return false;
}
if (indices_is_empty) {
return ReplaceReductionWithIdentity(node);
}
bool is_single_element_op = false;
TensorShapeProto input_tensor_shape, output_tensor_shape;
if (!IsReductionCandidateForSimplification(
@ -2524,20 +2557,7 @@ bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
(*node->mutable_attr())["Tshape"] = attr_type_indices;
return true;
} else if (simplifiable_to_identity) {
// Replace the reduction node with an identity node, that can be further
// optimized by the model pruner.
DataType output_type;
if (node->attr().count("T") != 0) {
output_type = node->attr().at("T").type();
} else {
// This is an 'any' or 'all' reduction. The output is always boolean.
output_type = DT_BOOL;
}
node->set_op("Identity");
node->clear_attr();
(*node->mutable_attr())["T"].set_type(output_type);
*node->mutable_input(1) = AsControlDependency(node->input(1));
return true;
return ReplaceReductionWithIdentity(node);
}
return false;
}

View File

@ -153,6 +153,11 @@ class ConstantFolding : public GraphOptimizer {
bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
NodeDef* node);
// Returns true iff the node is a reduction and its reduction indices are
// constant. Sets *indices_is_empty to true if the set of dimensions to reduce
// along is empty (this happens often in the gradient graphs).
bool IsReductionWithConstantIndices(const NodeDef& node,
bool* indices_is_empty) const;
// Returns true if there's a possibility that a Reduce node could be simplified
// to an Identity/Reshape.
bool IsReductionCandidateForSimplification(
@ -160,11 +165,12 @@ class ConstantFolding : public GraphOptimizer {
TensorShapeProto* input_tensor_shape,
TensorShapeProto* output_tensor_shape, bool* is_single_element_op) const;
// Returns true iff this reduction can be reduced to an identity (i.e if the
// set of dimensions to reduce along is empty). This happens often in the
// gradient graphs.
// input dimensions to reduce along are all of size 1 and keep_dims is true).
bool IsReductionSimplifiableToIdentity(
const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
// Changes a reduction into an Identity op, returning true on success.
bool ReplaceReductionWithIdentity(NodeDef* node) const;
// Simplifies a Reduction operation to an Identity/Reshape operation if
// applicable.
bool SimplifyReduction(GraphDef* optimized_graph,

View File

@ -2511,8 +2511,12 @@ TEST_F(ConstantFoldingTest, NoOpReduction) {
attr = attr.KeepDims(true);
Output p2 = ops::Prod(scope.WithOpName("p2"), v2, c2, attr);
// Test with unknown input shape.
Output a = ops::Placeholder(scope.WithOpName("a"), DT_FLOAT);
Output p3 = ops::Prod(scope.WithOpName("p3"), a, i, attr);
GrapplerItem item;
item.fetch = {"s", "p2"};
item.fetch = {"s", "p2", "p3"};
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
ConstantFolding optimizer(/*cpu_device=*/nullptr);
@ -2534,19 +2538,28 @@ TEST_F(ConstantFoldingTest, NoOpReduction) {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("v2", node.input(0));
EXPECT_EQ("^c2", node.input(1));
} else if (node.name() == "p3") {
found++;
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("a", node.input(0));
EXPECT_EQ("^i", node.input(1));
}
}
EXPECT_EQ(2, found);
EXPECT_EQ(3, found);
auto v_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 1}));
auto tensors_expected =
EvaluateNodes(item.graph, item.fetch, {{"v", v_t}, {"v2", v2_t}});
EXPECT_EQ(2, tensors_expected.size());
auto tensors = EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}});
EXPECT_EQ(2, tensors.size());
auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
{{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
EXPECT_EQ(3, tensors_expected.size());
auto tensors =
EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
EXPECT_EQ(3, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
test::ExpectTensorNear<float>(tensors_expected[2], tensors[2], 1e-5);
}
TEST_F(ConstantFoldingTest, SingleElementEmptyAxisReduction) {