Add post optimization graph to RunMetadata (when running eager functions)
This stores the pre-Grappler graph, the post-Grappler graph, and the partitioned graphs for each instantiated function. This makes the pre-optimization and post-optimization graphs available for display in TensorBoard. PiperOrigin-RevId: 233813975
This commit is contained in:
parent
8c92769a36
commit
7df8df33c5
@ -932,10 +932,25 @@ Status EagerKernelExecute(EagerContext* ctx, Device* device,
|
||||
{
|
||||
GraphCollector* collector = ctx->GetGraphCollector();
|
||||
mutex_lock mll(collector->mu);
|
||||
for (const auto& graph : collector->graphs) {
|
||||
|
||||
// Adding to partition graphs for backward compatibility.
|
||||
for (const auto& graph : collector->partitioned_graphs) {
|
||||
*ctx->RunMetadataProto()->add_partition_graphs() = graph;
|
||||
}
|
||||
collector->graphs.clear();
|
||||
|
||||
if (collector->dirty) {
|
||||
auto* function_graphs =
|
||||
ctx->RunMetadataProto()->add_function_graphs();
|
||||
*function_graphs->mutable_post_optimization_graph() =
|
||||
collector->optimized_graph;
|
||||
*function_graphs->mutable_pre_optimization_graph() =
|
||||
collector->raw_graph;
|
||||
for (const auto& graph : collector->partitioned_graphs) {
|
||||
*function_graphs->add_partition_graphs() = graph;
|
||||
}
|
||||
}
|
||||
|
||||
collector->ClearGraphs();
|
||||
}
|
||||
auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
|
||||
// Lazily initialize the RunMetadata with information about all devices if
|
||||
|
@ -526,6 +526,13 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
TF_RETURN_IF_ERROR(GetGraphAndRets(function_name, attrs, fdef, lib_def,
|
||||
&graph, &ret_node_names));
|
||||
|
||||
if (options.graph_collector != nullptr) {
|
||||
GraphDef def;
|
||||
graph->ToGraphDef(&def);
|
||||
*def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
|
||||
options.graph_collector->CollectRawGraph(def);
|
||||
}
|
||||
|
||||
DeviceSet device_set;
|
||||
for (auto d : device_mgr_->ListDevices()) {
|
||||
device_set.AddDevice(d);
|
||||
@ -592,6 +599,13 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
|
||||
DumpGraph("After all optimization passes", graph.get());
|
||||
|
||||
if (options.graph_collector != nullptr) {
|
||||
GraphDef def;
|
||||
graph->ToGraphDef(&def);
|
||||
*def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
|
||||
options.graph_collector->CollectOptimizedGraph(def);
|
||||
}
|
||||
|
||||
std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
|
||||
TF_RETURN_IF_ERROR(
|
||||
PartitionFunctionGraph(device_set, std::move(graph), &subgraphs));
|
||||
@ -600,7 +614,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
for (const auto& pair : subgraphs) {
|
||||
GraphDef def;
|
||||
pair.second->ToGraphDef(&def);
|
||||
options.graph_collector->CollectGraph(def);
|
||||
*def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
|
||||
options.graph_collector->CollectPartitionedGraph(def);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -525,11 +525,42 @@ struct TensorValue {
|
||||
// Used to store the raw, optimized, and partitioned graphs from function-calling ops.
|
||||
// Accumulates GraphDefs handed to it through the Collect*() methods and tracks
// whether anything new has arrived since the last ClearGraphs().
//
// NOTE(review): this is a diff rendering with the +/- markers stripped. The
// hunk header (-525,11 +525,42) suggests that the `graphs` member, the
// `CollectGraph` signature and the `graphs.push_back(graph)` statement below
// are removed (pre-change) lines shown next to their replacements -- confirm
// against the real diff before treating this text as compilable.
struct GraphCollector {
|
||||
// Protects every data member below (see the GUARDED_BY annotations).
mutex mu;
|
||||
// NOTE(review): presumably the pre-change member that partitioned_graphs
// replaces -- confirm.
std::vector<GraphDef> graphs GUARDED_BY(mu);
|
||||
// One entry is appended per CollectPartitionedGraph() call.
std::vector<GraphDef> partitioned_graphs GUARDED_BY(mu);
|
||||
// Filled by CollectRawGraph() via MergeFrom().
GraphDef raw_graph GUARDED_BY(mu);
|
||||
// Filled by CollectOptimizedGraph() via MergeFrom().
GraphDef optimized_graph GUARDED_BY(mu);
|
||||
|
||||
// NOTE(review): presumably the pre-change method signature that
// CollectRawGraph replaces; it has no closing brace of its own here -- confirm.
void CollectGraph(const GraphDef& graph) {
|
||||
// True once any Collect*() method has run; reset to false by ClearGraphs().
bool dirty GUARDED_BY(mu);
|
||||
|
||||
// Starts with nothing collected (dirty == false).
GraphCollector() : dirty(false) {}
|
||||
|
||||
// Merges `graph` into raw_graph and marks the collector dirty. Acquires mu.
// NOTE(review): MergeFrom() is used rather than assignment, so several calls
// before ClearGraphs() would presumably combine graphs -- confirm intended.
void CollectRawGraph(const GraphDef& graph) {
|
||||
mutex_lock ml(mu);
|
||||
// NOTE(review): presumably a removed pre-change statement -- confirm.
graphs.push_back(graph);
|
||||
raw_graph.MergeFrom(graph);
|
||||
dirty = true;
|
||||
}
|
||||
|
||||
// Merges `graph` into optimized_graph and marks the collector dirty.
// Acquires mu.
void CollectOptimizedGraph(const GraphDef& graph) {
|
||||
mutex_lock ml(mu);
|
||||
optimized_graph.MergeFrom(graph);
|
||||
dirty = true;
|
||||
}
|
||||
|
||||
// Appends a copy of `graph` to partitioned_graphs and marks the collector
// dirty. Acquires mu.
void CollectPartitionedGraph(const GraphDef& graph) {
|
||||
mutex_lock ml(mu);
|
||||
partitioned_graphs.push_back(graph);
|
||||
dirty = true;
|
||||
}
|
||||
|
||||
// Empties raw_graph, optimized_graph and partitioned_graphs and resets dirty.
// Does not lock: the annotation requires the caller to already hold mu.
void ClearGraphs() EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
||||
raw_graph.Clear();
|
||||
optimized_graph.Clear();
|
||||
partitioned_graphs.clear();
|
||||
dirty = false;
|
||||
}
|
||||
|
||||
// Returns the current value of dirty, read while holding mu.
bool HasUpdatedGraphs() {
|
||||
mutex_lock ml(mu);
|
||||
return dirty;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -520,6 +520,25 @@ message RunMetadata {
|
||||
|
||||
// Graphs of the partitions executed by executors.
|
||||
repeated GraphDef partition_graphs = 3;
|
||||
|
||||
// Groups the graphs recorded for one function: its partition graphs plus the
// graph before and after optimization.
message FunctionGraphs {
|
||||
// TODO(nareshmodi): Include some sort of function/cache-key identifier?
|
||||
// Partition graphs collected for this function.
repeated GraphDef partition_graphs = 1;
|
||||
|
||||
// Graph as collected before the optimization passes.
GraphDef pre_optimization_graph = 2;
|
||||
// Graph as collected after the optimization passes.
GraphDef post_optimization_graph = 3;
|
||||
}
|
||||
// This is only populated for graphs that are run as functions in TensorFlow
|
||||
// V2. There will be an entry below for each function that is traced.
|
||||
// The main use cases of the post_optimization_graph and the partition_graphs
|
||||
// is to give the caller insight into the graphs that were actually run by the
|
||||
// runtime. Additional information (such as those in step_stats) will match
|
||||
// these graphs.
|
||||
// We also include the pre_optimization_graph since it is usually easier to
|
||||
// read, and is helpful in situations where the caller wants to get a high
|
||||
// level idea of what the built graph looks like (since the various graph
|
||||
// optimization passes might change the structure of the graph significantly).
|
||||
repeated FunctionGraphs function_graphs = 4;
|
||||
}
|
||||
|
||||
// Defines a connection between two tensors in a `GraphDef`.
|
||||
|
@ -0,0 +1,27 @@
|
||||
path: "tensorflow.RunMetadata.FunctionGraphs"
|
||||
tf_proto {
|
||||
descriptor {
|
||||
name: "FunctionGraphs"
|
||||
field {
|
||||
name: "partition_graphs"
|
||||
number: 1
|
||||
label: LABEL_REPEATED
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.GraphDef"
|
||||
}
|
||||
field {
|
||||
name: "pre_optimization_graph"
|
||||
number: 2
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.GraphDef"
|
||||
}
|
||||
field {
|
||||
name: "post_optimization_graph"
|
||||
number: 3
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.GraphDef"
|
||||
}
|
||||
}
|
||||
}
|
@ -23,5 +23,36 @@ tf_proto {
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.GraphDef"
|
||||
}
|
||||
field {
|
||||
name: "function_graphs"
|
||||
number: 4
|
||||
label: LABEL_REPEATED
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.RunMetadata.FunctionGraphs"
|
||||
}
|
||||
nested_type {
|
||||
name: "FunctionGraphs"
|
||||
field {
|
||||
name: "partition_graphs"
|
||||
number: 1
|
||||
label: LABEL_REPEATED
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.GraphDef"
|
||||
}
|
||||
field {
|
||||
name: "pre_optimization_graph"
|
||||
number: 2
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.GraphDef"
|
||||
}
|
||||
field {
|
||||
name: "post_optimization_graph"
|
||||
number: 3
|
||||
label: LABEL_OPTIONAL
|
||||
type: TYPE_MESSAGE
|
||||
type_name: ".tensorflow.GraphDef"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user