[Graph] Avoid calling Node::GetNodeClassForOp() when a node is copied.

Instead, copy the class from the original node. This change also converts `kNodeClassTable` from a static `std::unordered_map<>` class member into a function-local `absl::flat_hash_map<>` inside `GetNodeClassForOp()`.
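
To make the win concrete, here is a minimal standalone sketch of the pattern (hypothetical names and a toy lookup, not the actual TensorFlow types):

#include <iostream>
#include <string>
#include <utility>

// Hypothetical, simplified stand-ins for TensorFlow's Node and NodeClass.
enum NodeClass { NC_SWITCH, NC_MERGE, NC_OTHER };

// Stands in for the real table-backed lookup; the cost being avoided is a
// string hash plus a map probe per call.
NodeClass GetNodeClassForOp(const std::string& op) {
  if (op == "Switch") return NC_SWITCH;
  if (op == "Merge") return NC_MERGE;
  return NC_OTHER;
}

struct Node {
  std::string op;
  NodeClass node_class;  // cached at construction time
};

// Constructing a node from an op name pays for the lookup once...
Node MakeNode(std::string op) {
  NodeClass c = GetNodeClassForOp(op);
  return Node{std::move(op), c};
}

// ...and copying just propagates the cached enum, with no lookup at all,
// which is the pattern this commit adopts.
Node CopyNode(const Node& n) { return Node{n.op, n.node_class}; }

int main() {
  Node a = MakeNode("Switch");
  Node b = CopyNode(a);  // no GetNodeClassForOp call on the copy path
  std::cout << (b.node_class == NC_SWITCH ? "copied class ok" : "bug") << "\n";
}

Graph rewriting passes clone nodes wholesale, so skipping a string hash and map probe per copied node is presumably the motivation here.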

PiperOrigin-RevId: 306945263
Change-Id: I8eb1c80b57fdf204fbc7072a55615dd688025e87
Author: Derek Murray, 2020-04-16 16:34:52 -07:00 (committed by TensorFlower Gardener)
parent 1bbc26d375
commit aec85065ff
2 changed files with 67 additions and 73 deletions

tensorflow/core/graph/graph.cc

@@ -17,6 +17,7 @@ limitations under the License.
 #include <vector>
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/node_properties.h"
@@ -38,65 +39,63 @@ namespace tensorflow {
 const int Graph::kControlSlot = -1;

 // Node

+Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
+  static const absl::flat_hash_map<string, Node::NodeClass>* kNodeClassTable =
 #define REF_CLASS(key, value) \
   {key, value}, { "Ref" key, value }
-
-const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
-    *new std::unordered_map<string, Node::NodeClass>({
-        // Keep in same order as NodeClass values
-        REF_CLASS("Switch", NC_SWITCH),
-        REF_CLASS("_SwitchN", NC_SWITCH),
-        REF_CLASS("Merge", NC_MERGE),
-        REF_CLASS("Enter", NC_ENTER),
-        REF_CLASS("Exit", NC_EXIT),
-        REF_CLASS("NextIteration", NC_NEXT_ITERATION),
-        {"LoopCond", NC_LOOP_COND},
-        {"ControlTrigger", NC_CONTROL_TRIGGER},
-        {"_Send", NC_SEND},
-        {"_HostSend", NC_HOST_SEND},
-        {"_Recv", NC_RECV},
-        {"_HostRecv", NC_HOST_RECV},
-        {"Const", NC_CONSTANT},
-        {"HostConst", NC_CONSTANT},
-        {"Variable", NC_VARIABLE},
-        {"VariableV2", NC_VARIABLE},
-        REF_CLASS("Identity", NC_IDENTITY),
-        {"GetSessionHandle", NC_GET_SESSION_HANDLE},
-        {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
-        {"GetSessionTensor", NC_GET_SESSION_TENSOR},
-        {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
-        {"Size", NC_METADATA},
-        {"Shape", NC_METADATA},
-        {"Rank", NC_METADATA},
-        {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
-        {"CollectiveReduce", NC_COLLECTIVE},
-        {"CollectiveBcastSend", NC_COLLECTIVE},
-        {"CollectiveBcastRecv", NC_COLLECTIVE},
-        {"CollectiveGather", NC_COLLECTIVE},
-        {"FakeParam", NC_FAKE_PARAM},
-        {"PartitionedCall", NC_PARTITIONED_CALL},
-        {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
-        {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
-        {"If", NC_IF},
-        {"StatelessIf", NC_IF},
-        {"While", NC_WHILE},
-        {"StatelessWhile", NC_WHILE},
-        // Not using the constants defined in FunctionLibraryDefinition
-        // for the
-        // 4 ops below because android inference library does not link
-        // tf.function related files.
-        {"_Arg", NC_ARG},
-        {"_DeviceArg", NC_ARG},
-        {"_Retval", NC_RETVAL},
-        {"_DeviceRetval", NC_RETVAL},
-        {"_XlaMerge", NC_MERGE},
-    });
-
+      new absl::flat_hash_map<string, Node::NodeClass>({
+          // Keep in same order as NodeClass values
+          REF_CLASS("Switch", NC_SWITCH),
+          REF_CLASS("_SwitchN", NC_SWITCH),
+          REF_CLASS("Merge", NC_MERGE),
+          REF_CLASS("Enter", NC_ENTER),
+          REF_CLASS("Exit", NC_EXIT),
+          REF_CLASS("NextIteration", NC_NEXT_ITERATION),
+          {"LoopCond", NC_LOOP_COND},
+          {"ControlTrigger", NC_CONTROL_TRIGGER},
+          {"_Send", NC_SEND},
+          {"_HostSend", NC_HOST_SEND},
+          {"_Recv", NC_RECV},
+          {"_HostRecv", NC_HOST_RECV},
+          {"Const", NC_CONSTANT},
+          {"HostConst", NC_CONSTANT},
+          {"Variable", NC_VARIABLE},
+          {"VariableV2", NC_VARIABLE},
+          REF_CLASS("Identity", NC_IDENTITY),
+          {"GetSessionHandle", NC_GET_SESSION_HANDLE},
+          {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
+          {"GetSessionTensor", NC_GET_SESSION_TENSOR},
+          {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
+          {"Size", NC_METADATA},
+          {"Shape", NC_METADATA},
+          {"Rank", NC_METADATA},
+          {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
+          {"CollectiveReduce", NC_COLLECTIVE},
+          {"CollectiveBcastSend", NC_COLLECTIVE},
+          {"CollectiveBcastRecv", NC_COLLECTIVE},
+          {"CollectiveGather", NC_COLLECTIVE},
+          {"FakeParam", NC_FAKE_PARAM},
+          {"PartitionedCall", NC_PARTITIONED_CALL},
+          {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
+          {"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
+          {"If", NC_IF},
+          {"StatelessIf", NC_IF},
+          {"While", NC_WHILE},
+          {"StatelessWhile", NC_WHILE},
+          // Not using the constants defined in FunctionLibraryDefinition
+          // for the
+          // 4 ops below because android inference library does not link
+          // tf.function related files.
+          {"_Arg", NC_ARG},
+          {"_DeviceArg", NC_ARG},
+          {"_Retval", NC_RETVAL},
+          {"_DeviceRetval", NC_RETVAL},
+          {"_XlaMerge", NC_MERGE},
+      });
 #undef REF_CLASS

-Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
-  auto it = kNodeClassTable.find(ts);
-  if (it != kNodeClassTable.end()) {
+  auto it = kNodeClassTable->find(ts);
+  if (it != kNodeClassTable->end()) {
     return it->second;
   } else {
     return NC_OTHER;
@@ -127,7 +126,7 @@ Node::Node()
 void Node::Initialize(int id, int cost_id,
                       std::shared_ptr<NodeProperties> props,
-                      bool is_function_op) {
+                      Node::NodeClass node_class) {
   DCHECK_EQ(id_, -1);
   DCHECK(in_edges_.empty());
   DCHECK(out_edges_.empty());
@@ -135,12 +134,7 @@ void Node::Initialize(int id, int cost_id,
   cost_id_ = cost_id;
   props_ = std::move(props);

-  // Initialize the class_ based on the type string
-  if (is_function_op) {
-    class_ = NC_FUNCTION_OP;
-  } else {
-    class_ = GetNodeClassForOp(props_->node_def.op());
-  }
+  class_ = node_class;
 }

 void Node::Clear() {
@@ -423,18 +417,21 @@ Node* Graph::AddNode(NodeDef node_def, Status* status) {
     return nullptr;
   }

+  Node::NodeClass node_class = op_reg_data->is_function_op
+                                   ? Node::NC_FUNCTION_OP
+                                   : Node::GetNodeClassForOp(node_def.op());
+
   Node* node = AllocateNode(
       std::make_shared<NodeProperties>(&op_reg_data->op_def,
                                        std::move(node_def), inputs, outputs),
-      nullptr, op_reg_data->is_function_op);
+      nullptr, node_class);
   return node;
 }

 Node* Graph::CopyNode(const Node* node) {
   DCHECK(!node->IsSource());
   DCHECK(!node->IsSink());
-  Node* copy =
-      AllocateNode(node->props_, node, node->class_ == Node::NC_FUNCTION_OP);
+  Node* copy = AllocateNode(node->props_, node, node->class_);
   copy->set_assigned_device_name(node->assigned_device_name());

   // Since the OpDef of a function may be owned by the Graph that owns 'node',
@@ -759,7 +756,7 @@ Status Graph::IsValidInputTensor(const Node* node, int idx) const {
 }

 Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
-                          const Node* cost_node, bool is_function_op) {
+                          const Node* cost_node, Node::NodeClass node_class) {
   Node* node = nullptr;
   if (free_nodes_.empty()) {
     node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
@@ -770,7 +767,7 @@ Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
   node->graph_ = this;
   const int id = nodes_.size();
   int cost_id = cost_node ? cost_node->cost_id() : id;
-  node->Initialize(id, cost_id, std::move(props), is_function_op);
+  node->Initialize(id, cost_id, std::move(props), node_class);
   nodes_.push_back(node);
   ++num_nodes_;
   return node;

tensorflow/core/graph/graph.h

@@ -236,10 +236,6 @@ class Node {
   friend class Graph;
   Node();

-  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
-                  bool is_function_op);
-
   // Releases memory from props_, in addition to restoring *this to its
   // uninitialized state.
   void Clear();
@@ -291,7 +287,8 @@ class Node {
     NC_OTHER  // Not a special kind of node
   };

-  static const std::unordered_map<string, NodeClass>& kNodeClassTable;
+  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
+                  NodeClass node_class);

   static NodeClass GetNodeClassForOp(const string& ts);
@@ -692,7 +689,7 @@ class Graph {
   //
   // Ownership of the returned Node is not transferred to caller.
   Node* AllocateNode(std::shared_ptr<NodeProperties> props,
-                     const Node* cost_node, bool is_function_op);
+                     const Node* cost_node, Node::NodeClass node_class);
   void ReleaseNode(Node* node);

   // Insert edge in free_edges_ for possible reuse.
   void RecycleEdge(const Edge* edge);
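
The second half of the change moves `kNodeClassTable` from an out-of-line static class member to a function-local, heap-allocated, never-deleted `absl::flat_hash_map`. A sketch of that idiom in isolation (hypothetical names; assumes Abseil is linked):

#include <iostream>
#include <string>

#include "absl/container/flat_hash_map.h"

enum Color { RED, BLUE, UNKNOWN };  // hypothetical enum for illustration

Color LookupColor(const std::string& name) {
  // Built once, on first call, with thread-safe C++11 static initialization.
  // Heap-allocating and never deleting the table sidesteps static
  // destruction-order problems at process exit.
  static const auto* kTable = new absl::flat_hash_map<std::string, Color>({
      {"red", RED},
      {"blue", BLUE},
  });
  auto it = kTable->find(name);
  return it != kTable->end() ? it->second : UNKNOWN;
}

int main() {
  std::cout << (LookupColor("red") == RED) << "\n";       // 1
  std::cout << (LookupColor("green") == UNKNOWN) << "\n"; // 1
}

`absl::flat_hash_map` generally probes faster and is more cache-friendly than `std::unordered_map`, at the cost of pointer stability across rehashes, which a read-only table built once never needs.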