[tpu:profiler] Write RunMetadata of the computation graph to event file, if available.
The RunMetadata can be used to annotate HLO graphs with colors based on node compute time. PiperOrigin-RevId: 167477021
This commit is contained in:
parent
d57572e996
commit
0302320e11
@ -38,7 +38,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
#include "tensorflow/core/util/event.pb.h"
|
||||
#include "tensorflow/core/util/events_writer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -136,14 +138,39 @@ ProfileResponse Profile(const string& service_addr, int duration_ms) {
|
||||
return response;
|
||||
}
|
||||
|
||||
void DumpGraph(StringPiece logdir, StringPiece run, const string& graph_def) {
|
||||
void DumpGraphEvents(const string& logdir, const string& run,
|
||||
const ProfileResponse& response) {
|
||||
int num_graphs = response.computation_graph_size();
|
||||
if (response.computation_graph_size() == 0) return;
|
||||
// The server might generates multiple graphs for one program; we simply
|
||||
// pick the first one.
|
||||
if (num_graphs > 1) {
|
||||
std::cout << num_graphs
|
||||
<< " TPU program variants observed over the profiling period. "
|
||||
<< "One computation graph will be chosen arbitrarily."
|
||||
<< std::endl;
|
||||
}
|
||||
// The graph plugin expects the graph in <logdir>/<run>/<event.file>.
|
||||
string run_dir = JoinPath(logdir, strings::StrCat(kGraphRunPrefix, run));
|
||||
TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir));
|
||||
EventsWriter event_writer(JoinPath(run_dir, "events"));
|
||||
Event event;
|
||||
event.set_graph_def(graph_def);
|
||||
// Add the computation graph.
|
||||
event.set_graph_def(response.computation_graph(0).SerializeAsString());
|
||||
event_writer.WriteEvent(event);
|
||||
std::cout << "Wrote a HLO graph to " << event_writer.FileName() << std::endl;
|
||||
|
||||
if (response.has_hlo_metadata()) {
|
||||
tensorflow::TaggedRunMetadata tagged_run_metadata;
|
||||
tagged_run_metadata.set_tag(run);
|
||||
tagged_run_metadata.set_run_metadata(
|
||||
response.hlo_metadata().SerializeAsString());
|
||||
tensorflow::Event meta_event;
|
||||
*meta_event.mutable_tagged_run_metadata() = tagged_run_metadata;
|
||||
event_writer.WriteEvent(meta_event);
|
||||
std::cout << "Wrote HLO ops run metadata to " << event_writer.FileName()
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -186,19 +213,7 @@ int main(int argc, char** argv) {
|
||||
LOG(INFO) << "Converting trace events to TraceViewer JSON.";
|
||||
tensorflow::tpu::DumpTraceToLogDirectory(run_dir, response.encoded_trace());
|
||||
}
|
||||
int num_graphs = response.computation_graph_size();
|
||||
if (num_graphs > 0) {
|
||||
// The server might generates multiple graphs for one program; we simply
|
||||
// pick the first one.
|
||||
if (num_graphs > 1) {
|
||||
std::cout << num_graphs
|
||||
<< " TPU program variants observed over the profiling period. "
|
||||
<< "One computation graph will be chosen arbitrarily."
|
||||
<< std::endl;
|
||||
}
|
||||
tensorflow::tpu::DumpGraph(
|
||||
FLAGS_logdir, run, response.computation_graph(0).SerializeAsString());
|
||||
}
|
||||
tensorflow::tpu::DumpGraphEvents(FLAGS_logdir, run, response);
|
||||
if (response.has_op_profile() &&
|
||||
(response.op_profile().has_by_program_structure() ||
|
||||
response.op_profile().has_by_category())) {
|
||||
|
@ -2,6 +2,7 @@ syntax = "proto3";
|
||||
package tensorflow;
|
||||
|
||||
import "tensorflow/core/framework/graph.proto";
|
||||
import "tensorflow/core/protobuf/config.proto";
|
||||
import "tensorflow/contrib/tpu/profiler/op_profile.proto";
|
||||
|
||||
// The TPUProfiler service retrieves performance information about
|
||||
@ -31,6 +32,10 @@ message ProfileResponse {
|
||||
// Graphs of programs executed on TPUs during the profiling period.
|
||||
repeated GraphDef computation_graph = 2;
|
||||
|
||||
// Performance profile that can be used to annotate HLO operations in the
|
||||
// computation graph.
|
||||
RunMetadata hlo_metadata = 5;
|
||||
|
||||
// Encoded Trace proto message that contains metadata about the trace captured
|
||||
// during the profiling period. Describes the devices and resources that
|
||||
// 'trace_events' refers to.
|
||||
@ -40,4 +45,5 @@ message ProfileResponse {
|
||||
// If the trace covers multiple programs, the longest-running one is analyzed.
|
||||
// See op_profile.proto for the detailed semantics of the returned profile.
|
||||
tpu.op_profile.Profile op_profile = 4;
|
||||
// next-field: 6
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user