diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 0f2e24690f3..13e20568fff 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -456,7 +456,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
 void DumpGraph(StringPiece label, const Graph* g) {
   // TODO(zhifengc): Change Graph to record #nodes.
   VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
-          << g->edges().size();
+          << g->num_edges();
   if (VLOG_IS_ON(2)) {
     for (const auto& line : str_util::Split(DebugString(g), '\n')) {
       VLOG(2) << "|| " << line;
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index bbf35590eb6..8f70ab8783c 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -424,7 +424,7 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
   n8 = NoOp() @ n4
   n9 = Identity[T=float](n3) @ n8
   n10 = Identity[T=float](n2) @ n8
-  n11 = NoOp() @ n10, n9
+  n11 = NoOp() @ n9, n10
   n5 = Mul[T=float](n2, n2) @ n11
   n6 = Add[T=float](n4, n5)
 }
@@ -500,8 +500,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
   OptimizeGraph(lib_, &g);
   const char* e2 = R"P(
 (n2:float, n3:float) -> (n9:float) {
-  n11 = Const[dtype=int32, value=Tensor]()
   n10 = Const[dtype=float, value=Tensor]()
+  n11 = Const[dtype=int32, value=Tensor]()
   n6 = Shape[T=float, out_type=int32](n2)
   n5 = Mul[T=float](n3, n10)
   n7 = BroadcastGradientArgs[T=int32](n6, n11)
@@ -614,10 +614,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
   n17 = Sum[T=float, Tidx=int32, keep_dims=false](n14, n16)
   n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, Tidx=int32, keep_dims=false]](n14, n16, n26)
   n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19)
-  n28 = Identity[T=float](n21:1)
   n27 = Identity[T=float](n21)
-  n6 = Identity[T=float](n28)
+  n28 = Identity[T=float](n21:1)
   n8 = Identity[T=float](n27)
+  n6 = Identity[T=float](n28)
 }
 )P";
   EXPECT_EQ(e1, DebugString(g.get()));
@@ -626,8 +626,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
   const char* e2 = R"P(
 (n4:float, n3:float) -> (n25:float, n23:float) {
   n2 = Const[dtype=float, value=Tensor]()
-  n8 = Const[dtype=int32, value=Tensor]()
   n7 = Const[dtype=int32, value=Tensor]()
+  n8 = Const[dtype=int32, value=Tensor]()
   n19 = Shape[T=float, out_type=int32](n3)
   n9 = Add[T=float](n4, n3)
   n20 = Shape[T=float, out_type=int32](n4)
@@ -641,10 +641,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
   n16 = Reshape[T=float, Tshape=int32](n2, n15)
   n17 = Div[T=int32](n14, n15)
   n18 = Tile[T=float, Tmultiples=int32](n16, n17)
-  n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21)
   n22 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21:1)
-  n25 = Reshape[T=float, Tshape=int32](n24, n20)
+  n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21)
   n23 = Reshape[T=float, Tshape=int32](n22, n19)
+  n25 = Reshape[T=float, Tshape=int32](n24, n20)
 }
 )P";
   EXPECT_EQ(e2, DebugString(g.get()));
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index e1657cb8622..a68a8f25093 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -344,7 +344,7 @@ const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
   CHECK(source->out_edges_.insert(e).second);
   CHECK(dest->in_edges_.insert(e).second);
   edges_.push_back(e);
-  edge_set_.insert(e);
+  ++num_edges_;
   return e;
 }
 
@@ -354,8 +354,8 @@ void Graph::RemoveEdge(const Edge* e) {
   CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
   CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
   CHECK_EQ(e, edges_[e->id_]);
+  CHECK_GT(num_edges_, 0);
 
-  CHECK_EQ(edge_set_.erase(e), size_t{1});
   edges_[e->id_] = nullptr;
 
   Edge* del = const_cast<Edge*>(e);
@@ -365,6 +365,7 @@
   del->src_output_ = kControlSlot - 1;
   del->dst_input_ = kControlSlot - 1;
   free_edges_.push_back(del);
+  --num_edges_;
 }
 
 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 11a49ec3b3d..bbb3af196d6 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -268,6 +268,66 @@ class Edge {
   int dst_input_;
 };
 
+// Allows for iteration of the edges of a Graph, by iterating the underlying
+// Graph.edges_ vector while skipping over null entries.
+class GraphEdgesIterable {
+ private:
+  const std::vector<Edge*>& edges_;
+
+ public:
+  explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
+      : edges_(edges) {}
+
+  typedef Edge* value_type;
+
+  class const_iterator {
+   private:
+    // The underlying iterator.
+    std::vector<value_type>::const_iterator iter_;
+
+    // The end of the underlying iterator.
+    std::vector<value_type>::const_iterator end_;
+
+    // Advances iter_ until it reaches a non-null item, or reaches the end.
+    void apply_filter() {
+      while (iter_ != end_ && *iter_ == nullptr) {
+        ++iter_;
+      }
+    }
+
+   public:
+    const_iterator(std::vector<value_type>::const_iterator iter,
+                   std::vector<value_type>::const_iterator end)
+        : iter_(iter), end_(end) {
+      apply_filter();
+    }
+
+    bool operator==(const const_iterator& other) const {
+      return iter_ == other.iter_;
+    }
+
+    bool operator!=(const const_iterator& other) const {
+      return iter_ != other.iter_;
+    }
+
+    // This is the prefix increment operator (++x), which is the operator
+    // used by C++ range iteration (for (x : y) ...). We intentionally do not
+    // provide a postfix increment operator.
+    const_iterator& operator++() {
+      ++iter_;
+      apply_filter();
+      return *this;
+    }
+
+    value_type operator*() { return *iter_; }
+  };
+
+  const_iterator begin() {
+    return const_iterator(edges_.begin(), edges_.end());
+  }
+  const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
+};
+
 // Thread compatible but not thread safe.
 class Graph {
  public:
@@ -345,7 +405,7 @@ class Graph {
   // smaller than num_edge_ids(). If one needs to create an array of
   // edges indexed by edge ids, num_edge_ids() should be used as the
   // array's size.
-  int num_edges() const { return edges().size(); }
+  int num_edges() const { return num_edges_; }
 
   // Serialize the nodes starting at `from_node_id` to a GraphDef.
   void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
@@ -381,7 +441,7 @@ class Graph {
 
   // Access to the set of all edges. Example usage:
   //   for (const Edge* e : graph.edges()) { ... }
-  const EdgeSet& edges() const { return edge_set_; }
+  GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
 
   // The pre-defined nodes.
   enum { kSourceId = 0, kSinkId = 1 };
@@ -421,9 +481,8 @@ class Graph {
   // the edge with that id was removed from the graph.
   std::vector<Edge*> edges_;
 
-  // For ease of iteration, we currently just keep a set of all live
-  // edges. May want to optimize by removing this copy.
-  EdgeSet edge_set_;
+  // The number of entries in edges_ that are not nullptr.
+  int num_edges_ = 0;
 
   // Allocated but free nodes and edges.
   std::vector<Node*> free_nodes_;