diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index faea843c69b..ea036604408 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -122,9 +122,9 @@ string ConstantFolding::AddControlDependency(const string& input_name) {
   }
   // We haven't found an existing node where we can anchor the control
   // dependency: add a new identity node.
-  int position = 0;
-  string ctrl_dep_name = ParseNodeName(input_name, &position);
-  strings::StrAppend(&ctrl_dep_name, "_", position);
+  int port = 0;
+  string ctrl_dep_name = ParseNodeName(input_name, &port);
+  strings::StrAppend(&ctrl_dep_name, "_", port);
   ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
 
   const DataType output_type = node->attr().at("T").type();
@@ -141,6 +141,48 @@ string ConstantFolding::AddControlDependency(const string& input_name) {
   }
 }
 
+Status ConvertShapeToConstant(const string& op, const DataType& type,
+                              const PartialTensorShape& shp, Tensor* value) {
+  if (op == "Shape" || op == "ShapeN") {
+    *value = Tensor(type, TensorShape({shp.dims()}));
+    for (int i = 0; i < shp.dims(); ++i) {
+      if (type == DT_INT32) {
+        if (shp.dim_size(i) >= INT_MAX) {
+          return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
+        }
+        value->flat<int32>()(i) = shp.dim_size(i);
+      } else {
+        value->flat<int64>()(i) = shp.dim_size(i);
+      }
+    }
+  } else if (op == "Size") {
+    int64 size = 1;
+    for (int i = 0; i < shp.dims(); ++i) {
+      size *= shp.dim_size(i);
+    }
+    *value = Tensor(type, TensorShape({}));
+    if (type == DT_INT32) {
+      if (size >= INT_MAX) {
+        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
+      }
+      value->flat<int32>()(0) = size;
+    } else {
+      value->flat<int64>()(0) = size;
+    }
+  } else {
+    *value = Tensor(type, TensorShape({}));
+    if (type == DT_INT32) {
+      if (shp.dims() >= INT_MAX) {
+        return Status(error::INVALID_ARGUMENT, "Invalid dimension size");
+      }
+      value->flat<int32>()(0) = shp.dims();
+    } else {
+      value->flat<int64>()(0) = shp.dims();
+    }
+  }
+  return Status::OK();
+}
+
 Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
                                           const GraphProperties& properties) {
   // We may add some nodes to the graph to encode control dependencies: there is
@@ -150,83 +192,84 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
   for (int i = 0; i < node_count; ++i) {
     NodeDef& node = *graph_.mutable_node(i);
     const string op = node.op();
-    if (op != "Shape" && op != "Size" && op != "Rank") {
+    if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
       continue;
     }
+
     std::vector<OpInfo::TensorProperties> output =
         properties.GetOutputProperties(node.name());
-    CHECK_EQ(1, output.size());
-    const DataType type = output[0].dtype();
-    CHECK(type == DT_INT32 || type == DT_INT64);
     std::vector<OpInfo::TensorProperties> input =
         properties.GetInputProperties(node.name());
-    CHECK_EQ(1, input.size());
+    if (op == "Shape" || op == "Size" || op == "Rank") {
+      CHECK_EQ(1, output.size());
+      CHECK_EQ(1, input.size());
+    }
+    CHECK_EQ(input.size(), output.size());
 
-    const TensorShapeProto shape = input[0].shape();
-    // Materialize the shapes using constants whenever possible.
-    PartialTensorShape shp(shape);
-    if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
-      bool valid = true;
-      Tensor value(type);
-      if (op == "Shape") {
-        value = Tensor(type, TensorShape({shp.dims()}));
-        for (int i = 0; i < shp.dims(); ++i) {
-          if (type == DT_INT32) {
-            if (shp.dim_size(i) >= INT_MAX) {
-              valid = false;
-              break;
+    for (int j = 0; j < output.size(); ++j) {
+      const DataType type = output[j].dtype();
+      CHECK(type == DT_INT32 || type == DT_INT64);
+      const TensorShapeProto shape = input[j].shape();
+      // Materialize the shapes using constants whenever possible.
+      PartialTensorShape shp(shape);
+      if (shp.IsFullyDefined() || (!shp.unknown_rank() && op == "Rank")) {
+        Tensor value(type);
+        auto status = ConvertShapeToConstant(op, type, shp, &value);
+        if (!status.ok()) {
+          continue;
+        }
+        // We rewrite the existing node for the first const output and
+        // create new nodes for the remaining const outputs (Note that ShapeN
+        // could have multiple outputs).
+        if (op == "Shape" || op == "Size" || op == "Rank") {
+          // Replace the node with the corresponding constant.
+          node.set_op("Const");
+          node.clear_attr();
+          (*node.mutable_attr())["dtype"].set_type(type);
+          value.AsProtoTensorContent(
+              (*node.mutable_attr())["value"].mutable_tensor());
+
+          // Turn the data input into a control dependency: this is needed to
+          // ensure that the constant value will only be run in the
+          // cases where the shape/rank/size would have been run in
+          // the original graph. Additional inputs are extra control
+          // dependencies that we preserve.
+          string ctrl_dep = AddControlDependency(node.input(0));
+          node.set_input(0, ctrl_dep);
+          node_map_->AddOutput(NodeName(ctrl_dep), node.name());
+        } else {
+          auto outputs = node_map_->GetOutputs(node.name());
+          for (const auto& output : outputs) {
+            for (int k = 0; k < output->input_size(); ++k) {
+              int port;
+              string node_name = ParseNodeName(output->input(k), &port);
+              if (node_name == node.name() && port == j) {
+                // Create a const node as ShapeN's output if not already.
+                string const_name =
+                    AddPrefixToNodeName(strings::StrCat(node.name(), "-", j),
+                                        kConstantFoldingConst);
+                if (node_map_->GetNode(const_name) == nullptr) {
+                  NodeDef* added_node = graph_.add_node();
+                  added_node->set_name(const_name);
+                  added_node->set_op("Const");
+                  added_node->set_device(node.device());
+                  node_map_->AddNode(added_node->name(), added_node);
+                  (*added_node->mutable_attr())["dtype"].set_type(type);
+                  value.AsProtoTensorContent(
+                      (*added_node->mutable_attr())["value"].mutable_tensor());
+                  // We add a control dependency to the original ShapeN node,
+                  // so that the node will only be run if all inputs of the
+                  // original ShapeN node are run.
+                  string ctrl_dep = AddControlDependency(node.name());
+                  *added_node->add_input() = ctrl_dep;
+                  node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
+                }
+                node_map_->UpdateInput(output->name(),
+                                       NodeName(output->input(k)), const_name);
+                *output->mutable_input(k) = const_name;
+              }
+            }
+          }
+        }
-            }
-            value.flat<int32>()(i) = shp.dim_size(i);
-          } else {
-            value.flat<int64>()(i) = shp.dim_size(i);
-          }
-        }
-      } else if (op == "Size") {
-        int64 size = 1;
-        for (int i = 0; i < shp.dims(); ++i) {
-          size *= shp.dim_size(i);
-        }
-        value = Tensor(type, TensorShape({}));
-        if (type == DT_INT32) {
-          if (size >= INT_MAX) {
-            valid = false;
-          } else {
-            value.flat<int32>()(0) = size;
-          }
-        } else {
-          value.flat<int64>()(0) = size;
-        }
-      } else {
-        value = Tensor(type, TensorShape({}));
-        if (type == DT_INT32) {
-          if (shp.dims() >= INT_MAX) {
-            valid = false;
-          } else {
-            value.flat<int32>()(0) = shp.dims();
-          }
-        } else {
-          value.flat<int64>()(0) = shp.dims();
-        }
-      }
-
-      if (valid) {
-        // Replace the node with the corresponding constant.
-        node.set_op("Const");
-        node.clear_attr();
-        (*node.mutable_attr())["dtype"].set_type(type);
-        value.AsProtoTensorContent(
-            (*node.mutable_attr())["value"].mutable_tensor());
-
-        // Turn the data input into a control dependency: this is needed to
-        // ensure that the constant value will only be generated in the cases
-        // where the shape/rank/size would have been generated in the original
-        // graph. Additional inputs are extra control dependencies that we
-        // preserve.
-        CHECK_LE(1, node.input_size());
-        string ctrl_dep = AddControlDependency(node.input(0));
-        node.set_input(0, ctrl_dep);
-        node_map_->AddOutput(NodeName(ctrl_dep), node.name());
-      }
     }
   }
@@ -427,9 +470,9 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
   });
 
   for (const auto& input : node.input()) {
-    int position = 0;
-    ParseNodeName(input, &position);
-    if (position < 0) {
+    int port = 0;
+    ParseNodeName(input, &port);
+    if (port < 0) {
       // Control dependency
       break;
     }
@@ -539,13 +582,13 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
   auto outputs = node_map_->GetOutputs(node->name());
   for (auto& output : outputs) {
     for (int i = 0; i < output->input_size(); i++) {
-      int position;
-      string node_name = ParseNodeName(output->input(i), &position);
+      int port;
+      string node_name = ParseNodeName(output->input(i), &port);
       if (node_name == node->name()) {
-        if (position == 0) {
+        if (port == 0) {
           *output->mutable_input(i) = const_out->name();
           node_map_->AddOutput(const_out->name(), output->name());
-        } else if (position == 1) {
+        } else if (port == 1) {
           *output->mutable_input(i) = const_index->name();
           node_map_->AddOutput(const_index->name(), output->name());
         } else {
@@ -630,10 +673,10 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
   auto outputs = node_map_->GetOutputs(node->name());
   for (const auto& output : outputs) {
     for (int i = 0; i < output->input_size(); i++) {
-      int position;
-      string node_name = ParseNodeName(output->input(i), &position);
+      int port;
+      string node_name = ParseNodeName(output->input(i), &port);
       if (node_name == node->name()) {
-        if (position < 0) {
+        if (port < 0) {
           // Propagate control dependencies if possible. If not, we'll just
           // preserve the existing control dependencies.
           if (constant_output != nullptr) {
@@ -641,17 +684,17 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
                                    constant_output->name());
             *output->mutable_input(i) = AsControlDependency(*constant_output);
           }
-        } else if (position < const_nodes.size() &&
-                   !const_nodes[position].name().empty()) {
+        } else if (port < const_nodes.size() &&
+                   !const_nodes[port].name().empty()) {
           // Replace alive outputs with the corresponding constant.
           node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
-                                 const_nodes[position].name());
-          *output->mutable_input(i) = const_nodes[position].name();
+                                 const_nodes[port].name());
+          *output->mutable_input(i) = const_nodes[port].name();
         } else {
           // Leave this edge alone.
-          VLOG(1) << "Preserving edge from " << node->name() << ":"
-                  << position << "[" << node->op() << "] to "
-                  << output->name() << ":" << i << "[" << output->op() << "]";
+          VLOG(1) << "Preserving edge from " << node->name() << ":" << port
+                  << "[" << node->op() << "] to " << output->name() << ":"
+                  << i << "[" << output->op() << "]";
         }
       }
     }
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 183d783b55b..a1dee6d2fb8 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -421,6 +421,64 @@ TEST_F(ConstantFoldingTest, ShapeMaterializationEmptyFetch) {
   EXPECT_EQ(3, found);
 }
 
+TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
+  tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+  Output v1 = ops::Variable(scope.WithOpName("v1"), {3, -1}, DT_FLOAT);
+  Output v2 = ops::Variable(scope.WithOpName("v2"), {}, DT_FLOAT);
+  Output v3 = ops::Variable(scope.WithOpName("v3"), {4, 6}, DT_FLOAT);
+  auto s = ops::ShapeN(scope.WithOpName("s"), {v1, v2, v3});
+  Output i1a = ops::Identity(scope.WithOpName("i1a"), s[0]);
+  Output i1b = ops::Identity(scope.WithOpName("i1b"), s[0]);
+  Output i2a = ops::Identity(scope.WithOpName("i2a"), s[1]);
+  Output i2b = ops::Identity(scope.WithOpName("i2b"), s[1]);
+  Output i2c = ops::Identity(scope.WithOpName("i2c"), s[1]);
+  Output i3a = ops::Identity(scope.WithOpName("i3a"), s[2]);
+  Output i3b = ops::Identity(scope.WithOpName("i3b"), s[2]);
+
+  GrapplerItem item;
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  ConstantFolding fold(nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  int found = 0;
+  for (const auto& node : output.node()) {
+    EXPECT_NE(AddPrefixToNodeName("s-0", kConstantFoldingConst), node.name());
+    EXPECT_NE(AddPrefixToNodeName("s-1", kConstantFoldingConst), node.name());
+    if (node.name() == "i1a" || node.name() == "i1b") {
+      ++found;
+      EXPECT_EQ("s", node.input(0));
+    }
+    if (node.name() == "i2a" || node.name() == "i2b" || node.name() == "i2c") {
+      ++found;
+      EXPECT_EQ("s:1", node.input(0));
+    }
+    if (node.name() == "i3a" || node.name() == "i3b") {
+      ++found;
+      EXPECT_EQ(AddPrefixToNodeName("s-2", kConstantFoldingConst),
+                node.input(0));
+    }
+    if (node.name() == "s") {
+      ++found;
+      EXPECT_EQ("ShapeN", node.op());
+      EXPECT_EQ("v1", node.input(0));
+      EXPECT_EQ("v2", node.input(1));
+      EXPECT_EQ("v3", node.input(2));
+    }
+    if (node.name() == AddPrefixToNodeName("s-2", kConstantFoldingConst)) {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ("^s", node.input(0));
+      Tensor value;
+      CHECK(value.FromProto(node.attr().at("value").tensor()));
+      EXPECT_EQ(4, value.flat<int>()(0));
+      EXPECT_EQ(6, value.flat<int>()(1));
+    }
+  }
+  EXPECT_EQ(9, found);
+}
+
 TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) {
   tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
   ops::Variable v_in(scope.WithOpName("v_in"), {3}, DT_FLOAT);
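
Note on the INT_MAX guards in ConvertShapeToConstant: Shape/Size/Rank/ShapeN can emit either DT_INT32 or DT_INT64, and a dimension (or element count) that does not fit in 32 bits cannot be narrowed into an int32 constant, so the function returns INVALID_ARGUMENT and MaterializeShapes skips the fold for that output. A minimal standalone sketch of that guard, independent of TensorFlow (FitsInInt32 is an illustrative helper, not part of this patch):

    #include <climits>
    #include <cstdint>
    #include <iostream>

    // Mirrors the check in ConvertShapeToConstant: values >= INT_MAX are
    // rejected rather than narrowed into a DT_INT32 shape/size constant.
    bool FitsInInt32(int64_t v) { return v < INT_MAX; }

    int main() {
      // Prints 1: a small dimension folds into an int32 Const.
      std::cout << FitsInInt32(6) << "\n";
      // Prints 0: an oversized value makes the optimizer skip the fold.
      std::cout << FitsInInt32(static_cast<int64_t>(INT_MAX) + 1) << "\n";
      return 0;
    }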