[tpu:profiler] Write RunMetadata of the computation graph to event file, if available.
The RunMetadata can be used to annotate HLO graphs with colors based on node compute time. PiperOrigin-RevId: 167477021
This commit is contained in:
parent
d57572e996
commit
0302320e11
@ -38,7 +38,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
|
#include "tensorflow/core/protobuf/config.pb.h"
|
||||||
#include "tensorflow/core/util/command_line_flags.h"
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
#include "tensorflow/core/util/event.pb.h"
|
||||||
#include "tensorflow/core/util/events_writer.h"
|
#include "tensorflow/core/util/events_writer.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -136,14 +138,39 @@ ProfileResponse Profile(const string& service_addr, int duration_ms) {
|
|||||||
return response;
|
return response;
|
||||||
}
|
}
|
||||||
|
|
||||||
void DumpGraph(StringPiece logdir, StringPiece run, const string& graph_def) {
|
void DumpGraphEvents(const string& logdir, const string& run,
|
||||||
|
const ProfileResponse& response) {
|
||||||
|
int num_graphs = response.computation_graph_size();
|
||||||
|
if (response.computation_graph_size() == 0) return;
|
||||||
|
// The server might generate multiple graphs for one program; we simply
|
||||||
|
// pick the first one.
|
||||||
|
if (num_graphs > 1) {
|
||||||
|
std::cout << num_graphs
|
||||||
|
<< " TPU program variants observed over the profiling period. "
|
||||||
|
<< "One computation graph will be chosen arbitrarily."
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
// The graph plugin expects the graph in <logdir>/<run>/<event.file>.
|
// The graph plugin expects the graph in <logdir>/<run>/<event.file>.
|
||||||
string run_dir = JoinPath(logdir, strings::StrCat(kGraphRunPrefix, run));
|
string run_dir = JoinPath(logdir, strings::StrCat(kGraphRunPrefix, run));
|
||||||
TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir));
|
TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir));
|
||||||
EventsWriter event_writer(JoinPath(run_dir, "events"));
|
EventsWriter event_writer(JoinPath(run_dir, "events"));
|
||||||
Event event;
|
Event event;
|
||||||
event.set_graph_def(graph_def);
|
// Add the computation graph.
|
||||||
|
event.set_graph_def(response.computation_graph(0).SerializeAsString());
|
||||||
event_writer.WriteEvent(event);
|
event_writer.WriteEvent(event);
|
||||||
|
std::cout << "Wrote a HLO graph to " << event_writer.FileName() << std::endl;
|
||||||
|
|
||||||
|
if (response.has_hlo_metadata()) {
|
||||||
|
tensorflow::TaggedRunMetadata tagged_run_metadata;
|
||||||
|
tagged_run_metadata.set_tag(run);
|
||||||
|
tagged_run_metadata.set_run_metadata(
|
||||||
|
response.hlo_metadata().SerializeAsString());
|
||||||
|
tensorflow::Event meta_event;
|
||||||
|
*meta_event.mutable_tagged_run_metadata() = tagged_run_metadata;
|
||||||
|
event_writer.WriteEvent(meta_event);
|
||||||
|
std::cout << "Wrote HLO ops run metadata to " << event_writer.FileName()
|
||||||
|
<< std::endl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@ -186,19 +213,7 @@ int main(int argc, char** argv) {
|
|||||||
LOG(INFO) << "Converting trace events to TraceViewer JSON.";
|
LOG(INFO) << "Converting trace events to TraceViewer JSON.";
|
||||||
tensorflow::tpu::DumpTraceToLogDirectory(run_dir, response.encoded_trace());
|
tensorflow::tpu::DumpTraceToLogDirectory(run_dir, response.encoded_trace());
|
||||||
}
|
}
|
||||||
int num_graphs = response.computation_graph_size();
|
tensorflow::tpu::DumpGraphEvents(FLAGS_logdir, run, response);
|
||||||
if (num_graphs > 0) {
|
|
||||||
// The server might generate multiple graphs for one program; we simply
|
|
||||||
// pick the first one.
|
|
||||||
if (num_graphs > 1) {
|
|
||||||
std::cout << num_graphs
|
|
||||||
<< " TPU program variants observed over the profiling period. "
|
|
||||||
<< "One computation graph will be chosen arbitrarily."
|
|
||||||
<< std::endl;
|
|
||||||
}
|
|
||||||
tensorflow::tpu::DumpGraph(
|
|
||||||
FLAGS_logdir, run, response.computation_graph(0).SerializeAsString());
|
|
||||||
}
|
|
||||||
if (response.has_op_profile() &&
|
if (response.has_op_profile() &&
|
||||||
(response.op_profile().has_by_program_structure() ||
|
(response.op_profile().has_by_program_structure() ||
|
||||||
response.op_profile().has_by_category())) {
|
response.op_profile().has_by_category())) {
|
||||||
|
@ -2,6 +2,7 @@ syntax = "proto3";
|
|||||||
package tensorflow;
|
package tensorflow;
|
||||||
|
|
||||||
import "tensorflow/core/framework/graph.proto";
|
import "tensorflow/core/framework/graph.proto";
|
||||||
|
import "tensorflow/core/protobuf/config.proto";
|
||||||
import "tensorflow/contrib/tpu/profiler/op_profile.proto";
|
import "tensorflow/contrib/tpu/profiler/op_profile.proto";
|
||||||
|
|
||||||
// The TPUProfiler service retrieves performance information about
|
// The TPUProfiler service retrieves performance information about
|
||||||
@ -31,6 +32,10 @@ message ProfileResponse {
|
|||||||
// Graphs of programs executed on TPUs during the profiling period.
|
// Graphs of programs executed on TPUs during the profiling period.
|
||||||
repeated GraphDef computation_graph = 2;
|
repeated GraphDef computation_graph = 2;
|
||||||
|
|
||||||
|
// Performance profile that can be used to annotate HLO operations in the
|
||||||
|
// computation graph.
|
||||||
|
RunMetadata hlo_metadata = 5;
|
||||||
|
|
||||||
// Encoded Trace proto message that contains metadata about the trace captured
|
// Encoded Trace proto message that contains metadata about the trace captured
|
||||||
// during the profiling period. Describes the devices and resources that
|
// during the profiling period. Describes the devices and resources that
|
||||||
// 'trace_events' refers to.
|
// 'trace_events' refers to.
|
||||||
@ -40,4 +45,5 @@ message ProfileResponse {
|
|||||||
// If the trace covers multiple programs, the longest-running one is analyzed.
|
// If the trace covers multiple programs, the longest-running one is analyzed.
|
||||||
// See op_profile.proto for the detailed semantics of the returned profile.
|
// See op_profile.proto for the detailed semantics of the returned profile.
|
||||||
tpu.op_profile.Profile op_profile = 4;
|
tpu.op_profile.Profile op_profile = 4;
|
||||||
|
// next-field: 6
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user