tfdbg: extend grpc_debug_server protocol for interactive debugging
Previously, a grpc-gated debug op has two modes: DISABLED and ENABLED. This CL splits the ENABLED state into two states: READ_ONLY and READ_WRITE. * READ_ONLY is equivalent to the previous ENABLED state, wherein a debug op publishes debug tensor to the grpc debug server and proceeds. It can be regarded as a "watchpoint" that doesn't block execution. * READ_WRITE is a "breakpoint". In addition to publishing the debug tensor, it blocks and awaits a EventReply proto response from the grpc debug server before proceeding. PiperOrigin-RevId: 164987725
This commit is contained in:
parent
c0f9b0a91e
commit
3c482c66b5
@ -65,10 +65,6 @@ class GrpcDebugTest : public ::testing::Test {
|
||||
|
||||
void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
|
||||
|
||||
void CreateEmptyEnabledSet(const string& grpc_debug_url) {
|
||||
DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
|
||||
}
|
||||
|
||||
const int64 GetChannelConnectionTimeoutMicros() {
|
||||
return DebugGrpcIO::channel_connection_timeout_micros;
|
||||
}
|
||||
@ -261,7 +257,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
|
||||
TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
|
||||
// Prepare the tensor to send.
|
||||
const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
|
||||
"test_namescope/test_node", 0,
|
||||
@ -283,8 +279,60 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
|
||||
TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
|
||||
urls, enable_gated_grpc));
|
||||
|
||||
server_data_.server->RequestDebugOpStateChangeAtNextStream(i == 0,
|
||||
kDebugNodeKey);
|
||||
server_data_.server->RequestDebugOpStateChangeAtNextStream(
|
||||
i == 0 ? EventReply::DebugOpStateChange::READ_ONLY
|
||||
: EventReply::DebugOpStateChange::DISABLED,
|
||||
kDebugNodeKey);
|
||||
|
||||
// Close the debug gRPC stream.
|
||||
Status close_status = DebugIO::CloseDebugURL(server_data_.url);
|
||||
ASSERT_TRUE(close_status.ok());
|
||||
|
||||
// Check dumped files according to the expected gating results.
|
||||
if (i < 2) {
|
||||
ASSERT_EQ(1, server_data_.server->node_names.size());
|
||||
ASSERT_EQ(1, server_data_.server->output_slots.size());
|
||||
ASSERT_EQ(1, server_data_.server->debug_ops.size());
|
||||
EXPECT_EQ(kDebugNodeKey.device_name,
|
||||
server_data_.server->device_names[0]);
|
||||
EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]);
|
||||
EXPECT_EQ(kDebugNodeKey.output_slot,
|
||||
server_data_.server->output_slots[0]);
|
||||
EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]);
|
||||
} else {
|
||||
ASSERT_EQ(0, server_data_.server->node_names.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUnderReadWriteMode) {
|
||||
// Prepare the tensor to send.
|
||||
const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
|
||||
"test_namescope/test_node", 0,
|
||||
"DebugIdentity");
|
||||
Tensor tensor(DT_INT32, TensorShape({1, 1}));
|
||||
tensor.flat<int>()(0) = 42;
|
||||
|
||||
const std::vector<string> urls({server_data_.url});
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
server_data_.server->ClearReceivedDebugData();
|
||||
const uint64 wall_time = Env::Default()->NowMicros();
|
||||
|
||||
// On the 1st send (i == 0), gating is disabled, so data should be sent.
|
||||
// On the 2nd send (i == 1), gating is enabled, and the server has enabled
|
||||
// the watch key in the previous send (READ_WRITE), so data should be
|
||||
// sent. In this iteration, the server response with a EventReply proto to
|
||||
// unblock the debug node.
|
||||
// On the 3rd send (i == 2), gating is enabled, but the server has disabled
|
||||
// the watch key in the previous send, so data should not be sent.
|
||||
const bool enable_gated_grpc = (i != 0);
|
||||
TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
|
||||
urls, enable_gated_grpc));
|
||||
|
||||
server_data_.server->RequestDebugOpStateChangeAtNextStream(
|
||||
i == 0 ? EventReply::DebugOpStateChange::READ_WRITE
|
||||
: EventReply::DebugOpStateChange::DISABLED,
|
||||
kDebugNodeKey);
|
||||
|
||||
// Close the debug gRPC stream.
|
||||
Status close_status = DebugIO::CloseDebugURL(server_data_.url);
|
||||
@ -308,8 +356,6 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
|
||||
}
|
||||
|
||||
TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) {
|
||||
CreateEmptyEnabledSet("grpc://localhost:3333");
|
||||
|
||||
ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
|
||||
{"grpc://localhost:3333"}));
|
||||
|
||||
@ -322,8 +368,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) {
|
||||
const string kGrpcUrl1 = "grpc://localhost:3333";
|
||||
const string kGrpcUrl2 = "grpc://localhost:3334";
|
||||
|
||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
|
||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "bar:0:DebugIdentity");
|
||||
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||
kGrpcUrl1, "foo:0:DebugIdentity",
|
||||
EventReply::DebugOpStateChange::READ_ONLY);
|
||||
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||
kGrpcUrl1, "bar:0:DebugIdentity",
|
||||
EventReply::DebugOpStateChange::READ_ONLY);
|
||||
|
||||
ASSERT_FALSE(
|
||||
DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", {kGrpcUrl1}));
|
||||
@ -350,9 +400,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
|
||||
const string kGrpcUrl2 = "grpc://localhost:3334";
|
||||
const string kGrpcUrl3 = "grpc://localhost:3335";
|
||||
|
||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
|
||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl2, "bar:0:DebugIdentity");
|
||||
CreateEmptyEnabledSet(kGrpcUrl3);
|
||||
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||
kGrpcUrl1, "foo:0:DebugIdentity",
|
||||
EventReply::DebugOpStateChange::READ_ONLY);
|
||||
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||
kGrpcUrl2, "bar:0:DebugIdentity",
|
||||
EventReply::DebugOpStateChange::READ_ONLY);
|
||||
|
||||
ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1}));
|
||||
ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2}));
|
||||
@ -375,7 +428,9 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
|
||||
}
|
||||
|
||||
TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
|
||||
DebugGrpcIO::EnableWatchKey("grpc://localhost:3333", "foo:0:DebugIdentity");
|
||||
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||
"grpc://localhost:3333", "foo:0:DebugIdentity",
|
||||
EventReply::DebugOpStateChange::READ_ONLY);
|
||||
|
||||
std::vector<string> debug_urls_1;
|
||||
ASSERT_FALSE(
|
||||
@ -385,7 +440,6 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
|
||||
TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) {
|
||||
const string kGrpcUrl1 = "grpc://localhost:3333";
|
||||
const string kWatch1 = "foo:0:DebugIdentity";
|
||||
CreateEmptyEnabledSet(kGrpcUrl1);
|
||||
|
||||
ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
|
||||
{DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
|
||||
@ -404,9 +458,8 @@ TEST_F(GrpcDebugTest, TestGateCopyNodeOnNonEmptyEnabledSet) {
|
||||
const string kGrpcUrl2 = "grpc://localhost:3334";
|
||||
const string kWatch1 = "foo:0:DebugIdentity";
|
||||
const string kWatch2 = "foo:1:DebugIdentity";
|
||||
CreateEmptyEnabledSet(kGrpcUrl1);
|
||||
CreateEmptyEnabledSet(kGrpcUrl2);
|
||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, kWatch1);
|
||||
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||
kGrpcUrl1, kWatch1, EventReply::DebugOpStateChange::READ_ONLY);
|
||||
|
||||
ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
|
||||
{DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
|
||||
|
@ -72,27 +72,44 @@ namespace test {
|
||||
output_slots.push_back(metadata.output_slot());
|
||||
debug_ops.push_back(debug_op);
|
||||
debug_tensors.push_back(tensor);
|
||||
|
||||
// If the debug node is currently in the READ_WRITE mode, send an
|
||||
// EventReply to 1) unblock the execution and 2) optionally modify the
|
||||
// value.
|
||||
const DebugNodeKey debug_node_key(metadata.device(), node_name,
|
||||
metadata.output_slot(), debug_op);
|
||||
if (write_enabled_debug_node_keys_.find(debug_node_key) !=
|
||||
write_enabled_debug_node_keys_.end()) {
|
||||
stream->Write(EventReply());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
mutex_lock l(changes_mu_);
|
||||
for (size_t i = 0; i < changes_to_enable_.size(); ++i) {
|
||||
mutex_lock l(states_mu_);
|
||||
for (size_t i = 0; i < new_states_.size(); ++i) {
|
||||
EventReply event_reply;
|
||||
EventReply::DebugOpStateChange* change =
|
||||
event_reply.add_debug_op_state_changes();
|
||||
change->set_change(changes_to_enable_[i]
|
||||
? EventReply::DebugOpStateChange::ENABLE
|
||||
: EventReply::DebugOpStateChange::DISABLE);
|
||||
change->set_node_name(changes_node_names_[i]);
|
||||
change->set_output_slot(changes_output_slots_[i]);
|
||||
change->set_debug_op(changes_debug_ops_[i]);
|
||||
|
||||
// State changes will take effect in the next stream, i.e., next debugged
|
||||
// Session.run() call.
|
||||
change->set_state(new_states_[i]);
|
||||
const DebugNodeKey& debug_node_key = debug_node_keys_[i];
|
||||
change->set_node_name(debug_node_key.node_name);
|
||||
change->set_output_slot(debug_node_key.output_slot);
|
||||
change->set_debug_op(debug_node_key.debug_op);
|
||||
stream->Write(event_reply);
|
||||
|
||||
if (new_states_[i] == EventReply::DebugOpStateChange::READ_WRITE) {
|
||||
write_enabled_debug_node_keys_.insert(debug_node_key);
|
||||
} else {
|
||||
write_enabled_debug_node_keys_.erase(debug_node_key);
|
||||
}
|
||||
}
|
||||
changes_to_enable_.clear();
|
||||
changes_node_names_.clear();
|
||||
changes_output_slots_.clear();
|
||||
changes_debug_ops_.clear();
|
||||
|
||||
debug_node_keys_.clear();
|
||||
new_states_.clear();
|
||||
}
|
||||
|
||||
return ::grpc::Status::OK;
|
||||
@ -109,13 +126,12 @@ void TestEventListenerImpl::ClearReceivedDebugData() {
|
||||
}
|
||||
|
||||
void TestEventListenerImpl::RequestDebugOpStateChangeAtNextStream(
|
||||
bool to_enable, const DebugNodeKey& debug_node_key) {
|
||||
mutex_lock l(changes_mu_);
|
||||
const EventReply::DebugOpStateChange::State new_state,
|
||||
const DebugNodeKey& debug_node_key) {
|
||||
mutex_lock l(states_mu_);
|
||||
|
||||
changes_to_enable_.push_back(to_enable);
|
||||
changes_node_names_.push_back(debug_node_key.node_name);
|
||||
changes_output_slots_.push_back(debug_node_key.output_slot);
|
||||
changes_debug_ops_.push_back(debug_node_key.debug_op);
|
||||
debug_node_keys_.push_back(debug_node_key);
|
||||
new_states_.push_back(new_state);
|
||||
}
|
||||
|
||||
void TestEventListenerImpl::RunServer(const int server_port) {
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
|
||||
|
||||
#include <atomic>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "grpc++/grpc++.h"
|
||||
#include "tensorflow/core/debug/debug_io_utils.h"
|
||||
@ -44,7 +45,8 @@ class TestEventListenerImpl final : public EventListener::Service {
|
||||
void ClearReceivedDebugData();
|
||||
|
||||
void RequestDebugOpStateChangeAtNextStream(
|
||||
bool to_enable, const DebugNodeKey& debug_node_key);
|
||||
const EventReply::DebugOpStateChange::State new_state,
|
||||
const DebugNodeKey& debug_node_key);
|
||||
|
||||
std::vector<string> debug_metadata_strings;
|
||||
std::vector<string> encoded_graph_defs;
|
||||
@ -58,12 +60,13 @@ class TestEventListenerImpl final : public EventListener::Service {
|
||||
std::atomic_bool stop_requested_;
|
||||
std::atomic_bool stopped_;
|
||||
|
||||
std::vector<bool> changes_to_enable_ GUARDED_BY(changes_mu_);
|
||||
std::vector<string> changes_node_names_ GUARDED_BY(changes_mu_);
|
||||
std::vector<int32> changes_output_slots_ GUARDED_BY(changes_mu_);
|
||||
std::vector<string> changes_debug_ops_ GUARDED_BY(changes_mu_);
|
||||
std::vector<DebugNodeKey> debug_node_keys_ GUARDED_BY(states_mu_);
|
||||
std::vector<EventReply::DebugOpStateChange::State> new_states_
|
||||
GUARDED_BY(states_mu_);
|
||||
|
||||
mutex changes_mu_;
|
||||
std::unordered_set<DebugNodeKey> write_enabled_debug_node_keys_;
|
||||
|
||||
mutex states_mu_;
|
||||
};
|
||||
|
||||
// Poll a gRPC debug server by sending a small tensor repeatedly till success.
|
||||
|
@ -15,6 +15,11 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/debug/debug_io_utils.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#ifndef PLATFORM_WINDOWS
|
||||
@ -297,6 +302,15 @@ DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
|
||||
strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
|
||||
device_path(DeviceNameToDevicePath(device_name)) {}
|
||||
|
||||
bool DebugNodeKey::operator==(const DebugNodeKey& other) const {
|
||||
return (device_name == other.device_name && node_name == other.node_name &&
|
||||
output_slot == other.output_slot && debug_op == other.debug_op);
|
||||
}
|
||||
|
||||
bool DebugNodeKey::operator!=(const DebugNodeKey& other) const {
|
||||
return !((*this) == other);
|
||||
}
|
||||
|
||||
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
|
||||
Env* env(Env::Default());
|
||||
|
||||
@ -537,7 +551,7 @@ bool DebugIO::IsCopyNodeGateOpen(
|
||||
DebugIO::kGrpcURLScheme)) {
|
||||
return true;
|
||||
} else {
|
||||
if (DebugGrpcIO::IsGateOpen(spec.watch_key, spec.url)) {
|
||||
if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -557,7 +571,7 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
|
||||
DebugIO::kGrpcURLScheme)) {
|
||||
return true;
|
||||
} else {
|
||||
if (DebugGrpcIO::IsGateOpen(watch_key, debug_url)) {
|
||||
if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -575,7 +589,7 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
|
||||
if (debug_url.find(kGrpcURLScheme) != 0) {
|
||||
return true;
|
||||
} else {
|
||||
return DebugGrpcIO::IsGateOpen(watch_key, debug_url);
|
||||
return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key);
|
||||
}
|
||||
#else
|
||||
return true;
|
||||
@ -731,6 +745,12 @@ bool DebugGrpcChannel::WriteEvent(const Event& event) {
|
||||
return reader_writer_->Write(event);
|
||||
}
|
||||
|
||||
bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
|
||||
mutex_lock l(mu_);
|
||||
|
||||
return reader_writer_->Read(event_reply);
|
||||
}
|
||||
|
||||
Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
|
||||
mutex_lock l(mu_);
|
||||
|
||||
@ -744,13 +764,8 @@ Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
|
||||
string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
|
||||
debug_op_state_change.output_slot(),
|
||||
":", debug_op_state_change.debug_op());
|
||||
if (debug_op_state_change.change() ==
|
||||
EventReply::DebugOpStateChange::ENABLE) {
|
||||
DebugGrpcIO::EnableWatchKey(url_, watch_key);
|
||||
} else if (debug_op_state_change.change() ==
|
||||
EventReply::DebugOpStateChange::DISABLE) {
|
||||
DebugGrpcIO::DisableWatchKey(url_, watch_key);
|
||||
}
|
||||
DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
|
||||
debug_op_state_change.state());
|
||||
}
|
||||
}
|
||||
|
||||
@ -789,7 +804,8 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
|
||||
const DebugNodeKey& debug_node_key, const Tensor& tensor,
|
||||
const uint64 wall_time_us, const string& grpc_stream_url,
|
||||
const bool gated) {
|
||||
if (gated && !IsGateOpen(debug_node_key.debug_node_name, grpc_stream_url)) {
|
||||
if (gated &&
|
||||
!IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
std::vector<Event> events;
|
||||
@ -799,10 +815,35 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
|
||||
TF_RETURN_IF_ERROR(
|
||||
SendEventProtoThroughGrpcStream(event, grpc_stream_url));
|
||||
}
|
||||
if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
|
||||
EventReply event_reply;
|
||||
TF_RETURN_IF_ERROR(ReceiveEventReplyProtoThroughGrpcStream(
|
||||
&event_reply, grpc_stream_url));
|
||||
// TODO(cais): Support new tensor value carried in the EventReply for
|
||||
// overriding the value of the tensor being published.
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
// static
|
||||
Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
|
||||
EventReply* event_reply, const string& grpc_stream_url) {
|
||||
std::shared_ptr<DebugGrpcChannel> debug_grpc_channel;
|
||||
{
|
||||
mutex_lock l(streams_mu);
|
||||
std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
|
||||
stream_channels = GetStreamChannels();
|
||||
debug_grpc_channel = (*stream_channels)[grpc_stream_url];
|
||||
}
|
||||
if (debug_grpc_channel->ReadEventReply(event_reply)) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
return errors::Cancelled(strings::StrCat(
|
||||
"Reading EventReply from stream URL ", grpc_stream_url, " failed."));
|
||||
}
|
||||
}
|
||||
|
||||
// static
|
||||
Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
|
||||
const Event& event_proto, const string& grpc_stream_url) {
|
||||
@ -821,9 +862,7 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
|
||||
debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr));
|
||||
TF_RETURN_IF_ERROR(
|
||||
debug_grpc_channel->Connect(channel_connection_timeout_micros));
|
||||
|
||||
(*stream_channels)[grpc_stream_url] = debug_grpc_channel;
|
||||
CreateEmptyEnabledSet(grpc_stream_url);
|
||||
} else {
|
||||
debug_grpc_channel = (*stream_channels)[grpc_stream_url];
|
||||
}
|
||||
@ -838,16 +877,22 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// static
|
||||
bool DebugGrpcIO::IsGateOpen(const string& watch_key,
|
||||
const string& grpc_debug_url) {
|
||||
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
|
||||
GetEnabledWatchKeys();
|
||||
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
|
||||
bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
|
||||
const string& watch_key) {
|
||||
const DebugNodeName2State* enabled_node_to_state =
|
||||
GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
|
||||
return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
|
||||
}
|
||||
|
||||
bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
|
||||
const string& watch_key) {
|
||||
const DebugNodeName2State* enabled_node_to_state =
|
||||
GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
|
||||
auto it = enabled_node_to_state->find(watch_key);
|
||||
if (it == enabled_node_to_state->end()) {
|
||||
return false;
|
||||
} else {
|
||||
const auto& url_enabled = (*enabled_watch_keys)[grpc_debug_url];
|
||||
return url_enabled.find(watch_key) != url_enabled.end();
|
||||
return it->second == EventReply::DebugOpStateChange::READ_WRITE;
|
||||
}
|
||||
}
|
||||
|
||||
@ -871,56 +916,46 @@ Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
|
||||
}
|
||||
|
||||
// static
|
||||
std::unordered_map<string, std::unordered_set<string>>*
|
||||
DebugGrpcIO::GetEnabledWatchKeys() {
|
||||
static std::unordered_map<string, std::unordered_set<string>>*
|
||||
enabled_watch_keys =
|
||||
new std::unordered_map<string, std::unordered_set<string>>();
|
||||
return enabled_watch_keys;
|
||||
std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
|
||||
DebugGrpcIO::GetEnabledDebugOpStates() {
|
||||
static std::unordered_map<string, DebugNodeName2State>*
|
||||
enabled_debug_op_states =
|
||||
new std::unordered_map<string, DebugNodeName2State>();
|
||||
return enabled_debug_op_states;
|
||||
}
|
||||
|
||||
// static
|
||||
void DebugGrpcIO::EnableWatchKey(const string& grpc_debug_url,
|
||||
const string& watch_key) {
|
||||
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
|
||||
GetEnabledWatchKeys();
|
||||
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
|
||||
CreateEmptyEnabledSet(grpc_debug_url);
|
||||
DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
|
||||
const string& grpc_debug_url) {
|
||||
std::unordered_map<string, DebugNodeName2State>* states =
|
||||
GetEnabledDebugOpStates();
|
||||
if (states->find(grpc_debug_url) == states->end()) {
|
||||
DebugNodeName2State url_enabled_debug_op_states;
|
||||
(*states)[grpc_debug_url] = url_enabled_debug_op_states;
|
||||
}
|
||||
(*enabled_watch_keys)[grpc_debug_url].insert(watch_key);
|
||||
return &(*states)[grpc_debug_url];
|
||||
}
|
||||
|
||||
// static
|
||||
void DebugGrpcIO::DisableWatchKey(const string& grpc_debug_url,
|
||||
const string& watch_key) {
|
||||
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
|
||||
GetEnabledWatchKeys();
|
||||
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
|
||||
LOG(WARNING) << "Attempt to disable a watch key for an unregistered gRPC "
|
||||
<< "debug URL: " << grpc_debug_url;
|
||||
} else {
|
||||
std::unordered_set<string>& url_enabled =
|
||||
(*enabled_watch_keys)[grpc_debug_url];
|
||||
if (url_enabled.find(watch_key) == url_enabled.end()) {
|
||||
LOG(WARNING) << "Attempt to disable a watch key that is not currently "
|
||||
<< "enabled at " << grpc_debug_url << ": " << watch_key;
|
||||
void DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||
const string& grpc_debug_url, const string& watch_key,
|
||||
const EventReply::DebugOpStateChange::State new_state) {
|
||||
DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
|
||||
if (new_state == EventReply::DebugOpStateChange::DISABLED) {
|
||||
if (states->find(watch_key) == states->end()) {
|
||||
LOG(ERROR) << "Attempt to disable a watch key that is not currently "
|
||||
<< "enabled at " << grpc_debug_url << ": " << watch_key;
|
||||
} else {
|
||||
url_enabled.erase(watch_key);
|
||||
states->erase(watch_key);
|
||||
}
|
||||
} else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) {
|
||||
(*states)[watch_key] = new_state;
|
||||
}
|
||||
}
|
||||
|
||||
// static
|
||||
void DebugGrpcIO::ClearEnabledWatchKeys() { GetEnabledWatchKeys()->clear(); }
|
||||
|
||||
// static
|
||||
void DebugGrpcIO::CreateEmptyEnabledSet(const string& grpc_debug_url) {
|
||||
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
|
||||
GetEnabledWatchKeys();
|
||||
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
|
||||
std::unordered_set<string> empty_watch_keys;
|
||||
(*enabled_watch_keys)[grpc_debug_url] = empty_watch_keys;
|
||||
}
|
||||
void DebugGrpcIO::ClearEnabledWatchKeys() {
|
||||
GetEnabledDebugOpStates()->clear();
|
||||
}
|
||||
|
||||
#endif // #ifndef PLATFORM_WINDOWS
|
||||
|
@ -16,8 +16,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
|
||||
#define TENSORFLOW_DEBUG_IO_UTILS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
@ -49,6 +54,9 @@ struct DebugNodeKey {
|
||||
// ,job_localhost,replica_0,task_0,cpu_0.
|
||||
static const string DeviceNameToDevicePath(const string& device_name);
|
||||
|
||||
bool operator==(const DebugNodeKey& other) const;
|
||||
bool operator!=(const DebugNodeKey& other) const;
|
||||
|
||||
const string device_name;
|
||||
const string node_name;
|
||||
const int32 output_slot;
|
||||
@ -219,6 +227,19 @@ class DebugFileIO {
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace std {
|
||||
|
||||
template <>
|
||||
struct hash<::tensorflow::DebugNodeKey> {
|
||||
size_t operator()(const ::tensorflow::DebugNodeKey& k) const {
|
||||
return ::tensorflow::Hash64(
|
||||
::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":",
|
||||
k.output_slot, ":", k.debug_op, ":"));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
// TODO(cais): Support grpc:// debug URLs in open source once Python grpc
|
||||
// genrule becomes available. See b/23796275.
|
||||
#ifndef PLATFORM_WINDOWS
|
||||
@ -252,7 +273,8 @@ class DebugGrpcChannel {
|
||||
// Write an Event proto to the debug gRPC stream.
|
||||
//
|
||||
// Thread-safety: Safe with respect to other calls to the same method and
|
||||
// call to Close().
|
||||
// calls to ReadEventReply() and Close().
|
||||
//
|
||||
// Args:
|
||||
// event: The event proto to be written to the stream.
|
||||
//
|
||||
@ -260,6 +282,19 @@ class DebugGrpcChannel {
|
||||
// True iff the write is successful.
|
||||
bool WriteEvent(const Event& event);
|
||||
|
||||
// Read an EventReply proto from the debug gRPC stream.
|
||||
//
|
||||
// This method blocks and waits for an EventReply from the server.
|
||||
// Thread-safety: Safe with respect to other calls to the same method and
|
||||
// calls to WriteEvent() and Close().
|
||||
//
|
||||
// Args:
|
||||
// event_reply: the to-be-modified EventReply proto passed as reference.
|
||||
//
|
||||
// Returns:
|
||||
// True iff the read is successful.
|
||||
bool ReadEventReply(EventReply* event_reply);
|
||||
|
||||
// Receive EventReplies from server (if any) and close the stream and the
|
||||
// channel.
|
||||
Status ReceiveServerRepliesAndClose();
|
||||
@ -294,48 +329,49 @@ class DebugGrpcIO {
|
||||
static Status SendEventProtoThroughGrpcStream(const Event& event_proto,
|
||||
const string& grpc_stream_url);
|
||||
|
||||
// Checks whether a debug watch key is allowed to send data to a given grpc://
|
||||
// debug URL given the current gating status.
|
||||
//
|
||||
// Args:
|
||||
// watch_key: debug tensor watch key, in the format of
|
||||
// tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
|
||||
// grpc_debug_url: the debug URL, e.g., "grpc://localhost:3333",
|
||||
//
|
||||
// Returns:
|
||||
// Whether the sending of debug data to grpc_debug_url should
|
||||
// proceed.
|
||||
static bool IsGateOpen(const string& watch_key, const string& grpc_debug_url);
|
||||
// Receive an EventReply proto through a debug gRPC stream.
|
||||
static Status ReceiveEventReplyProtoThroughGrpcStream(
|
||||
EventReply* event_reply, const string& grpc_stream_url);
|
||||
|
||||
// Check whether a debug watch key is read-activated at a given gRPC URL.
|
||||
static bool IsReadGateOpen(const string& grpc_debug_url,
|
||||
const string& watch_key);
|
||||
|
||||
// Check whether a debug watch key is write-activated (i.e., read- and
|
||||
// write-activated) at a given gRPC URL.
|
||||
static bool IsWriteGateOpen(const string& grpc_debug_url,
|
||||
const string& watch_key);
|
||||
|
||||
// Closes a gRPC stream to the given address, if it exists.
|
||||
// Thread-safety: Safe with respect to other calls to the same method and
|
||||
// calls to SendTensorThroughGrpcStream().
|
||||
static Status CloseGrpcStream(const string& grpc_stream_url);
|
||||
|
||||
// Enables a debug watch key at a grpc:// debug URL.
|
||||
static void EnableWatchKey(const string& grpc_debug_url,
|
||||
const string& watch_key);
|
||||
|
||||
// Disables a debug watch key at a grpc:// debug URL.
|
||||
static void DisableWatchKey(const string& grpc_debug_url,
|
||||
const string& watch_key);
|
||||
// Set the gRPC state of a debug node key.
|
||||
// TODO(cais): Include device information in watch_key.
|
||||
static void SetDebugNodeKeyGrpcState(
|
||||
const string& grpc_debug_url, const string& watch_key,
|
||||
const EventReply::DebugOpStateChange::State new_state);
|
||||
|
||||
private:
|
||||
using DebugNodeName2State =
|
||||
std::unordered_map<string, EventReply::DebugOpStateChange::State>;
|
||||
|
||||
// Returns a global map from grpc debug URLs to the corresponding
|
||||
// DebugGrpcChannels.
|
||||
static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
|
||||
GetStreamChannels();
|
||||
|
||||
// Returns a global map from grpc debug URLs to the enabled gated debug nodes.
|
||||
// The keys are grpc:// URLs of the debug servers, e.g., "grpc://foo:3333".
|
||||
// Each value element of the value has the format
|
||||
// <node_name>:<output_slot>:<debug_op>", e.g.,
|
||||
// "Weights_1:0:DebugNumericSummary".
|
||||
static std::unordered_map<string, std::unordered_set<string>>*
|
||||
GetEnabledWatchKeys();
|
||||
// Returns a map from debug URL to a map from debug op name to enabled state.
|
||||
static std::unordered_map<string, DebugNodeName2State>*
|
||||
GetEnabledDebugOpStates();
|
||||
|
||||
// Returns a map from debug op names to enabled state, for a given debug URL.
|
||||
static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl(
|
||||
const string& grpc_debug_url);
|
||||
|
||||
// Clear enabled debug op state from all debug URLs (if any).
|
||||
static void ClearEnabledWatchKeys();
|
||||
static void CreateEmptyEnabledSet(const string& grpc_debug_url);
|
||||
|
||||
static mutex streams_mu;
|
||||
static int64 channel_connection_timeout_micros;
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/debug/debug_io_utils.h"
|
||||
|
||||
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
||||
@ -62,6 +64,39 @@ TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) {
|
||||
debug_node_key.device_path);
|
||||
}
|
||||
|
||||
TEST_F(DebugIOUtilsTest, EqualityOfDebugNodeKeys) {
|
||||
const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2",
|
||||
"hidden_1/MatMul", 0, "DebugIdentity");
|
||||
const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2",
|
||||
"hidden_1/MatMul", 0, "DebugIdentity");
|
||||
const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2",
|
||||
"hidden_1/BiasAdd", 0, "DebugIdentity");
|
||||
const DebugNodeKey debug_node_key_4("/job:worker/replica:1/task:0/gpu:2",
|
||||
"hidden_1/MatMul", 0,
|
||||
"DebugNumericSummary");
|
||||
EXPECT_EQ(debug_node_key_1, debug_node_key_2);
|
||||
EXPECT_NE(debug_node_key_1, debug_node_key_3);
|
||||
EXPECT_NE(debug_node_key_1, debug_node_key_4);
|
||||
EXPECT_NE(debug_node_key_3, debug_node_key_4);
|
||||
}
|
||||
|
||||
TEST_F(DebugIOUtilsTest, DebugNodeKeysIsHashable) {
|
||||
const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2",
|
||||
"hidden_1/MatMul", 0, "DebugIdentity");
|
||||
const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2",
|
||||
"hidden_1/MatMul", 0, "DebugIdentity");
|
||||
const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2",
|
||||
"hidden_1/BiasAdd", 0, "DebugIdentity");
|
||||
|
||||
std::unordered_set<DebugNodeKey> keys;
|
||||
keys.insert(debug_node_key_1);
|
||||
ASSERT_EQ(1, keys.size());
|
||||
keys.insert(debug_node_key_3);
|
||||
ASSERT_EQ(2, keys.size());
|
||||
keys.erase(debug_node_key_2);
|
||||
ASSERT_EQ(1, keys.size());
|
||||
}
|
||||
|
||||
TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) {
|
||||
Initialize();
|
||||
|
||||
|
@ -17,6 +17,7 @@ syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "tensorflow/core/framework/tensor.proto";
|
||||
import "tensorflow/core/util/event.proto";
|
||||
|
||||
// Reply message from EventListener to the client, i.e., to the source of the
|
||||
@ -24,18 +25,25 @@ import "tensorflow/core/util/event.proto";
|
||||
// TensorFlow graph being executed.
|
||||
message EventReply {
|
||||
message DebugOpStateChange {
|
||||
enum Change {
|
||||
DISABLE = 0;
|
||||
ENABLE = 1;
|
||||
enum State {
|
||||
STATE_UNSPECIFIED = 0;
|
||||
DISABLED = 1;
|
||||
READ_ONLY = 2;
|
||||
READ_WRITE = 3;
|
||||
}
|
||||
|
||||
Change change = 1;
|
||||
State state = 1;
|
||||
string node_name = 2;
|
||||
int32 output_slot = 3;
|
||||
string debug_op = 4;
|
||||
}
|
||||
|
||||
repeated DebugOpStateChange debug_op_state_changes = 1;
|
||||
|
||||
// New tensor value to override the current tensor value with.
|
||||
TensorProto tensor = 2;
|
||||
// TODO(cais): Make use of this field to implement overriding of tensor value
|
||||
// during debugging.
|
||||
}
|
||||
|
||||
// EventListener: Receives Event protos, e.g., from debugged TensorFlow
|
||||
|
@ -259,10 +259,6 @@ class DebugNumericSummaryOpTest : public OpsTestBase {
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
|
||||
|
||||
void CreateEmptyEnabledSet(const string& grpc_debug_url) {
|
||||
DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
@ -604,7 +600,6 @@ TEST_F(DebugNumericSummaryOpTest, BoolSuccess) {
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) {
|
||||
ClearEnabledWatchKeys();
|
||||
CreateEmptyEnabledSet("grpc://server:3333");
|
||||
|
||||
std::vector<string> debug_urls({"grpc://server:3333"});
|
||||
TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));
|
||||
@ -617,8 +612,9 @@ TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) {
|
||||
|
||||
TEST_F(DebugNumericSummaryOpTest, DisabledDueToNonMatchingWatchKey) {
|
||||
ClearEnabledWatchKeys();
|
||||
DebugGrpcIO::EnableWatchKey("grpc://server:3333",
|
||||
"FakeTensor:1:DebugNumeriSummary");
|
||||
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||
"grpc://server:3333", "FakeTensor:1:DebugNumeriSummary",
|
||||
EventReply::DebugOpStateChange::READ_ONLY);
|
||||
|
||||
std::vector<string> debug_urls({"grpc://server:3333"});
|
||||
TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));
|
||||
|
@ -31,16 +31,20 @@ from tensorflow.core.debug import debug_service_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.debug.lib import debug_data
|
||||
from tensorflow.python.debug.lib import debug_service_pb2_grpc
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
DebugWatch = collections.namedtuple("DebugWatch",
|
||||
["node_name", "output_slot", "debug_op"])
|
||||
|
||||
|
||||
def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op):
|
||||
"""Make EventReply proto to represent a request to watch/unwatch a debug op.
|
||||
def _watch_key_event_reply(new_state, node_name, output_slot, debug_op):
|
||||
"""Make `EventReply` proto to represent a request to watch/unwatch a debug op.
|
||||
|
||||
Args:
|
||||
to_enable: (`bool`) whether the request is to enable the watch key.
|
||||
new_state: (`debug_service_pb2.EventReply.DebugOpStateChange.State`) the new
|
||||
state to set the debug node to, i.e., whether the debug node will become
|
||||
disabled under the grpc mode (`DISABLED`), become a watchpoint
|
||||
(`READ_ONLY`) or become a breakpoint (`READ_WRITE`).
|
||||
node_name: (`str`) name of the node.
|
||||
output_slot: (`int`) output slot of the tensor.
|
||||
debug_op: (`str`) the debug op attached to node_name:output_slot tensor to
|
||||
@ -51,9 +55,7 @@ def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op):
|
||||
"""
|
||||
event_reply = debug_service_pb2.EventReply()
|
||||
state_change = event_reply.debug_op_state_changes.add()
|
||||
state_change.change = (
|
||||
debug_service_pb2.EventReply.DebugOpStateChange.ENABLE
|
||||
if to_enable else debug_service_pb2.EventReply.DebugOpStateChange.DISABLE)
|
||||
state_change.state = new_state
|
||||
state_change.node_name = node_name
|
||||
state_change.output_slot = output_slot
|
||||
state_change.debug_op = debug_op
|
||||
@ -125,6 +127,7 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
||||
|
||||
self._event_reply_queue = queue.Queue()
|
||||
self._gated_grpc_debug_watches = set()
|
||||
self._breakpoints = set()
|
||||
|
||||
def SendEvents(self, request_iterator, context):
|
||||
"""Implementation of the SendEvents service method.
|
||||
@ -171,11 +174,30 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
||||
maybe_tensor_event = self._process_tensor_event_in_chunks(
|
||||
event, tensor_chunks)
|
||||
if maybe_tensor_event:
|
||||
stream_handler.on_value_event(maybe_tensor_event)
|
||||
event_reply = stream_handler.on_value_event(maybe_tensor_event)
|
||||
if event_reply is not None:
|
||||
yield event_reply
|
||||
|
||||
# The server writes EventReply messages, if any.
|
||||
while not self._event_reply_queue.empty():
|
||||
yield self._event_reply_queue.get()
|
||||
event_reply = self._event_reply_queue.get()
|
||||
for state_change in event_reply.debug_op_state_changes:
|
||||
if (state_change.state ==
|
||||
debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE):
|
||||
logging.info("Adding breakpoint %s:%d:%s", state_change.node_name,
|
||||
state_change.output_slot, state_change.debug_op)
|
||||
self._breakpoints.add(
|
||||
(state_change.node_name, state_change.output_slot,
|
||||
state_change.debug_op))
|
||||
elif (state_change.state ==
|
||||
debug_service_pb2.EventReply.DebugOpStateChange.DISABLED):
|
||||
logging.info("Removing watchpoint or breakpoint: %s:%d:%s",
|
||||
state_change.node_name, state_change.output_slot,
|
||||
state_change.debug_op)
|
||||
self._breakpoints.discard(
|
||||
(state_change.node_name, state_change.output_slot,
|
||||
state_change.debug_op))
|
||||
yield event_reply
|
||||
|
||||
def _process_tensor_event_in_chunks(self, event, tensor_chunks):
|
||||
"""Possibly reassemble event chunks.
|
||||
@ -336,8 +358,8 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
||||
finally:
|
||||
self._server_lock.release()
|
||||
|
||||
def request_watch(self, node_name, output_slot, debug_op):
|
||||
"""Request enabling a debug tensor watch.
|
||||
def request_watch(self, node_name, output_slot, debug_op, breakpoint=False):
|
||||
"""Request enabling a debug tensor watchpoint or breakpoint.
|
||||
|
||||
This will let the server send a EventReply to the client side
|
||||
(i.e., the debugged TensorFlow runtime process) to request adding a watch
|
||||
@ -355,12 +377,19 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
||||
output_slot: (`int`) output slot index of the tensor to watch.
|
||||
debug_op: (`str`) name of the debug op to enable. This should not include
|
||||
any attribute substrings.
|
||||
breakpoint: (`bool`) Iff `True`, the debug op will block and wait until it
|
||||
receives an `EventReply` response from the server. The `EventReply`
|
||||
proto may carry a TensorProto that modifies the value of the debug op's
|
||||
output tensor.
|
||||
"""
|
||||
self._event_reply_queue.put(
|
||||
_watch_key_event_reply(True, node_name, output_slot, debug_op))
|
||||
_watch_key_event_reply(
|
||||
debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE
|
||||
if breakpoint else debug_service_pb2.EventReply.DebugOpStateChange.
|
||||
READ_ONLY, node_name, output_slot, debug_op))
|
||||
|
||||
def request_unwatch(self, node_name, output_slot, debug_op):
|
||||
"""Request disabling a debug tensor watch.
|
||||
"""Request disabling a debug tensor watchpoint or breakpoint.
|
||||
|
||||
The request will take effect on the next debugged `Session.run()` call.
|
||||
|
||||
@ -374,7 +403,18 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
||||
any attribute substrings.
|
||||
"""
|
||||
self._event_reply_queue.put(
|
||||
_watch_key_event_reply(False, node_name, output_slot, debug_op))
|
||||
_watch_key_event_reply(debug_service_pb2.EventReply.DebugOpStateChange.
|
||||
DISABLED, node_name, output_slot, debug_op))
|
||||
|
||||
@property
|
||||
def breakpoints(self):
|
||||
"""Get a set of the currently-activated breakpoints.
|
||||
|
||||
Returns:
|
||||
A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g.,
|
||||
{("MatMul", 0, "DebugIdentity")}.
|
||||
"""
|
||||
return self._breakpoints
|
||||
|
||||
def gated_grpc_debug_watches(self):
|
||||
"""Get the list of debug watches with attribute gated_grpc=True.
|
||||
|
@ -31,6 +31,7 @@ import time
|
||||
|
||||
import portpicker
|
||||
|
||||
from tensorflow.core.debug import debug_service_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python.client import session
|
||||
@ -138,13 +139,26 @@ class EventListenerTestStreamHandler(
|
||||
|
||||
Args:
|
||||
event: The Event proto carrying a tensor value.
|
||||
|
||||
Returns:
|
||||
If the debug node belongs to the set of currently activated breakpoints,
|
||||
a `EventReply` proto will be returned.
|
||||
"""
|
||||
if self._dump_dir:
|
||||
self._write_value_event(event)
|
||||
else:
|
||||
value = event.summary.value[0]
|
||||
tensor_value = debug_data.load_tensor_from_event(event)
|
||||
self._event_listener_servicer.debug_tensor_values[value.node_name].append(
|
||||
debug_data.load_tensor_from_event(event))
|
||||
tensor_value)
|
||||
|
||||
items = event.summary.value[0].node_name.split(":")
|
||||
node_name = items[0]
|
||||
output_slot = int(items[1])
|
||||
debug_op = items[2]
|
||||
if ((node_name, output_slot, debug_op) in
|
||||
self._event_listener_servicer.breakpoints):
|
||||
return debug_service_pb2.EventReply()
|
||||
|
||||
def _try_makedirs(self, dir_path):
|
||||
if not os.path.isdir(dir_path):
|
||||
|
@ -556,6 +556,55 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
|
||||
[50 + 5.0 * i],
|
||||
self._server_2.debug_tensor_values["v:0:DebugIdentity"])
|
||||
|
||||
def testToggleBreakpointWorks(self):
|
||||
with session.Session(config=no_rewrite_session_config()) as sess:
|
||||
v = variables.Variable(50.0, name="v")
|
||||
delta = constant_op.constant(5.0, name="delta")
|
||||
inc_v = state_ops.assign_add(v, delta, name="inc_v")
|
||||
|
||||
sess.run(v.initializer)
|
||||
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(
|
||||
run_options,
|
||||
sess.graph,
|
||||
debug_ops=["DebugIdentity(gated_grpc=true)"],
|
||||
debug_urls=[self._debug_server_url_1])
|
||||
|
||||
for i in xrange(4):
|
||||
self._server_1.clear_data()
|
||||
|
||||
# N.B.: These requests will be fulfilled not in this debugged
|
||||
# Session.run() invocation, but in the next one.
|
||||
if i in (0, 2):
|
||||
# Enable breakpoint at delta:0:DebugIdentity in runs 0 and 2.
|
||||
self._server_1.request_watch(
|
||||
"delta", 0, "DebugIdentity", breakpoint=True)
|
||||
else:
|
||||
# Disable the breakpoint in runs 1 and 3.
|
||||
self._server_1.request_unwatch("delta", 0, "DebugIdentity")
|
||||
|
||||
output = sess.run(inc_v, options=run_options, run_metadata=run_metadata)
|
||||
self.assertAllClose(50.0 + 5.0 * (i + 1), output)
|
||||
|
||||
if i in (0, 2):
|
||||
# After the end of runs 0 and 2, the server has received the requests
|
||||
# to enable the breakpoint at delta:0:DebugIdentity. So the server
|
||||
# should keep track of the correct breakpoints.
|
||||
self.assertSetEqual({("delta", 0, "DebugIdentity")},
|
||||
self._server_1.breakpoints)
|
||||
else:
|
||||
# During runs 1 and 3, the server should have received the published
|
||||
# debug tensor delta:0:DebugIdentity. The breakpoint should have been
|
||||
# unblocked by EventReply reponses from the server.
|
||||
self.assertAllClose(
|
||||
[5.0],
|
||||
self._server_1.debug_tensor_values["delta:0:DebugIdentity"])
|
||||
# After the runs, the server should have properly removed the
|
||||
# breakpoints due to the request_unwatch calls.
|
||||
self.assertSetEqual(set(), self._server_1.breakpoints)
|
||||
|
||||
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
|
||||
with session.Session() as sess:
|
||||
v = variables.Variable(50.0, name="v")
|
||||
|
Loading…
Reference in New Issue
Block a user