tfdbg core: implement gRPC debug URLs

Change: 139976177
Shanqing Cai 2016-11-22 17:30:19 -08:00 committed by TensorFlower Gardener
parent cce0d12e13
commit ef2a926ec0
13 changed files with 762 additions and 61 deletions

View File

@ -4,8 +4,6 @@
file(GLOB tf_core_direct_session_srcs
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.cc"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/direct_session.h"
"${tensorflow_source_dir}/tensorflow/core/debug/*.h"
"${tensorflow_source_dir}/tensorflow/core/debug/*.cc"
)
file(GLOB_RECURSE tf_core_direct_session_test_srcs
@ -18,3 +16,5 @@ list(REMOVE_ITEM tf_core_direct_session_srcs ${tf_core_direct_session_test_srcs}
add_library(tf_core_direct_session OBJECT ${tf_core_direct_session_srcs})
add_dependencies(tf_core_direct_session tf_core_cpu)
add_definitions(-DNOTFDBG)
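
Both this CMake target and the Android Bazel targets in the next file now pass a -DNOTFDBG define, presumably so that builds that cannot take on the new gRPC dependency can compile the debugger hooks out. The diff does not show where sources test the macro; the following is only a minimal, self-contained sketch of how such a compile-time guard behaves when -DNOTFDBG is (or is not) passed.

// Minimal sketch of a -DNOTFDBG compile-time guard. The actual guard sites
// inside TensorFlow are assumptions and are not shown in this diff.
#include <iostream>

int main() {
#ifndef NOTFDBG
  std::cout << "tfdbg support compiled in" << std::endl;
#else
  std::cout << "tfdbg support compiled out (NOTFDBG defined)" << std::endl;
#endif
  return 0;
}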

View File

@ -656,7 +656,6 @@ filegroup(
name = "android_srcs",
srcs = [
":proto_text_srcs_all",
"//tensorflow/core/debug:android_srcs",
"//tensorflow/core/kernels:android_srcs",
"//tensorflow/core/platform/default/build_config:android_srcs",
"//tensorflow/core/util/ctc:android_srcs",
@ -685,6 +684,7 @@ filegroup(
"**/*testutil*",
"**/*testlib*",
"**/*main.cc",
"debug/**/*",
"graph/dot.*",
"lib/jpeg/**/*",
"lib/png/**/*",
@ -723,7 +723,10 @@ filegroup(
cc_library(
name = "android_tensorflow_lib_lite",
srcs = if_android(["//tensorflow/core:android_srcs"]),
copts = tf_copts() + ["-Os"],
copts = tf_copts() + [
"-Os",
"-DNOTFDBG",
],
linkopts = ["-lz"],
tags = [
"manual",
@ -749,7 +752,7 @@ cc_library(
"//tensorflow/core/kernels:android_core_ops",
"//tensorflow/core/kernels:android_extended_ops",
]),
copts = tf_copts() + ["-Os"] + ["-std=c++11"],
copts = tf_copts() + ["-Os"] + ["-std=c++11"] + ["-DNOTFDBG"],
visibility = ["//visibility:public"],
deps = [
":protos_cc",
@ -782,7 +785,7 @@ cc_library(
cc_library(
name = "android_tensorflow_lib",
srcs = if_android([":android_op_registrations_and_gradients"]),
copts = tf_copts(),
copts = tf_copts() + ["-DNOTFDBG"],
tags = [
"manual",
"notap",
@ -807,6 +810,7 @@ cc_library(
copts = tf_copts() + [
"-Os",
"-DSUPPORT_SELECTIVE_REGISTRATION",
"-DNOTFDBG",
],
tags = [
"manual",

View File

@ -65,6 +65,7 @@ tf_cuda_library(
copts = tf_copts(),
linkstatic = 1,
deps = [
":debug_io_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@ -82,17 +83,51 @@ tf_cuda_library(
copts = tf_copts(),
linkstatic = 1,
deps = [
":debug_service_proto_cc", # TODO(cais): Confirm safe.
":debug_service_proto_cc",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"@grpc//:grpc++_unsecure",
],
alwayslink = 1,
)
tf_cuda_library(
name = "debug_grpc_testlib",
srcs = ["debug_grpc_testlib.cc"],
hdrs = ["debug_grpc_testlib.h"],
copts = tf_copts(),
linkstatic = 1,
deps = [
":debug_graph_utils",
":debug_io_utils",
":debug_service_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@grpc//:grpc++_unsecure",
],
alwayslink = 1,
)
cc_binary(
name = "debug_test_server_main",
srcs = [
"debug_test_server_main.cc",
],
deps = [
":debug_grpc_testlib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@grpc//:grpc++_unsecure",
],
)
tf_cc_test_gpu(
name = "debug_gateway_test",
size = "small",
@ -129,6 +164,7 @@ tf_cc_test(
srcs = ["debug_io_utils_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":debug_grpc_testlib",
":debug_io_utils",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
@ -143,15 +179,43 @@ tf_cc_test(
],
)
filegroup(
name = "android_srcs",
srcs = [
"debug_graph_utils.cc",
"debug_graph_utils.h",
tf_cc_test(
name = "debug_grpc_io_utils_test",
size = "small",
srcs = ["debug_grpc_io_utils_test.cc"],
data = [
":debug_test_server_main",
],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":debug_graph_utils",
":debug_grpc_testlib",
":debug_io_utils",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
visibility = ["//visibility:public"],
)
# TODO(cais): Add the following back in when tfdbg is supported on Android.
# filegroup(
# name = "android_srcs",
# srcs = [
# "debug_graph_utils.cc",
# "debug_graph_utils.h",
# "debug_io_utils.cc",
# "debug_io_utils.h",
# ],
# visibility = ["//visibility:public"],
# )
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -35,8 +36,9 @@ DebuggerState::DebuggerState(
}
DebuggerState::~DebuggerState() {
// TODO(cais): This is currently no-op. For gRPC debug URLs in debug_urls_,
// add cleanup actions such as closing streams.
for (const string& debug_url : debug_urls_) {
DebugIO::CloseDebugURL(debug_url);
}
}
const string DebuggerState::SummarizeDebugTensorWatches() {

View File

@ -94,11 +94,11 @@ class DebugNodeInserter {
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches, Graph* graph,
Device* device);
// Get canonical name of the copy node.
// Get canonical name of a copy node.
static const string GetCopyNodeName(const string& node_name,
const int output_slot);
// Get canonical name of the debug node.
// Get canonical name of a debug node.
static const string GetDebugNodeName(const string& tensor_name,
const int debug_op_num,
const string& debug_op_name);
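
For orientation, the canonical names mentioned in these comments follow a prefix-based scheme defined in debug_graph_utils.cc, which this diff does not touch. The sketch below is a hypothetical illustration only; the "__copy_" and "__dbg_" prefixes and the separator layout are assumptions, not taken from this change.

#include <iostream>
#include <string>

// Hypothetical illustration of a canonical-name scheme for copy and debug
// nodes; the real format lives in debug_graph_utils.cc.
std::string CopyNodeName(const std::string& node_name, int output_slot) {
  return "__copy_" + node_name + "_" + std::to_string(output_slot);
}

std::string DebugNodeName(const std::string& tensor_name, int debug_op_num,
                          const std::string& debug_op_name) {
  return "__dbg_" + tensor_name + "_" + std::to_string(debug_op_num) + "_" +
         debug_op_name;
}

int main() {
  std::cout << CopyNodeName("hidden/MatMul", 0) << "\n";
  std::cout << DebugNodeName("hidden/MatMul:0", 0, "DebugIdentity") << "\n";
  return 0;
}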

View File

@ -0,0 +1,216 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/debug/debug_grpc_testlib.h"
#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/event.pb.h"
namespace tensorflow {
namespace {
class GrpcDebugTest : public ::testing::Test {
protected:
bool SetUpServer() {
// Obtain port number for the test server.
int port = testing::PickUnusedPortOrDie();
server_client_pair.reset(new test::GrpcTestServerClientPair(port));
// Launch a debug test server in a subprocess.
const string test_server_bin = strings::StrCat(
testing::TensorFlowSrcRoot(), "/core/debug/debug_test_server_main");
const std::vector<string> argv(
{test_server_bin,
strings::Printf("%d", server_client_pair->server_port),
server_client_pair->dump_root});
subprocess_ = testing::CreateSubProcess(argv);
return subprocess_->Start();
}
void TearDownServer() {
// Stop the test server subprocess.
subprocess_->Kill(9);
// Clean up server dump directory.
int64 undeleted_files = -1;
int64 undeleted_dirs = -1;
Env::Default()->DeleteRecursively(server_client_pair->dump_root,
&undeleted_files, &undeleted_dirs);
ASSERT_EQ(0, undeleted_files);
ASSERT_EQ(0, undeleted_dirs);
}
std::unique_ptr<test::GrpcTestServerClientPair> server_client_pair;
private:
std::shared_ptr<SubProcess> subprocess_;
};
TEST_F(GrpcDebugTest, AttemptToSendToNonexistentGrpcAddress) {
Tensor tensor(DT_FLOAT, TensorShape({1, 1}));
tensor.flat<float>()(0) = 42.0;
const string kInvalidGrpcUrl = "grpc://0.0.0.0:0";
// An attempt to publish a debug tensor to the invalid URL should lead to a
// non-OK Status.
Status publish_status = DebugIO::PublishDebugTensor(
"foo_tensor", "DebugIdentity", tensor, Env::Default()->NowMicros(),
{kInvalidGrpcUrl});
ASSERT_FALSE(publish_status.ok());
ASSERT_NE(
string::npos,
publish_status.error_message().find(
"Channel at the following gRPC address is not ready: 0.0.0.0:0"));
DebugIO::CloseDebugURL(kInvalidGrpcUrl);
}
TEST_F(GrpcDebugTest, SendSingleDebugTensorViaGrpcTest) {
// Start the server process.
ASSERT_TRUE(SetUpServer());
// Poll the server with Event stream requests until first success.
ASSERT_TRUE(server_client_pair->PollTillFirstRequestSucceeds());
// Verify that the expected dump file exists.
std::vector<string> dump_files;
Env::Default()->GetChildren(server_client_pair->dump_root, &dump_files);
ASSERT_EQ(1, dump_files.size());
ASSERT_EQ(0, dump_files[0].find("prep_node_0_DebugIdentity_"));
TearDownServer();
}
TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) {
const int kSends = 4;
// Start the server process.
ASSERT_TRUE(SetUpServer());
// Prepare the tensors to be sent.
std::vector<Tensor> tensors;
for (int i = 0; i < kSends; ++i) {
Tensor tensor(DT_INT32, TensorShape({1, 1}));
tensor.flat<int>()(0) = i * i;
tensors.push_back(tensor);
}
// Poll the server with Event stream requests until first success.
ASSERT_TRUE(server_client_pair->PollTillFirstRequestSucceeds());
thread::ThreadPool* tp =
new thread::ThreadPool(Env::Default(), "grpc_debug_test", kSends);
mutex mu;
Notification all_done;
int tensor_count GUARDED_BY(mu) = 0;
std::vector<Status> statuses GUARDED_BY(mu);
const std::vector<string> urls({server_client_pair->test_server_url});
// Set up the concurrent tasks of sending Tensors via an Event stream to the
// server.
auto fn = [this, &mu, &tensor_count, &tensors, &statuses, &all_done,
&urls]() {
int this_count;
{
mutex_lock l(mu);
this_count = tensor_count++;
}
// Different concurrent tasks will send different tensors.
const uint64 wall_time = Env::Default()->NowMicros();
Status publish_status = DebugIO::PublishDebugTensor(
strings::StrCat("synchronized_node_", this_count, ":0"),
"DebugIdentity", tensors[this_count], wall_time, urls);
{
mutex_lock l(mu);
statuses.push_back(publish_status);
if (this_count == kSends - 1 && !all_done.HasBeenNotified()) {
all_done.Notify();
}
}
};
// Schedule the concurrent tasks.
for (int i = 0; i < kSends; ++i) {
tp->Schedule(fn);
}
// Wait for all client tasks to finish.
all_done.WaitForNotification();
delete tp;
// Close the debug gRPC stream.
Status close_status =
DebugIO::CloseDebugURL(server_client_pair->test_server_url);
ASSERT_TRUE(close_status.ok());
// Check all statuses from the PublishDebugTensor() calls.
for (const Status& status : statuses) {
TF_ASSERT_OK(status);
}
// Load the dump files generated by the server upon receiving the tensors
// via the Event stream.
std::vector<string> dump_files;
Env::Default()->GetChildren(server_client_pair->dump_root, &dump_files);
// One prep tensor plus kSends concurrent tensors are expected.
ASSERT_EQ(1 + kSends, dump_files.size());
// Verify the content of the dumped tensors (in Event proto files).
for (const string& dump_file : dump_files) {
if (dump_file.find("prep_node") == 0) {
continue;
}
std::vector<string> items = str_util::Split(dump_file, '_');
int tensor_index;
strings::safe_strto32(items[2], &tensor_index);
const string file_path =
io::JoinPath(server_client_pair->dump_root, dump_file);
Event event;
TF_ASSERT_OK(ReadEventFromFile(file_path, &event));
const TensorProto& tensor_proto = event.summary().value(0).tensor();
Tensor tensor(tensor_proto.dtype());
ASSERT_TRUE(tensor.FromProto(tensor_proto));
// Verify the content of the tensor sent via the Event stream.
ASSERT_EQ(TensorShape({1, 1}), tensor.shape());
ASSERT_EQ(tensor_index * tensor_index, tensor.flat<int>()(0));
}
TearDownServer();
}
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,106 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/debug/debug_grpc_testlib.h"
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
namespace test {
::grpc::Status TestEventListenerImpl::SendEvents(
::grpc::ServerContext* context,
::grpc::ServerReaderWriter< ::tensorflow::EventReply, ::tensorflow::Event>*
stream) {
Event event;
while (stream->Read(&event)) {
const Summary::Value& val = event.summary().value(0);
std::vector<string> name_items =
tensorflow::str_util::Split(val.node_name(), ':');
const string node_name = name_items[0];
int32 output_slot = 0;
tensorflow::strings::safe_strto32(name_items[1], &output_slot);
const string debug_op = name_items[2];
const TensorProto& tensor_proto = val.tensor();
Tensor tensor(tensor_proto.dtype());
if (!tensor.FromProto(tensor_proto)) {
return ::grpc::Status::CANCELLED;
}
string dump_path;
DebugFileIO::DumpTensorToDir(node_name, output_slot, debug_op, tensor,
event.wall_time(), dump_root, &dump_path);
}
return ::grpc::Status::OK;
}
GrpcTestServerClientPair::GrpcTestServerClientPair(const int server_port)
: server_port(server_port) {
const int kTensorSize = 2;
prep_tensor_.reset(
new Tensor(DT_FLOAT, TensorShape({kTensorSize, kTensorSize})));
for (int i = 0; i < kTensorSize * kTensorSize; ++i) {
prep_tensor_->flat<float>()(i) = static_cast<float>(i);
}
// Obtain server's gRPC url.
test_server_url = strings::StrCat("grpc://0.0.0.0:", server_port);
// Obtain dump directory for the stream server.
string tmp_dir = port::Tracing::LogDir();
dump_root =
io::JoinPath(tmp_dir, strings::StrCat("tfdbg_dump_port", server_port, "_",
Env::Default()->NowMicros()));
}
bool GrpcTestServerClientPair::PollTillFirstRequestSucceeds() {
const std::vector<string> urls({test_server_url});
int n_attempts = 0;
bool success = false;
// Try a number of times to send the Event proto to the server, as it may
// take the server a few seconds to start up and become responsive.
while (n_attempts++ < kMaxAttempts) {
const uint64 wall_time = Env::Default()->NowMicros();
Status publish_s = DebugIO::PublishDebugTensor(
"prep_node:0", "DebugIdentity", *prep_tensor_, wall_time, urls);
Status close_s = DebugIO::CloseDebugURL(test_server_url);
if (publish_s.ok() && close_s.ok()) {
success = true;
break;
} else {
Env::Default()->SleepForMicroseconds(kSleepDurationMicros);
}
}
return success;
}
} // namespace test
} // namespace tensorflow

View File

@ -0,0 +1,68 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
#define TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
#include "grpc++/grpc++.h"
#include "tensorflow/core/debug/debug_service.grpc.pb.h"
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
namespace test {
class TestEventListenerImpl final : public EventListener::Service {
public:
TestEventListenerImpl(const string& dump_root) : dump_root(dump_root) {}
::grpc::Status SendEvents(
::grpc::ServerContext* context,
::grpc::ServerReaderWriter< ::tensorflow::EventReply,
::tensorflow::Event>* stream);
string dump_root;
};
class GrpcTestServerClientPair {
public:
GrpcTestServerClientPair(const int server_port);
virtual ~GrpcTestServerClientPair() {}
// Keep sending requests to the test server until the first success.
// This is necessary because the server may take a certain amount of time
// to start up and become responsive.
//
// Returns: A boolean indicating whether a successful response was obtained
// within the maximum number of attempts.
bool PollTillFirstRequestSucceeds();
string dump_root;
int server_port;
string test_server_url;
private:
std::unique_ptr<Tensor> prep_tensor_;
const int kMaxAttempts = 100;
const int kSleepDurationMicros = 100 * 1000;
};
} // namespace test
} // namespace tensorflow
#endif // TENSORFLOW_DEBUG_GRPC_TESTLIB_H_

View File

@ -17,6 +17,8 @@ limitations under the License.
#include <vector>
#include "grpc++/create_channel.h"
#include "tensorflow/core/debug/debug_service.grpc.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
@ -56,6 +58,35 @@ Event WrapTensorAsEvent(const string& tensor_name, const string& debug_op,
} // namespace
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
Env* env(Env::Default());
string content;
uint64 file_size = 0;
Status s = env->GetFileSize(dump_file_path, &file_size);
if (!s.ok()) {
return s;
}
content.resize(file_size);
std::unique_ptr<RandomAccessFile> file;
s = env->NewRandomAccessFile(dump_file_path, &file);
if (!s.ok()) {
return s;
}
StringPiece result;
s = file->Read(0, file_size, &result, &(content)[0]);
if (!s.ok()) {
return s;
}
event->ParseFromString(content);
return Status::OK();
}
// static
const char* const DebugIO::kFileURLScheme = "file://";
// static
@ -85,6 +116,7 @@ Status DebugIO::PublishDebugTensor(const string& tensor_name,
}
int num_failed_urls = 0;
std::vector<Status> fail_statuses;
for (const string& url : debug_urls) {
if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
@ -94,12 +126,18 @@ Status DebugIO::PublishDebugTensor(const string& tensor_name,
wall_time_us, dump_root_dir, nullptr);
if (!s.ok()) {
num_failed_urls++;
fail_statuses.push_back(s);
}
} else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
// TODO(cais): Implement PublishTensor with grpc urls.
return Status(error::UNIMPLEMENTED,
strings::StrCat("Puslishing to GRPC debug target is not ",
"implemented yet"));
const string grpc_server_stream_addr = url.substr(strlen(kGrpcURLScheme));
Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
node_name, output_slot, debug_op, tensor, wall_time_us,
grpc_server_stream_addr);
if (!s.ok()) {
num_failed_urls++;
fail_statuses.push_back(s);
}
} else {
return Status(error::UNAVAILABLE,
strings::StrCat("Invalid debug target URL: ", url));
@ -109,13 +147,31 @@ Status DebugIO::PublishDebugTensor(const string& tensor_name,
if (num_failed_urls == 0) {
return Status::OK();
} else {
return Status(
error::INTERNAL,
strings::StrCat("Puslishing to ", num_failed_urls, " of ",
debug_urls.size(), " debug target URLs failed"));
string error_message = strings::StrCat(
"Publishing to ", num_failed_urls, " of ", debug_urls.size(),
" debug target URLs failed, due to the following errors:");
for (Status& status : fail_statuses) {
error_message =
strings::StrCat(error_message, " ", status.error_message(), ";");
}
return Status(error::INTERNAL, error_message);
}
}
Status DebugIO::CloseDebugURL(const string& debug_url) {
if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) {
return DebugGrpcIO::CloseGrpcStream(
debug_url.substr(strlen(DebugIO::kGrpcURLScheme)));
} else {
// No-op for non-gRPC URLs.
return Status::OK();
}
}
// static
static Status CloseDebugURL(const string& debug_url) { return Status::OK(); }
// static
Status DebugFileIO::DumpTensorToDir(
const string& node_name, const int32 output_slot, const string& debug_op,
@ -208,4 +264,96 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
}
}
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
: ctx_(),
channel_(::grpc::CreateCustomChannel(server_stream_addr,
::grpc::InsecureChannelCredentials(),
::grpc::ChannelArguments())),
stub_(EventListener::NewStub(channel_)),
reader_writer_(stub_->SendEvents(&ctx_)),
mu_() {}
// TODO(cais): Set GRPC_ARG_MAX_MESSAGE_LENGTH to max if necessary.
bool DebugGrpcChannel::is_channel_ready() {
return channel_->GetState(false) == GRPC_CHANNEL_READY;
}
bool DebugGrpcChannel::WriteEvent(const Event& event) {
mutex_lock l(mu_);
return reader_writer_->Write(event);
}
Status DebugGrpcChannel::Close() {
mutex_lock l(mu_);
reader_writer_->WritesDone();
if (reader_writer_->Finish().ok()) {
std::cout << "Finish() returned ok status" << std::endl; // DEBUG
return Status::OK();
} else {
std::cout << "Finish() returned non-ok status" << std::endl; // DEBUG
return Status(error::FAILED_PRECONDITION,
"Failed to close debug GRPC stream.");
}
}
// static
mutex DebugGrpcIO::streams_mu;
std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>
DebugGrpcIO::stream_channels;
// static
Status DebugGrpcIO::SendTensorThroughGrpcStream(
const string& node_name, const int32 output_slot, const string& debug_op,
const Tensor& tensor, const uint64 wall_time_us,
const string& server_stream_addr) {
const string tensor_name = strings::StrCat(node_name, ":", output_slot);
// Prepare tensor Event data to be sent.
Event event = WrapTensorAsEvent(tensor_name, debug_op, tensor, wall_time_us);
std::shared_ptr<DebugGrpcChannel> debug_grpc_channel;
{
mutex_lock l(streams_mu);
if (stream_channels.find(server_stream_addr) == stream_channels.end()) {
debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr));
if (!debug_grpc_channel->is_channel_ready()) {
return errors::FailedPrecondition(
strings::StrCat("Channel at the following gRPC address is ",
"not ready: ", server_stream_addr));
}
stream_channels[server_stream_addr] = debug_grpc_channel;
} else {
debug_grpc_channel = stream_channels[server_stream_addr];
}
}
bool write_ok = debug_grpc_channel->WriteEvent(event);
if (!write_ok) {
return errors::Cancelled(strings::StrCat("Writing event to stream URL ",
server_stream_addr, " failed."));
}
return Status::OK();
}
Status DebugGrpcIO::CloseGrpcStream(const string& server_stream_addr) {
mutex_lock l(streams_mu);
if (stream_channels.find(server_stream_addr) != stream_channels.end()) {
// Stream of the specified address exists. Close it and remove it from
// record.
Status s;
s = stream_channels[server_stream_addr]->Close();
stream_channels.erase(server_stream_addr);
return s;
} else {
// Stream of the specified address does not exist. No action.
return Status::OK();
}
}
} // namespace tensorflow

View File

@ -16,6 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
#define TENSORFLOW_DEBUG_IO_UTILS_H_
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/debug/debug_service.grpc.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@ -23,6 +27,8 @@ limitations under the License.
namespace tensorflow {
Status ReadEventFromFile(const string& dump_file_path, Event* event);
class DebugIO {
public:
// Publish a tensor to a debug target URL.
@ -36,12 +42,14 @@ class DebugIO {
// tensor: The Tensor object being published.
// wall_time_us: Time stamp for the Tensor. Unit: microseconds (us).
// debug_urls: An array of debug target URLs, e.g.,
// "file:///foo/tfdbg_dump", "grpc://localhot:11011"
// "file:///foo/tfdbg_dump", "grpc://localhost:11011"
static Status PublishDebugTensor(const string& tensor_name,
const string& debug_op, const Tensor& tensor,
const uint64 wall_time_us,
const gtl::ArraySlice<string>& debug_urls);
static Status CloseDebugURL(const string& debug_url);
private:
static const char* const kFileURLScheme;
static const char* const kGrpcURLScheme;
@ -70,7 +78,7 @@ class DebugFileIO {
// tensor: The Tensor object to be dumped to file.
// wall_time_us: Wall time at which the Tensor is generated during graph
// execution. Unit: microseconds (us).
// dump_root_dir: Root diretory for dumping the tensor.
// dump_root_dir: Root directory for dumping the tensor.
// dump_file_path: The actual dump file path (passed as reference).
static Status DumpTensorToDir(const string& node_name,
const int32 output_slot, const string& debug_op,
@ -104,6 +112,68 @@ class DebugFileIO {
static Status RecursiveCreateDir(Env* env, const string& dir);
};
class DebugGrpcChannel {
public:
// Constructor of DebugGrpcChannel.
//
// Args:
// server_stream_addr: Address (host name and port) of the debug stream
// server implementing the EventListener service (see
// debug_service.proto). E.g., "127.0.0.1:12345".
DebugGrpcChannel(const string& server_stream_addr);
virtual ~DebugGrpcChannel() {}
// Query whether the gRPC channel is ready for use.
bool is_channel_ready();
// Write an Event proto to the debug gRPC stream.
//
// Thread-safety: Safe with respect to other calls to the same method and
// calls to Close().
// Args:
// event: The event proto to be written to the stream.
//
// Returns:
// True iff the write is successful.
bool WriteEvent(const Event& event);
// Close the stream and the channel.
Status Close();
private:
::grpc::ClientContext ctx_;
std::shared_ptr<::grpc::Channel> channel_;
std::unique_ptr<EventListener::Stub> stub_;
std::unique_ptr<::grpc::ClientReaderWriterInterface<Event, EventReply>>
reader_writer_;
mutex mu_;
};
class DebugGrpcIO {
public:
// Send a tensor through a debug gRPC stream.
// Thread-safety: Safe with respect to other calls to the same method and
// calls to CloseGrpcStream().
static Status SendTensorThroughGrpcStream(const string& node_name,
const int32 output_slot,
const string& debug_op,
const Tensor& tensor,
const uint64 wall_time_us,
const string& server_stream_addr);
// Close a gRPC stream to the given address, if it exists.
// Thread-safety: Safe with respect to other calls to the same method and
// calls to SendTensorThroughGrpcStream().
static Status CloseGrpcStream(const string& server_stream_addr);
private:
static mutex streams_mu;
static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>
stream_channels GUARDED_BY(streams_mu);
};
} // namespace tensorflow
#endif // TENSORFLOW_DEBUG_IO_UTILS_H_
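
Taken together, the new public surface above lets a client publish a debug tensor to one or more debug URLs and then tear down any gRPC streams it opened. The snippet below is a usage sketch against these declarations only; the port number, node name, and dump directory are placeholder values, not part of this change.

#include <vector>

#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/env.h"

namespace tensorflow {

// Usage sketch: publish one tensor to a file:// target and a grpc:// target,
// then close the gRPC stream. "localhost:11011" and the dump root are
// placeholders.
Status PublishToDebugTargets(const Tensor& tensor) {
  const std::vector<string> debug_urls = {"file:///tmp/tfdbg_dump",
                                          "grpc://localhost:11011"};
  const uint64 wall_time_us = Env::Default()->NowMicros();
  // The tensor name takes the "<node_name>:<output_slot>" form used
  // throughout this change.
  Status s = DebugIO::PublishDebugTensor("hidden/MatMul:0", "DebugIdentity",
                                         tensor, wall_time_us, debug_urls);
  // gRPC URLs keep a stream open until explicitly closed; for file:// URLs
  // CloseDebugURL is a no-op.
  for (const string& url : debug_urls) {
    Status close_s = DebugIO::CloseDebugURL(url);
    if (s.ok() && !close_s.ok()) {
      s = close_s;
    }
  }
  return s;
}

}  // namespace tensorflow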

View File

@ -42,33 +42,6 @@ class DebugIOUtilsTest : public ::testing::Test {
tensor_b_->flat<string>()(1) = "garply";
}
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
string content;
uint64 file_size = 0;
Status s = env_->GetFileSize(dump_file_path, &file_size);
if (!s.ok()) {
return s;
}
content.resize(file_size);
std::unique_ptr<RandomAccessFile> file;
s = env_->NewRandomAccessFile(dump_file_path, &file);
if (!s.ok()) {
return s;
}
StringPiece result;
s = file->Read(0, file_size, &result, &(content)[0]);
if (!s.ok()) {
return s;
}
event->ParseFromString(content);
return Status::OK();
}
Env* env_;
std::unique_ptr<Tensor> tensor_a_;
std::unique_ptr<Tensor> tensor_b_;
@ -84,7 +57,7 @@ TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) {
const string kNodeName = "foo/bar/qux/tensor_a";
const string kDebugOpName = "DebugIdentity";
const int32 output_slot = 0;
uint64 wall_time = env_->NowMicros();
const uint64 wall_time = env_->NowMicros();
string dump_file_path;
TF_ASSERT_OK(DebugFileIO::DumpTensorToDir(kNodeName, output_slot,
@ -127,7 +100,7 @@ TEST_F(DebugIOUtilsTest, DumpStringTensorToFileSunnyDay) {
const string kNodeName = "quux/grault/tensor_b";
const string kDebugOpName = "DebugIdentity";
const int32 output_slot = 1;
uint64 wall_time = env_->NowMicros();
const uint64 wall_time = env_->NowMicros();
string dump_file_name;
Status s = DebugFileIO::DumpTensorToDir(kNodeName, output_slot, kDebugOpName,
@ -190,7 +163,7 @@ TEST_F(DebugIOUtilsTest, DumpTensorToFileCannotCreateDirectory) {
const string kNodeName = "baz/tensor_a";
const string kDebugOpName = "DebugIdentity";
const int32 output_slot = 0;
uint64 wall_time = env_->NowMicros();
const uint64 wall_time = env_->NowMicros();
string dump_file_name;
Status s = DebugFileIO::DumpTensorToDir(kNodeName, output_slot, kDebugOpName,
@ -215,8 +188,7 @@ TEST_F(DebugIOUtilsTest, PublishTensorToMultipleFileURLs) {
const string kNodeName = "foo/bar/qux/tensor_a";
const string kDebugOpName = "DebugIdentity";
const int32 output_slot = 0;
uint64 wall_time = env_->NowMicros();
const uint64 wall_time = env_->NowMicros();
std::vector<string> dump_roots;
std::vector<string> dump_file_paths;
@ -237,6 +209,7 @@ TEST_F(DebugIOUtilsTest, PublishTensorToMultipleFileURLs) {
const string tensor_name = strings::StrCat(kNodeName, ":", output_slot);
const string debug_node_name =
strings::StrCat(tensor_name, ":", kDebugOpName);
Status s = DebugIO::PublishDebugTensor(tensor_name, kDebugOpName, *tensor_a_,
wall_time, urls);
ASSERT_TRUE(s.ok());
@ -283,7 +256,7 @@ TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) {
thread::ThreadPool* tp =
new thread::ThreadPool(Env::Default(), "test", kConcurrentPubs);
uint64 wall_time = env_->NowMicros();
const uint64 wall_time = env_->NowMicros();
const string dump_root_base = testing::TmpDir();
const string tensor_name = strings::StrCat(kNodeName, ":", kOutputSlot);
@ -318,6 +291,7 @@ TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) {
std::vector<string> urls;
urls.push_back(debug_url);
Status s = DebugIO::PublishDebugTensor(tensor_name, kDebugOpName,
*tensor_a_, wall_time, urls);
ASSERT_TRUE(s.ok());

View File

@ -0,0 +1,49 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "grpc++/grpc++.h"
#include "tensorflow/core/debug/debug_grpc_testlib.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
// Usage: debug_test_server_main <port> <dump_root>
int main(int argc, char* argv[]) {
if (argc != 3) {
std::cerr << "Usage: debug_test_server_main <port> <dump_root>"
<< std::endl;
return 1;
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
int port = 0;
tensorflow::strings::safe_strto32(argv[1], &port);
std::string test_server_addr = tensorflow::strings::StrCat("0.0.0.0:", port);
tensorflow::test::TestEventListenerImpl debug_test_server(argv[2]);
::grpc::ServerBuilder builder;
builder.AddListeningPort(test_server_addr,
::grpc::InsecureServerCredentials());
builder.RegisterService(&debug_test_server);
std::unique_ptr<::grpc::Server> test_server = builder.BuildAndStart();
test_server->Wait();
return 0;
}

View File

@ -4152,7 +4152,7 @@ input_min: If range is given, this is the min of the range.
input_max: If range is given, this is the max of the range.
)doc");
// EXPERIMENTAL: tfdb debugger-inserted ops.
// EXPERIMENTAL: tfdbg debugger-inserted ops.
REGISTER_OP("Copy")
.Input("input: T")
.Output("output: T")