diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index c6e8573cc28..392a0a7a61a 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -932,10 +932,25 @@ Status EagerKernelExecute(EagerContext* ctx, Device* device, { GraphCollector* collector = ctx->GetGraphCollector(); mutex_lock mll(collector->mu); - for (const auto& graph : collector->graphs) { + + // Adding to partition graphs for backward compatibility. + for (const auto& graph : collector->partitioned_graphs) { *ctx->RunMetadataProto()->add_partition_graphs() = graph; } - collector->graphs.clear(); + + if (collector->dirty) { + auto* function_graphs = + ctx->RunMetadataProto()->add_function_graphs(); + *function_graphs->mutable_post_optimization_graph() = + collector->optimized_graph; + *function_graphs->mutable_pre_optimization_graph() = + collector->raw_graph; + for (const auto& graph : collector->partitioned_graphs) { + *function_graphs->add_partition_graphs() = graph; + } + } + + collector->ClearGraphs(); } auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats(); // Lazily initialize the RunMetadata with information about all devices if diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 950a93671c7..608ce8028ac 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -526,6 +526,13 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( TF_RETURN_IF_ERROR(GetGraphAndRets(function_name, attrs, fdef, lib_def, &graph, &ret_node_names)); + if (options.graph_collector != nullptr) { + GraphDef def; + graph->ToGraphDef(&def); + *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto(); + options.graph_collector->CollectRawGraph(def); + } + DeviceSet device_set; 
for (auto d : device_mgr_->ListDevices()) { device_set.AddDevice(d); @@ -592,6 +599,13 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); DumpGraph("After all optimization passes", graph.get()); + if (options.graph_collector != nullptr) { + GraphDef def; + graph->ToGraphDef(&def); + *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto(); + options.graph_collector->CollectOptimizedGraph(def); + } + std::unordered_map> subgraphs; TF_RETURN_IF_ERROR( PartitionFunctionGraph(device_set, std::move(graph), &subgraphs)); @@ -600,7 +614,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( for (const auto& pair : subgraphs) { GraphDef def; pair.second->ToGraphDef(&def); - options.graph_collector->CollectGraph(def); + *def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto(); + options.graph_collector->CollectPartitionedGraph(def); } } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 06b90964ad1..f128b407241 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -525,11 +525,42 @@ struct TensorValue { // Used to store partitioned graphs from function-calling ops. 
struct GraphCollector { mutex mu; - std::vector graphs GUARDED_BY(mu); + std::vector partitioned_graphs GUARDED_BY(mu); + GraphDef raw_graph GUARDED_BY(mu); + GraphDef optimized_graph GUARDED_BY(mu); - void CollectGraph(const GraphDef& graph) { + bool dirty GUARDED_BY(mu); + + GraphCollector() : dirty(false) {} + + void CollectRawGraph(const GraphDef& graph) { mutex_lock ml(mu); - graphs.push_back(graph); + raw_graph.MergeFrom(graph); + dirty = true; + } + + void CollectOptimizedGraph(const GraphDef& graph) { + mutex_lock ml(mu); + optimized_graph.MergeFrom(graph); + dirty = true; + } + + void CollectPartitionedGraph(const GraphDef& graph) { + mutex_lock ml(mu); + partitioned_graphs.push_back(graph); + dirty = true; + } + + void ClearGraphs() EXCLUSIVE_LOCKS_REQUIRED(mu) { + raw_graph.Clear(); + optimized_graph.Clear(); + partitioned_graphs.clear(); + dirty = false; + } + + bool HasUpdatedGraphs() { + mutex_lock ml(mu); + return dirty; } }; diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 44e98542ec0..3e24235369a 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -520,6 +520,25 @@ message RunMetadata { // Graphs of the partitions executed by executors. repeated GraphDef partition_graphs = 3; + + message FunctionGraphs { + // TODO(nareshmodi): Include some sort of function/cache-key identifier? + repeated GraphDef partition_graphs = 1; + + GraphDef pre_optimization_graph = 2; + GraphDef post_optimization_graph = 3; + } + // This is only populated for graphs that are run as functions in TensorFlow + // V2. There will be an entry below for each function that is traced. + // The main use case of the post_optimization_graph and the partition_graphs + // is to give the caller insight into the graphs that were actually run by the + // runtime. Additional information (such as that in step_stats) will match + // these graphs. 
+ // We also include the pre_optimization_graph since it is usually easier to + // read, and is helpful in situations where the caller wants to get a high + // level idea of what the built graph looks like (since the various graph + // optimization passes might change the structure of the graph significantly). + repeated FunctionGraphs function_graphs = 4; } // Defines a connection between two tensors in a `GraphDef`. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-metadata.-function-graphs.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-metadata.-function-graphs.pbtxt new file mode 100644 index 00000000000..d2e2f583d21 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-metadata.-function-graphs.pbtxt @@ -0,0 +1,27 @@ +path: "tensorflow.RunMetadata.FunctionGraphs" +tf_proto { + descriptor { + name: "FunctionGraphs" + field { + name: "partition_graphs" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphDef" + } + field { + name: "pre_optimization_graph" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphDef" + } + field { + name: "post_optimization_graph" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphDef" + } + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-metadata.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-metadata.pbtxt index 1287940326c..777b889745f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-run-metadata.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-metadata.pbtxt @@ -23,5 +23,36 @@ tf_proto { type: TYPE_MESSAGE type_name: ".tensorflow.GraphDef" } + field { + name: "function_graphs" + number: 4 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: ".tensorflow.RunMetadata.FunctionGraphs" + } + nested_type { + name: "FunctionGraphs" + field { + name: "partition_graphs" + number: 1 + label: LABEL_REPEATED + type: TYPE_MESSAGE + type_name: 
".tensorflow.GraphDef" + } + field { + name: "pre_optimization_graph" + number: 2 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphDef" + } + field { + name: "post_optimization_graph" + number: 3 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: ".tensorflow.GraphDef" + } + } } }