diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 1f8a4d06c7a..98cf4a2491d 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_properties.h" @@ -38,65 +39,63 @@ namespace tensorflow { const int Graph::kControlSlot = -1; // Node +Node::NodeClass Node::GetNodeClassForOp(const string& ts) { + static const absl::flat_hash_map* kNodeClassTable = #define REF_CLASS(key, value) \ {key, value}, { "Ref" key, value } - -const std::unordered_map& Node::kNodeClassTable = - *new std::unordered_map({ - // Keep in same order as NodeClass values - REF_CLASS("Switch", NC_SWITCH), - REF_CLASS("_SwitchN", NC_SWITCH), - REF_CLASS("Merge", NC_MERGE), - REF_CLASS("Enter", NC_ENTER), - REF_CLASS("Exit", NC_EXIT), - REF_CLASS("NextIteration", NC_NEXT_ITERATION), - {"LoopCond", NC_LOOP_COND}, - {"ControlTrigger", NC_CONTROL_TRIGGER}, - {"_Send", NC_SEND}, - {"_HostSend", NC_HOST_SEND}, - {"_Recv", NC_RECV}, - {"_HostRecv", NC_HOST_RECV}, - {"Const", NC_CONSTANT}, - {"HostConst", NC_CONSTANT}, - {"Variable", NC_VARIABLE}, - {"VariableV2", NC_VARIABLE}, - REF_CLASS("Identity", NC_IDENTITY), - {"GetSessionHandle", NC_GET_SESSION_HANDLE}, - {"GetSessionHandleV2", NC_GET_SESSION_HANDLE}, - {"GetSessionTensor", NC_GET_SESSION_TENSOR}, - {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR}, - {"Size", NC_METADATA}, - {"Shape", NC_METADATA}, - {"Rank", NC_METADATA}, - {"_ScopedAllocator", NC_SCOPED_ALLOCATOR}, - {"CollectiveReduce", NC_COLLECTIVE}, - {"CollectiveBcastSend", NC_COLLECTIVE}, - {"CollectiveBcastRecv", NC_COLLECTIVE}, - {"CollectiveGather", NC_COLLECTIVE}, - {"FakeParam", NC_FAKE_PARAM}, - {"PartitionedCall", NC_PARTITIONED_CALL}, - {"StatefulPartitionedCall", NC_PARTITIONED_CALL}, - {"SymbolicGradient", NC_SYMBOLIC_GRADIENT}, - {"If", NC_IF}, - {"StatelessIf", NC_IF}, - {"While", NC_WHILE}, - {"StatelessWhile", NC_WHILE}, - // Not using the constants defined in FunctionLibraryDefinition - // for the - // 4 ops below because android inference library does not link - // tf.function related files. - {"_Arg", NC_ARG}, - {"_DeviceArg", NC_ARG}, - {"_Retval", NC_RETVAL}, - {"_DeviceRetval", NC_RETVAL}, - {"_XlaMerge", NC_MERGE}, - }); - + new absl::flat_hash_map({ + // Keep in same order as NodeClass values + REF_CLASS("Switch", NC_SWITCH), + REF_CLASS("_SwitchN", NC_SWITCH), + REF_CLASS("Merge", NC_MERGE), + REF_CLASS("Enter", NC_ENTER), + REF_CLASS("Exit", NC_EXIT), + REF_CLASS("NextIteration", NC_NEXT_ITERATION), + {"LoopCond", NC_LOOP_COND}, + {"ControlTrigger", NC_CONTROL_TRIGGER}, + {"_Send", NC_SEND}, + {"_HostSend", NC_HOST_SEND}, + {"_Recv", NC_RECV}, + {"_HostRecv", NC_HOST_RECV}, + {"Const", NC_CONSTANT}, + {"HostConst", NC_CONSTANT}, + {"Variable", NC_VARIABLE}, + {"VariableV2", NC_VARIABLE}, + REF_CLASS("Identity", NC_IDENTITY), + {"GetSessionHandle", NC_GET_SESSION_HANDLE}, + {"GetSessionHandleV2", NC_GET_SESSION_HANDLE}, + {"GetSessionTensor", NC_GET_SESSION_TENSOR}, + {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR}, + {"Size", NC_METADATA}, + {"Shape", NC_METADATA}, + {"Rank", NC_METADATA}, + {"_ScopedAllocator", NC_SCOPED_ALLOCATOR}, + {"CollectiveReduce", NC_COLLECTIVE}, + {"CollectiveBcastSend", NC_COLLECTIVE}, + {"CollectiveBcastRecv", NC_COLLECTIVE}, + {"CollectiveGather", NC_COLLECTIVE}, + {"FakeParam", NC_FAKE_PARAM}, + {"PartitionedCall", NC_PARTITIONED_CALL}, + {"StatefulPartitionedCall", NC_PARTITIONED_CALL}, + {"SymbolicGradient", NC_SYMBOLIC_GRADIENT}, + {"If", NC_IF}, + {"StatelessIf", NC_IF}, + {"While", NC_WHILE}, + {"StatelessWhile", NC_WHILE}, + // Not using the constants defined in FunctionLibraryDefinition + // for the + // 4 ops below because android inference library does not link + // tf.function related files. + {"_Arg", NC_ARG}, + {"_DeviceArg", NC_ARG}, + {"_Retval", NC_RETVAL}, + {"_DeviceRetval", NC_RETVAL}, + {"_XlaMerge", NC_MERGE}, + }); #undef REF_CLASS -Node::NodeClass Node::GetNodeClassForOp(const string& ts) { - auto it = kNodeClassTable.find(ts); - if (it != kNodeClassTable.end()) { + auto it = kNodeClassTable->find(ts); + if (it != kNodeClassTable->end()) { return it->second; } else { return NC_OTHER; @@ -127,7 +126,7 @@ Node::Node() void Node::Initialize(int id, int cost_id, std::shared_ptr props, - bool is_function_op) { + Node::NodeClass node_class) { DCHECK_EQ(id_, -1); DCHECK(in_edges_.empty()); DCHECK(out_edges_.empty()); @@ -135,12 +134,7 @@ void Node::Initialize(int id, int cost_id, cost_id_ = cost_id; props_ = std::move(props); - // Initialize the class_ based on the type string - if (is_function_op) { - class_ = NC_FUNCTION_OP; - } else { - class_ = GetNodeClassForOp(props_->node_def.op()); - } + class_ = node_class; } void Node::Clear() { @@ -423,18 +417,21 @@ Node* Graph::AddNode(NodeDef node_def, Status* status) { return nullptr; } + Node::NodeClass node_class = op_reg_data->is_function_op + ? Node::NC_FUNCTION_OP + : Node::GetNodeClassForOp(node_def.op()); + Node* node = AllocateNode( std::make_shared(&op_reg_data->op_def, std::move(node_def), inputs, outputs), - nullptr, op_reg_data->is_function_op); + nullptr, node_class); return node; } Node* Graph::CopyNode(const Node* node) { DCHECK(!node->IsSource()); DCHECK(!node->IsSink()); - Node* copy = - AllocateNode(node->props_, node, node->class_ == Node::NC_FUNCTION_OP); + Node* copy = AllocateNode(node->props_, node, node->class_); copy->set_assigned_device_name(node->assigned_device_name()); // Since the OpDef of a function may be owned by the Graph that owns 'node', @@ -759,7 +756,7 @@ Status Graph::IsValidInputTensor(const Node* node, int idx) const { } Node* Graph::AllocateNode(std::shared_ptr props, - const Node* cost_node, bool is_function_op) { + const Node* cost_node, Node::NodeClass node_class) { Node* node = nullptr; if (free_nodes_.empty()) { node = new (arena_.Alloc(sizeof(Node))) Node; // placement new @@ -770,7 +767,7 @@ Node* Graph::AllocateNode(std::shared_ptr props, node->graph_ = this; const int id = nodes_.size(); int cost_id = cost_node ? cost_node->cost_id() : id; - node->Initialize(id, cost_id, std::move(props), is_function_op); + node->Initialize(id, cost_id, std::move(props), node_class); nodes_.push_back(node); ++num_nodes_; return node; diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index cdb2d123eaf..675f96fa5cd 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -236,10 +236,6 @@ class Node { friend class Graph; Node(); - - void Initialize(int id, int cost_id, std::shared_ptr props, - bool is_function_op); - // Releases memory from props_, in addition to restoring *this to its // uninitialized state. void Clear(); @@ -291,7 +287,8 @@ class Node { NC_OTHER // Not a special kind of node }; - static const std::unordered_map& kNodeClassTable; + void Initialize(int id, int cost_id, std::shared_ptr props, + NodeClass node_class); static NodeClass GetNodeClassForOp(const string& ts); @@ -692,7 +689,7 @@ class Graph { // // Ownership of the returned Node is not transferred to caller. Node* AllocateNode(std::shared_ptr props, - const Node* cost_node, bool is_function_op); + const Node* cost_node, Node::NodeClass node_class); void ReleaseNode(Node* node); // Insert edge in free_edges_ for possible reuse. void RecycleEdge(const Edge* edge);