[Graph] Avoid calling Node::GetNodeClassForOp()
when a node is copied.
Instead, copy the class from the original node. This change also modifies the `kNodeClassTable` to use `absl::flat_hash_map<>`. PiperOrigin-RevId: 306945263 Change-Id: I8eb1c80b57fdf204fbc7072a55615dd688025e87
This commit is contained in:
parent
1bbc26d375
commit
aec85065ff
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/node_properties.h"
|
#include "tensorflow/core/framework/node_properties.h"
|
||||||
@ -38,65 +39,63 @@ namespace tensorflow {
|
|||||||
const int Graph::kControlSlot = -1;
|
const int Graph::kControlSlot = -1;
|
||||||
|
|
||||||
// Node
|
// Node
|
||||||
|
Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
|
||||||
|
static const absl::flat_hash_map<string, Node::NodeClass>* kNodeClassTable =
|
||||||
#define REF_CLASS(key, value) \
|
#define REF_CLASS(key, value) \
|
||||||
{key, value}, { "Ref" key, value }
|
{key, value}, { "Ref" key, value }
|
||||||
|
new absl::flat_hash_map<string, Node::NodeClass>({
|
||||||
const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
|
// Keep in same order as NodeClass values
|
||||||
*new std::unordered_map<string, Node::NodeClass>({
|
REF_CLASS("Switch", NC_SWITCH),
|
||||||
// Keep in same order as NodeClass values
|
REF_CLASS("_SwitchN", NC_SWITCH),
|
||||||
REF_CLASS("Switch", NC_SWITCH),
|
REF_CLASS("Merge", NC_MERGE),
|
||||||
REF_CLASS("_SwitchN", NC_SWITCH),
|
REF_CLASS("Enter", NC_ENTER),
|
||||||
REF_CLASS("Merge", NC_MERGE),
|
REF_CLASS("Exit", NC_EXIT),
|
||||||
REF_CLASS("Enter", NC_ENTER),
|
REF_CLASS("NextIteration", NC_NEXT_ITERATION),
|
||||||
REF_CLASS("Exit", NC_EXIT),
|
{"LoopCond", NC_LOOP_COND},
|
||||||
REF_CLASS("NextIteration", NC_NEXT_ITERATION),
|
{"ControlTrigger", NC_CONTROL_TRIGGER},
|
||||||
{"LoopCond", NC_LOOP_COND},
|
{"_Send", NC_SEND},
|
||||||
{"ControlTrigger", NC_CONTROL_TRIGGER},
|
{"_HostSend", NC_HOST_SEND},
|
||||||
{"_Send", NC_SEND},
|
{"_Recv", NC_RECV},
|
||||||
{"_HostSend", NC_HOST_SEND},
|
{"_HostRecv", NC_HOST_RECV},
|
||||||
{"_Recv", NC_RECV},
|
{"Const", NC_CONSTANT},
|
||||||
{"_HostRecv", NC_HOST_RECV},
|
{"HostConst", NC_CONSTANT},
|
||||||
{"Const", NC_CONSTANT},
|
{"Variable", NC_VARIABLE},
|
||||||
{"HostConst", NC_CONSTANT},
|
{"VariableV2", NC_VARIABLE},
|
||||||
{"Variable", NC_VARIABLE},
|
REF_CLASS("Identity", NC_IDENTITY),
|
||||||
{"VariableV2", NC_VARIABLE},
|
{"GetSessionHandle", NC_GET_SESSION_HANDLE},
|
||||||
REF_CLASS("Identity", NC_IDENTITY),
|
{"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
|
||||||
{"GetSessionHandle", NC_GET_SESSION_HANDLE},
|
{"GetSessionTensor", NC_GET_SESSION_TENSOR},
|
||||||
{"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
|
{"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
|
||||||
{"GetSessionTensor", NC_GET_SESSION_TENSOR},
|
{"Size", NC_METADATA},
|
||||||
{"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
|
{"Shape", NC_METADATA},
|
||||||
{"Size", NC_METADATA},
|
{"Rank", NC_METADATA},
|
||||||
{"Shape", NC_METADATA},
|
{"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
|
||||||
{"Rank", NC_METADATA},
|
{"CollectiveReduce", NC_COLLECTIVE},
|
||||||
{"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
|
{"CollectiveBcastSend", NC_COLLECTIVE},
|
||||||
{"CollectiveReduce", NC_COLLECTIVE},
|
{"CollectiveBcastRecv", NC_COLLECTIVE},
|
||||||
{"CollectiveBcastSend", NC_COLLECTIVE},
|
{"CollectiveGather", NC_COLLECTIVE},
|
||||||
{"CollectiveBcastRecv", NC_COLLECTIVE},
|
{"FakeParam", NC_FAKE_PARAM},
|
||||||
{"CollectiveGather", NC_COLLECTIVE},
|
{"PartitionedCall", NC_PARTITIONED_CALL},
|
||||||
{"FakeParam", NC_FAKE_PARAM},
|
{"StatefulPartitionedCall", NC_PARTITIONED_CALL},
|
||||||
{"PartitionedCall", NC_PARTITIONED_CALL},
|
{"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
|
||||||
{"StatefulPartitionedCall", NC_PARTITIONED_CALL},
|
{"If", NC_IF},
|
||||||
{"SymbolicGradient", NC_SYMBOLIC_GRADIENT},
|
{"StatelessIf", NC_IF},
|
||||||
{"If", NC_IF},
|
{"While", NC_WHILE},
|
||||||
{"StatelessIf", NC_IF},
|
{"StatelessWhile", NC_WHILE},
|
||||||
{"While", NC_WHILE},
|
// Not using the constants defined in FunctionLibraryDefinition
|
||||||
{"StatelessWhile", NC_WHILE},
|
// for the
|
||||||
// Not using the constants defined in FunctionLibraryDefinition
|
// 4 ops below because android inference library does not link
|
||||||
// for the
|
// tf.function related files.
|
||||||
// 4 ops below because android inference library does not link
|
{"_Arg", NC_ARG},
|
||||||
// tf.function related files.
|
{"_DeviceArg", NC_ARG},
|
||||||
{"_Arg", NC_ARG},
|
{"_Retval", NC_RETVAL},
|
||||||
{"_DeviceArg", NC_ARG},
|
{"_DeviceRetval", NC_RETVAL},
|
||||||
{"_Retval", NC_RETVAL},
|
{"_XlaMerge", NC_MERGE},
|
||||||
{"_DeviceRetval", NC_RETVAL},
|
});
|
||||||
{"_XlaMerge", NC_MERGE},
|
|
||||||
});
|
|
||||||
|
|
||||||
#undef REF_CLASS
|
#undef REF_CLASS
|
||||||
|
|
||||||
Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
|
auto it = kNodeClassTable->find(ts);
|
||||||
auto it = kNodeClassTable.find(ts);
|
if (it != kNodeClassTable->end()) {
|
||||||
if (it != kNodeClassTable.end()) {
|
|
||||||
return it->second;
|
return it->second;
|
||||||
} else {
|
} else {
|
||||||
return NC_OTHER;
|
return NC_OTHER;
|
||||||
@ -127,7 +126,7 @@ Node::Node()
|
|||||||
|
|
||||||
void Node::Initialize(int id, int cost_id,
|
void Node::Initialize(int id, int cost_id,
|
||||||
std::shared_ptr<NodeProperties> props,
|
std::shared_ptr<NodeProperties> props,
|
||||||
bool is_function_op) {
|
Node::NodeClass node_class) {
|
||||||
DCHECK_EQ(id_, -1);
|
DCHECK_EQ(id_, -1);
|
||||||
DCHECK(in_edges_.empty());
|
DCHECK(in_edges_.empty());
|
||||||
DCHECK(out_edges_.empty());
|
DCHECK(out_edges_.empty());
|
||||||
@ -135,12 +134,7 @@ void Node::Initialize(int id, int cost_id,
|
|||||||
cost_id_ = cost_id;
|
cost_id_ = cost_id;
|
||||||
|
|
||||||
props_ = std::move(props);
|
props_ = std::move(props);
|
||||||
// Initialize the class_ based on the type string
|
class_ = node_class;
|
||||||
if (is_function_op) {
|
|
||||||
class_ = NC_FUNCTION_OP;
|
|
||||||
} else {
|
|
||||||
class_ = GetNodeClassForOp(props_->node_def.op());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Node::Clear() {
|
void Node::Clear() {
|
||||||
@ -423,18 +417,21 @@ Node* Graph::AddNode(NodeDef node_def, Status* status) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Node::NodeClass node_class = op_reg_data->is_function_op
|
||||||
|
? Node::NC_FUNCTION_OP
|
||||||
|
: Node::GetNodeClassForOp(node_def.op());
|
||||||
|
|
||||||
Node* node = AllocateNode(
|
Node* node = AllocateNode(
|
||||||
std::make_shared<NodeProperties>(&op_reg_data->op_def,
|
std::make_shared<NodeProperties>(&op_reg_data->op_def,
|
||||||
std::move(node_def), inputs, outputs),
|
std::move(node_def), inputs, outputs),
|
||||||
nullptr, op_reg_data->is_function_op);
|
nullptr, node_class);
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
|
|
||||||
Node* Graph::CopyNode(const Node* node) {
|
Node* Graph::CopyNode(const Node* node) {
|
||||||
DCHECK(!node->IsSource());
|
DCHECK(!node->IsSource());
|
||||||
DCHECK(!node->IsSink());
|
DCHECK(!node->IsSink());
|
||||||
Node* copy =
|
Node* copy = AllocateNode(node->props_, node, node->class_);
|
||||||
AllocateNode(node->props_, node, node->class_ == Node::NC_FUNCTION_OP);
|
|
||||||
copy->set_assigned_device_name(node->assigned_device_name());
|
copy->set_assigned_device_name(node->assigned_device_name());
|
||||||
|
|
||||||
// Since the OpDef of a function may be owned by the Graph that owns 'node',
|
// Since the OpDef of a function may be owned by the Graph that owns 'node',
|
||||||
@ -759,7 +756,7 @@ Status Graph::IsValidInputTensor(const Node* node, int idx) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
|
Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
|
||||||
const Node* cost_node, bool is_function_op) {
|
const Node* cost_node, Node::NodeClass node_class) {
|
||||||
Node* node = nullptr;
|
Node* node = nullptr;
|
||||||
if (free_nodes_.empty()) {
|
if (free_nodes_.empty()) {
|
||||||
node = new (arena_.Alloc(sizeof(Node))) Node; // placement new
|
node = new (arena_.Alloc(sizeof(Node))) Node; // placement new
|
||||||
@ -770,7 +767,7 @@ Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
|
|||||||
node->graph_ = this;
|
node->graph_ = this;
|
||||||
const int id = nodes_.size();
|
const int id = nodes_.size();
|
||||||
int cost_id = cost_node ? cost_node->cost_id() : id;
|
int cost_id = cost_node ? cost_node->cost_id() : id;
|
||||||
node->Initialize(id, cost_id, std::move(props), is_function_op);
|
node->Initialize(id, cost_id, std::move(props), node_class);
|
||||||
nodes_.push_back(node);
|
nodes_.push_back(node);
|
||||||
++num_nodes_;
|
++num_nodes_;
|
||||||
return node;
|
return node;
|
||||||
|
@ -236,10 +236,6 @@ class Node {
|
|||||||
friend class Graph;
|
friend class Graph;
|
||||||
Node();
|
Node();
|
||||||
|
|
||||||
|
|
||||||
void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
|
|
||||||
bool is_function_op);
|
|
||||||
|
|
||||||
// Releases memory from props_, in addition to restoring *this to its
|
// Releases memory from props_, in addition to restoring *this to its
|
||||||
// uninitialized state.
|
// uninitialized state.
|
||||||
void Clear();
|
void Clear();
|
||||||
@ -291,7 +287,8 @@ class Node {
|
|||||||
NC_OTHER // Not a special kind of node
|
NC_OTHER // Not a special kind of node
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::unordered_map<string, NodeClass>& kNodeClassTable;
|
void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props,
|
||||||
|
NodeClass node_class);
|
||||||
|
|
||||||
static NodeClass GetNodeClassForOp(const string& ts);
|
static NodeClass GetNodeClassForOp(const string& ts);
|
||||||
|
|
||||||
@ -692,7 +689,7 @@ class Graph {
|
|||||||
//
|
//
|
||||||
// Ownership of the returned Node is not transferred to caller.
|
// Ownership of the returned Node is not transferred to caller.
|
||||||
Node* AllocateNode(std::shared_ptr<NodeProperties> props,
|
Node* AllocateNode(std::shared_ptr<NodeProperties> props,
|
||||||
const Node* cost_node, bool is_function_op);
|
const Node* cost_node, Node::NodeClass node_class);
|
||||||
void ReleaseNode(Node* node);
|
void ReleaseNode(Node* node);
|
||||||
// Insert edge in free_edges_ for possible reuse.
|
// Insert edge in free_edges_ for possible reuse.
|
||||||
void RecycleEdge(const Edge* edge);
|
void RecycleEdge(const Edge* edge);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user