tfdbg: extend grpc_debug_server protocol for interactive debugging

Previously, a grpc-gated debug op had two states: DISABLED and ENABLED.
This CL splits the ENABLED state into two states: READ_ONLY and READ_WRITE.

* READ_ONLY is equivalent to the previous ENABLED state: the debug op
  publishes its debug tensor to the grpc debug server and proceeds. It can be
  regarded as a "watchpoint" that doesn't block execution.
* READ_WRITE is a "breakpoint": in addition to publishing the debug tensor,
  the op blocks and awaits an EventReply proto response from the grpc debug
  server before proceeding.
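
As a rough illustration of the protocol (a minimal sketch, not part of this CL: the node name and debug op below are placeholders, and the helper simply mirrors the `_watch_key_event_reply` function changed later in this CL), a grpc debug server expresses the three states in an EventReply like this:

```python
from tensorflow.core.debug import debug_service_pb2


def make_state_change(new_state, node_name, output_slot, debug_op):
  """Build an EventReply asking a gated debug op to switch state."""
  event_reply = debug_service_pb2.EventReply()
  change = event_reply.debug_op_state_changes.add()
  change.state = new_state  # DISABLED, READ_ONLY, or READ_WRITE
  change.node_name = node_name
  change.output_slot = output_slot
  change.debug_op = debug_op
  return event_reply


# Watchpoint: the debug op publishes its tensor and proceeds.
watchpoint = make_state_change(
    debug_service_pb2.EventReply.DebugOpStateChange.READ_ONLY,
    "delta", 0, "DebugIdentity")

# Breakpoint: the debug op publishes its tensor, then blocks until the
# server sends back an EventReply.
breakpoint_request = make_state_change(
    debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE,
    "delta", 0, "DebugIdentity")
```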

PiperOrigin-RevId: 164987725
Author: Shanqing Cai, 2017-08-11 09:37:02 -07:00; committed by TensorFlower Gardener
Commit 3c482c66b5 (parent c0f9b0a91e)
11 changed files with 439 additions and 154 deletions

View File

@@ -65,10 +65,6 @@ class GrpcDebugTest : public ::testing::Test {
   void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }

-  void CreateEmptyEnabledSet(const string& grpc_debug_url) {
-    DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
-  }
-
   const int64 GetChannelConnectionTimeoutMicros() {
     return DebugGrpcIO::channel_connection_timeout_micros;
   }
@@ -261,7 +257,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) {
   }
 }

-TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
+TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
   // Prepare the tensor to send.
   const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
                                    "test_namescope/test_node", 0,
@@ -283,8 +279,60 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
     TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
                                              urls, enable_gated_grpc));
-    server_data_.server->RequestDebugOpStateChangeAtNextStream(i == 0,
-                                                               kDebugNodeKey);
+    server_data_.server->RequestDebugOpStateChangeAtNextStream(
+        i == 0 ? EventReply::DebugOpStateChange::READ_ONLY
+               : EventReply::DebugOpStateChange::DISABLED,
+        kDebugNodeKey);
+
+    // Close the debug gRPC stream.
+    Status close_status = DebugIO::CloseDebugURL(server_data_.url);
+    ASSERT_TRUE(close_status.ok());
+
+    // Check dumped files according to the expected gating results.
+    if (i < 2) {
+      ASSERT_EQ(1, server_data_.server->node_names.size());
+      ASSERT_EQ(1, server_data_.server->output_slots.size());
+      ASSERT_EQ(1, server_data_.server->debug_ops.size());
+      EXPECT_EQ(kDebugNodeKey.device_name,
+                server_data_.server->device_names[0]);
+      EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]);
+      EXPECT_EQ(kDebugNodeKey.output_slot,
+                server_data_.server->output_slots[0]);
+      EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]);
+    } else {
+      ASSERT_EQ(0, server_data_.server->node_names.size());
+    }
+  }
+}
+
+TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUnderReadWriteMode) {
+  // Prepare the tensor to send.
+  const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+                                   "test_namescope/test_node", 0,
+                                   "DebugIdentity");
+  Tensor tensor(DT_INT32, TensorShape({1, 1}));
+  tensor.flat<int>()(0) = 42;
+
+  const std::vector<string> urls({server_data_.url});
+  for (int i = 0; i < 3; ++i) {
+    server_data_.server->ClearReceivedDebugData();
+    const uint64 wall_time = Env::Default()->NowMicros();
+
+    // On the 1st send (i == 0), gating is disabled, so data should be sent.
+    // On the 2nd send (i == 1), gating is enabled, and the server has enabled
+    // the watch key in the previous send (READ_WRITE), so data should be
+    // sent. In this iteration, the server responds with an EventReply proto
+    // to unblock the debug node.
+    // On the 3rd send (i == 2), gating is enabled, but the server has disabled
+    // the watch key in the previous send, so data should not be sent.
+    const bool enable_gated_grpc = (i != 0);
+    TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
+                                             urls, enable_gated_grpc));
+    server_data_.server->RequestDebugOpStateChangeAtNextStream(
+        i == 0 ? EventReply::DebugOpStateChange::READ_WRITE
               : EventReply::DebugOpStateChange::DISABLED,
+        kDebugNodeKey);

     // Close the debug gRPC stream.
     Status close_status = DebugIO::CloseDebugURL(server_data_.url);
@@ -308,8 +356,6 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
 }

 TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) {
-  CreateEmptyEnabledSet("grpc://localhost:3333");
-
   ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
                                             {"grpc://localhost:3333"}));
@@ -322,8 +368,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) {
   const string kGrpcUrl1 = "grpc://localhost:3333";
   const string kGrpcUrl2 = "grpc://localhost:3334";
-  DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
-  DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "bar:0:DebugIdentity");
+  DebugGrpcIO::SetDebugNodeKeyGrpcState(
+      kGrpcUrl1, "foo:0:DebugIdentity",
+      EventReply::DebugOpStateChange::READ_ONLY);
+  DebugGrpcIO::SetDebugNodeKeyGrpcState(
+      kGrpcUrl1, "bar:0:DebugIdentity",
+      EventReply::DebugOpStateChange::READ_ONLY);

   ASSERT_FALSE(
       DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", {kGrpcUrl1}));
@@ -350,9 +400,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
   const string kGrpcUrl2 = "grpc://localhost:3334";
   const string kGrpcUrl3 = "grpc://localhost:3335";
-  DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
-  DebugGrpcIO::EnableWatchKey(kGrpcUrl2, "bar:0:DebugIdentity");
-  CreateEmptyEnabledSet(kGrpcUrl3);
+  DebugGrpcIO::SetDebugNodeKeyGrpcState(
+      kGrpcUrl1, "foo:0:DebugIdentity",
+      EventReply::DebugOpStateChange::READ_ONLY);
+  DebugGrpcIO::SetDebugNodeKeyGrpcState(
+      kGrpcUrl2, "bar:0:DebugIdentity",
+      EventReply::DebugOpStateChange::READ_ONLY);

   ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1}));
   ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2}));
@@ -375,7 +428,9 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
 }

 TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
-  DebugGrpcIO::EnableWatchKey("grpc://localhost:3333", "foo:0:DebugIdentity");
+  DebugGrpcIO::SetDebugNodeKeyGrpcState(
+      "grpc://localhost:3333", "foo:0:DebugIdentity",
+      EventReply::DebugOpStateChange::READ_ONLY);

   std::vector<string> debug_urls_1;
   ASSERT_FALSE(
@@ -385,7 +440,6 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {

 TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) {
   const string kGrpcUrl1 = "grpc://localhost:3333";
   const string kWatch1 = "foo:0:DebugIdentity";
-  CreateEmptyEnabledSet(kGrpcUrl1);

   ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
       {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
@@ -404,9 +458,8 @@ TEST_F(GrpcDebugTest, TestGateCopyNodeOnNonEmptyEnabledSet) {
   const string kGrpcUrl2 = "grpc://localhost:3334";
   const string kWatch1 = "foo:0:DebugIdentity";
   const string kWatch2 = "foo:1:DebugIdentity";
-  CreateEmptyEnabledSet(kGrpcUrl1);
-  CreateEmptyEnabledSet(kGrpcUrl2);
-  DebugGrpcIO::EnableWatchKey(kGrpcUrl1, kWatch1);
+  DebugGrpcIO::SetDebugNodeKeyGrpcState(
+      kGrpcUrl1, kWatch1, EventReply::DebugOpStateChange::READ_ONLY);

   ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
       {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));

View File

@@ -72,27 +72,44 @@ namespace test {
       output_slots.push_back(metadata.output_slot());
       debug_ops.push_back(debug_op);
       debug_tensors.push_back(tensor);
+
+      // If the debug node is currently in the READ_WRITE mode, send an
+      // EventReply to 1) unblock the execution and 2) optionally modify the
+      // value.
+      const DebugNodeKey debug_node_key(metadata.device(), node_name,
+                                        metadata.output_slot(), debug_op);
+      if (write_enabled_debug_node_keys_.find(debug_node_key) !=
+          write_enabled_debug_node_keys_.end()) {
+        stream->Write(EventReply());
+      }
     }
   }

   {
-    mutex_lock l(changes_mu_);
-    for (size_t i = 0; i < changes_to_enable_.size(); ++i) {
+    mutex_lock l(states_mu_);
+    for (size_t i = 0; i < new_states_.size(); ++i) {
       EventReply event_reply;
       EventReply::DebugOpStateChange* change =
           event_reply.add_debug_op_state_changes();
-      change->set_change(changes_to_enable_[i]
-                             ? EventReply::DebugOpStateChange::ENABLE
-                             : EventReply::DebugOpStateChange::DISABLE);
-      change->set_node_name(changes_node_names_[i]);
-      change->set_output_slot(changes_output_slots_[i]);
-      change->set_debug_op(changes_debug_ops_[i]);
+      // State changes will take effect in the next stream, i.e., next debugged
+      // Session.run() call.
+      change->set_state(new_states_[i]);
+      const DebugNodeKey& debug_node_key = debug_node_keys_[i];
+      change->set_node_name(debug_node_key.node_name);
+      change->set_output_slot(debug_node_key.output_slot);
+      change->set_debug_op(debug_node_key.debug_op);
       stream->Write(event_reply);
+      if (new_states_[i] == EventReply::DebugOpStateChange::READ_WRITE) {
+        write_enabled_debug_node_keys_.insert(debug_node_key);
+      } else {
+        write_enabled_debug_node_keys_.erase(debug_node_key);
+      }
     }
-    changes_to_enable_.clear();
-    changes_node_names_.clear();
-    changes_output_slots_.clear();
-    changes_debug_ops_.clear();
+    debug_node_keys_.clear();
+    new_states_.clear();
   }

   return ::grpc::Status::OK;
@@ -109,13 +126,12 @@ void TestEventListenerImpl::ClearReceivedDebugData() {
 }

 void TestEventListenerImpl::RequestDebugOpStateChangeAtNextStream(
-    bool to_enable, const DebugNodeKey& debug_node_key) {
-  mutex_lock l(changes_mu_);
-
-  changes_to_enable_.push_back(to_enable);
-  changes_node_names_.push_back(debug_node_key.node_name);
-  changes_output_slots_.push_back(debug_node_key.output_slot);
-  changes_debug_ops_.push_back(debug_node_key.debug_op);
+    const EventReply::DebugOpStateChange::State new_state,
+    const DebugNodeKey& debug_node_key) {
+  mutex_lock l(states_mu_);
+
+  debug_node_keys_.push_back(debug_node_key);
+  new_states_.push_back(new_state);
 }

 void TestEventListenerImpl::RunServer(const int server_port) {

View File

@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_DEBUG_GRPC_TESTLIB_H_

 #include <atomic>
+#include <unordered_set>

 #include "grpc++/grpc++.h"
 #include "tensorflow/core/debug/debug_io_utils.h"
@@ -44,7 +45,8 @@ class TestEventListenerImpl final : public EventListener::Service {
   void ClearReceivedDebugData();

   void RequestDebugOpStateChangeAtNextStream(
-      bool to_enable, const DebugNodeKey& debug_node_key);
+      const EventReply::DebugOpStateChange::State new_state,
+      const DebugNodeKey& debug_node_key);

   std::vector<string> debug_metadata_strings;
   std::vector<string> encoded_graph_defs;
@@ -58,12 +60,13 @@ class TestEventListenerImpl final : public EventListener::Service {
   std::atomic_bool stop_requested_;
   std::atomic_bool stopped_;

-  std::vector<bool> changes_to_enable_ GUARDED_BY(changes_mu_);
-  std::vector<string> changes_node_names_ GUARDED_BY(changes_mu_);
-  std::vector<int32> changes_output_slots_ GUARDED_BY(changes_mu_);
-  std::vector<string> changes_debug_ops_ GUARDED_BY(changes_mu_);
+  std::vector<DebugNodeKey> debug_node_keys_ GUARDED_BY(states_mu_);
+  std::vector<EventReply::DebugOpStateChange::State> new_states_
+      GUARDED_BY(states_mu_);
+
+  std::unordered_set<DebugNodeKey> write_enabled_debug_node_keys_;

-  mutex changes_mu_;
+  mutex states_mu_;
 };

 // Poll a gRPC debug server by sending a small tensor repeatedly till success.

View File

@@ -15,6 +15,11 @@ limitations under the License.

 #include "tensorflow/core/debug/debug_io_utils.h"

+#include <stddef.h>
+#include <string.h>
+#include <cmath>
+#include <limits>
+#include <utility>
 #include <vector>

 #ifndef PLATFORM_WINDOWS
@@ -297,6 +302,15 @@ DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
           strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
       device_path(DeviceNameToDevicePath(device_name)) {}

+bool DebugNodeKey::operator==(const DebugNodeKey& other) const {
+  return (device_name == other.device_name && node_name == other.node_name &&
+          output_slot == other.output_slot && debug_op == other.debug_op);
+}
+
+bool DebugNodeKey::operator!=(const DebugNodeKey& other) const {
+  return !((*this) == other);
+}
+
 Status ReadEventFromFile(const string& dump_file_path, Event* event) {
   Env* env(Env::Default());
@@ -537,7 +551,7 @@ bool DebugIO::IsCopyNodeGateOpen(
                           DebugIO::kGrpcURLScheme)) {
       return true;
     } else {
-      if (DebugGrpcIO::IsGateOpen(spec.watch_key, spec.url)) {
+      if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) {
        return true;
      }
    }
@@ -557,7 +571,7 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
                           DebugIO::kGrpcURLScheme)) {
      return true;
    } else {
-      if (DebugGrpcIO::IsGateOpen(watch_key, debug_url)) {
+      if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) {
        return true;
      }
    }
@@ -575,7 +589,7 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
  if (debug_url.find(kGrpcURLScheme) != 0) {
    return true;
  } else {
-    return DebugGrpcIO::IsGateOpen(watch_key, debug_url);
+    return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key);
  }
 #else
  return true;
@@ -731,6 +745,12 @@ bool DebugGrpcChannel::WriteEvent(const Event& event) {
   return reader_writer_->Write(event);
 }

+bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
+  mutex_lock l(mu_);
+
+  return reader_writer_->Read(event_reply);
+}
+
 Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
   mutex_lock l(mu_);
@@ -744,13 +764,8 @@ Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
       string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
                                          debug_op_state_change.output_slot(),
                                          ":", debug_op_state_change.debug_op());
-      if (debug_op_state_change.change() ==
-          EventReply::DebugOpStateChange::ENABLE) {
-        DebugGrpcIO::EnableWatchKey(url_, watch_key);
-      } else if (debug_op_state_change.change() ==
-                 EventReply::DebugOpStateChange::DISABLE) {
-        DebugGrpcIO::DisableWatchKey(url_, watch_key);
-      }
+      DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
+                                            debug_op_state_change.state());
     }
   }
@@ -789,7 +804,8 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
     const DebugNodeKey& debug_node_key, const Tensor& tensor,
     const uint64 wall_time_us, const string& grpc_stream_url,
     const bool gated) {
-  if (gated && !IsGateOpen(debug_node_key.debug_node_name, grpc_stream_url)) {
+  if (gated &&
+      !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
     return Status::OK();
   } else {
     std::vector<Event> events;
@@ -799,10 +815,35 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
       TF_RETURN_IF_ERROR(
           SendEventProtoThroughGrpcStream(event, grpc_stream_url));
     }
+    if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
+      EventReply event_reply;
+      TF_RETURN_IF_ERROR(ReceiveEventReplyProtoThroughGrpcStream(
+          &event_reply, grpc_stream_url));
+      // TODO(cais): Support new tensor value carried in the EventReply for
+      // overriding the value of the tensor being published.
+    }
     return Status::OK();
   }
 }

+// static
+Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
+    EventReply* event_reply, const string& grpc_stream_url) {
+  std::shared_ptr<DebugGrpcChannel> debug_grpc_channel;
+  {
+    mutex_lock l(streams_mu);
+    std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
+        stream_channels = GetStreamChannels();
+    debug_grpc_channel = (*stream_channels)[grpc_stream_url];
+  }
+  if (debug_grpc_channel->ReadEventReply(event_reply)) {
+    return Status::OK();
+  } else {
+    return errors::Cancelled(strings::StrCat(
+        "Reading EventReply from stream URL ", grpc_stream_url, " failed."));
+  }
+}
+
 // static
 Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
     const Event& event_proto, const string& grpc_stream_url) {
@@ -821,9 +862,7 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
     debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr));
     TF_RETURN_IF_ERROR(
         debug_grpc_channel->Connect(channel_connection_timeout_micros));
     (*stream_channels)[grpc_stream_url] = debug_grpc_channel;
-    CreateEmptyEnabledSet(grpc_stream_url);
   } else {
     debug_grpc_channel = (*stream_channels)[grpc_stream_url];
   }
@@ -838,16 +877,22 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
   return Status::OK();
 }

-// static
-bool DebugGrpcIO::IsGateOpen(const string& watch_key,
-                             const string& grpc_debug_url) {
-  std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
-      GetEnabledWatchKeys();
-  if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
+bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
+                                 const string& watch_key) {
+  const DebugNodeName2State* enabled_node_to_state =
+      GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
+  return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
+}
+
+bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
+                                  const string& watch_key) {
+  const DebugNodeName2State* enabled_node_to_state =
+      GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
+  auto it = enabled_node_to_state->find(watch_key);
+  if (it == enabled_node_to_state->end()) {
     return false;
   } else {
-    const auto& url_enabled = (*enabled_watch_keys)[grpc_debug_url];
-    return url_enabled.find(watch_key) != url_enabled.end();
+    return it->second == EventReply::DebugOpStateChange::READ_WRITE;
   }
 }
@@ -871,56 +916,46 @@ Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
 }

 // static
-std::unordered_map<string, std::unordered_set<string>>*
-DebugGrpcIO::GetEnabledWatchKeys() {
-  static std::unordered_map<string, std::unordered_set<string>>*
-      enabled_watch_keys =
-          new std::unordered_map<string, std::unordered_set<string>>();
-  return enabled_watch_keys;
+std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
+DebugGrpcIO::GetEnabledDebugOpStates() {
+  static std::unordered_map<string, DebugNodeName2State>*
+      enabled_debug_op_states =
+          new std::unordered_map<string, DebugNodeName2State>();
+  return enabled_debug_op_states;
 }

 // static
-void DebugGrpcIO::EnableWatchKey(const string& grpc_debug_url,
-                                 const string& watch_key) {
-  std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
-      GetEnabledWatchKeys();
-  if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
-    CreateEmptyEnabledSet(grpc_debug_url);
+DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
+    const string& grpc_debug_url) {
+  std::unordered_map<string, DebugNodeName2State>* states =
+      GetEnabledDebugOpStates();
+  if (states->find(grpc_debug_url) == states->end()) {
+    DebugNodeName2State url_enabled_debug_op_states;
+    (*states)[grpc_debug_url] = url_enabled_debug_op_states;
   }
-  (*enabled_watch_keys)[grpc_debug_url].insert(watch_key);
+  return &(*states)[grpc_debug_url];
 }

 // static
-void DebugGrpcIO::DisableWatchKey(const string& grpc_debug_url,
-                                  const string& watch_key) {
-  std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
-      GetEnabledWatchKeys();
-  if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
-    LOG(WARNING) << "Attempt to disable a watch key for an unregistered gRPC "
-                 << "debug URL: " << grpc_debug_url;
-  } else {
-    std::unordered_set<string>& url_enabled =
-        (*enabled_watch_keys)[grpc_debug_url];
-    if (url_enabled.find(watch_key) == url_enabled.end()) {
-      LOG(WARNING) << "Attempt to disable a watch key that is not currently "
-                   << "enabled at " << grpc_debug_url << ": " << watch_key;
+void DebugGrpcIO::SetDebugNodeKeyGrpcState(
+    const string& grpc_debug_url, const string& watch_key,
+    const EventReply::DebugOpStateChange::State new_state) {
+  DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
+  if (new_state == EventReply::DebugOpStateChange::DISABLED) {
+    if (states->find(watch_key) == states->end()) {
+      LOG(ERROR) << "Attempt to disable a watch key that is not currently "
+                 << "enabled at " << grpc_debug_url << ": " << watch_key;
     } else {
-      url_enabled.erase(watch_key);
+      states->erase(watch_key);
     }
+  } else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) {
+    (*states)[watch_key] = new_state;
   }
 }

 // static
-void DebugGrpcIO::ClearEnabledWatchKeys() { GetEnabledWatchKeys()->clear(); }
-
-// static
-void DebugGrpcIO::CreateEmptyEnabledSet(const string& grpc_debug_url) {
-  std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
-      GetEnabledWatchKeys();
-  if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
-    std::unordered_set<string> empty_watch_keys;
-    (*enabled_watch_keys)[grpc_debug_url] = empty_watch_keys;
-  }
-}
+void DebugGrpcIO::ClearEnabledWatchKeys() {
+  GetEnabledDebugOpStates()->clear();
+}

 #endif  // #ifndef PLATFORM_WINDOWS
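
To make the gating rule above concrete, here is a minimal Python sketch (not part of the CL) of the logic that `IsReadGateOpen` and `IsWriteGateOpen` implement over the per-URL watch-key-to-state map; a plain dict stands in for `DebugNodeName2State`, and the watch keys are placeholders:

```python
from tensorflow.core.debug import debug_service_pb2

_State = debug_service_pb2.EventReply.DebugOpStateChange


def is_read_gate_open(states, watch_key):
  # Any enabled state (READ_ONLY or READ_WRITE) allows the debug tensor to be
  # published to the grpc debug server.
  return watch_key in states


def is_write_gate_open(states, watch_key):
  # Only READ_WRITE additionally makes the debug op block until an EventReply
  # arrives from the server.
  return states.get(watch_key) == _State.READ_WRITE


# Per-URL map from watch key ("node:slot:debug_op") to state.
states = {"delta:0:DebugIdentity": _State.READ_WRITE}
assert is_read_gate_open(states, "delta:0:DebugIdentity")
assert is_write_gate_open(states, "delta:0:DebugIdentity")
assert not is_read_gate_open(states, "v:0:DebugIdentity")
```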

View File

@@ -16,8 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
 #define TENSORFLOW_DEBUG_IO_UTILS_H_

+#include <cstddef>
+#include <functional>
+#include <memory>
+#include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <vector>

 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/graph/graph.h"
@@ -49,6 +54,9 @@ struct DebugNodeKey {
   //   ,job_localhost,replica_0,task_0,cpu_0.
   static const string DeviceNameToDevicePath(const string& device_name);

+  bool operator==(const DebugNodeKey& other) const;
+  bool operator!=(const DebugNodeKey& other) const;
+
   const string device_name;
   const string node_name;
   const int32 output_slot;
@@ -219,6 +227,19 @@ class DebugFileIO {

 }  // namespace tensorflow

+namespace std {
+
+template <>
+struct hash<::tensorflow::DebugNodeKey> {
+  size_t operator()(const ::tensorflow::DebugNodeKey& k) const {
+    return ::tensorflow::Hash64(
+        ::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":",
+                                      k.output_slot, ":", k.debug_op, ":"));
+  }
+};
+
+}  // namespace std
+
 // TODO(cais): Support grpc:// debug URLs in open source once Python grpc
 // genrule becomes available. See b/23796275.
 #ifndef PLATFORM_WINDOWS
@@ -252,7 +273,8 @@ class DebugGrpcChannel {
   // Write an Event proto to the debug gRPC stream.
   //
   // Thread-safety: Safe with respect to other calls to the same method and
-  //   call to Close().
+  //   calls to ReadEventReply() and Close().
+  //
   // Args:
   //   event: The event proto to be written to the stream.
   //
@@ -260,6 +282,19 @@ class DebugGrpcChannel {
   //   True iff the write is successful.
   bool WriteEvent(const Event& event);

+  // Read an EventReply proto from the debug gRPC stream.
+  //
+  // This method blocks and waits for an EventReply from the server.
+  // Thread-safety: Safe with respect to other calls to the same method and
+  //   calls to WriteEvent() and Close().
+  //
+  // Args:
+  //   event_reply: the to-be-modified EventReply proto passed as reference.
+  //
+  // Returns:
+  //   True iff the read is successful.
+  bool ReadEventReply(EventReply* event_reply);
+
   // Receive EventReplies from server (if any) and close the stream and the
   // channel.
   Status ReceiveServerRepliesAndClose();
@@ -294,48 +329,49 @@ class DebugGrpcIO {
   static Status SendEventProtoThroughGrpcStream(const Event& event_proto,
                                                 const string& grpc_stream_url);

-  // Checks whether a debug watch key is allowed to send data to a given grpc://
-  // debug URL given the current gating status.
-  //
-  // Args:
-  //   watch_key: debug tensor watch key, in the format of
-  //     tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
-  //   grpc_debug_url: the debug URL, e.g., "grpc://localhost:3333",
-  //
-  // Returns:
-  //   Whether the sending of debug data to grpc_debug_url should
-  //   proceed.
-  static bool IsGateOpen(const string& watch_key, const string& grpc_debug_url);
+  // Receive an EventReply proto through a debug gRPC stream.
+  static Status ReceiveEventReplyProtoThroughGrpcStream(
+      EventReply* event_reply, const string& grpc_stream_url);
+
+  // Check whether a debug watch key is read-activated at a given gRPC URL.
+  static bool IsReadGateOpen(const string& grpc_debug_url,
+                             const string& watch_key);
+
+  // Check whether a debug watch key is write-activated (i.e., read- and
+  // write-activated) at a given gRPC URL.
+  static bool IsWriteGateOpen(const string& grpc_debug_url,
+                              const string& watch_key);

   // Closes a gRPC stream to the given address, if it exists.
   // Thread-safety: Safe with respect to other calls to the same method and
   //   calls to SendTensorThroughGrpcStream().
   static Status CloseGrpcStream(const string& grpc_stream_url);

-  // Enables a debug watch key at a grpc:// debug URL.
-  static void EnableWatchKey(const string& grpc_debug_url,
-                             const string& watch_key);
-
-  // Disables a debug watch key at a grpc:// debug URL.
-  static void DisableWatchKey(const string& grpc_debug_url,
-                              const string& watch_key);
+  // Set the gRPC state of a debug node key.
+  // TODO(cais): Include device information in watch_key.
+  static void SetDebugNodeKeyGrpcState(
+      const string& grpc_debug_url, const string& watch_key,
+      const EventReply::DebugOpStateChange::State new_state);

  private:
+  using DebugNodeName2State =
+      std::unordered_map<string, EventReply::DebugOpStateChange::State>;
+
   // Returns a global map from grpc debug URLs to the corresponding
   // DebugGrpcChannels.
   static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
   GetStreamChannels();

-  // Returns a global map from grpc debug URLs to the enabled gated debug nodes.
-  // The keys are grpc:// URLs of the debug servers, e.g., "grpc://foo:3333".
-  // Each value element of the value has the format
-  //   <node_name>:<output_slot>:<debug_op>", e.g.,
-  //   "Weights_1:0:DebugNumericSummary".
-  static std::unordered_map<string, std::unordered_set<string>>*
-  GetEnabledWatchKeys();
+  // Returns a map from debug URL to a map from debug op name to enabled state.
+  static std::unordered_map<string, DebugNodeName2State>*
+  GetEnabledDebugOpStates();
+
+  // Returns a map from debug op names to enabled state, for a given debug URL.
+  static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl(
+      const string& grpc_debug_url);
+
+  // Clear enabled debug op state from all debug URLs (if any).
   static void ClearEnabledWatchKeys();

-  static void CreateEmptyEnabledSet(const string& grpc_debug_url);
-
   static mutex streams_mu;
   static int64 channel_connection_timeout_micros;

View File

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <unordered_set>
+
 #include "tensorflow/core/debug/debug_io_utils.h"

 #include "tensorflow/core/debug/debugger_event_metadata.pb.h"
@@ -62,6 +64,39 @@ TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) {
             debug_node_key.device_path);
 }

+TEST_F(DebugIOUtilsTest, EqualityOfDebugNodeKeys) {
+  const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2",
+                                      "hidden_1/MatMul", 0, "DebugIdentity");
+  const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2",
+                                      "hidden_1/MatMul", 0, "DebugIdentity");
+  const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2",
+                                      "hidden_1/BiasAdd", 0, "DebugIdentity");
+  const DebugNodeKey debug_node_key_4("/job:worker/replica:1/task:0/gpu:2",
+                                      "hidden_1/MatMul", 0,
+                                      "DebugNumericSummary");
+  EXPECT_EQ(debug_node_key_1, debug_node_key_2);
+  EXPECT_NE(debug_node_key_1, debug_node_key_3);
+  EXPECT_NE(debug_node_key_1, debug_node_key_4);
+  EXPECT_NE(debug_node_key_3, debug_node_key_4);
+}
+
+TEST_F(DebugIOUtilsTest, DebugNodeKeysIsHashable) {
+  const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2",
+                                      "hidden_1/MatMul", 0, "DebugIdentity");
+  const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2",
+                                      "hidden_1/MatMul", 0, "DebugIdentity");
+  const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2",
+                                      "hidden_1/BiasAdd", 0, "DebugIdentity");
+
+  std::unordered_set<DebugNodeKey> keys;
+  keys.insert(debug_node_key_1);
+  ASSERT_EQ(1, keys.size());
+  keys.insert(debug_node_key_3);
+  ASSERT_EQ(2, keys.size());
+  keys.erase(debug_node_key_2);
+  ASSERT_EQ(1, keys.size());
+}
+
 TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) {
   Initialize();

View File

@@ -17,6 +17,7 @@ syntax = "proto3";

 package tensorflow;

+import "tensorflow/core/framework/tensor.proto";
 import "tensorflow/core/util/event.proto";

 // Reply message from EventListener to the client, i.e., to the source of the
@@ -24,18 +25,25 @@ import "tensorflow/core/util/event.proto";
 // TensorFlow graph being executed.
 message EventReply {
   message DebugOpStateChange {
-    enum Change {
-      DISABLE = 0;
-      ENABLE = 1;
+    enum State {
+      STATE_UNSPECIFIED = 0;
+      DISABLED = 1;
+      READ_ONLY = 2;
+      READ_WRITE = 3;
     }

-    Change change = 1;
+    State state = 1;
     string node_name = 2;
     int32 output_slot = 3;
     string debug_op = 4;
   }

   repeated DebugOpStateChange debug_op_state_changes = 1;
+
+  // New tensor value to override the current tensor value with.
+  TensorProto tensor = 2;
+  // TODO(cais): Make use of this field to implement overriding of tensor value
+  // during debugging.
 }

 // EventListener: Receives Event protos, e.g., from debugged TensorFlow

View File

@@ -259,10 +259,6 @@ class DebugNumericSummaryOpTest : public OpsTestBase {
 #if defined(PLATFORM_GOOGLE)
   void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
-
-  void CreateEmptyEnabledSet(const string& grpc_debug_url) {
-    DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
-  }
 #endif
 };
@@ -604,7 +600,6 @@ TEST_F(DebugNumericSummaryOpTest, BoolSuccess) {
 #if defined(PLATFORM_GOOGLE)
 TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) {
   ClearEnabledWatchKeys();
-  CreateEmptyEnabledSet("grpc://server:3333");
   std::vector<string> debug_urls({"grpc://server:3333"});
   TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));
@@ -617,8 +612,9 @@ TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) {
 TEST_F(DebugNumericSummaryOpTest, DisabledDueToNonMatchingWatchKey) {
   ClearEnabledWatchKeys();
-  DebugGrpcIO::EnableWatchKey("grpc://server:3333",
-                              "FakeTensor:1:DebugNumeriSummary");
+  DebugGrpcIO::SetDebugNodeKeyGrpcState(
+      "grpc://server:3333", "FakeTensor:1:DebugNumeriSummary",
+      EventReply::DebugOpStateChange::READ_ONLY);
   std::vector<string> debug_urls({"grpc://server:3333"});
   TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));

View File

@@ -31,16 +31,20 @@ from tensorflow.core.debug import debug_service_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.python.debug.lib import debug_data
 from tensorflow.python.debug.lib import debug_service_pb2_grpc
+from tensorflow.python.platform import tf_logging as logging

 DebugWatch = collections.namedtuple("DebugWatch",
                                     ["node_name", "output_slot", "debug_op"])

-def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op):
-  """Make EventReply proto to represent a request to watch/unwatch a debug op.
+def _watch_key_event_reply(new_state, node_name, output_slot, debug_op):
+  """Make `EventReply` proto to represent a request to watch/unwatch a debug op.

   Args:
-    to_enable: (`bool`) whether the request is to enable the watch key.
+    new_state: (`debug_service_pb2.EventReply.DebugOpStateChange.State`) the new
+      state to set the debug node to, i.e., whether the debug node will become
+      disabled under the grpc mode (`DISABLED`), become a watchpoint
+      (`READ_ONLY`) or become a breakpoint (`READ_WRITE`).
     node_name: (`str`) name of the node.
     output_slot: (`int`) output slot of the tensor.
     debug_op: (`str`) the debug op attached to node_name:output_slot tensor to
@@ -51,9 +55,7 @@ def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op):
   """
   event_reply = debug_service_pb2.EventReply()
   state_change = event_reply.debug_op_state_changes.add()
-  state_change.change = (
-      debug_service_pb2.EventReply.DebugOpStateChange.ENABLE
-      if to_enable else debug_service_pb2.EventReply.DebugOpStateChange.DISABLE)
+  state_change.state = new_state
   state_change.node_name = node_name
   state_change.output_slot = output_slot
   state_change.debug_op = debug_op
@@ -125,6 +127,7 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
     self._event_reply_queue = queue.Queue()

     self._gated_grpc_debug_watches = set()
+    self._breakpoints = set()

   def SendEvents(self, request_iterator, context):
     """Implementation of the SendEvents service method.
@@ -171,11 +174,30 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
         maybe_tensor_event = self._process_tensor_event_in_chunks(
             event, tensor_chunks)
         if maybe_tensor_event:
-          stream_handler.on_value_event(maybe_tensor_event)
+          event_reply = stream_handler.on_value_event(maybe_tensor_event)
+          if event_reply is not None:
+            yield event_reply

       # The server writes EventReply messages, if any.
       while not self._event_reply_queue.empty():
-        yield self._event_reply_queue.get()
+        event_reply = self._event_reply_queue.get()
+        for state_change in event_reply.debug_op_state_changes:
+          if (state_change.state ==
+              debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE):
+            logging.info("Adding breakpoint %s:%d:%s", state_change.node_name,
+                         state_change.output_slot, state_change.debug_op)
+            self._breakpoints.add(
+                (state_change.node_name, state_change.output_slot,
+                 state_change.debug_op))
+          elif (state_change.state ==
+                debug_service_pb2.EventReply.DebugOpStateChange.DISABLED):
+            logging.info("Removing watchpoint or breakpoint: %s:%d:%s",
+                         state_change.node_name, state_change.output_slot,
+                         state_change.debug_op)
+            self._breakpoints.discard(
+                (state_change.node_name, state_change.output_slot,
+                 state_change.debug_op))
+        yield event_reply

   def _process_tensor_event_in_chunks(self, event, tensor_chunks):
     """Possibly reassemble event chunks.
@@ -336,8 +358,8 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
     finally:
       self._server_lock.release()

-  def request_watch(self, node_name, output_slot, debug_op):
-    """Request enabling a debug tensor watch.
+  def request_watch(self, node_name, output_slot, debug_op, breakpoint=False):
+    """Request enabling a debug tensor watchpoint or breakpoint.

     This will let the server send a EventReply to the client side
     (i.e., the debugged TensorFlow runtime process) to request adding a watch
@@ -355,12 +377,19 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
       output_slot: (`int`) output slot index of the tensor to watch.
       debug_op: (`str`) name of the debug op to enable. This should not include
         any attribute substrings.
+      breakpoint: (`bool`) Iff `True`, the debug op will block and wait until it
+        receives an `EventReply` response from the server. The `EventReply`
+        proto may carry a TensorProto that modifies the value of the debug op's
+        output tensor.
     """
     self._event_reply_queue.put(
-        _watch_key_event_reply(True, node_name, output_slot, debug_op))
+        _watch_key_event_reply(
+            debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE
+            if breakpoint else debug_service_pb2.EventReply.DebugOpStateChange.
+            READ_ONLY, node_name, output_slot, debug_op))

   def request_unwatch(self, node_name, output_slot, debug_op):
-    """Request disabling a debug tensor watch.
+    """Request disabling a debug tensor watchpoint or breakpoint.

     The request will take effect on the next debugged `Session.run()` call.
@@ -374,7 +403,18 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
         any attribute substrings.
     """
     self._event_reply_queue.put(
-        _watch_key_event_reply(False, node_name, output_slot, debug_op))
+        _watch_key_event_reply(debug_service_pb2.EventReply.DebugOpStateChange.
+                               DISABLED, node_name, output_slot, debug_op))
+
+  @property
+  def breakpoints(self):
+    """Get a set of the currently-activated breakpoints.
+
+    Returns:
+      A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g.,
+      {("MatMul", 0, "DebugIdentity")}.
+    """
+    return self._breakpoints

   def gated_grpc_debug_watches(self):
     """Get the list of debug watches with attribute gated_grpc=True.

View File

@@ -31,6 +31,7 @@ import time

 import portpicker

+from tensorflow.core.debug import debug_service_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.util import event_pb2
 from tensorflow.python.client import session
@@ -138,13 +139,26 @@ class EventListenerTestStreamHandler(
     Args:
       event: The Event proto carrying a tensor value.
+
+    Returns:
+      If the debug node belongs to the set of currently activated breakpoints,
+      an `EventReply` proto will be returned.
     """
     if self._dump_dir:
       self._write_value_event(event)
     else:
       value = event.summary.value[0]
+      tensor_value = debug_data.load_tensor_from_event(event)
       self._event_listener_servicer.debug_tensor_values[value.node_name].append(
-          debug_data.load_tensor_from_event(event))
+          tensor_value)
+
+      items = event.summary.value[0].node_name.split(":")
+      node_name = items[0]
+      output_slot = int(items[1])
+      debug_op = items[2]
+      if ((node_name, output_slot, debug_op) in
+          self._event_listener_servicer.breakpoints):
+        return debug_service_pb2.EventReply()

   def _try_makedirs(self, dir_path):
     if not os.path.isdir(dir_path):

View File

@@ -556,6 +556,55 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
             [50 + 5.0 * i],
             self._server_2.debug_tensor_values["v:0:DebugIdentity"])

+  def testToggleBreakpointWorks(self):
+    with session.Session(config=no_rewrite_session_config()) as sess:
+      v = variables.Variable(50.0, name="v")
+      delta = constant_op.constant(5.0, name="delta")
+      inc_v = state_ops.assign_add(v, delta, name="inc_v")
+
+      sess.run(v.initializer)
+
+      run_metadata = config_pb2.RunMetadata()
+      run_options = config_pb2.RunOptions(output_partition_graphs=True)
+      debug_utils.watch_graph(
+          run_options,
+          sess.graph,
+          debug_ops=["DebugIdentity(gated_grpc=true)"],
+          debug_urls=[self._debug_server_url_1])
+
+      for i in xrange(4):
+        self._server_1.clear_data()
+
+        # N.B.: These requests will be fulfilled not in this debugged
+        # Session.run() invocation, but in the next one.
+        if i in (0, 2):
+          # Enable breakpoint at delta:0:DebugIdentity in runs 0 and 2.
+          self._server_1.request_watch(
+              "delta", 0, "DebugIdentity", breakpoint=True)
+        else:
+          # Disable the breakpoint in runs 1 and 3.
+          self._server_1.request_unwatch("delta", 0, "DebugIdentity")
+
+        output = sess.run(inc_v, options=run_options, run_metadata=run_metadata)
+        self.assertAllClose(50.0 + 5.0 * (i + 1), output)
+
+        if i in (0, 2):
+          # After the end of runs 0 and 2, the server has received the requests
+          # to enable the breakpoint at delta:0:DebugIdentity. So the server
+          # should keep track of the correct breakpoints.
+          self.assertSetEqual({("delta", 0, "DebugIdentity")},
+                              self._server_1.breakpoints)
+        else:
+          # During runs 1 and 3, the server should have received the published
+          # debug tensor delta:0:DebugIdentity. The breakpoint should have been
+          # unblocked by EventReply responses from the server.
+          self.assertAllClose(
+              [5.0],
+              self._server_1.debug_tensor_values["delta:0:DebugIdentity"])
+
+      # After the runs, the server should have properly removed the
+      # breakpoints due to the request_unwatch calls.
+      self.assertSetEqual(set(), self._server_1.breakpoints)
+
   def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
     with session.Session() as sess:
       v = variables.Variable(50.0, name="v")