Add post optimization graph to RunMetadata (when running eager functions)

This stores the pre-grappler graph + post-grappler graph + partitioned graphs
for each instantiated function.

This will be useful for retrieving the pre-optimization and post-optimization
graphs for display within TensorBoard.

PiperOrigin-RevId: 233813975
This commit is contained in:
Akshay Modi 2019-02-13 13:08:05 -08:00 committed by TensorFlower Gardener
parent 8c92769a36
commit 7df8df33c5
6 changed files with 144 additions and 6 deletions

View File

@ -932,10 +932,25 @@ Status EagerKernelExecute(EagerContext* ctx, Device* device,
{
GraphCollector* collector = ctx->GetGraphCollector();
mutex_lock mll(collector->mu);
for (const auto& graph : collector->graphs) {
// Adding to partition graphs for backward compatibility.
for (const auto& graph : collector->partitioned_graphs) {
*ctx->RunMetadataProto()->add_partition_graphs() = graph;
}
collector->graphs.clear();
if (collector->dirty) {
auto* function_graphs =
ctx->RunMetadataProto()->add_function_graphs();
*function_graphs->mutable_post_optimization_graph() =
collector->optimized_graph;
*function_graphs->mutable_pre_optimization_graph() =
collector->raw_graph;
for (const auto& graph : collector->partitioned_graphs) {
*function_graphs->add_partition_graphs() = graph;
}
}
collector->ClearGraphs();
}
auto* step_stats = ctx->RunMetadataProto()->mutable_step_stats();
// Lazily initialize the RunMetadata with information about all devices if

View File

@ -526,6 +526,13 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
TF_RETURN_IF_ERROR(GetGraphAndRets(function_name, attrs, fdef, lib_def,
&graph, &ret_node_names));
if (options.graph_collector != nullptr) {
GraphDef def;
graph->ToGraphDef(&def);
*def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
options.graph_collector->CollectRawGraph(def);
}
DeviceSet device_set;
for (auto d : device_mgr_->ListDevices()) {
device_set.AddDevice(d);
@ -592,6 +599,13 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
DumpGraph("After all optimization passes", graph.get());
if (options.graph_collector != nullptr) {
GraphDef def;
graph->ToGraphDef(&def);
*def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
options.graph_collector->CollectOptimizedGraph(def);
}
std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
TF_RETURN_IF_ERROR(
PartitionFunctionGraph(device_set, std::move(graph), &subgraphs));
@ -600,7 +614,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
for (const auto& pair : subgraphs) {
GraphDef def;
pair.second->ToGraphDef(&def);
options.graph_collector->CollectGraph(def);
*def.mutable_library() = lib_def->ReachableDefinitions(def).ToProto();
options.graph_collector->CollectPartitionedGraph(def);
}
}

View File

@ -525,11 +525,42 @@ struct TensorValue {
// Used to store partitioned graphs from function-calling ops.
struct GraphCollector {
mutex mu;
std::vector<GraphDef> graphs GUARDED_BY(mu);
std::vector<GraphDef> partitioned_graphs GUARDED_BY(mu);
GraphDef raw_graph GUARDED_BY(mu);
GraphDef optimized_graph GUARDED_BY(mu);
void CollectGraph(const GraphDef& graph) {
bool dirty GUARDED_BY(mu);
GraphCollector() : dirty(false) {}
void CollectRawGraph(const GraphDef& graph) {
mutex_lock ml(mu);
graphs.push_back(graph);
raw_graph.MergeFrom(graph);
dirty = true;
}
void CollectOptimizedGraph(const GraphDef& graph) {
mutex_lock ml(mu);
optimized_graph.MergeFrom(graph);
dirty = true;
}
void CollectPartitionedGraph(const GraphDef& graph) {
mutex_lock ml(mu);
partitioned_graphs.push_back(graph);
dirty = true;
}
void ClearGraphs() EXCLUSIVE_LOCKS_REQUIRED(mu) {
raw_graph.Clear();
optimized_graph.Clear();
partitioned_graphs.clear();
dirty = false;
}
bool HasUpdatedGraphs() {
mutex_lock ml(mu);
return dirty;
}
};

View File

@ -520,6 +520,25 @@ message RunMetadata {
// Graphs of the partitions executed by executors.
repeated GraphDef partition_graphs = 3;
message FunctionGraphs {
// TODO(nareshmodi): Include some sort of function/cache-key identifier?
repeated GraphDef partition_graphs = 1;
GraphDef pre_optimization_graph = 2;
GraphDef post_optimization_graph = 3;
}
// This is only populated for graphs that are run as functions in TensorFlow
// V2. There will be an entry below for each function that is traced.
// The main use cases of the post_optimization_graph and the partition_graphs
// are to give the caller insight into the graphs that were actually run by the
// runtime. Additional information (such as those in step_stats) will match
// these graphs.
// We also include the pre_optimization_graph since it is usually easier to
// read, and is helpful in situations where the caller wants to get a high
// level idea of what the built graph looks like (since the various graph
// optimization passes might change the structure of the graph significantly).
repeated FunctionGraphs function_graphs = 4;
}
// Defines a connection between two tensors in a `GraphDef`.

View File

@ -0,0 +1,27 @@
path: "tensorflow.RunMetadata.FunctionGraphs"
tf_proto {
descriptor {
name: "FunctionGraphs"
field {
name: "partition_graphs"
number: 1
label: LABEL_REPEATED
type: TYPE_MESSAGE
type_name: ".tensorflow.GraphDef"
}
field {
name: "pre_optimization_graph"
number: 2
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: ".tensorflow.GraphDef"
}
field {
name: "post_optimization_graph"
number: 3
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: ".tensorflow.GraphDef"
}
}
}

View File

@ -23,5 +23,36 @@ tf_proto {
type: TYPE_MESSAGE
type_name: ".tensorflow.GraphDef"
}
field {
name: "function_graphs"
number: 4
label: LABEL_REPEATED
type: TYPE_MESSAGE
type_name: ".tensorflow.RunMetadata.FunctionGraphs"
}
nested_type {
name: "FunctionGraphs"
field {
name: "partition_graphs"
number: 1
label: LABEL_REPEATED
type: TYPE_MESSAGE
type_name: ".tensorflow.GraphDef"
}
field {
name: "pre_optimization_graph"
number: 2
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: ".tensorflow.GraphDef"
}
field {
name: "post_optimization_graph"
number: 3
label: LABEL_OPTIONAL
type: TYPE_MESSAGE
type_name: ".tensorflow.GraphDef"
}
}
}
}