From f18697daa15744e7fc51caf5f0a2da40904dede2 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 16 Mar 2020 22:39:17 -0700 Subject: [PATCH] [tf.data] Several optimizations for the graph hashing code. 1. Avoid copying the `GraphDef` each time a `GraphHasher` is created. The graph always outlives the hasher, so an unowned pointer is acceptable here. Should save O(#nodes) copies. 2. Use the same `FunctionLibraryDefinition` for all hashing. Previously we were converting it to and from a submessage of `GraphDef`, which led to a lot of copies, dynamic allocations, etc. Instead, we either build it once for the root node, or (ideally) the user passes in an already-constructed library, then we use that for all nodes. Since the function library typically has O(1) functions per node, this saves O(#nodes^2) copies. PiperOrigin-RevId: 301307984 Change-Id: I6e28ffd1df908840e946e43d3be3dc2f5106eb55 --- tensorflow/core/kernels/data/dataset_utils.cc | 37 ++++++++++++------- tensorflow/core/kernels/data/dataset_utils.h | 2 + tensorflow/core/kernels/data/rewrite_utils.cc | 4 +- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index 1502ff1951d..ee135b3c9db 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -153,10 +153,11 @@ Status ParseInputNodeName(const std::string& input_name, std::string* node_name, // https://stackoverflow.com/questions/11338746/directed-graphs-with-a-given-root-node-match-another-directed-graph-for-equali class GraphHasher { public: - explicit GraphHasher(const GraphDef& graph_def, const NodeDef* root_node) - : graph_def_(graph_def), - root_node_(root_node), - flib_def_(OpRegistry::Global(), graph_def.library()) {} + // `GraphHasher` does not take ownership of `graph_def`, `root_node`, or + // `flib_def`. 
+ explicit GraphHasher(const GraphDef* graph_def, const NodeDef* root_node, + const FunctionLibraryDefinition* flib_def) + : graph_def_(graph_def), root_node_(root_node), flib_def_(flib_def) {} Status ComputeHash(uint64* hash) { TF_RETURN_IF_ERROR(Init()); @@ -189,7 +190,7 @@ class GraphHasher { TF_RETURN_IF_ERROR(ParseInputNodeName(node->input(i), &node_name, &suffix, &is_control_input)); const NodeDef* input_node; - TF_RETURN_IF_ERROR(FindNode(graph_def_, node_name, &input_node)); + TF_RETURN_IF_ERROR(FindNode(*graph_def_, node_name, &input_node)); // If we've already seen this node before, skip it and don't add it to // the queue. @@ -308,20 +309,19 @@ class GraphHasher { } Status HashFunction(const NameAttrList& func, uint64* hash) { - const FunctionDef* fdef = flib_def_.Find(func.name()); + const FunctionDef* fdef = flib_def_->Find(func.name()); // Convert to a GraphDef. std::unique_ptr<FunctionBody> fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()), - &flib_def_, &fbody)); + flib_def_, &fbody)); GraphDef graph_def = fbody->graph->ToGraphDefDebug(); - graph_def.mutable_library()->MergeFrom(flib_def_.ToProto()); // For each return node, we create a new GraphHasher to compute a hash. // We then combine these hashes to produce the hash ordered. uint64 ret_nodes_hash = 0; for (const auto& ret_node : fbody->ret_nodes) { - GraphHasher ret_node_hasher(graph_def, &ret_node->def()); + GraphHasher ret_node_hasher(&graph_def, &ret_node->def(), flib_def_); uint64 ret_node_hash = 0; TF_RETURN_IF_ERROR(ret_node_hasher.ComputeHash(&ret_node_hash)); ret_nodes_hash = Hash64Combine(ret_nodes_hash, ret_node_hash); @@ -359,9 +359,9 @@ class GraphHasher { } }; - const GraphDef graph_def_; - const NodeDef* root_node_; - const FunctionLibraryDefinition flib_def_; + const GraphDef* const graph_def_; // Not owned. + const NodeDef* const root_node_; // Not owned. + const FunctionLibraryDefinition* const flib_def_; // Not owned. 
// Edges that need to be pruned as their presence will cause cycles. absl::flat_hash_set cycle_forming_edges_; absl::flat_hash_map nodes_; @@ -397,7 +397,14 @@ Status HashTensor(const Tensor& tensor, uint64* hash) { } Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) { - GraphHasher graph_hasher(graph, &node); + const FunctionLibraryDefinition flib_def(OpRegistry::Global(), + graph.library()); + return HashNode(graph, node, flib_def, hash); +} + +Status HashNode(const GraphDef& graph, const NodeDef& node, + const FunctionLibraryDefinition& flib_def, uint64* hash) { + GraphHasher graph_hasher(&graph, &node, &flib_def); return graph_hasher.ComputeHash(hash); } @@ -414,7 +421,9 @@ Status HashGraph(const GraphDef& graph_def, uint64* hash) { return errors::Internal("Cannot find sink node for dataset graph."); } - GraphHasher graph_hasher(graph_def, sink); + const FunctionLibraryDefinition flib_def(OpRegistry::Global(), + graph_def.library()); + GraphHasher graph_hasher(&graph_def, sink, &flib_def); TF_RETURN_IF_ERROR(graph_hasher.ComputeHash(hash)); return Status::OK(); } diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 7c0857a5225..bedd5facda9 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -135,6 +135,8 @@ Status VerifyShapesCompatible(const std::vector& expected, // NOTE: There is currently no guarantee that the hash of a subgraph will stay // the same between TensorFlow builds. Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash); +Status HashNode(const GraphDef& graph, const NodeDef& node, + const FunctionLibraryDefinition& flib_def, uint64* hash); // Returns a stable hash of the given tensor. 
// diff --git a/tensorflow/core/kernels/data/rewrite_utils.cc b/tensorflow/core/kernels/data/rewrite_utils.cc index 3717016bba4..609c402fd29 100644 --- a/tensorflow/core/kernels/data/rewrite_utils.cc +++ b/tensorflow/core/kernels/data/rewrite_utils.cc @@ -195,8 +195,10 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, if (record_fingerprint) { (*ctx->runner())([graph_def = std::move(graph_def), + lib_def = lib_def.release(), input_list = std::move(input_list), output_node = std::move(output_node)]() { + std::unique_ptr<FunctionLibraryDefinition> lib_def_owner(lib_def); const NodeDef* node_def = nullptr; for (const auto& node : graph_def.node()) { if (node.name() == output_node) { @@ -209,7 +211,7 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, return; } uint64 hash = 0; - Status s = HashNode(graph_def, *node_def, &hash); + Status s = HashNode(graph_def, *node_def, *lib_def, &hash); if (!s.ok()) { VLOG(3) << "Failed to hash graph: " << s.ToString(); return;