[tpu profiler] Dump HLO graphs in profile responses to the log directory.
PiperOrigin-RevId: 163318992
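
A successful capture now leaves two artifacts under --logdir (sketch of the
resulting layout; the exact events-file suffix is chosen by EventsWriter and
includes a timestamp and hostname):

    <logdir>/plugins/profile/<run>/trace
    <logdir>/tpu_profiler.hlo_graph.<run>/events.out.tfevents.<timestamp>.<host>

where <run> is the capture's start timestamp. TensorBoard's trace viewer reads
the first file; its graph plugin, which expects a graph in
<logdir>/<run>/<event.file>, reads the second.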
parent dd1f0cdddb
commit cda80a7850
tensorflow/contrib/tpu/profiler/BUILD
@@ -21,8 +21,10 @@ cc_binary(
     visibility = ["//tensorflow/contrib/tpu/profiler:__subpackages__"],
     deps = [
         ":tpu_profiler_proto_cc",
+        "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/distributed_runtime/rpc:grpc_util",
         "@grpc//:grpc++_unsecure",
     ],
tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -26,6 +26,7 @@ limitations under the License.
 
 #include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -33,6 +34,7 @@ limitations under the License.
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/util/command_line_flags.h"
+#include "tensorflow/core/util/events_writer.h"
 
 namespace tensorflow {
 namespace tpu {
@@ -47,6 +49,7 @@ using ::tensorflow::WriteStringToFile;
 
 constexpr char kProfilePluginDirectory[] = "plugins/profile/";
 constexpr char kTraceFileName[] = "trace";
+constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph.";
 
 tensorflow::string GetCurrentTimeStampAsString() {
   char s[128];
@@ -55,10 +58,10 @@ tensorflow::string GetCurrentTimeStampAsString() {
   return s;
 }
 
-// The trace will be stored in <logdir>/plugins/profile/<timestamp>/trace.
-void DumpTraceToLogDirectory(const tensorflow::string& logdir,
+// The trace will be stored in <logdir>/plugins/profile/<run>/trace.
+void DumpTraceToLogDirectory(tensorflow::StringPiece logdir,
+                             tensorflow::StringPiece run,
                              tensorflow::StringPiece trace) {
-  tensorflow::string run = GetCurrentTimeStampAsString();
   tensorflow::string run_dir = JoinPath(logdir, kProfilePluginDirectory, run);
   TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir));
   tensorflow::string path = JoinPath(run_dir, kTraceFileName);
@@ -83,6 +86,18 @@ ProfileResponse Profile(const tensorflow::string& service_addr,
   return response;
 }
 
+void DumpGraph(tensorflow::StringPiece logdir, tensorflow::StringPiece run,
+               const tensorflow::string& graph_def) {
+  // The graph plugin expects the graph in <logdir>/<run>/<event.file>.
+  tensorflow::string run_dir =
+      JoinPath(logdir, tensorflow::strings::StrCat(kGraphRunPrefix, run));
+  TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir));
+  tensorflow::EventsWriter event_writer(JoinPath(run_dir, "events"));
+  tensorflow::Event event;
+  event.set_graph_def(graph_def);
+  event_writer.WriteEvent(event);
+}
+
 }  // namespace
 }  // namespace tpu
 }  // namespace tensorflow
@@ -111,14 +126,28 @@ int main(int argc, char** argv) {
   int duration_ms = FLAGS_duration_ms;
   tensorflow::ProfileResponse response =
       tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms);
-  // Ignore computation_graph for now.
+  // Use the current timestamp as the run name.
+  tensorflow::string run = tensorflow::tpu::GetCurrentTimeStampAsString();
   if (response.encoded_trace().empty()) {
     LOG(WARNING) << "No trace event is collected during the " << duration_ms
                  << "ms interval.";
   } else {
-    tensorflow::tpu::DumpTraceToLogDirectory(FLAGS_logdir,
+    tensorflow::tpu::DumpTraceToLogDirectory(FLAGS_logdir, run,
                                              response.encoded_trace());
   }
+  int num_graphs = response.computation_graph_size();
+  if (num_graphs > 0) {
+    // The server might generate multiple graphs for one program; we simply
+    // pick the first one.
+    if (num_graphs > 1) {
+      LOG(INFO) << num_graphs
+                << " TPU program variants observed over the profiling period. "
+                << "One computation graph will be chosen arbitrarily.";
+    }
+    tensorflow::tpu::DumpGraph(
+        FLAGS_logdir, run, response.computation_graph(0).SerializeAsString());
+  }
   // Print this at the end so that it's not buried in irrelevant LOG messages.
   std::cout
       << "NOTE: using the trace duration " << duration_ms << "ms." << std::endl
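
For offline inspection, a minimal reader sketch (not part of this commit; the
helper name ReadDumpedGraph is invented for illustration, and the events-file
pattern assumes EventsWriter's standard naming) that recovers the dumped
GraphDef from the run directory:

#include <memory>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/event.pb.h"

// Reads back the GraphDef that DumpGraph stored under <run_dir>/events.*.
tensorflow::GraphDef ReadDumpedGraph(const tensorflow::string& run_dir) {
  tensorflow::Env* env = tensorflow::Env::Default();
  // EventsWriter appends ".out.tfevents.<timestamp>.<hostname>" to the
  // "events" prefix passed in DumpGraph, so locate the file by pattern.
  std::vector<tensorflow::string> files;
  TF_CHECK_OK(env->GetMatchingPaths(
      tensorflow::io::JoinPath(run_dir, "events.out.tfevents.*"), &files));
  CHECK(!files.empty()) << "no events file found in " << run_dir;

  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_CHECK_OK(env->NewRandomAccessFile(files[0], &file));
  tensorflow::io::RecordReader reader(file.get());

  // The first record is the EventsWriter version header, so skip records
  // until one carries a serialized graph.
  tensorflow::uint64 offset = 0;
  tensorflow::string record;
  tensorflow::GraphDef graph_def;
  while (reader.ReadRecord(&offset, &record).ok()) {
    tensorflow::Event event;
    CHECK(event.ParseFromString(record));
    if (!event.graph_def().empty()) {
      CHECK(graph_def.ParseFromString(event.graph_def()));
      break;
    }
  }
  return graph_def;
}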