diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 718aa69ebf6..acd642044b4 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -680,6 +680,7 @@ Status CreateConstantTensorAttrValue(DataType type, double value,
                                      const TensorShapeProto& shape,
                                      AttrValue* attr_tensor) {
   TensorProto* t = attr_tensor->mutable_tensor();
+  t->set_dtype(type);
   *t->mutable_tensor_shape() = shape;
   switch (type) {
     SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
@@ -1332,45 +1333,47 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
       }
       const TensorShapeProto& output_shape =
           properties.GetOutputProperties(node.name())[0].shape();
-      const TensorShapeProto& x_shape =
-          properties.GetInputProperties(node.name())[0].shape();
+
+      // Simplify element-wise multiplication by ones or addition of zeros.
       const TensorShapeProto& y_shape =
           properties.GetInputProperties(node.name())[1].shape();
       const bool x_is_zero = IsZeros(*x);
-      const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
-      const bool y_is_zero = IsZeros(*y);
+      const bool x_is_one = IsOnes(*x);
       const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
-
-      // Simplify addition of zeros.
-      if (is_add) {
-        if (x_is_zero && y_matches_output_shape) {
-          // 0 + y = y.
-          ReplaceAddOrMulWithIdentity(1, &node);
-          continue;
-        } else if (y_is_zero && x_matches_output_shape) {
-          // x + 0 = y.
-          ReplaceAddOrMulWithIdentity(0, &node);
-          continue;
-        }
+      if (y_matches_output_shape &&
+          ((is_mul && x_is_one) || (is_add && x_is_zero))) {
+        // 1 * y = y or 0 + y = y.
+        ReplaceAddOrMulWithIdentity(1, &node);
+        continue;
       }
-
-      // Simplify element-wise multiplication by ones.
-      if (is_mul) {
-        if (IsOnes(*x) && y_matches_output_shape) {
-          // 1 * y = y.
-          ReplaceAddOrMulWithIdentity(1, &node);
-          continue;
-        }
-        if (IsOnes(*y) && x_matches_output_shape) {
-          // x * 1 = x.
-          ReplaceAddOrMulWithIdentity(0, &node);
-          continue;
-        }
+      const TensorShapeProto& x_shape =
+          properties.GetInputProperties(node.name())[0].shape();
+      const bool y_is_zero = IsZeros(*y);
+      const bool y_is_one = IsOnes(*y);
+      const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+      if (x_matches_output_shape &&
+          ((is_mul && y_is_one) || (is_add && y_is_zero))) {
+        // x * 1 = x or x + 0 = x.
+        ReplaceAddOrMulWithIdentity(0, &node);
+        continue;
       }
 
       // Simplify multiplication and matmul by zeros.
-      if (x_is_zero || y_is_zero) {
-        TF_RETURN_IF_ERROR(ReplaceAddOrMulWithConstant(0, output_shape, &node));
+      if (!is_add && (x_is_zero || y_is_zero)) {
+        const PartialTensorShape shp(output_shape);
+        if (shp.IsFullyDefined()) {
+          TF_RETURN_IF_ERROR(
+              ReplaceAddOrMulWithConstant(0, output_shape, &node));
+          continue;
+        }
+        // Even if an input shape is only partially known, we may know that it
+        // matches the output shape and thus forward the corresponding zero
+        // input.
+        if (is_mul && x_is_zero && x_matches_output_shape) {
+          ReplaceAddOrMulWithIdentity(0, &node);
+        } else if (is_mul && y_is_zero && y_matches_output_shape) {
+          ReplaceAddOrMulWithIdentity(1, &node);
+        }
       }
     }
   }
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index ffa09b8e294..21011eb7902 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -198,6 +198,120 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
   }
 }
 
+TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output x_known =
+      ops::Placeholder(s.WithOpName("x_known"), DT_FLOAT,
+                       ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output x_partially_known =
+      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
+                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
+  Output zeros_known = ops::ZerosLike(s.WithOpName("zeros_known"), x_known);
+  Output zeros_partially_known =
+      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
+  Output zeros_unknown =
+      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);
+
+  // Multiplies without any additional ops to supply the output shape.
+  int count = 0;
+  std::vector<Output> muls;
+  std::unordered_set<string> not_converted;
+  std::unordered_set<string> to_const;
+  std::unordered_set<string> to_identity;
+  for (const auto* x : {&x_known, &x_partially_known, &x_unknown}) {
+    for (const auto* zeros :
+         {&zeros_known, &zeros_partially_known, &zeros_unknown}) {
+      const string name = strings::StrCat("mul_", count++);
+      muls.push_back(ops::Mul(s.WithOpName(name), *x, *zeros));
+      if (x == &x_partially_known && zeros == &zeros_partially_known) {
+        to_identity.insert(name);
+      } else if (x == &x_unknown || zeros == &zeros_unknown) {
+        not_converted.insert(name);
+      } else {
+        to_const.insert(name);
+      }
+    }
+  }
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+                            nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  LOG(INFO) << output.DebugString();
+
+  EXPECT_EQ(15, output.node_size());
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    const string& name = node.name();
+    if (to_const.count(name) > 0) {
+      EXPECT_EQ("Const", node.op()) << node.name();
+    } else if (to_identity.count(name) > 0) {
+      EXPECT_EQ("Identity", node.op()) << node.name();
+    } else if (not_converted.count(name) > 0) {
+      EXPECT_EQ("Mul", node.op()) << node.name();
+    }
+  }
+}
+
+TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output known_shape = ops::Const(s.WithOpName("known_shape"), 0.0f, {2, 2});
+  Output x_partially_known =
+      ops::Placeholder(s.WithOpName("x_partially_unknown"), DT_FLOAT,
+                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+  Output x_unknown = ops::Placeholder(s.WithOpName("x_unknown"), DT_FLOAT);
+  Output zeros_partially_known =
+      ops::ZerosLike(s.WithOpName("zeros_partially_known"), x_partially_known);
+  Output zeros_unknown =
+      ops::ZerosLike(s.WithOpName("zeros_unknown"), x_unknown);
+
+  // If at least one of the inputs to AddN has a known shape, shape inference
+  // will propagate the shape back to the inputs of AddN, making the
+  // output shapes of all its inputs known.
+  std::vector<Output> muls_deduced_output_shape;
+  std::unordered_set<string> to_const;
+  int count = 0;
+  for (const auto& x : {x_partially_known, x_unknown}) {
+    for (const auto& zeros : {zeros_partially_known, zeros_unknown}) {
+      const string name = strings::StrCat("mul_", count++);
+      muls_deduced_output_shape.push_back(
+          ops::Mul(s.WithOpName(name), x, zeros));
+      to_const.insert(name);
+    }
+  }
+  // We add a known shape as input to AddN to propagate it back to the
+  // multiplies above, which means they can all be turned into Const nodes.
+  muls_deduced_output_shape.push_back(known_shape);
+  Output addn1 = ops::AddN(s.WithOpName("addn1"), muls_deduced_output_shape);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+                            nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  LOG(INFO) << output.DebugString();
+
+  EXPECT_EQ(10, output.node_size());
+  for (int i = 0; i < output.node_size(); ++i) {
+    const NodeDef& node = output.node(i);
+    const string& name = node.name();
+    if (to_const.count(name) > 0) {
+      EXPECT_EQ("Const", node.op()) << node.name();
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_TRUE(IsControlInput(node.input(0)));
+      EXPECT_TRUE(IsControlInput(node.input(1)));
+    }
+  }
+}
+
 TEST_F(ConstantFoldingTest, CreateConstNodes) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();