tfdbg: extend grpc_debug_server protocol for interactive debugging

Previously, a grpc-gated debug op had two modes: DISABLED and ENABLED.
This CL splits the ENABLED state into two states: READ_ONLY and READ_WRITE.

* READ_ONLY is equivalent to the previous ENABLED state, wherein a debug op
  publishes the debug tensor to the grpc debug server and proceeds. It can be
  regarded as a "watchpoint" that doesn't block execution.
* READ_WRITE is a "breakpoint". In addition to publishing the debug tensor,
  it blocks and awaits an EventReply proto response from the grpc debug server
  before proceeding.
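
The exchange can be sketched in Python as follows. This is an illustrative model of the gated debug op's behavior, not the actual C++ implementation in debug_io_utils; publish_tensor and read_event_reply are stand-ins for the real gRPC stream calls (DebugGrpcChannel::WriteEvent / ReadEventReply).

DISABLED, READ_ONLY, READ_WRITE = "DISABLED", "READ_ONLY", "READ_WRITE"

def run_gated_debug_op(state, publish_tensor, read_event_reply):
    # DISABLED: the gate is closed, no debug data is sent.
    if state == DISABLED:
        return None
    # READ_ONLY and READ_WRITE both publish the debug tensor ("watchpoint").
    publish_tensor()
    # READ_WRITE additionally blocks until the server sends an EventReply
    # ("breakpoint"), which may later carry an overriding tensor value.
    if state == READ_WRITE:
        return read_event_reply()
    return None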

PiperOrigin-RevId: 164987725
Shanqing Cai, 2017-08-11 09:37:02 -07:00 (committed by TensorFlower Gardener)
parent c0f9b0a91e
commit 3c482c66b5
11 changed files with 439 additions and 154 deletions


@ -65,10 +65,6 @@ class GrpcDebugTest : public ::testing::Test {
void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
void CreateEmptyEnabledSet(const string& grpc_debug_url) {
DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
}
const int64 GetChannelConnectionTimeoutMicros() {
return DebugGrpcIO::channel_connection_timeout_micros;
}
@ -261,7 +257,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) {
}
}
TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
// Prepare the tensor to send.
const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
"test_namescope/test_node", 0,
@ -283,8 +279,60 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
urls, enable_gated_grpc));
server_data_.server->RequestDebugOpStateChangeAtNextStream(i == 0,
kDebugNodeKey);
server_data_.server->RequestDebugOpStateChangeAtNextStream(
i == 0 ? EventReply::DebugOpStateChange::READ_ONLY
: EventReply::DebugOpStateChange::DISABLED,
kDebugNodeKey);
// Close the debug gRPC stream.
Status close_status = DebugIO::CloseDebugURL(server_data_.url);
ASSERT_TRUE(close_status.ok());
// Check dumped files according to the expected gating results.
if (i < 2) {
ASSERT_EQ(1, server_data_.server->node_names.size());
ASSERT_EQ(1, server_data_.server->output_slots.size());
ASSERT_EQ(1, server_data_.server->debug_ops.size());
EXPECT_EQ(kDebugNodeKey.device_name,
server_data_.server->device_names[0]);
EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]);
EXPECT_EQ(kDebugNodeKey.output_slot,
server_data_.server->output_slots[0]);
EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]);
} else {
ASSERT_EQ(0, server_data_.server->node_names.size());
}
}
}
TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUnderReadWriteMode) {
// Prepare the tensor to send.
const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
"test_namescope/test_node", 0,
"DebugIdentity");
Tensor tensor(DT_INT32, TensorShape({1, 1}));
tensor.flat<int>()(0) = 42;
const std::vector<string> urls({server_data_.url});
for (int i = 0; i < 3; ++i) {
server_data_.server->ClearReceivedDebugData();
const uint64 wall_time = Env::Default()->NowMicros();
// On the 1st send (i == 0), gating is disabled, so data should be sent.
// On the 2nd send (i == 1), gating is enabled, and the server has enabled
// the watch key in the previous send (READ_WRITE), so data should be
// sent. In this iteration, the server responds with an EventReply proto to
// unblock the debug node.
// On the 3rd send (i == 2), gating is enabled, but the server has disabled
// the watch key in the previous send, so data should not be sent.
const bool enable_gated_grpc = (i != 0);
TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
urls, enable_gated_grpc));
server_data_.server->RequestDebugOpStateChangeAtNextStream(
i == 0 ? EventReply::DebugOpStateChange::READ_WRITE
: EventReply::DebugOpStateChange::DISABLED,
kDebugNodeKey);
// Close the debug gRPC stream.
Status close_status = DebugIO::CloseDebugURL(server_data_.url);
@ -308,8 +356,6 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
}
TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) {
CreateEmptyEnabledSet("grpc://localhost:3333");
ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
{"grpc://localhost:3333"}));
@ -322,8 +368,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) {
const string kGrpcUrl1 = "grpc://localhost:3333";
const string kGrpcUrl2 = "grpc://localhost:3334";
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "bar:0:DebugIdentity");
DebugGrpcIO::SetDebugNodeKeyGrpcState(
kGrpcUrl1, "foo:0:DebugIdentity",
EventReply::DebugOpStateChange::READ_ONLY);
DebugGrpcIO::SetDebugNodeKeyGrpcState(
kGrpcUrl1, "bar:0:DebugIdentity",
EventReply::DebugOpStateChange::READ_ONLY);
ASSERT_FALSE(
DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", {kGrpcUrl1}));
@ -350,9 +400,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
const string kGrpcUrl2 = "grpc://localhost:3334";
const string kGrpcUrl3 = "grpc://localhost:3335";
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
DebugGrpcIO::EnableWatchKey(kGrpcUrl2, "bar:0:DebugIdentity");
CreateEmptyEnabledSet(kGrpcUrl3);
DebugGrpcIO::SetDebugNodeKeyGrpcState(
kGrpcUrl1, "foo:0:DebugIdentity",
EventReply::DebugOpStateChange::READ_ONLY);
DebugGrpcIO::SetDebugNodeKeyGrpcState(
kGrpcUrl2, "bar:0:DebugIdentity",
EventReply::DebugOpStateChange::READ_ONLY);
ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1}));
ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2}));
@ -375,7 +428,9 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
}
TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
DebugGrpcIO::EnableWatchKey("grpc://localhost:3333", "foo:0:DebugIdentity");
DebugGrpcIO::SetDebugNodeKeyGrpcState(
"grpc://localhost:3333", "foo:0:DebugIdentity",
EventReply::DebugOpStateChange::READ_ONLY);
std::vector<string> debug_urls_1;
ASSERT_FALSE(
@ -385,7 +440,6 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) {
const string kGrpcUrl1 = "grpc://localhost:3333";
const string kWatch1 = "foo:0:DebugIdentity";
CreateEmptyEnabledSet(kGrpcUrl1);
ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
{DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
@ -404,9 +458,8 @@ TEST_F(GrpcDebugTest, TestGateCopyNodeOnNonEmptyEnabledSet) {
const string kGrpcUrl2 = "grpc://localhost:3334";
const string kWatch1 = "foo:0:DebugIdentity";
const string kWatch2 = "foo:1:DebugIdentity";
CreateEmptyEnabledSet(kGrpcUrl1);
CreateEmptyEnabledSet(kGrpcUrl2);
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, kWatch1);
DebugGrpcIO::SetDebugNodeKeyGrpcState(
kGrpcUrl1, kWatch1, EventReply::DebugOpStateChange::READ_ONLY);
ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
{DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));


@ -72,27 +72,44 @@ namespace test {
output_slots.push_back(metadata.output_slot());
debug_ops.push_back(debug_op);
debug_tensors.push_back(tensor);
// If the debug node is currently in the READ_WRITE mode, send an
// EventReply to 1) unblock the execution and 2) optionally modify the
// value.
const DebugNodeKey debug_node_key(metadata.device(), node_name,
metadata.output_slot(), debug_op);
if (write_enabled_debug_node_keys_.find(debug_node_key) !=
write_enabled_debug_node_keys_.end()) {
stream->Write(EventReply());
}
}
}
{
mutex_lock l(changes_mu_);
for (size_t i = 0; i < changes_to_enable_.size(); ++i) {
mutex_lock l(states_mu_);
for (size_t i = 0; i < new_states_.size(); ++i) {
EventReply event_reply;
EventReply::DebugOpStateChange* change =
event_reply.add_debug_op_state_changes();
change->set_change(changes_to_enable_[i]
? EventReply::DebugOpStateChange::ENABLE
: EventReply::DebugOpStateChange::DISABLE);
change->set_node_name(changes_node_names_[i]);
change->set_output_slot(changes_output_slots_[i]);
change->set_debug_op(changes_debug_ops_[i]);
// State changes will take effect in the next stream, i.e., next debugged
// Session.run() call.
change->set_state(new_states_[i]);
const DebugNodeKey& debug_node_key = debug_node_keys_[i];
change->set_node_name(debug_node_key.node_name);
change->set_output_slot(debug_node_key.output_slot);
change->set_debug_op(debug_node_key.debug_op);
stream->Write(event_reply);
if (new_states_[i] == EventReply::DebugOpStateChange::READ_WRITE) {
write_enabled_debug_node_keys_.insert(debug_node_key);
} else {
write_enabled_debug_node_keys_.erase(debug_node_key);
}
}
changes_to_enable_.clear();
changes_node_names_.clear();
changes_output_slots_.clear();
changes_debug_ops_.clear();
debug_node_keys_.clear();
new_states_.clear();
}
return ::grpc::Status::OK;
@ -109,13 +126,12 @@ void TestEventListenerImpl::ClearReceivedDebugData() {
}
void TestEventListenerImpl::RequestDebugOpStateChangeAtNextStream(
bool to_enable, const DebugNodeKey& debug_node_key) {
mutex_lock l(changes_mu_);
const EventReply::DebugOpStateChange::State new_state,
const DebugNodeKey& debug_node_key) {
mutex_lock l(states_mu_);
changes_to_enable_.push_back(to_enable);
changes_node_names_.push_back(debug_node_key.node_name);
changes_output_slots_.push_back(debug_node_key.output_slot);
changes_debug_ops_.push_back(debug_node_key.debug_op);
debug_node_keys_.push_back(debug_node_key);
new_states_.push_back(new_state);
}
void TestEventListenerImpl::RunServer(const int server_port) {

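The bookkeeping above (debug_node_keys_/new_states_ and write_enabled_debug_node_keys_ in TestEventListenerImpl) can be summarized by this small Python model; apply_state_changes and the plain list/set types are illustrative stand-ins for the C++ members, not part of the testlib.

READ_WRITE, DISABLED = "READ_WRITE", "DISABLED"

def apply_state_changes(pending_changes, write_enabled_keys):
    # Send one state change per queued request and remember which keys the
    # server must later unblock with an EventReply (the READ_WRITE keys).
    for key, state in pending_changes:
        if state == READ_WRITE:
            write_enabled_keys.add(key)
        else:
            write_enabled_keys.discard(key)
    # Changes take effect in the next stream, i.e., the next Session.run().
    pending_changes.clear()

write_enabled = set()
apply_state_changes([("delta:0:DebugIdentity", READ_WRITE)], write_enabled)
assert write_enabled == {"delta:0:DebugIdentity"}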

@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
#include <atomic>
#include <unordered_set>
#include "grpc++/grpc++.h"
#include "tensorflow/core/debug/debug_io_utils.h"
@ -44,7 +45,8 @@ class TestEventListenerImpl final : public EventListener::Service {
void ClearReceivedDebugData();
void RequestDebugOpStateChangeAtNextStream(
bool to_enable, const DebugNodeKey& debug_node_key);
const EventReply::DebugOpStateChange::State new_state,
const DebugNodeKey& debug_node_key);
std::vector<string> debug_metadata_strings;
std::vector<string> encoded_graph_defs;
@ -58,12 +60,13 @@ class TestEventListenerImpl final : public EventListener::Service {
std::atomic_bool stop_requested_;
std::atomic_bool stopped_;
std::vector<bool> changes_to_enable_ GUARDED_BY(changes_mu_);
std::vector<string> changes_node_names_ GUARDED_BY(changes_mu_);
std::vector<int32> changes_output_slots_ GUARDED_BY(changes_mu_);
std::vector<string> changes_debug_ops_ GUARDED_BY(changes_mu_);
std::vector<DebugNodeKey> debug_node_keys_ GUARDED_BY(states_mu_);
std::vector<EventReply::DebugOpStateChange::State> new_states_
GUARDED_BY(states_mu_);
mutex changes_mu_;
std::unordered_set<DebugNodeKey> write_enabled_debug_node_keys_;
mutex states_mu_;
};
// Poll a gRPC debug server by sending a small tensor repeatedly till success.


@ -15,6 +15,11 @@ limitations under the License.
#include "tensorflow/core/debug/debug_io_utils.h"
#include <stddef.h>
#include <string.h>
#include <cmath>
#include <limits>
#include <utility>
#include <vector>
#ifndef PLATFORM_WINDOWS
@ -297,6 +302,15 @@ DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
device_path(DeviceNameToDevicePath(device_name)) {}
bool DebugNodeKey::operator==(const DebugNodeKey& other) const {
return (device_name == other.device_name && node_name == other.node_name &&
output_slot == other.output_slot && debug_op == other.debug_op);
}
bool DebugNodeKey::operator!=(const DebugNodeKey& other) const {
return !((*this) == other);
}
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
Env* env(Env::Default());
@ -537,7 +551,7 @@ bool DebugIO::IsCopyNodeGateOpen(
DebugIO::kGrpcURLScheme)) {
return true;
} else {
if (DebugGrpcIO::IsGateOpen(spec.watch_key, spec.url)) {
if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) {
return true;
}
}
@ -557,7 +571,7 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
DebugIO::kGrpcURLScheme)) {
return true;
} else {
if (DebugGrpcIO::IsGateOpen(watch_key, debug_url)) {
if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) {
return true;
}
}
@ -575,7 +589,7 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
if (debug_url.find(kGrpcURLScheme) != 0) {
return true;
} else {
return DebugGrpcIO::IsGateOpen(watch_key, debug_url);
return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key);
}
#else
return true;
@ -731,6 +745,12 @@ bool DebugGrpcChannel::WriteEvent(const Event& event) {
return reader_writer_->Write(event);
}
bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
mutex_lock l(mu_);
return reader_writer_->Read(event_reply);
}
Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
mutex_lock l(mu_);
@ -744,13 +764,8 @@ Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
debug_op_state_change.output_slot(),
":", debug_op_state_change.debug_op());
if (debug_op_state_change.change() ==
EventReply::DebugOpStateChange::ENABLE) {
DebugGrpcIO::EnableWatchKey(url_, watch_key);
} else if (debug_op_state_change.change() ==
EventReply::DebugOpStateChange::DISABLE) {
DebugGrpcIO::DisableWatchKey(url_, watch_key);
}
DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
debug_op_state_change.state());
}
}
@ -789,7 +804,8 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
const DebugNodeKey& debug_node_key, const Tensor& tensor,
const uint64 wall_time_us, const string& grpc_stream_url,
const bool gated) {
if (gated && !IsGateOpen(debug_node_key.debug_node_name, grpc_stream_url)) {
if (gated &&
!IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
return Status::OK();
} else {
std::vector<Event> events;
@ -799,10 +815,35 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
TF_RETURN_IF_ERROR(
SendEventProtoThroughGrpcStream(event, grpc_stream_url));
}
if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
EventReply event_reply;
TF_RETURN_IF_ERROR(ReceiveEventReplyProtoThroughGrpcStream(
&event_reply, grpc_stream_url));
// TODO(cais): Support new tensor value carried in the EventReply for
// overriding the value of the tensor being published.
}
return Status::OK();
}
}
// static
Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
EventReply* event_reply, const string& grpc_stream_url) {
std::shared_ptr<DebugGrpcChannel> debug_grpc_channel;
{
mutex_lock l(streams_mu);
std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
stream_channels = GetStreamChannels();
debug_grpc_channel = (*stream_channels)[grpc_stream_url];
}
if (debug_grpc_channel->ReadEventReply(event_reply)) {
return Status::OK();
} else {
return errors::Cancelled(strings::StrCat(
"Reading EventReply from stream URL ", grpc_stream_url, " failed."));
}
}
// static
Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
const Event& event_proto, const string& grpc_stream_url) {
@ -821,9 +862,7 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr));
TF_RETURN_IF_ERROR(
debug_grpc_channel->Connect(channel_connection_timeout_micros));
(*stream_channels)[grpc_stream_url] = debug_grpc_channel;
CreateEmptyEnabledSet(grpc_stream_url);
} else {
debug_grpc_channel = (*stream_channels)[grpc_stream_url];
}
@ -838,16 +877,22 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
return Status::OK();
}
// static
bool DebugGrpcIO::IsGateOpen(const string& watch_key,
const string& grpc_debug_url) {
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
GetEnabledWatchKeys();
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
const string& watch_key) {
const DebugNodeName2State* enabled_node_to_state =
GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
}
bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
const string& watch_key) {
const DebugNodeName2State* enabled_node_to_state =
GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
auto it = enabled_node_to_state->find(watch_key);
if (it == enabled_node_to_state->end()) {
return false;
} else {
const auto& url_enabled = (*enabled_watch_keys)[grpc_debug_url];
return url_enabled.find(watch_key) != url_enabled.end();
return it->second == EventReply::DebugOpStateChange::READ_WRITE;
}
}
@ -871,56 +916,46 @@ Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
}
// static
std::unordered_map<string, std::unordered_set<string>>*
DebugGrpcIO::GetEnabledWatchKeys() {
static std::unordered_map<string, std::unordered_set<string>>*
enabled_watch_keys =
new std::unordered_map<string, std::unordered_set<string>>();
return enabled_watch_keys;
std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
DebugGrpcIO::GetEnabledDebugOpStates() {
static std::unordered_map<string, DebugNodeName2State>*
enabled_debug_op_states =
new std::unordered_map<string, DebugNodeName2State>();
return enabled_debug_op_states;
}
// static
void DebugGrpcIO::EnableWatchKey(const string& grpc_debug_url,
const string& watch_key) {
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
GetEnabledWatchKeys();
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
CreateEmptyEnabledSet(grpc_debug_url);
DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
const string& grpc_debug_url) {
std::unordered_map<string, DebugNodeName2State>* states =
GetEnabledDebugOpStates();
if (states->find(grpc_debug_url) == states->end()) {
DebugNodeName2State url_enabled_debug_op_states;
(*states)[grpc_debug_url] = url_enabled_debug_op_states;
}
(*enabled_watch_keys)[grpc_debug_url].insert(watch_key);
return &(*states)[grpc_debug_url];
}
// static
void DebugGrpcIO::DisableWatchKey(const string& grpc_debug_url,
const string& watch_key) {
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
GetEnabledWatchKeys();
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
LOG(WARNING) << "Attempt to disable a watch key for an unregistered gRPC "
<< "debug URL: " << grpc_debug_url;
} else {
std::unordered_set<string>& url_enabled =
(*enabled_watch_keys)[grpc_debug_url];
if (url_enabled.find(watch_key) == url_enabled.end()) {
LOG(WARNING) << "Attempt to disable a watch key that is not currently "
<< "enabled at " << grpc_debug_url << ": " << watch_key;
void DebugGrpcIO::SetDebugNodeKeyGrpcState(
const string& grpc_debug_url, const string& watch_key,
const EventReply::DebugOpStateChange::State new_state) {
DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
if (new_state == EventReply::DebugOpStateChange::DISABLED) {
if (states->find(watch_key) == states->end()) {
LOG(ERROR) << "Attempt to disable a watch key that is not currently "
<< "enabled at " << grpc_debug_url << ": " << watch_key;
} else {
url_enabled.erase(watch_key);
states->erase(watch_key);
}
} else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) {
(*states)[watch_key] = new_state;
}
}
// static
void DebugGrpcIO::ClearEnabledWatchKeys() { GetEnabledWatchKeys()->clear(); }
// static
void DebugGrpcIO::CreateEmptyEnabledSet(const string& grpc_debug_url) {
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
GetEnabledWatchKeys();
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
std::unordered_set<string> empty_watch_keys;
(*enabled_watch_keys)[grpc_debug_url] = empty_watch_keys;
}
void DebugGrpcIO::ClearEnabledWatchKeys() {
GetEnabledDebugOpStates()->clear();
}
#endif // #ifndef PLATFORM_WINDOWS


@ -16,8 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
#define TENSORFLOW_DEBUG_IO_UTILS_H_
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
@ -49,6 +54,9 @@ struct DebugNodeKey {
// ,job_localhost,replica_0,task_0,cpu_0.
static const string DeviceNameToDevicePath(const string& device_name);
bool operator==(const DebugNodeKey& other) const;
bool operator!=(const DebugNodeKey& other) const;
const string device_name;
const string node_name;
const int32 output_slot;
@ -219,6 +227,19 @@ class DebugFileIO {
} // namespace tensorflow
namespace std {
template <>
struct hash<::tensorflow::DebugNodeKey> {
size_t operator()(const ::tensorflow::DebugNodeKey& k) const {
return ::tensorflow::Hash64(
::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":",
k.output_slot, ":", k.debug_op, ":"));
}
};
} // namespace std
// TODO(cais): Support grpc:// debug URLs in open source once Python grpc
// genrule becomes available. See b/23796275.
#ifndef PLATFORM_WINDOWS
@ -252,7 +273,8 @@ class DebugGrpcChannel {
// Write an Event proto to the debug gRPC stream.
//
// Thread-safety: Safe with respect to other calls to the same method and
// call to Close().
// calls to ReadEventReply() and Close().
//
// Args:
// event: The event proto to be written to the stream.
//
@ -260,6 +282,19 @@ class DebugGrpcChannel {
// True iff the write is successful.
bool WriteEvent(const Event& event);
// Read an EventReply proto from the debug gRPC stream.
//
// This method blocks and waits for an EventReply from the server.
// Thread-safety: Safe with respect to other calls to the same method and
// calls to WriteEvent() and Close().
//
// Args:
// event_reply: the to-be-modified EventReply proto passed as reference.
//
// Returns:
// True iff the read is successful.
bool ReadEventReply(EventReply* event_reply);
// Receive EventReplies from server (if any) and close the stream and the
// channel.
Status ReceiveServerRepliesAndClose();
@ -294,48 +329,49 @@ class DebugGrpcIO {
static Status SendEventProtoThroughGrpcStream(const Event& event_proto,
const string& grpc_stream_url);
// Checks whether a debug watch key is allowed to send data to a given grpc://
// debug URL given the current gating status.
//
// Args:
// watch_key: debug tensor watch key, in the format of
// tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
// grpc_debug_url: the debug URL, e.g., "grpc://localhost:3333",
//
// Returns:
// Whether the sending of debug data to grpc_debug_url should
// proceed.
static bool IsGateOpen(const string& watch_key, const string& grpc_debug_url);
// Receive an EventReply proto through a debug gRPC stream.
static Status ReceiveEventReplyProtoThroughGrpcStream(
EventReply* event_reply, const string& grpc_stream_url);
// Check whether a debug watch key is read-activated at a given gRPC URL.
static bool IsReadGateOpen(const string& grpc_debug_url,
const string& watch_key);
// Check whether a debug watch key is write-activated (i.e., read- and
// write-activated) at a given gRPC URL.
static bool IsWriteGateOpen(const string& grpc_debug_url,
const string& watch_key);
// Closes a gRPC stream to the given address, if it exists.
// Thread-safety: Safe with respect to other calls to the same method and
// calls to SendTensorThroughGrpcStream().
static Status CloseGrpcStream(const string& grpc_stream_url);
// Enables a debug watch key at a grpc:// debug URL.
static void EnableWatchKey(const string& grpc_debug_url,
const string& watch_key);
// Disables a debug watch key at a grpc:// debug URL.
static void DisableWatchKey(const string& grpc_debug_url,
const string& watch_key);
// Set the gRPC state of a debug node key.
// TODO(cais): Include device information in watch_key.
static void SetDebugNodeKeyGrpcState(
const string& grpc_debug_url, const string& watch_key,
const EventReply::DebugOpStateChange::State new_state);
private:
using DebugNodeName2State =
std::unordered_map<string, EventReply::DebugOpStateChange::State>;
// Returns a global map from grpc debug URLs to the corresponding
// DebugGrpcChannels.
static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
GetStreamChannels();
// Returns a global map from grpc debug URLs to the enabled gated debug nodes.
// The keys are grpc:// URLs of the debug servers, e.g., "grpc://foo:3333".
// Each value element of the value has the format
// <node_name>:<output_slot>:<debug_op>", e.g.,
// "Weights_1:0:DebugNumericSummary".
static std::unordered_map<string, std::unordered_set<string>>*
GetEnabledWatchKeys();
// Returns a map from debug URL to a map from debug op name to enabled state.
static std::unordered_map<string, DebugNodeName2State>*
GetEnabledDebugOpStates();
// Returns a map from debug op names to enabled state, for a given debug URL.
static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl(
const string& grpc_debug_url);
// Clear enabled debug op state from all debug URLs (if any).
static void ClearEnabledWatchKeys();
static void CreateEmptyEnabledSet(const string& grpc_debug_url);
static mutex streams_mu;
static int64 channel_connection_timeout_micros;

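The read/write gate semantics declared in this header can be modeled in Python as below; enabled_states is a plain-dict stand-in for the per-URL state map returned by GetEnabledDebugOpStatesAtUrl, and the two functions mirror the logic of IsReadGateOpen and IsWriteGateOpen shown in the debug_io_utils implementation above.

READ_ONLY, READ_WRITE = "READ_ONLY", "READ_WRITE"
enabled_states = {}  # grpc_debug_url -> {watch_key: state}

def is_read_gate_open(grpc_debug_url, watch_key):
    # Open if the watch key is enabled at all (READ_ONLY or READ_WRITE).
    return watch_key in enabled_states.get(grpc_debug_url, {})

def is_write_gate_open(grpc_debug_url, watch_key):
    # Open only if the watch key is in the READ_WRITE ("breakpoint") state.
    return enabled_states.get(grpc_debug_url, {}).get(watch_key) == READ_WRITE

enabled_states["grpc://localhost:3333"] = {"foo:0:DebugIdentity": READ_ONLY}
assert is_read_gate_open("grpc://localhost:3333", "foo:0:DebugIdentity")
assert not is_write_gate_open("grpc://localhost:3333", "foo:0:DebugIdentity")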

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <unordered_set>
#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
@ -62,6 +64,39 @@ TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) {
debug_node_key.device_path);
}
TEST_F(DebugIOUtilsTest, EqualityOfDebugNodeKeys) {
const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2",
"hidden_1/MatMul", 0, "DebugIdentity");
const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2",
"hidden_1/MatMul", 0, "DebugIdentity");
const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2",
"hidden_1/BiasAdd", 0, "DebugIdentity");
const DebugNodeKey debug_node_key_4("/job:worker/replica:1/task:0/gpu:2",
"hidden_1/MatMul", 0,
"DebugNumericSummary");
EXPECT_EQ(debug_node_key_1, debug_node_key_2);
EXPECT_NE(debug_node_key_1, debug_node_key_3);
EXPECT_NE(debug_node_key_1, debug_node_key_4);
EXPECT_NE(debug_node_key_3, debug_node_key_4);
}
TEST_F(DebugIOUtilsTest, DebugNodeKeysIsHashable) {
const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2",
"hidden_1/MatMul", 0, "DebugIdentity");
const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2",
"hidden_1/MatMul", 0, "DebugIdentity");
const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2",
"hidden_1/BiasAdd", 0, "DebugIdentity");
std::unordered_set<DebugNodeKey> keys;
keys.insert(debug_node_key_1);
ASSERT_EQ(1, keys.size());
keys.insert(debug_node_key_3);
ASSERT_EQ(2, keys.size());
keys.erase(debug_node_key_2);
ASSERT_EQ(1, keys.size());
}
TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) {
Initialize();


@ -17,6 +17,7 @@ syntax = "proto3";
package tensorflow;
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/util/event.proto";
// Reply message from EventListener to the client, i.e., to the source of the
@ -24,18 +25,25 @@ import "tensorflow/core/util/event.proto";
// TensorFlow graph being executed.
message EventReply {
message DebugOpStateChange {
enum Change {
DISABLE = 0;
ENABLE = 1;
enum State {
STATE_UNSPECIFIED = 0;
DISABLED = 1;
READ_ONLY = 2;
READ_WRITE = 3;
}
Change change = 1;
State state = 1;
string node_name = 2;
int32 output_slot = 3;
string debug_op = 4;
}
repeated DebugOpStateChange debug_op_state_changes = 1;
// New tensor value to override the current tensor value with.
TensorProto tensor = 2;
// TODO(cais): Make use of this field to implement overriding of tensor value
// during debugging.
}
// EventListener: Receives Event protos, e.g., from debugged TensorFlow

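With the new State enum, a state-change request can be assembled from Python as in the following sketch, which mirrors the _watch_key_event_reply helper changed later in this commit; the node name and debug op used here are purely illustrative.

from tensorflow.core.debug import debug_service_pb2

# Build an EventReply that turns delta:0:DebugIdentity into a breakpoint.
event_reply = debug_service_pb2.EventReply()
state_change = event_reply.debug_op_state_changes.add()
state_change.state = debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE
state_change.node_name = "delta"
state_change.output_slot = 0
state_change.debug_op = "DebugIdentity"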

@ -259,10 +259,6 @@ class DebugNumericSummaryOpTest : public OpsTestBase {
#if defined(PLATFORM_GOOGLE)
void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
void CreateEmptyEnabledSet(const string& grpc_debug_url) {
DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
}
#endif
};
@ -604,7 +600,6 @@ TEST_F(DebugNumericSummaryOpTest, BoolSuccess) {
#if defined(PLATFORM_GOOGLE)
TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) {
ClearEnabledWatchKeys();
CreateEmptyEnabledSet("grpc://server:3333");
std::vector<string> debug_urls({"grpc://server:3333"});
TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));
@ -617,8 +612,9 @@ TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) {
TEST_F(DebugNumericSummaryOpTest, DisabledDueToNonMatchingWatchKey) {
ClearEnabledWatchKeys();
DebugGrpcIO::EnableWatchKey("grpc://server:3333",
"FakeTensor:1:DebugNumeriSummary");
DebugGrpcIO::SetDebugNodeKeyGrpcState(
"grpc://server:3333", "FakeTensor:1:DebugNumeriSummary",
EventReply::DebugOpStateChange::READ_ONLY);
std::vector<string> debug_urls({"grpc://server:3333"});
TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));


@ -31,16 +31,20 @@ from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_service_pb2_grpc
from tensorflow.python.platform import tf_logging as logging
DebugWatch = collections.namedtuple("DebugWatch",
["node_name", "output_slot", "debug_op"])
def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op):
"""Make EventReply proto to represent a request to watch/unwatch a debug op.
def _watch_key_event_reply(new_state, node_name, output_slot, debug_op):
"""Make `EventReply` proto to represent a request to watch/unwatch a debug op.
Args:
to_enable: (`bool`) whether the request is to enable the watch key.
new_state: (`debug_service_pb2.EventReply.DebugOpStateChange.State`) the new
state to set the debug node to, i.e., whether the debug node will become
disabled under the grpc mode (`DISABLED`), become a watchpoint
(`READ_ONLY`) or become a breakpoint (`READ_WRITE`).
node_name: (`str`) name of the node.
output_slot: (`int`) output slot of the tensor.
debug_op: (`str`) the debug op attached to node_name:output_slot tensor to
@ -51,9 +55,7 @@ def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op):
"""
event_reply = debug_service_pb2.EventReply()
state_change = event_reply.debug_op_state_changes.add()
state_change.change = (
debug_service_pb2.EventReply.DebugOpStateChange.ENABLE
if to_enable else debug_service_pb2.EventReply.DebugOpStateChange.DISABLE)
state_change.state = new_state
state_change.node_name = node_name
state_change.output_slot = output_slot
state_change.debug_op = debug_op
@ -125,6 +127,7 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
self._event_reply_queue = queue.Queue()
self._gated_grpc_debug_watches = set()
self._breakpoints = set()
def SendEvents(self, request_iterator, context):
"""Implementation of the SendEvents service method.
@ -171,11 +174,30 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
maybe_tensor_event = self._process_tensor_event_in_chunks(
event, tensor_chunks)
if maybe_tensor_event:
stream_handler.on_value_event(maybe_tensor_event)
event_reply = stream_handler.on_value_event(maybe_tensor_event)
if event_reply is not None:
yield event_reply
# The server writes EventReply messages, if any.
while not self._event_reply_queue.empty():
yield self._event_reply_queue.get()
event_reply = self._event_reply_queue.get()
for state_change in event_reply.debug_op_state_changes:
if (state_change.state ==
debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE):
logging.info("Adding breakpoint %s:%d:%s", state_change.node_name,
state_change.output_slot, state_change.debug_op)
self._breakpoints.add(
(state_change.node_name, state_change.output_slot,
state_change.debug_op))
elif (state_change.state ==
debug_service_pb2.EventReply.DebugOpStateChange.DISABLED):
logging.info("Removing watchpoint or breakpoint: %s:%d:%s",
state_change.node_name, state_change.output_slot,
state_change.debug_op)
self._breakpoints.discard(
(state_change.node_name, state_change.output_slot,
state_change.debug_op))
yield event_reply
def _process_tensor_event_in_chunks(self, event, tensor_chunks):
"""Possibly reassemble event chunks.
@ -336,8 +358,8 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
finally:
self._server_lock.release()
def request_watch(self, node_name, output_slot, debug_op):
"""Request enabling a debug tensor watch.
def request_watch(self, node_name, output_slot, debug_op, breakpoint=False):
"""Request enabling a debug tensor watchpoint or breakpoint.
This will let the server send an EventReply to the client side
(i.e., the debugged TensorFlow runtime process) to request adding a watch
@ -355,12 +377,19 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
output_slot: (`int`) output slot index of the tensor to watch.
debug_op: (`str`) name of the debug op to enable. This should not include
any attribute substrings.
breakpoint: (`bool`) Iff `True`, the debug op will block and wait until it
receives an `EventReply` response from the server. The `EventReply`
proto may carry a TensorProto that modifies the value of the debug op's
output tensor.
"""
self._event_reply_queue.put(
_watch_key_event_reply(True, node_name, output_slot, debug_op))
_watch_key_event_reply(
debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE
if breakpoint else debug_service_pb2.EventReply.DebugOpStateChange.
READ_ONLY, node_name, output_slot, debug_op))
def request_unwatch(self, node_name, output_slot, debug_op):
"""Request disabling a debug tensor watch.
"""Request disabling a debug tensor watchpoint or breakpoint.
The request will take effect on the next debugged `Session.run()` call.
@ -374,7 +403,18 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
any attribute substrings.
"""
self._event_reply_queue.put(
_watch_key_event_reply(False, node_name, output_slot, debug_op))
_watch_key_event_reply(debug_service_pb2.EventReply.DebugOpStateChange.
DISABLED, node_name, output_slot, debug_op))
@property
def breakpoints(self):
"""Get a set of the currently-activated breakpoints.
Returns:
A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g.,
{("MatMul", 0, "DebugIdentity")}.
"""
return self._breakpoints
def gated_grpc_debug_watches(self):
"""Get the list of debug watches with attribute gated_grpc=True.

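Taken together, the new request_watch/request_unwatch signatures and the breakpoints property support a toggle pattern like the sketch below. Here server and run_once are assumptions: an EventListenerBaseServicer subclass and a function that performs one debugged Session.run(), as exercised in the test later in this commit.

def toggle_breakpoint_across_runs(server, run_once):
    # Request a breakpoint at delta:0:DebugIdentity; the change is delivered
    # during the next run and takes effect in the run after that.
    server.request_watch("delta", 0, "DebugIdentity", breakpoint=True)
    run_once()
    assert ("delta", 0, "DebugIdentity") in server.breakpoints

    # Request removal; the intervening run hits the breakpoint (publish, then
    # block until the server's EventReply), after which it is removed.
    server.request_unwatch("delta", 0, "DebugIdentity")
    run_once()
    assert not server.breakpoints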

@ -31,6 +31,7 @@ import time
import portpicker
from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session
@ -138,13 +139,26 @@ class EventListenerTestStreamHandler(
Args:
event: The Event proto carrying a tensor value.
Returns:
If the debug node belongs to the set of currently activated breakpoints,
an `EventReply` proto will be returned.
"""
if self._dump_dir:
self._write_value_event(event)
else:
value = event.summary.value[0]
tensor_value = debug_data.load_tensor_from_event(event)
self._event_listener_servicer.debug_tensor_values[value.node_name].append(
debug_data.load_tensor_from_event(event))
tensor_value)
items = event.summary.value[0].node_name.split(":")
node_name = items[0]
output_slot = int(items[1])
debug_op = items[2]
if ((node_name, output_slot, debug_op) in
self._event_listener_servicer.breakpoints):
return debug_service_pb2.EventReply()
def _try_makedirs(self, dir_path):
if not os.path.isdir(dir_path):


@ -556,6 +556,55 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
[50 + 5.0 * i],
self._server_2.debug_tensor_values["v:0:DebugIdentity"])
def testToggleBreakpointWorks(self):
with session.Session(config=no_rewrite_session_config()) as sess:
v = variables.Variable(50.0, name="v")
delta = constant_op.constant(5.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
sess.run(v.initializer)
run_metadata = config_pb2.RunMetadata()
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_utils.watch_graph(
run_options,
sess.graph,
debug_ops=["DebugIdentity(gated_grpc=true)"],
debug_urls=[self._debug_server_url_1])
for i in xrange(4):
self._server_1.clear_data()
# N.B.: These requests will be fulfilled not in this debugged
# Session.run() invocation, but in the next one.
if i in (0, 2):
# Enable breakpoint at delta:0:DebugIdentity in runs 0 and 2.
self._server_1.request_watch(
"delta", 0, "DebugIdentity", breakpoint=True)
else:
# Disable the breakpoint in runs 1 and 3.
self._server_1.request_unwatch("delta", 0, "DebugIdentity")
output = sess.run(inc_v, options=run_options, run_metadata=run_metadata)
self.assertAllClose(50.0 + 5.0 * (i + 1), output)
if i in (0, 2):
# After the end of runs 0 and 2, the server has received the requests
# to enable the breakpoint at delta:0:DebugIdentity. So the server
# should keep track of the correct breakpoints.
self.assertSetEqual({("delta", 0, "DebugIdentity")},
self._server_1.breakpoints)
else:
# During runs 1 and 3, the server should have received the published
# debug tensor delta:0:DebugIdentity. The breakpoint should have been
# unblocked by EventReply responses from the server.
self.assertAllClose(
[5.0],
self._server_1.debug_tensor_values["delta:0:DebugIdentity"])
# After the runs, the server should have properly removed the
# breakpoints due to the request_unwatch calls.
self.assertSetEqual(set(), self._server_1.breakpoints)
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
with session.Session() as sess:
v = variables.Variable(50.0, name="v")