[tf.data] Several optimizations for the graph hashing code.

1. Avoid copying the `GraphDef` each time a `GraphHasher` is created. The graph always outlives the hasher, so an unowned pointer is acceptable here. Should save O(#nodes) copies.
2. Use the same `FunctionLibraryDefinition` for all hashing. Previously we were converting it to and from a submessage of `GraphDef`, which led to many copies and dynamic allocations. Instead, we either build it once for the root node, or (ideally) the user passes in an already-constructed library, and we use that for all nodes. Since the function library typically has O(1) functions per node, this saves O(#nodes^2) copies. (A usage sketch follows this list.)
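
To make the intended usage concrete, here is a minimal caller-side sketch, not part of this change: the `HashAllNodes` helper name, the include paths, and the namespaces are assumptions for illustration, but the `HashNode` overload it calls is the one added by this commit, and the `FunctionLibraryDefinition` construction mirrors what the new code does internally.

// Sketch only. Include paths are assumed based on where these utilities
// lived at the time; HashAllNodes is a hypothetical helper.
#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"  // assumed location of HashNode
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace data {

// Hashes every node in `graph_def`, building the FunctionLibraryDefinition
// once up front instead of once per HashNode call (optimization 2 above).
Status HashAllNodes(const GraphDef& graph_def, std::vector<uint64>* hashes) {
  const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
                                           graph_def.library());
  for (const NodeDef& node : graph_def.node()) {
    uint64 hash = 0;
    // New overload added by this change: the caller supplies the library.
    TF_RETURN_IF_ERROR(HashNode(graph_def, node, flib_def, &hash));
    hashes->push_back(hash);
  }
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow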

PiperOrigin-RevId: 301307984
Change-Id: I6e28ffd1df908840e946e43d3be3dc2f5106eb55
Author: Derek Murray, 2020-03-16 22:39:17 -07:00; committed by TensorFlower Gardener
parent 489126360d
commit f18697daa1
3 changed files with 28 additions and 15 deletions


@@ -153,10 +153,11 @@ Status ParseInputNodeName(const std::string& input_name, std::string* node_name,
 // https://stackoverflow.com/questions/11338746/directed-graphs-with-a-given-root-node-match-another-directed-graph-for-equali
 class GraphHasher {
  public:
-  explicit GraphHasher(const GraphDef& graph_def, const NodeDef* root_node)
-      : graph_def_(graph_def),
-        root_node_(root_node),
-        flib_def_(OpRegistry::Global(), graph_def.library()) {}
+  // `GraphHasher` does not take ownership of `graph_def`, `root_node`, or
+  // `flib_def`.
+  explicit GraphHasher(const GraphDef* graph_def, const NodeDef* root_node,
+                       const FunctionLibraryDefinition* flib_def)
+      : graph_def_(graph_def), root_node_(root_node), flib_def_(flib_def) {}
 
   Status ComputeHash(uint64* hash) {
     TF_RETURN_IF_ERROR(Init());
@@ -189,7 +190,7 @@ class GraphHasher {
       TF_RETURN_IF_ERROR(ParseInputNodeName(node->input(i), &node_name,
                                             &suffix, &is_control_input));
       const NodeDef* input_node;
-      TF_RETURN_IF_ERROR(FindNode(graph_def_, node_name, &input_node));
+      TF_RETURN_IF_ERROR(FindNode(*graph_def_, node_name, &input_node));
 
       // If we've already seen this node before, skip it and don't add it to
       // the queue.
@@ -308,20 +309,19 @@ class GraphHasher {
   }
 
   Status HashFunction(const NameAttrList& func, uint64* hash) {
-    const FunctionDef* fdef = flib_def_.Find(func.name());
+    const FunctionDef* fdef = flib_def_->Find(func.name());
 
     // Convert to a GraphDef.
     std::unique_ptr<FunctionBody> fbody;
     TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
-                                               &flib_def_, &fbody));
+                                               flib_def_, &fbody));
     GraphDef graph_def = fbody->graph->ToGraphDefDebug();
-    graph_def.mutable_library()->MergeFrom(flib_def_.ToProto());
 
     // For each return node, we create a new GraphHasher to compute a hash.
     // We then combine these hashes to produce the hash ordered.
     uint64 ret_nodes_hash = 0;
     for (const auto& ret_node : fbody->ret_nodes) {
-      GraphHasher ret_node_hasher(graph_def, &ret_node->def());
+      GraphHasher ret_node_hasher(&graph_def, &ret_node->def(), flib_def_);
       uint64 ret_node_hash = 0;
       TF_RETURN_IF_ERROR(ret_node_hasher.ComputeHash(&ret_node_hash));
       ret_nodes_hash = Hash64Combine(ret_nodes_hash, ret_node_hash);
@@ -359,9 +359,9 @@ class GraphHasher {
     }
   };
 
-  const GraphDef graph_def_;
-  const NodeDef* root_node_;
-  const FunctionLibraryDefinition flib_def_;
+  const GraphDef* const graph_def_;                  // Not owned.
+  const NodeDef* const root_node_;                   // Not owned.
+  const FunctionLibraryDefinition* const flib_def_;  // Not owned.
   // Edges that need to be pruned as their presence will cause cycles.
   absl::flat_hash_set<uint64> cycle_forming_edges_;
   absl::flat_hash_map<const NodeDef*, NodeRep> nodes_;
@@ -397,7 +397,14 @@ Status HashTensor(const Tensor& tensor, uint64* hash) {
 }
 
 Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) {
-  GraphHasher graph_hasher(graph, &node);
+  const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
+                                           graph.library());
+  return HashNode(graph, node, flib_def, hash);
+}
+
+Status HashNode(const GraphDef& graph, const NodeDef& node,
+                const FunctionLibraryDefinition& flib_def, uint64* hash) {
+  GraphHasher graph_hasher(&graph, &node, &flib_def);
   return graph_hasher.ComputeHash(hash);
 }
@@ -414,7 +421,9 @@ Status HashGraph(const GraphDef& graph_def, uint64* hash) {
     return errors::Internal("Cannot find sink node for dataset graph.");
   }
 
-  GraphHasher graph_hasher(graph_def, sink);
+  const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
+                                           graph_def.library());
+  GraphHasher graph_hasher(&graph_def, sink, &flib_def);
   TF_RETURN_IF_ERROR(graph_hasher.ComputeHash(hash));
   return Status::OK();
 }


@@ -135,6 +135,8 @@ Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
 // NOTE: There is currently no guarantee that the hash of a subgraph will stay
 // the same between TensorFlow builds.
 Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash);
+Status HashNode(const GraphDef& graph, const NodeDef& node,
+                const FunctionLibraryDefinition& flib_def, uint64* hash);
 
 // Returns a stable hash of the given tensor.
 //


@@ -195,8 +195,10 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
   if (record_fingerprint) {
     (*ctx->runner())([graph_def = std::move(graph_def),
+                      lib_def = lib_def.release(),
                       input_list = std::move(input_list),
                       output_node = std::move(output_node)]() {
+      std::unique_ptr<FunctionLibraryDefinition> lib_def_owner(lib_def);
       const NodeDef* node_def = nullptr;
       for (const auto& node : graph_def.node()) {
         if (node.name() == output_node) {
@@ -209,7 +211,7 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
           return;
         }
         uint64 hash = 0;
-        Status s = HashNode(graph_def, *node_def, &hash);
+        Status s = HashNode(graph_def, *node_def, *lib_def, &hash);
         if (!s.ok()) {
           VLOG(3) << "Failed to hash graph: " << s.ToString();
           return;