diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index f88b995c89f..42b23576d68 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -97,8 +97,10 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:cluster",
+        "//tensorflow/core/grappler/costs:graph_properties",
     ],
 )
 
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index e66ae05fd75..781387e29a6 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -24,7 +24,9 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -99,8 +101,86 @@ Status NumOutputs(const NodeDef& node, int* num_outputs) {
 }
 }  // namespace
 
-bool ConstantFolding::IsConst(const NodeDef& node) const {
-  return node.op() == "Const";
+Status ConstantFolding::MaterializeShapes(const GrapplerItem& item) {
+  GraphProperties properties(item);
+  TF_RETURN_IF_ERROR(properties.InferStatically());
+  for (NodeDef& node : *graph_.mutable_node()) {
+    const string op = node.op();
+    if (op != "Shape" && op != "Size" && op != "Rank") {
+      continue;
+    }
+    std::vector<OpInfo::TensorProperties> output =
+        properties.GetOutputProperties(node.name());
+    CHECK_EQ(1, output.size());
+    const DataType type = output[0].dtype();
+    CHECK(type == DT_INT32 || type == DT_INT64);
+
+    std::vector<OpInfo::TensorProperties> input =
+        properties.GetInputProperties(node.name());
+    CHECK_EQ(1, input.size());
+    const TensorShapeProto shape = input[0].shape();
+
+    // Materialize the shapes using constants whenever possible.
+    TensorShape shp(shape);
+    if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
+      bool valid = true;
+      Tensor value(type);
+      if (op == "Shape") {
+        value = Tensor(type, TensorShape({shp.dims()}));
+        for (int i = 0; i < shp.dims(); ++i) {
+          if (type == DT_INT32) {
+            if (shp.dim_size(i) >= INT_MAX) {
+              valid = false;
+              break;
+            }
+            value.flat<int32>()(i) = shp.dim_size(i);
+          } else {
+            value.flat<int64>()(i) = shp.dim_size(i);
+          }
+        }
+      } else if (op == "Size") {
+        int64 size = 1;
+        for (int i = 0; i < shp.dims(); ++i) {
+          size *= shp.dim_size(i);
+        }
+        value = Tensor(type, TensorShape({}));
+        if (type == DT_INT32) {
+          if (size >= INT_MAX) {
+            valid = false;
+          } else {
+            value.flat<int32>()(0) = size;
+          }
+        } else {
+          value.flat<int64>()(0) = size;
+        }
+      } else {
+        value = Tensor(type, TensorShape({}));
+        if (type == DT_INT32) {
+          if (shp.dims() >= INT_MAX) {
+            valid = false;
+          } else {
+            value.flat<int32>()(0) = shp.dims();
+          }
+        } else {
+          value.flat<int64>()(0) = shp.dims();
+        }
+      }
+
+      if (valid) {
+        // Replace the node with the corresponding constant.
+        node.set_op("Const");
+        node.clear_attr();
+        (*node.mutable_attr())["dtype"].set_type(type);
+        value.AsProtoTensorContent(
+            (*node.mutable_attr())["value"].mutable_tensor());
+
+        // Turn the inputs into control dependencies.
+        CHECK_EQ(1, node.input_size());
+        node.set_input(0, strings::StrCat("^", node.input(0)));
+      }
+    }
+  }
+  return Status::OK();
 }
 
 bool ConstantFolding::IsFoldable(const NodeDef& node) const {
@@ -155,7 +235,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
     if (input[0] == '^') {
       continue;
     }
-    bool is_const = IsConst(*node_map_->GetNode(input));
+    bool is_const = IsConstant(*node_map_->GetNode(input));
     if (!is_const) {
       return false;
     }
@@ -326,6 +406,7 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
   device_.reset(new DeviceSimple());
   *output = GraphDef();
+  TF_RETURN_IF_ERROR(MaterializeShapes(item));
   TF_RETURN_IF_ERROR(FoldGraph(output));
   LOG(INFO) << "Optimized graph size: " << output->node_size();
   return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index d5b1c9e4bb8..fd77fc945e3 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -42,7 +42,7 @@ class ConstantFolding : public GraphOptimizer {
                   const GraphDef& optimize_output, double result) override;
 
  private:
-  bool IsConst(const NodeDef& node) const;
+  Status MaterializeShapes(const GrapplerItem& item);
 
   bool IsFoldable(const NodeDef& node) const;
 
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 6259caa0fb1..58bbb817d0b 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -193,6 +193,58 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
   EXPECT_EQ(2, found);
 }
 
+TEST_F(ConstantFoldingTest, ShapeMaterialization) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  Output v1 = ops::Variable(scope.WithOpName("v1"), {3}, DT_FLOAT);
+  Output v2 = ops::Variable(scope.WithOpName("v2"), {5, 7}, DT_FLOAT);
+  Output v3 = ops::Variable(scope.WithOpName("v3"), {11, 13}, DT_FLOAT);
+  Output rank = ops::Rank(scope.WithOpName("rank"), v1);
+  Output shape = ops::Shape(scope.WithOpName("shape"), v2);
+  Output size = ops::Size(scope.WithOpName("size"), v3);
+  Output p1 = ops::Multiply(scope.WithOpName("p1"), size, rank);
+  Output p2 = ops::Multiply(scope.WithOpName("p2"), p1, shape);
+
+  GrapplerItem item;
+  item.fetch.push_back("p2");
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  ConstantFolding fold;
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  int found = 0;
+  for (const auto& node : output.node()) {
+    if (node.name() == "size") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^v3", node.input(0));
+      Tensor value;
+      CHECK(value.FromProto(node.attr().at("value").tensor()));
+      EXPECT_EQ(11 * 13, value.flat<int32>()(0));
+    } else if (node.name() == "rank") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^v1", node.input(0));
+      Tensor value;
+      CHECK(value.FromProto(node.attr().at("value").tensor()));
+      EXPECT_EQ(1, value.flat<int32>()(0));
+    } else if (node.name() == "shape") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^v2", node.input(0));
+      Tensor value;
+      CHECK(value.FromProto(node.attr().at("value").tensor()));
+      EXPECT_EQ(5, value.flat<int32>()(0));
+      EXPECT_EQ(7, value.flat<int32>()(1));
+    }
+  }
+  EXPECT_EQ(3, found);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow