Merge pull request #28057 from benbarsdell:fix-const-fold-dyn-shape-noop-reduction
PiperOrigin-RevId: 246892307
This commit is contained in:
commit
a83582dc61
@ -2382,20 +2382,28 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConstantFolding::IsReductionCandidateForSimplification(
|
||||
const NodeDef& node, const GraphProperties& properties,
|
||||
TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
|
||||
bool* is_single_element_op) const {
|
||||
bool ConstantFolding::IsReductionWithConstantIndices(
|
||||
const NodeDef& node, bool* indices_is_empty) const {
|
||||
// Ensure its an appropriate Reduce node.
|
||||
if (!IsReduction(node) || node.input_size() < 2) {
|
||||
return false;
|
||||
}
|
||||
// Ensure that the axes to reduce by are constant.
|
||||
NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
|
||||
if (!IsReallyConstant(*reductions_indices)) {
|
||||
if (!IsReallyConstant(*reductions_indices) ||
|
||||
!reductions_indices->attr().count("value")) {
|
||||
return false;
|
||||
}
|
||||
const TensorShapeProto& reduction_indices_shape =
|
||||
reductions_indices->attr().at("value").tensor().tensor_shape();
|
||||
*indices_is_empty = TensorShape(reduction_indices_shape).num_elements() == 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConstantFolding::IsReductionCandidateForSimplification(
|
||||
const NodeDef& node, const GraphProperties& properties,
|
||||
TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
|
||||
bool* is_single_element_op) const {
|
||||
// Get the properties of the input & output tensors and check if they both
|
||||
// contain a single element.
|
||||
if (!properties.HasInputProperties(node.name()) ||
|
||||
@ -2460,9 +2468,34 @@ bool ConstantFolding::IsReductionSimplifiableToIdentity(
|
||||
return simplifiable;
|
||||
}
|
||||
|
||||
bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
|
||||
// Replace the reduction node with an identity node, that can be further
|
||||
// optimized by other passes.
|
||||
DataType output_type;
|
||||
if (node->attr().count("T") != 0) {
|
||||
output_type = node->attr().at("T").type();
|
||||
} else if (IsAny(*node) || IsAll(*node)) {
|
||||
output_type = DT_BOOL;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
node->set_op("Identity");
|
||||
node->clear_attr();
|
||||
(*node->mutable_attr())["T"].set_type(output_type);
|
||||
*node->mutable_input(1) = AsControlDependency(node->input(1));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
|
||||
const GraphProperties& properties,
|
||||
NodeDef* node) {
|
||||
bool indices_is_empty = false;
|
||||
if (!IsReductionWithConstantIndices(*node, &indices_is_empty)) {
|
||||
return false;
|
||||
}
|
||||
if (indices_is_empty) {
|
||||
return ReplaceReductionWithIdentity(node);
|
||||
}
|
||||
bool is_single_element_op = false;
|
||||
TensorShapeProto input_tensor_shape, output_tensor_shape;
|
||||
if (!IsReductionCandidateForSimplification(
|
||||
@ -2524,20 +2557,7 @@ bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
|
||||
(*node->mutable_attr())["Tshape"] = attr_type_indices;
|
||||
return true;
|
||||
} else if (simplifiable_to_identity) {
|
||||
// Replace the reduction node with an identity node, that can be further
|
||||
// optimized by the model pruner.
|
||||
DataType output_type;
|
||||
if (node->attr().count("T") != 0) {
|
||||
output_type = node->attr().at("T").type();
|
||||
} else {
|
||||
// This is an 'any' or 'all' reduction. The output is always boolean.
|
||||
output_type = DT_BOOL;
|
||||
}
|
||||
node->set_op("Identity");
|
||||
node->clear_attr();
|
||||
(*node->mutable_attr())["T"].set_type(output_type);
|
||||
*node->mutable_input(1) = AsControlDependency(node->input(1));
|
||||
return true;
|
||||
return ReplaceReductionWithIdentity(node);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -153,6 +153,11 @@ class ConstantFolding : public GraphOptimizer {
|
||||
bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
|
||||
NodeDef* node);
|
||||
|
||||
// Returns true iff the node is a reduction and its reduction indices are
|
||||
// constant. Sets *indices_is_empty to true if the set of dimensions to reduce
|
||||
// along is empty (this happens often in the gradient graphs).
|
||||
bool IsReductionWithConstantIndices(const NodeDef& node,
|
||||
bool* indices_is_empty) const;
|
||||
// Returns true if theres a possibility that a Reduce node could be simplified
|
||||
// to an Identity/Reshape.
|
||||
bool IsReductionCandidateForSimplification(
|
||||
@ -160,11 +165,12 @@ class ConstantFolding : public GraphOptimizer {
|
||||
TensorShapeProto* input_tensor_shape,
|
||||
TensorShapeProto* output_tensor_shape, bool* is_single_element_op) const;
|
||||
// Returns true iff this reduction can be reduced to an identity (i.e if the
|
||||
// set of dimensions to reduce along is empty). This happens often in the
|
||||
// gradient graphs.
|
||||
// input dimensions to reduce along are all of size 1 and keep_dims is true).
|
||||
bool IsReductionSimplifiableToIdentity(
|
||||
const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
|
||||
const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
|
||||
// Changes a reduction into an Identity op, returning true on success.
|
||||
bool ReplaceReductionWithIdentity(NodeDef* node) const;
|
||||
// Simplifies a Reduction operation to an Identity/Reshape operation if
|
||||
// applicable.
|
||||
bool SimplifyReduction(GraphDef* optimized_graph,
|
||||
|
@ -2511,8 +2511,12 @@ TEST_F(ConstantFoldingTest, NoOpReduction) {
|
||||
attr = attr.KeepDims(true);
|
||||
Output p2 = ops::Prod(scope.WithOpName("p2"), v2, c2, attr);
|
||||
|
||||
// Test with unknown input shape.
|
||||
Output a = ops::Placeholder(scope.WithOpName("a"), DT_FLOAT);
|
||||
Output p3 = ops::Prod(scope.WithOpName("p3"), a, i, attr);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"s", "p2"};
|
||||
item.fetch = {"s", "p2", "p3"};
|
||||
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
|
||||
|
||||
ConstantFolding optimizer(/*cpu_device=*/nullptr);
|
||||
@ -2534,19 +2538,28 @@ TEST_F(ConstantFoldingTest, NoOpReduction) {
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("v2", node.input(0));
|
||||
EXPECT_EQ("^c2", node.input(1));
|
||||
} else if (node.name() == "p3") {
|
||||
found++;
|
||||
EXPECT_EQ("Identity", node.op());
|
||||
EXPECT_EQ(2, node.input_size());
|
||||
EXPECT_EQ("a", node.input(0));
|
||||
EXPECT_EQ("^i", node.input(1));
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(2, found);
|
||||
EXPECT_EQ(3, found);
|
||||
|
||||
auto v_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
|
||||
auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 1}));
|
||||
auto tensors_expected =
|
||||
EvaluateNodes(item.graph, item.fetch, {{"v", v_t}, {"v2", v2_t}});
|
||||
EXPECT_EQ(2, tensors_expected.size());
|
||||
auto tensors = EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}});
|
||||
EXPECT_EQ(2, tensors.size());
|
||||
auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 5, 7}));
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch,
|
||||
{{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
|
||||
EXPECT_EQ(3, tensors_expected.size());
|
||||
auto tensors =
|
||||
EvaluateNodes(output, item.fetch, {{"v", v_t}, {"v2", v2_t}, {"a", a_t}});
|
||||
EXPECT_EQ(3, tensors.size());
|
||||
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
|
||||
test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-5);
|
||||
test::ExpectTensorNear<float>(tensors_expected[2], tensors[2], 1e-5);
|
||||
}
|
||||
|
||||
TEST_F(ConstantFoldingTest, SingleElementEmptyAxisReduction) {
|
||||
|
Loading…
Reference in New Issue
Block a user