From 9e7875437f2f547095b06f123a0a9ad07ec02475 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 31 Jul 2017 16:06:05 -0700 Subject: [PATCH] Add the option of including Shape, ShapeN, Size and Rank in the standard TensorFlow constant propagation pass, when the inputs to those Ops have sufficiently known static shape. PiperOrigin-RevId: 163762750 --- tensorflow/compiler/tf2xla/xla_compiler.cc | 2 +- .../core/common_runtime/constant_folding.cc | 415 ++++++++++++++---- .../core/common_runtime/constant_folding.h | 5 + .../common_runtime/constant_folding_test.cc | 191 ++++++++ .../core/common_runtime/direct_session.cc | 3 +- tensorflow/core/common_runtime/function.cc | 4 +- .../core/common_runtime/graph_optimizer.cc | 8 +- .../core/common_runtime/graph_optimizer.h | 17 +- .../core/distributed_runtime/graph_mgr.cc | 3 +- .../core/grappler/grappler_item_builder.cc | 3 +- 10 files changed, 562 insertions(+), 89 deletions(-) diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 6c81d4e9f82..89d92173c37 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -161,7 +161,7 @@ Status XlaCompiler::CompileFunction( opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); optimizer.Optimize(flib_runtime_.get(), flib_runtime_->env(), - /*device=*/nullptr, &graph); + /*device=*/nullptr, &graph, /*shape_map=*/nullptr); VLOG(1) << "===================================================="; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 914683d9fa3..8dfb8b45de2 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -43,11 +43,182 @@ namespace tensorflow { namespace { -bool IsConstantFoldable(const Node* n, - const std::function<bool(const Node*)>& consider) { +// Test to see if the Op is one that turns into a constant when its +// inputs' shapes are known. +bool IsShapeOp(const Node* n) { + const auto& ts = n->type_string(); + return ts == "Shape" || ts == "ShapeN" || ts == "Rank" || ts == "Size"; +} + +// Reads the partially-known shape of each of n's inputs from shape_map, and +// stores it to input_shapes. Returns false if any input does not have a shape +// in shape_map. +bool ReadPartialShapesFromShapeMap( + const Node* n, + const std::unordered_map<const Node*, std::vector<PartialTensorShape>>* + shape_map, + std::vector<PartialTensorShape>* input_shapes) { + CHECK(shape_map != nullptr); + for (const Edge* in : n->in_edges()) { + // Don't need to check if incoming control edges have known shapes. + if (in->IsControlEdge()) continue; + if (shape_map->count(in->src()) == 0) { + // One of n's inputs doesn't have known shapes, so don't replace n. + return false; + } + const auto& known_shape = shape_map->at(in->src()); + CHECK_GT(known_shape.size(), in->src_output()); + input_shapes->push_back(known_shape[in->src_output()]); + } + return true; +} + +// If all of n's inputs have fully-defined shapes, inserts those shapes as a +// vector of Tensors in the shape_replacement_map. 
+bool MaybeReplaceShapeOrShapeNOp( + const Node* n, const std::vector<PartialTensorShape>& input_shapes, + std::unordered_map<const Node*, std::vector<Tensor>>* + shape_replacement_map) { + std::vector<Tensor> defined_shape; + for (const auto& shape : input_shapes) { + if (!shape.IsFullyDefined()) { + return false; + } + const int rank = shape.dims(); + DataType op_type = n->output_type(0); + Tensor t(op_type, TensorShape({rank})); + if (op_type == DT_INT64) { + auto vec = t.vec<int64>(); + for (int i = 0; i < rank; ++i) { + vec(i) = shape.dim_size(i); + } + } else { + CHECK(op_type == DT_INT32); + auto vec = t.vec<int32>(); + for (int i = 0; i < rank; ++i) { + if (shape.dim_size(i) > INT_MAX) { + VLOG(1) << "Node " << n->name() << " has input shape dimension " << i + << " of " << shape.dim_size(i) << " but type INT32 " + << " so not replacing as constant: this will trigger a " + "runtime error later."; + return false; + } + vec(i) = static_cast<int32>(shape.dim_size(i)); + } + } + defined_shape.push_back(t); + } + // All the inputs had known shapes so we can replace the node by constants + // later in the rewrite. + shape_replacement_map->insert({n, defined_shape}); + return true; +} + +// If n's input has defined rank, inserts that rank as a Tensor in the +// shape_replacement_map. +bool MaybeReplaceRankOp(const Node* n, + const std::vector<PartialTensorShape>& input_shapes, + std::unordered_map<const Node*, std::vector<Tensor>>* + shape_replacement_map) { + CHECK_EQ(input_shapes.size(), 1); + if (input_shapes[0].unknown_rank()) { + return false; + } + Tensor t(DT_INT32, TensorShape({})); + t.scalar<int32>()() = input_shapes[0].dims(); + shape_replacement_map->insert({n, {t}}); + return true; +} + +// If n's input has defined size, inserts that size as a Tensor in the +// shape_replacement_map. +bool MaybeReplaceSizeOp(const Node* n, + const std::vector<PartialTensorShape>& input_shapes, + std::unordered_map<const Node*, std::vector<Tensor>>* + shape_replacement_map) { + CHECK_EQ(input_shapes.size(), 1); + if (!input_shapes[0].IsFullyDefined()) { + return false; + } + DataType op_type = n->output_type(0); + Tensor t(op_type, TensorShape({})); + int64 size = input_shapes[0].num_elements(); + if (op_type == DT_INT64) { + t.scalar<int64>()() = size; + } else { + CHECK(op_type == DT_INT32); + if (size > INT_MAX) { + VLOG(1) << "Node " << n->name() << " has input shape size " << size + << " but type INT32 " + << " so not replacing as constant: this will trigger a runtime " + "error later."; + return false; + } + t.scalar<int32>()() = static_cast<int32>(size); + } + shape_replacement_map->insert({n, {t}}); + return true; +} + +// If n is a shape Op (Shape, ShapeN, Rank, or Size) and its inputs have their +// shapes specified in shape_map, then adds to shape_replacement_map a mapping +// from n to a vector of Tensors, where Tensor k is the (statically known) value +// on n's kth output edge. shape_replacement_map has an entry for n iff +// MaybeReplaceShapeOp returns true, so it's valid to use +// shape_replacement_map->count(n) as a test to see if n is a shape op that can +// be replaced. +bool MaybeReplaceShapeOp( + const Node* n, + const std::unordered_map<const Node*, std::vector<PartialTensorShape>>* + shape_map, + std::unordered_map<const Node*, std::vector<Tensor>>* + shape_replacement_map) { + if (shape_map == nullptr || !IsShapeOp(n)) { + return false; + } + // input_shapes will contain the shapes of each of n's inputs. 
+ std::vector<PartialTensorShape> input_shapes; + if (!ReadPartialShapesFromShapeMap(n, shape_map, &input_shapes)) { + return false; + } + const auto& ts = n->type_string(); + if (ts == "Shape" || ts == "ShapeN") { + if (!MaybeReplaceShapeOrShapeNOp(n, input_shapes, shape_replacement_map)) { + return false; + } + } else if (ts == "Rank") { + if (!MaybeReplaceRankOp(n, input_shapes, shape_replacement_map)) { + return false; + } + } else { + CHECK_EQ(ts, "Size"); + if (!MaybeReplaceSizeOp(n, input_shapes, shape_replacement_map)) { + return false; + } + } + return true; +} + +// Returns true if n can be evaluated as constant. shape_map maps from +// nodes to the partially-known shapes of their outputs. consider if +// non-null returns a bool indicating whether a given (non-Const, +// non-Shape) node is eligible to be +// constant-propagated. shape_replacement_map is filled in with a +// vector of constant output tensors for constant-foldable shape nodes +// (Shape, ShapeN, Size, or Rank). +bool IsConstantFoldable( + const Node* n, + const std::unordered_map<const Node*, std::vector<PartialTensorShape>>* + shape_map, + const std::function<bool(const Node*)>& consider, + std::unordered_map<const Node*, std::vector<Tensor>>* + shape_replacement_map) { if (n->IsConstant()) { return true; } + if (MaybeReplaceShapeOp(n, shape_map, shape_replacement_map)) { + return true; + } if (n->op_def().is_stateful()) { return false; } @@ -82,56 +253,81 @@ bool IsConstantFoldable(const Node* n, return true; } +// If n is eligible for constant-folding, adds it to nodes, and places its +// control dependencies and those transitively of its constant-foldable inputs +// into constant_control_deps. If n is a constant-foldable shape node (Shape, +// ShapeN, Rank, or Size), also puts its outputs into shape_replacement_map. +void ConsiderConstantFoldableNode( + Node* n, const ConstantFoldingOptions& opts, std::vector<Node*>* nodes, + std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps, + std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map, + bool* internal_node_inserted) { + if (IsConstantFoldable(n, opts.shape_map, opts.consider, + shape_replacement_map)) { + // A node is constant provided all of its non-control incoming Tensors come + // from constant nodes, or it's a shape Op with statically known inputs in + // which case it is placed in shape_replacement_map. + // + // We allow control dependencies from non-constant nodes to constant nodes, + // but to preserve the graph structure we must transfer the control + // dependency onto any constant replacement. + bool all_parents_constant = true; + for (const Edge* in : n->in_edges()) { + // Allows non-constant -> constant control edges. + if (!in->IsControlEdge() && + constant_control_deps->count(in->src()) == 0) { + all_parents_constant = false; + break; + } + } + if (all_parents_constant || shape_replacement_map->count(n) != 0) { + gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n]; + for (const Edge* e : n->in_edges()) { + if (constant_control_deps->count(e->src()) == 0) { + // This branch is taken if the incoming edge is a control dependency, + // in which case we want to add it to the dependencies being + // accumulated for this node, or the incoming edge is not + // constant. The latter may happen when n is a shape node and the + // source has known shape. 
In that case add a control dependency from + // the source node, since there was previously a data dependency and + // we want to preserve sequencing constraints. + if (!e->src()->IsSource()) { + control_deps.insert(e->src()); + } + } else { + // If the parent has been accumulating control dependencies, add all + // of its transitive control deps. + const gtl::FlatSet<Node*>& parent_deps = + (*constant_control_deps)[e->src()]; + control_deps.insert(parent_deps.begin(), parent_deps.end()); + } + } + nodes->push_back(n); + if (!n->IsConstant()) { + *internal_node_inserted = true; + } + } + } +} + // Returns the constant foldable nodes in `nodes` in topological order. // Populates `constant_control_deps` with the non-constant control dependencies // of each constant node. void FindConstantFoldableNodes( - const Graph* graph, ConstantFoldingOptions opts, std::vector<Node*>* nodes, - std::unordered_map<const Node*, gtl::FlatSet<Node*>>* - constant_control_deps) { + const Graph* graph, const ConstantFoldingOptions& opts, + std::vector<Node*>* nodes, + std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps, + std::unordered_map<const Node*, std::vector<Tensor>>* + shape_replacement_map) { bool internal_node_inserted = false; - // Walk the nodes in data flow order - ReverseDFS( - *graph, nullptr, - [nodes, constant_control_deps, &internal_node_inserted, opts](Node* n) { - if (IsConstantFoldable(n, opts.consider)) { - // A node is constant provided all of its non-control - // incoming Tensors come from constant nodes. - // - // We allow control dependencies from non-constant nodes to constant - // nodes, but to preserve the graph structure we must transfer the - // control dependency onto any constant replacement. - bool all_parents_constant = true; - for (const Edge* in : n->in_edges()) { - // Allows non-constant -> constant control edges. - if (!in->IsControlEdge() && - constant_control_deps->count(in->src()) == 0) { - all_parents_constant = false; - break; - } - } - if (all_parents_constant) { - gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n]; - for (const Edge* e : n->in_edges()) { - if (constant_control_deps->count(e->src()) == 0) { - if (!e->src()->IsSource()) { - control_deps.insert(e->src()); - } - } else { - // If the parent is constant, add all of its transitive control - // deps. - const gtl::FlatSet<Node*>& parent_deps = - (*constant_control_deps)[e->src()]; - control_deps.insert(parent_deps.begin(), parent_deps.end()); - } - } - nodes->push_back(n); - if (!n->IsConstant()) { - internal_node_inserted = true; - } - } - } - }); + // Walk the nodes in data flow order. + ReverseDFS(*graph, nullptr, + [nodes, constant_control_deps, shape_replacement_map, + &internal_node_inserted, &opts](Node* n) { + ConsiderConstantFoldableNode( + n, opts, nodes, constant_control_deps, shape_replacement_map, + &internal_node_inserted); + }); // If we have inserted just leaf level nodes, then there is nothing to fold. if (!internal_node_inserted) { nodes->clear(); @@ -141,31 +337,93 @@ void FindConstantFoldableNodes( typedef std::pair<Node*, int> NodeAndOutput; +int64 UniqueConstantId() { + static std::atomic_int_fast64_t id; + return id.fetch_add(1); +} + +// Adds n to constant_graph which is being built up for subsequent evaluation of +// constant propagation. node_map is the mapping of nodes in the original graph +// to nodes in the constant graph. 
The value of an entry in node_map is a vector +// of nodes because a ShapeN node in the original graph is replaced by a vector +// of Constant nodes in the constant graph. +void AddNodeToConstantGraph( + Node* n, std::unordered_map<Node*, std::vector<Node*>>* node_map, + Graph* constant_graph) { + std::vector<Node*>& added = (*node_map)[n]; + added.push_back(constant_graph->CopyNode(n)); + for (const Edge* in_edge : n->in_edges()) { + // Don't copy control edges to the constant graph. + if (!in_edge->IsControlEdge()) { + Node* in = in_edge->src(); + auto it = node_map->find(in); + CHECK(it != node_map->end()) + << n->DebugString() << " <-" << in->DebugString(); + if (it->second.size() == 1) { + constant_graph->AddEdge(it->second[0], in_edge->src_output(), added[0], + in_edge->dst_input()); + } else { + // The original source node had multiple outputs and was replaced by a + // vector of constants, so the edge comes from the 0th output of the kth + // added constant, rather than the kth output of the added node as in + // the standard case above. + constant_graph->AddEdge(it->second[in_edge->src_output()], 0, added[0], + in_edge->dst_input()); + } + } + } +} + +// Replaces constant-foldable shape node n by a vector of constants in +// constant_graph, which is being built up for subsequent evaluation of constant +// propagation. node_map is the mapping of nodes in the original graph to nodes +// in the constant graph. The value of an entry in node_map is a vector of nodes +// because a ShapeN node in the original graph is replaced by a vector of +// Constant nodes in the constant graph. +void AddShapeNodeToConstantGraph( + Node* n, + const std::unordered_map<const Node*, std::vector<Tensor>>& + shape_replacement_map, + std::unordered_map<Node*, std::vector<Node*>>* node_map, + Graph* constant_graph) { + std::vector<Node*>& added = (*node_map)[n]; + const string& node_name = n->name(); + for (const Tensor& t : shape_replacement_map.at(n)) { + auto builder = + NodeDefBuilder(strings::StrCat(constant_graph->NewName(node_name), + "__cf__", UniqueConstantId()), + "Const") + .Attr("dtype", t.dtype()) + .Attr("value", t); + NodeDef def; + CHECK(builder.Finalize(&def).ok()); + Node* constant_node; + CHECK(NodeBuilder(builder).Finalize(constant_graph, &constant_node).ok()); + added.push_back(constant_node); + } + // Don't copy incoming edges to shape nodes that are being replaced. +} + // Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g' // will contain copies of the nodes in 'nodes'. In addition, if there is an edge // going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in // 'nodes', then 'tensors_to_fetch' will contain the mapping from the // corresponding copy of 'n' and the edge number in 'g' to 'n'. 
-Graph* GetConstantGraph(const Graph* orig_graph, - const std::vector<Node*>& nodes, - std::map<NodeAndOutput, Node*>* tensors_to_fetch) { +Graph* GetConstantGraph( + const Graph* orig_graph, const std::vector<Node*>& nodes, + const std::unordered_map<const Node*, std::vector<Tensor>>& + shape_replacement_map, + std::map<NodeAndOutput, Node*>* tensors_to_fetch) { Graph* constant_graph = new Graph(orig_graph->op_registry()); - std::unordered_map<Node*, Node*> node_map; - node_map[orig_graph->source_node()] = constant_graph->source_node(); - node_map[orig_graph->sink_node()] = constant_graph->sink_node(); + std::unordered_map<Node*, std::vector<Node*>> node_map; + node_map[orig_graph->source_node()] = {constant_graph->source_node()}; + node_map[orig_graph->sink_node()] = {constant_graph->sink_node()}; for (Node* n : nodes) { - Node* added = constant_graph->CopyNode(n); - node_map[n] = added; - for (const Edge* in_edge : n->in_edges()) { - // Don't copy control edges to the constant graph. - if (!in_edge->IsControlEdge()) { - Node* in = in_edge->src(); - auto it = node_map.find(in); - CHECK(it != node_map.end()) - << n->DebugString() << " <-" << in->DebugString(); - constant_graph->AddEdge(it->second, in_edge->src_output(), added, - in_edge->dst_input()); - } + if (shape_replacement_map.count(n) == 0) { + AddNodeToConstantGraph(n, &node_map, constant_graph); + } else { + AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map, + constant_graph); } } @@ -173,8 +431,19 @@ Graph* GetConstantGraph(const Graph* orig_graph, for (const Edge* out_edge : added_nodes.first->out_edges()) { if (node_map.count(out_edge->dst()) == 0) { if (out_edge->IsControlEdge()) continue; - tensors_to_fetch->insert( - {{added_nodes.second, out_edge->src_output()}, added_nodes.first}); + if (added_nodes.second.size() == 1) { + tensors_to_fetch->insert( + {{added_nodes.second[0], out_edge->src_output()}, + added_nodes.first}); + } else { + // The node had multiple outputs and was replaced by a + // vector of constants, so the NodeAndOutput is the 0th + // output of the kth added constant, rather than the kth + // output of the added node as in the standard case above. + tensors_to_fetch->insert( + {{added_nodes.second[out_edge->src_output()], 0}, + added_nodes.first}); + } } } } @@ -182,11 +451,6 @@ Graph* GetConstantGraph(const Graph* orig_graph, return constant_graph; } -int64 UniqueConstantId() { - static std::atomic_int_fast64_t id; - return id.fetch_add(1); -} - // Replaces the identified Tensor in 'graph' by a 'Const' node with // the value supplied in 'constant'. 'partition_device', if non-null // is the device where the graph executes. 
Returns true if the
@@ -291,8 +555,9 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
 
   std::vector<Node*> constant_foldable_nodes;
   std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
+  std::unordered_map<const Node*, std::vector<Tensor>> shape_replacement_map;
   FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes,
-                            &constant_control_deps);
+                            &constant_control_deps, &shape_replacement_map);
   if (constant_foldable_nodes.empty()) {
     VLOG(1) << "No constant foldable nodes found";
     *was_mutated = false;
@@ -302,7 +567,8 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
 
   std::map<NodeAndOutput, Node*> tensors_to_fetch;
   std::unique_ptr<Graph> constant_graph(
-      GetConstantGraph(graph, constant_foldable_nodes, &tensors_to_fetch));
+      GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
+                       &tensors_to_fetch));
   DumpGraph("Constant graph", constant_graph.get());
 
   if (tensors_to_fetch.empty()) {
@@ -337,7 +603,6 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
   if (!s.ok()) {
     VLOG(1) << "Could not fetch constants: " << s;
     *was_mutated = false;
-    // This is not an error, so return the status as OK.
     return s;
   }
 
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index 93289b875f5..07cb5def8b4 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -29,6 +29,11 @@ struct ConstantFoldingOptions {
   // If "consider" is not a nullptr, then only constant fold a node "n" if
   // consider(n) returns true.
   std::function<bool(const Node*)> consider = nullptr;
+  // If shape_map is not a nullptr, it is a map from node n to a
+  // vector of the (potentially partially-known) shapes of its
+  // outputs.
+  const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
+      shape_map = nullptr;  // not owned
 };
 
 // Perform constant folding optimization on "graph".
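Note: a minimal usage sketch of the new ConstantFoldingOptions::shape_map
field, condensed from the tests added below. `g` is an existing Graph and `n`
is a Node in it whose single output has a fully-known shape; both are
placeholder names, and the shape {2, 3, 4} is illustrative.

    std::unordered_map<const Node*, std::vector<PartialTensorShape>> shape_map;
    PartialTensorShape ps;
    int dims[] = {2, 3, 4};
    TF_CHECK_OK(PartialTensorShape::MakePartialShape(dims, 3, &ps));
    shape_map[n].push_back(ps);  // one PartialTensorShape per output of n

    ConstantFoldingOptions opts;
    opts.shape_map = &shape_map;  // not owned; leave null to skip shape folding
    bool was_mutated;
    TF_CHECK_OK(
        ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));

With such a map, Shape(n) folds to Const({2, 3, 4}), Rank(n) to Const(3), and
Size(n) to Const(24) even though n itself is not constant. If the shape is
only partially known, Shape and Size are left alone, while Rank still folds
whenever the rank itself is known (see the PartialShape test below).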
diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc
index 4a8560960ed..c76ad647a0a 100644
--- a/tensorflow/core/common_runtime/constant_folding_test.cc
+++ b/tensorflow/core/common_runtime/constant_folding_test.cc
@@ -363,6 +363,197 @@ TEST_F(ConstantFoldingTest, ControlDependencies) {
   }
 }
 
+TEST_F(ConstantFoldingTest, SimpleShapeKnown) {
+  Graph g(OpRegistry::Global());
+  {
+    Scope s = Scope::NewRootScope();
+    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
+                              "sender", 0, "receiver");
+    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
+    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
+                              "sender", 0, "receiver");
+    auto shape_n = ops::ShapeN(s.WithOpName("shape_n"), {recv0, recv1});
+    auto rank = ops::Rank(s.WithOpName("rank"), recv0);
+    auto size = ops::Size(s.WithOpName("size"), recv1);
+    auto recv2 = ops::_Recv(s.WithOpName("recv2"), DT_FLOAT, "recv2", "sender",
+                            0, "receiver");
+    auto c = ops::Const<int>(s.WithControlDependencies(recv2), 3);
+    auto add0 = ops::Add(s.WithControlDependencies(c), rank, size);
+    auto add1 = ops::Add(s, shape, shape_n[0]);
+    auto add2 = ops::Add(s, shape_n[1], shape_n[1]);
+    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
+                            "receiver");
+    auto send1 = ops::_Send(s.WithOpName("send1"), add1, "send1", "sender", 0,
+                            "receiver");
+    auto send2 = ops::_Send(s.WithOpName("send2"), add2, "send2", "sender", 0,
+                            "receiver");
+    TF_ASSERT_OK(s.ToGraph(&g));
+  }
+  std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
+  Node* recv0 = orig_index.at("recv0");
+  Node* recv1 = orig_index.at("recv1");
+  PartialTensorShape ps0;
+  int r0_dims[] = {1, 2};
+  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
+  PartialTensorShape ps1;
+  int r1_dims[] = {2, 3, 4};
+  TF_EXPECT_OK(PartialTensorShape::MakePartialShape<int>(r1_dims, 3, &ps1));
+  std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
+  map[recv0].push_back(ps0);
+  map[recv1].push_back(ps1);
+  ConstantFoldingOptions opts;
+  opts.shape_map = &map;
+  bool was_mutated;
+  TF_EXPECT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+  EXPECT_TRUE(was_mutated);
+
+  std::unordered_map<string, Node*> index = NodeNameIndex(g);
+  Node* recv2 = index.at("recv2");
+  Node* send0 = index.at("send0");
+  Node* send1 = index.at("send1");
+  Node* send2 = index.at("send2");
+
+  ASSERT_EQ(1, send0->num_inputs());
+  Node* cf0 = *(send0->in_nodes().begin());
+  ExpectNodeEqual<int>(cf0, {26}, {});
+
+  ASSERT_EQ(1, send1->num_inputs());
+  Node* cf1 = *(send1->in_nodes().begin());
+  ExpectNodeEqual<int>(cf1, {2, 4}, {2});
+
+  ASSERT_EQ(1, send2->num_inputs());
+  Node* cf2 = *(send2->in_nodes().begin());
+  ExpectNodeEqual<int>(cf2, {4, 6, 8}, {3});
+
+  ASSERT_EQ(3, cf0->in_edges().size());
+  for (const Edge* e : cf0->in_edges()) {
+    EXPECT_TRUE(e->IsControlEdge());
+    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1 || e->src() == recv2)
+        << e->src()->name();
+  }
+
+  ASSERT_EQ(2, cf1->in_edges().size());
+  for (const Edge* e : cf1->in_edges()) {
+    EXPECT_TRUE(e->IsControlEdge());
+    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
+  }
+
+  ASSERT_EQ(2, cf2->in_edges().size());
+  for (const Edge* e : cf2->in_edges()) {
+    EXPECT_TRUE(e->IsControlEdge());
+    EXPECT_TRUE(e->src() == recv0 || e->src() == recv1) << e->src()->name();
+  }
+}
+
+TEST_F(ConstantFoldingTest, PartialShape) {
+  Graph g(OpRegistry::Global());
+  {
+    Scope s = Scope::NewRootScope();
+    Output recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0",
+                              "sender", 0, "receiver");
+    Output recv1 = ops::_Recv(s.WithOpName("recv1"), DT_FLOAT, "recv1",
+                              "sender", 0, "receiver");
+    auto shape = ops::Shape(s.WithOpName("shape"), recv0);
+    auto rank0 = ops::Rank(s.WithOpName("rank0"), recv0);
+    auto rank1 = ops::Rank(s.WithOpName("rank1"), recv1);
+    auto size = ops::Size(s.WithOpName("size"), recv0);
+    auto send0 = ops::_Send(s.WithOpName("send0"), rank0, "send0", "sender", 0,
+                            "receiver");
+    auto send1 = ops::_Send(s.WithOpName("send1"), shape, "send1", "sender", 0,
+                            "receiver");
+    auto send2 = ops::_Send(s.WithOpName("send2"), size, "send2", "sender", 0,
+                            "receiver");
+    auto send3 = ops::_Send(s.WithOpName("send3"), rank1, "send3", "sender", 0,
+                            "receiver");
+    TF_ASSERT_OK(s.ToGraph(&g));
+  }
+  std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
+  Node* recv0 = orig_index.at("recv0");
+  Node* recv1 = orig_index.at("recv1");
+  PartialTensorShape ps0;
+  int r0_dims[] = {-1, -1};
+  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(r0_dims, 2, &ps0));
+  PartialTensorShape ps1;
+  std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
+  map[recv0].push_back(ps0);
+  map[recv1].push_back(ps1);
+  ConstantFoldingOptions opts;
+  opts.shape_map = &map;
+  bool was_mutated;
+  TF_EXPECT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+  EXPECT_TRUE(was_mutated);
+
+  std::unordered_map<string, Node*> index = NodeNameIndex(g);
+  Node* shape = index.at("shape");
+  Node* size = index.at("size");
+  Node* rank1 = index.at("rank1");
+  Node* send0 = index.at("send0");
+  Node* send1 = index.at("send1");
+  Node* send2 = index.at("send2");
+  Node* send3 = index.at("send3");
+
+  ASSERT_EQ(1, send0->num_inputs());
+  Node* cf0 = *(send0->in_nodes().begin());
+  ExpectNodeEqual<int>(cf0, {2}, {});
+
+  ASSERT_EQ(1, send1->num_inputs());
+  Node* ncf1 = *(send1->in_nodes().begin());
+  EXPECT_EQ(ncf1, shape);
+
+  ASSERT_EQ(1, send2->num_inputs());
+  Node* ncf2 = *(send2->in_nodes().begin());
+  EXPECT_EQ(ncf2, size);
+
+  ASSERT_EQ(1, send3->num_inputs());
+  Node* ncf3 = *(send3->in_nodes().begin());
+  EXPECT_EQ(ncf3, rank1);
+}
+
+TEST_F(ConstantFoldingTest, ConstShapeKnown) {
+  Graph g(OpRegistry::Global());
+  {
+    Scope s = Scope::NewRootScope();
+    auto recv0 = ops::_Recv(s.WithOpName("recv0"), DT_FLOAT, "recv0", "sender",
+                            0, "receiver");
+    auto c0 =
+        ops::Const<int>(s.WithOpName("c0").WithControlDependencies(recv0), 1);
+    auto rank = ops::Rank(s.WithOpName("rank"), c0);
+    auto add0 = ops::Add(s, rank, rank);
+    auto send0 = ops::_Send(s.WithOpName("send0"), add0, "send0", "sender", 0,
+                            "receiver");
+    TF_ASSERT_OK(s.ToGraph(&g));
+  }
+  std::unordered_map<string, Node*> orig_index = NodeNameIndex(g);
+  Node* c0 = orig_index.at("c0");
+  PartialTensorShape ps0;
+  int c0_dims[] = {};
+  TF_EXPECT_OK(PartialTensorShape::MakePartialShape(c0_dims, 0, &ps0));
+  std::unordered_map<const Node*, std::vector<PartialTensorShape>> map;
+  map[c0].push_back(ps0);
+  ConstantFoldingOptions opts;
+  opts.shape_map = &map;
+  bool was_mutated;
+  TF_EXPECT_OK(
+      ConstantFold(opts, nullptr, Env::Default(), nullptr, &g, &was_mutated));
+  EXPECT_TRUE(was_mutated);
+
+  std::unordered_map<string, Node*> index = NodeNameIndex(g);
+  Node* recv0 = index.at("recv0");
+  Node* send0 = index.at("send0");
+
+  ASSERT_EQ(1, send0->num_inputs());
+  Node* cf0 = *(send0->in_nodes().begin());
+  ExpectNodeEqual<int>(cf0, {0}, {});
+ + ASSERT_EQ(1, cf0->in_edges().size()); + for (const Edge* e : cf0->in_edges()) { + EXPECT_TRUE(e->IsControlEdge()); + EXPECT_TRUE(e->src() == recv0) << e->src()->name(); + } +} + namespace { const char kTestMemRegionName[] = "test://test"; diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 294ad13a0a4..aaba2aa7875 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1194,7 +1194,8 @@ Status DirectSession::GetOrCreateExecutors( }; params.node_outputs_cb = node_outputs_callback_; - optimizer.Optimize(lib, options_.env, device, &iter->second); + optimizer.Optimize(lib, options_.env, device, &iter->second, + /*shape_map=*/nullptr); // EXPERIMENTAL: tfdbg inserts debug nodes in the graph. if (!options.debug_options.debug_tensor_watch_opts().empty()) { diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 64c3747ce1c..b7fad68de76 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -461,7 +461,7 @@ void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g) { opts.set_do_function_inlining(true); opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); - optimizer.Optimize(lib, lib->env(), lib->device(), g); + optimizer.Optimize(lib, lib->env(), lib->device(), g, /*shape_map=*/nullptr); } Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { @@ -470,7 +470,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { std::unique_ptr<Graph> g(new Graph(lib_def_)); CopyGraph(*fbody->graph, g.get()); - optimizer_.Optimize(this, env(), device(), &g); + optimizer_.Optimize(this, env(), device(), &g, /*shape_map=*/nullptr); TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device()->device_type()), device()->name(), g.get())); diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index edfecfae06e..c32d9f8a2e9 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -33,8 +33,11 @@ GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) { GraphOptimizer::~GraphOptimizer() {} -void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, - Device* device, std::unique_ptr<Graph>* graph) { +void GraphOptimizer::Optimize( + FunctionLibraryRuntime* runtime, Env* env, Device* device, + std::unique_ptr<Graph>* graph, + const std::unordered_map<const Node*, std::vector<PartialTensorShape>>* + shape_map) { Graph* g = graph->get(); DumpGraph("Initial", g); @@ -57,6 +60,7 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, if (opts_.do_constant_folding()) { ConstantFoldingOptions cf_opts; + cf_opts.shape_map = shape_map; bool was_mutated; ConstantFold(cf_opts, runtime, env, device, g, &was_mutated) .IgnoreError(); diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index a6b10356ce1..c145adde829 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -30,12 +30,17 @@ class GraphOptimizer { ~GraphOptimizer(); // Applies optimization passes specified in 'opts' to 'graph'. - // Maybe replace *graph with a new graph object. - // 'device' is device on which the 'graph' will execute. 
It's passed to the
-  // optimizers so that they can respect constraints if any, that should be
-  // respected.
-  void Optimize(FunctionLibraryRuntime* runtime, Env* env, Device* device,
-                std::unique_ptr<Graph>* graph);
+  // Maybe replace *graph with a new graph object. 'device' is the
+  // device on which 'graph' will execute. It's passed to the
+  // optimizers so that they can respect any device-specific
+  // constraints. If shape_map is not null it maps from nodes in
+  // graph to partially-known shapes of their outputs, and may be
+  // used, e.g., in the constant folding pass.
+  void Optimize(
+      FunctionLibraryRuntime* runtime, Env* env, Device* device,
+      std::unique_ptr<Graph>* graph,
+      const std::unordered_map<const Node*, std::vector<PartialTensorShape>>*
+          shape_map);
 
  private:
   OptimizerOptions opts_;
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 8d4b1be0261..ce186700b33 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -244,7 +244,8 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
     }
   };
 
-  optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph);
+  optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
+                     /*shape_map=*/nullptr);
 
   // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph.
   if (!debug_options.debug_tensor_watch_opts().empty()) {
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 0c2801e8bc3..ed8bbe58c46 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -127,7 +127,8 @@ Status OptimizeGraph(const GraphDef& graph_def, GraphDef* output_graph_def,
 
   // Optimize the graph.
   GraphOptimizer optimizer(*optimizer_opts);
-  optimizer.Optimize(flib.get(), env, devices[0], &graphptr);
+  optimizer.Optimize(flib.get(), env, devices[0], &graphptr,
+                     /*shape_map=*/nullptr);
 
   graphptr->ToGraphDef(output_graph_def);
   return Status::OK();
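Note: every call site updated in this patch passes /*shape_map=*/nullptr, so
behavior is unchanged until a caller supplies shapes. A sketch of how such a
caller might thread shape information through the new parameter; here
`optimizer_opts`, `runtime`, `env`, `device`, and `graph` are placeholders,
and how the map gets populated (e.g., from a shape-inference pass) is left
abstract:

    std::unordered_map<const Node*, std::vector<PartialTensorShape>> shape_map;
    // ... fill shape_map with one PartialTensorShape per node output ...
    GraphOptimizer optimizer(optimizer_opts);
    // The map is not owned by the optimizer and must outlive this call.
    optimizer.Optimize(runtime, env, device, &graph, &shape_map);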