From 3c482c66b5a1f74875969e96834ff7564e829668 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Fri, 11 Aug 2017 09:37:02 -0700 Subject: [PATCH] tfdbg: extend grpc_debug_server protocol for interactive debugging Previously, a grpc-gated debug op had two modes: DISABLED and ENABLED. This CL splits the ENABLED state into two states: READ_ONLY and READ_WRITE. * READ_ONLY is equivalent to the previous ENABLED state, wherein a debug op publishes the debug tensor to the grpc debug server and proceeds. It can be regarded as a "watchpoint" that doesn't block execution. * READ_WRITE is a "breakpoint". In addition to publishing the debug tensor, it blocks and awaits an EventReply proto response from the grpc debug server before proceeding. PiperOrigin-RevId: 164987725 --- .../core/debug/debug_grpc_io_utils_test.cc | 91 ++++++++--- tensorflow/core/debug/debug_grpc_testlib.cc | 52 +++--- tensorflow/core/debug/debug_grpc_testlib.h | 15 +- tensorflow/core/debug/debug_io_utils.cc | 151 +++++++++++------- tensorflow/core/debug/debug_io_utils.h | 92 +++++++---- tensorflow/core/debug/debug_io_utils_test.cc | 35 ++++ tensorflow/core/debug/debug_service.proto | 16 +- tensorflow/core/kernels/debug_ops_test.cc | 10 +- .../python/debug/lib/grpc_debug_server.py | 66 ++++++-- .../debug/lib/grpc_debug_test_server.py | 16 +- .../debug/lib/session_debug_grpc_test.py | 49 ++++++ 11 files changed, 439 insertions(+), 154 deletions(-) diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc index 65104241820..803cce85585 100644 --- a/tensorflow/core/debug/debug_grpc_io_utils_test.cc +++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc @@ -65,10 +65,6 @@ class GrpcDebugTest : public ::testing::Test { void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); } - void CreateEmptyEnabledSet(const string& grpc_debug_url) { - DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url); - } - const int64 GetChannelConnectionTimeoutMicros() { 
return DebugGrpcIO::channel_connection_timeout_micros; } @@ -261,7 +257,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) { } } -TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) { +TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUsingGrpcGating) { // Prepare the tensor to send. const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "test_namescope/test_node", 0, @@ -283,8 +279,60 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) { TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time, urls, enable_gated_grpc)); - server_data_.server->RequestDebugOpStateChangeAtNextStream(i == 0, - kDebugNodeKey); + server_data_.server->RequestDebugOpStateChangeAtNextStream( + i == 0 ? EventReply::DebugOpStateChange::READ_ONLY + : EventReply::DebugOpStateChange::DISABLED, + kDebugNodeKey); + + // Close the debug gRPC stream. + Status close_status = DebugIO::CloseDebugURL(server_data_.url); + ASSERT_TRUE(close_status.ok()); + + // Check dumped files according to the expected gating results. + if (i < 2) { + ASSERT_EQ(1, server_data_.server->node_names.size()); + ASSERT_EQ(1, server_data_.server->output_slots.size()); + ASSERT_EQ(1, server_data_.server->debug_ops.size()); + EXPECT_EQ(kDebugNodeKey.device_name, + server_data_.server->device_names[0]); + EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]); + EXPECT_EQ(kDebugNodeKey.output_slot, + server_data_.server->output_slots[0]); + EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]); + } else { + ASSERT_EQ(0, server_data_.server->node_names.size()); + } + } +} + +TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUnderReadWriteMode) { + // Prepare the tensor to send. 
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", + "test_namescope/test_node", 0, + "DebugIdentity"); + Tensor tensor(DT_INT32, TensorShape({1, 1})); + tensor.flat()(0) = 42; + + const std::vector urls({server_data_.url}); + for (int i = 0; i < 3; ++i) { + server_data_.server->ClearReceivedDebugData(); + const uint64 wall_time = Env::Default()->NowMicros(); + + // On the 1st send (i == 0), gating is disabled, so data should be sent. + // On the 2nd send (i == 1), gating is enabled, and the server has enabled + // the watch key in the previous send (READ_WRITE), so data should be + // sent. In this iteration, the server responds with an EventReply proto to + // unblock the debug node. + // On the 3rd send (i == 2), gating is enabled, but the server has disabled + // the watch key in the previous send, so data should not be sent. + const bool enable_gated_grpc = (i != 0); + TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time, + urls, enable_gated_grpc)); + + server_data_.server->RequestDebugOpStateChangeAtNextStream( + i == 0 ? EventReply::DebugOpStateChange::READ_WRITE + : EventReply::DebugOpStateChange::DISABLED, + kDebugNodeKey); // Close the debug gRPC stream. 
Status close_status = DebugIO::CloseDebugURL(server_data_.url); @@ -308,8 +356,6 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) { } TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) { - CreateEmptyEnabledSet("grpc://localhost:3333"); - ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {"grpc://localhost:3333"})); @@ -322,8 +368,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) { const string kGrpcUrl1 = "grpc://localhost:3333"; const string kGrpcUrl2 = "grpc://localhost:3334"; - DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity"); - DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "bar:0:DebugIdentity"); + DebugGrpcIO::SetDebugNodeKeyGrpcState( + kGrpcUrl1, "foo:0:DebugIdentity", + EventReply::DebugOpStateChange::READ_ONLY); + DebugGrpcIO::SetDebugNodeKeyGrpcState( + kGrpcUrl1, "bar:0:DebugIdentity", + EventReply::DebugOpStateChange::READ_ONLY); ASSERT_FALSE( DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", {kGrpcUrl1})); @@ -350,9 +400,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) { const string kGrpcUrl2 = "grpc://localhost:3334"; const string kGrpcUrl3 = "grpc://localhost:3335"; - DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity"); - DebugGrpcIO::EnableWatchKey(kGrpcUrl2, "bar:0:DebugIdentity"); - CreateEmptyEnabledSet(kGrpcUrl3); + DebugGrpcIO::SetDebugNodeKeyGrpcState( + kGrpcUrl1, "foo:0:DebugIdentity", + EventReply::DebugOpStateChange::READ_ONLY); + DebugGrpcIO::SetDebugNodeKeyGrpcState( + kGrpcUrl2, "bar:0:DebugIdentity", + EventReply::DebugOpStateChange::READ_ONLY); ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1})); ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2})); @@ -375,7 +428,9 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) { } TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) { - DebugGrpcIO::EnableWatchKey("grpc://localhost:3333", 
"foo:0:DebugIdentity"); + DebugGrpcIO::SetDebugNodeKeyGrpcState( + "grpc://localhost:3333", "foo:0:DebugIdentity", + EventReply::DebugOpStateChange::READ_ONLY); std::vector debug_urls_1; ASSERT_FALSE( @@ -385,7 +440,6 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) { TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) { const string kGrpcUrl1 = "grpc://localhost:3333"; const string kWatch1 = "foo:0:DebugIdentity"; - CreateEmptyEnabledSet(kGrpcUrl1); ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen( {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)})); @@ -404,9 +458,8 @@ TEST_F(GrpcDebugTest, TestGateCopyNodeOnNonEmptyEnabledSet) { const string kGrpcUrl2 = "grpc://localhost:3334"; const string kWatch1 = "foo:0:DebugIdentity"; const string kWatch2 = "foo:1:DebugIdentity"; - CreateEmptyEnabledSet(kGrpcUrl1); - CreateEmptyEnabledSet(kGrpcUrl2); - DebugGrpcIO::EnableWatchKey(kGrpcUrl1, kWatch1); + DebugGrpcIO::SetDebugNodeKeyGrpcState( + kGrpcUrl1, kWatch1, EventReply::DebugOpStateChange::READ_ONLY); ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen( {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)})); diff --git a/tensorflow/core/debug/debug_grpc_testlib.cc b/tensorflow/core/debug/debug_grpc_testlib.cc index acfbf7852b2..aa80ea84e34 100644 --- a/tensorflow/core/debug/debug_grpc_testlib.cc +++ b/tensorflow/core/debug/debug_grpc_testlib.cc @@ -72,27 +72,44 @@ namespace test { output_slots.push_back(metadata.output_slot()); debug_ops.push_back(debug_op); debug_tensors.push_back(tensor); + + // If the debug node is currently in the READ_WRITE mode, send an + // EventReply to 1) unblock the execution and 2) optionally modify the + // value. 
+ const DebugNodeKey debug_node_key(metadata.device(), node_name, + metadata.output_slot(), debug_op); + if (write_enabled_debug_node_keys_.find(debug_node_key) != + write_enabled_debug_node_keys_.end()) { + stream->Write(EventReply()); + } } } { - mutex_lock l(changes_mu_); - for (size_t i = 0; i < changes_to_enable_.size(); ++i) { + mutex_lock l(states_mu_); + for (size_t i = 0; i < new_states_.size(); ++i) { EventReply event_reply; EventReply::DebugOpStateChange* change = event_reply.add_debug_op_state_changes(); - change->set_change(changes_to_enable_[i] - ? EventReply::DebugOpStateChange::ENABLE - : EventReply::DebugOpStateChange::DISABLE); - change->set_node_name(changes_node_names_[i]); - change->set_output_slot(changes_output_slots_[i]); - change->set_debug_op(changes_debug_ops_[i]); + + // State changes will take effect in the next stream, i.e., next debugged + // Session.run() call. + change->set_state(new_states_[i]); + const DebugNodeKey& debug_node_key = debug_node_keys_[i]; + change->set_node_name(debug_node_key.node_name); + change->set_output_slot(debug_node_key.output_slot); + change->set_debug_op(debug_node_key.debug_op); stream->Write(event_reply); + + if (new_states_[i] == EventReply::DebugOpStateChange::READ_WRITE) { + write_enabled_debug_node_keys_.insert(debug_node_key); + } else { + write_enabled_debug_node_keys_.erase(debug_node_key); + } } - changes_to_enable_.clear(); - changes_node_names_.clear(); - changes_output_slots_.clear(); - changes_debug_ops_.clear(); + + debug_node_keys_.clear(); + new_states_.clear(); } return ::grpc::Status::OK; @@ -109,13 +126,12 @@ void TestEventListenerImpl::ClearReceivedDebugData() { } void TestEventListenerImpl::RequestDebugOpStateChangeAtNextStream( - bool to_enable, const DebugNodeKey& debug_node_key) { - mutex_lock l(changes_mu_); + const EventReply::DebugOpStateChange::State new_state, + const DebugNodeKey& debug_node_key) { + mutex_lock l(states_mu_); - changes_to_enable_.push_back(to_enable); - 
changes_node_names_.push_back(debug_node_key.node_name); - changes_output_slots_.push_back(debug_node_key.output_slot); - changes_debug_ops_.push_back(debug_node_key.debug_op); + debug_node_keys_.push_back(debug_node_key); + new_states_.push_back(new_state); } void TestEventListenerImpl::RunServer(const int server_port) { diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h index d55933c5809..58361bf78f4 100644 --- a/tensorflow/core/debug/debug_grpc_testlib.h +++ b/tensorflow/core/debug/debug_grpc_testlib.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_DEBUG_GRPC_TESTLIB_H_ #include +#include #include "grpc++/grpc++.h" #include "tensorflow/core/debug/debug_io_utils.h" @@ -44,7 +45,8 @@ class TestEventListenerImpl final : public EventListener::Service { void ClearReceivedDebugData(); void RequestDebugOpStateChangeAtNextStream( - bool to_enable, const DebugNodeKey& debug_node_key); + const EventReply::DebugOpStateChange::State new_state, + const DebugNodeKey& debug_node_key); std::vector debug_metadata_strings; std::vector encoded_graph_defs; @@ -58,12 +60,13 @@ class TestEventListenerImpl final : public EventListener::Service { std::atomic_bool stop_requested_; std::atomic_bool stopped_; - std::vector changes_to_enable_ GUARDED_BY(changes_mu_); - std::vector changes_node_names_ GUARDED_BY(changes_mu_); - std::vector changes_output_slots_ GUARDED_BY(changes_mu_); - std::vector changes_debug_ops_ GUARDED_BY(changes_mu_); + std::vector debug_node_keys_ GUARDED_BY(states_mu_); + std::vector new_states_ + GUARDED_BY(states_mu_); - mutex changes_mu_; + std::unordered_set write_enabled_debug_node_keys_; + + mutex states_mu_; }; // Poll a gRPC debug server by sending a small tensor repeatedly till success. 
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 98ef68a9aa7..2c0284a6e0b 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -15,6 +15,11 @@ limitations under the License. #include "tensorflow/core/debug/debug_io_utils.h" +#include +#include +#include +#include +#include #include #ifndef PLATFORM_WINDOWS @@ -297,6 +302,15 @@ DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name, strings::StrCat(node_name, ":", output_slot, ":", debug_op)), device_path(DeviceNameToDevicePath(device_name)) {} +bool DebugNodeKey::operator==(const DebugNodeKey& other) const { + return (device_name == other.device_name && node_name == other.node_name && + output_slot == other.output_slot && debug_op == other.debug_op); +} + +bool DebugNodeKey::operator!=(const DebugNodeKey& other) const { + return !((*this) == other); +} + Status ReadEventFromFile(const string& dump_file_path, Event* event) { Env* env(Env::Default()); @@ -537,7 +551,7 @@ bool DebugIO::IsCopyNodeGateOpen( DebugIO::kGrpcURLScheme)) { return true; } else { - if (DebugGrpcIO::IsGateOpen(spec.watch_key, spec.url)) { + if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) { return true; } } @@ -557,7 +571,7 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key, DebugIO::kGrpcURLScheme)) { return true; } else { - if (DebugGrpcIO::IsGateOpen(watch_key, debug_url)) { + if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) { return true; } } @@ -575,7 +589,7 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key, if (debug_url.find(kGrpcURLScheme) != 0) { return true; } else { - return DebugGrpcIO::IsGateOpen(watch_key, debug_url); + return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key); } #else return true; @@ -731,6 +745,12 @@ bool DebugGrpcChannel::WriteEvent(const Event& event) { return reader_writer_->Write(event); } +bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) 
{ + mutex_lock l(mu_); + + return reader_writer_->Read(event_reply); +} + Status DebugGrpcChannel::ReceiveServerRepliesAndClose() { mutex_lock l(mu_); @@ -744,13 +764,8 @@ Status DebugGrpcChannel::ReceiveServerRepliesAndClose() { string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":", debug_op_state_change.output_slot(), ":", debug_op_state_change.debug_op()); - if (debug_op_state_change.change() == - EventReply::DebugOpStateChange::ENABLE) { - DebugGrpcIO::EnableWatchKey(url_, watch_key); - } else if (debug_op_state_change.change() == - EventReply::DebugOpStateChange::DISABLE) { - DebugGrpcIO::DisableWatchKey(url_, watch_key); - } + DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key, + debug_op_state_change.state()); } } @@ -789,7 +804,8 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream( const DebugNodeKey& debug_node_key, const Tensor& tensor, const uint64 wall_time_us, const string& grpc_stream_url, const bool gated) { - if (gated && !IsGateOpen(debug_node_key.debug_node_name, grpc_stream_url)) { + if (gated && + !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) { return Status::OK(); } else { std::vector events; @@ -799,10 +815,35 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream( TF_RETURN_IF_ERROR( SendEventProtoThroughGrpcStream(event, grpc_stream_url)); } + if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) { + EventReply event_reply; + TF_RETURN_IF_ERROR(ReceiveEventReplyProtoThroughGrpcStream( + &event_reply, grpc_stream_url)); + // TODO(cais): Support new tensor value carried in the EventReply for + // overriding the value of the tensor being published. 
+ } return Status::OK(); } } +// static +Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream( + EventReply* event_reply, const string& grpc_stream_url) { + std::shared_ptr debug_grpc_channel; + { + mutex_lock l(streams_mu); + std::unordered_map>* + stream_channels = GetStreamChannels(); + debug_grpc_channel = (*stream_channels)[grpc_stream_url]; + } + if (debug_grpc_channel->ReadEventReply(event_reply)) { + return Status::OK(); + } else { + return errors::Cancelled(strings::StrCat( + "Reading EventReply from stream URL ", grpc_stream_url, " failed.")); + } +} + // static Status DebugGrpcIO::SendEventProtoThroughGrpcStream( const Event& event_proto, const string& grpc_stream_url) { @@ -821,9 +862,7 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream( debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr)); TF_RETURN_IF_ERROR( debug_grpc_channel->Connect(channel_connection_timeout_micros)); - (*stream_channels)[grpc_stream_url] = debug_grpc_channel; - CreateEmptyEnabledSet(grpc_stream_url); } else { debug_grpc_channel = (*stream_channels)[grpc_stream_url]; } @@ -838,16 +877,22 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream( return Status::OK(); } -// static -bool DebugGrpcIO::IsGateOpen(const string& watch_key, - const string& grpc_debug_url) { - std::unordered_map>* enabled_watch_keys = - GetEnabledWatchKeys(); - if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) { +bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url, + const string& watch_key) { + const DebugNodeName2State* enabled_node_to_state = + GetEnabledDebugOpStatesAtUrl(grpc_debug_url); + return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end(); +} + +bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url, + const string& watch_key) { + const DebugNodeName2State* enabled_node_to_state = + GetEnabledDebugOpStatesAtUrl(grpc_debug_url); + auto it = enabled_node_to_state->find(watch_key); + if (it == 
enabled_node_to_state->end()) { return false; } else { - const auto& url_enabled = (*enabled_watch_keys)[grpc_debug_url]; - return url_enabled.find(watch_key) != url_enabled.end(); + return it->second == EventReply::DebugOpStateChange::READ_WRITE; } } @@ -871,56 +916,46 @@ Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) { } // static -std::unordered_map>* -DebugGrpcIO::GetEnabledWatchKeys() { - static std::unordered_map>* - enabled_watch_keys = - new std::unordered_map>(); - return enabled_watch_keys; +std::unordered_map* +DebugGrpcIO::GetEnabledDebugOpStates() { + static std::unordered_map* + enabled_debug_op_states = + new std::unordered_map(); + return enabled_debug_op_states; } // static -void DebugGrpcIO::EnableWatchKey(const string& grpc_debug_url, - const string& watch_key) { - std::unordered_map>* enabled_watch_keys = - GetEnabledWatchKeys(); - if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) { - CreateEmptyEnabledSet(grpc_debug_url); +DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl( + const string& grpc_debug_url) { + std::unordered_map* states = + GetEnabledDebugOpStates(); + if (states->find(grpc_debug_url) == states->end()) { + DebugNodeName2State url_enabled_debug_op_states; + (*states)[grpc_debug_url] = url_enabled_debug_op_states; } - (*enabled_watch_keys)[grpc_debug_url].insert(watch_key); + return &(*states)[grpc_debug_url]; } // static -void DebugGrpcIO::DisableWatchKey(const string& grpc_debug_url, - const string& watch_key) { - std::unordered_map>* enabled_watch_keys = - GetEnabledWatchKeys(); - if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) { - LOG(WARNING) << "Attempt to disable a watch key for an unregistered gRPC " - << "debug URL: " << grpc_debug_url; - } else { - std::unordered_set& url_enabled = - (*enabled_watch_keys)[grpc_debug_url]; - if (url_enabled.find(watch_key) == url_enabled.end()) { - LOG(WARNING) << "Attempt to disable a watch 
key that is not currently " - << "enabled at " << grpc_debug_url << ": " << watch_key; +void DebugGrpcIO::SetDebugNodeKeyGrpcState( + const string& grpc_debug_url, const string& watch_key, + const EventReply::DebugOpStateChange::State new_state) { + DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url); + if (new_state == EventReply::DebugOpStateChange::DISABLED) { + if (states->find(watch_key) == states->end()) { + LOG(ERROR) << "Attempt to disable a watch key that is not currently " + << "enabled at " << grpc_debug_url << ": " << watch_key; } else { - url_enabled.erase(watch_key); + states->erase(watch_key); } + } else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) { + (*states)[watch_key] = new_state; } } // static -void DebugGrpcIO::ClearEnabledWatchKeys() { GetEnabledWatchKeys()->clear(); } - -// static -void DebugGrpcIO::CreateEmptyEnabledSet(const string& grpc_debug_url) { - std::unordered_map>* enabled_watch_keys = - GetEnabledWatchKeys(); - if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) { - std::unordered_set empty_watch_keys; - (*enabled_watch_keys)[grpc_debug_url] = empty_watch_keys; - } +void DebugGrpcIO::ClearEnabledWatchKeys() { + GetEnabledDebugOpStates()->clear(); } #endif // #ifndef PLATFORM_WINDOWS diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h index caf9f5341d3..35e735172ba 100644 --- a/tensorflow/core/debug/debug_io_utils.h +++ b/tensorflow/core/debug/debug_io_utils.h @@ -16,8 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_DEBUG_IO_UTILS_H_ #define TENSORFLOW_DEBUG_IO_UTILS_H_ +#include +#include +#include +#include #include #include +#include #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/graph.h" @@ -49,6 +54,9 @@ struct DebugNodeKey { // ,job_localhost,replica_0,task_0,cpu_0. 
static const string DeviceNameToDevicePath(const string& device_name); + bool operator==(const DebugNodeKey& other) const; + bool operator!=(const DebugNodeKey& other) const; + const string device_name; const string node_name; const int32 output_slot; @@ -219,6 +227,19 @@ class DebugFileIO { } // namespace tensorflow +namespace std { + +template <> +struct hash<::tensorflow::DebugNodeKey> { + size_t operator()(const ::tensorflow::DebugNodeKey& k) const { + return ::tensorflow::Hash64( + ::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":", + k.output_slot, ":", k.debug_op, ":")); + } +}; + +} // namespace std + // TODO(cais): Support grpc:// debug URLs in open source once Python grpc // genrule becomes available. See b/23796275. #ifndef PLATFORM_WINDOWS @@ -252,7 +273,8 @@ class DebugGrpcChannel { // Write an Event proto to the debug gRPC stream. // // Thread-safety: Safe with respect to other calls to the same method and - // call to Close(). + // calls to ReadEventReply() and Close(). + // // Args: // event: The event proto to be written to the stream. // @@ -260,6 +282,19 @@ class DebugGrpcChannel { // True iff the write is successful. bool WriteEvent(const Event& event); + // Read an EventReply proto from the debug gRPC stream. + // + // This method blocks and waits for an EventReply from the server. + // Thread-safety: Safe with respect to other calls to the same method and + // calls to WriteEvent() and Close(). + // + // Args: + // event_reply: the to-be-modified EventReply proto passed as reference. + // + // Returns: + // True iff the read is successful. + bool ReadEventReply(EventReply* event_reply); + // Receive EventReplies from server (if any) and close the stream and the // channel. 
Status ReceiveServerRepliesAndClose(); @@ -294,48 +329,49 @@ class DebugGrpcIO { static Status SendEventProtoThroughGrpcStream(const Event& event_proto, const string& grpc_stream_url); - // Checks whether a debug watch key is allowed to send data to a given grpc:// - // debug URL given the current gating status. - // - // Args: - // watch_key: debug tensor watch key, in the format of - // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity". - // grpc_debug_url: the debug URL, e.g., "grpc://localhost:3333", - // - // Returns: - // Whether the sending of debug data to grpc_debug_url should - // proceed. - static bool IsGateOpen(const string& watch_key, const string& grpc_debug_url); + // Receive an EventReply proto through a debug gRPC stream. + static Status ReceiveEventReplyProtoThroughGrpcStream( + EventReply* event_reply, const string& grpc_stream_url); + + // Check whether a debug watch key is read-activated at a given gRPC URL. + static bool IsReadGateOpen(const string& grpc_debug_url, + const string& watch_key); + + // Check whether a debug watch key is write-activated (i.e., read- and + // write-activated) at a given gRPC URL. + static bool IsWriteGateOpen(const string& grpc_debug_url, + const string& watch_key); // Closes a gRPC stream to the given address, if it exists. // Thread-safety: Safe with respect to other calls to the same method and // calls to SendTensorThroughGrpcStream(). static Status CloseGrpcStream(const string& grpc_stream_url); - // Enables a debug watch key at a grpc:// debug URL. - static void EnableWatchKey(const string& grpc_debug_url, - const string& watch_key); - - // Disables a debug watch key at a grpc:// debug URL. - static void DisableWatchKey(const string& grpc_debug_url, - const string& watch_key); + // Set the gRPC state of a debug node key. + // TODO(cais): Include device information in watch_key. 
+ static void SetDebugNodeKeyGrpcState( + const string& grpc_debug_url, const string& watch_key, + const EventReply::DebugOpStateChange::State new_state); private: + using DebugNodeName2State = + std::unordered_map; + // Returns a global map from grpc debug URLs to the corresponding // DebugGrpcChannels. static std::unordered_map>* GetStreamChannels(); - // Returns a global map from grpc debug URLs to the enabled gated debug nodes. - // The keys are grpc:// URLs of the debug servers, e.g., "grpc://foo:3333". - // Each value element of the value has the format - // ::", e.g., - // "Weights_1:0:DebugNumericSummary". - static std::unordered_map>* - GetEnabledWatchKeys(); + // Returns a map from debug URL to a map from debug op name to enabled state. + static std::unordered_map* + GetEnabledDebugOpStates(); + // Returns a map from debug op names to enabled state, for a given debug URL. + static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl( + const string& grpc_debug_url); + + // Clear enabled debug op state from all debug URLs (if any). static void ClearEnabledWatchKeys(); - static void CreateEmptyEnabledSet(const string& grpc_debug_url); static mutex streams_mu; static int64 channel_connection_timeout_micros; diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc index df6fb1d2fe1..eee9d3f97e7 100644 --- a/tensorflow/core/debug/debug_io_utils_test.cc +++ b/tensorflow/core/debug/debug_io_utils_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "tensorflow/core/debug/debug_io_utils.h" #include "tensorflow/core/debug/debugger_event_metadata.pb.h" @@ -62,6 +64,39 @@ TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) { debug_node_key.device_path); } +TEST_F(DebugIOUtilsTest, EqualityOfDebugNodeKeys) { + const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2", + "hidden_1/MatMul", 0, "DebugIdentity"); + const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2", + "hidden_1/MatMul", 0, "DebugIdentity"); + const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2", + "hidden_1/BiasAdd", 0, "DebugIdentity"); + const DebugNodeKey debug_node_key_4("/job:worker/replica:1/task:0/gpu:2", + "hidden_1/MatMul", 0, + "DebugNumericSummary"); + EXPECT_EQ(debug_node_key_1, debug_node_key_2); + EXPECT_NE(debug_node_key_1, debug_node_key_3); + EXPECT_NE(debug_node_key_1, debug_node_key_4); + EXPECT_NE(debug_node_key_3, debug_node_key_4); +} + +TEST_F(DebugIOUtilsTest, DebugNodeKeysIsHashable) { + const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2", + "hidden_1/MatMul", 0, "DebugIdentity"); + const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2", + "hidden_1/MatMul", 0, "DebugIdentity"); + const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2", + "hidden_1/BiasAdd", 0, "DebugIdentity"); + + std::unordered_set keys; + keys.insert(debug_node_key_1); + ASSERT_EQ(1, keys.size()); + keys.insert(debug_node_key_3); + ASSERT_EQ(2, keys.size()); + keys.erase(debug_node_key_2); + ASSERT_EQ(1, keys.size()); +} + TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) { Initialize(); diff --git a/tensorflow/core/debug/debug_service.proto b/tensorflow/core/debug/debug_service.proto index 63d6668292a..547c0576f08 100644 --- a/tensorflow/core/debug/debug_service.proto +++ b/tensorflow/core/debug/debug_service.proto @@ -17,6 +17,7 @@ syntax = 
"proto3"; package tensorflow; +import "tensorflow/core/framework/tensor.proto"; import "tensorflow/core/util/event.proto"; // Reply message from EventListener to the client, i.e., to the source of the @@ -24,18 +25,25 @@ import "tensorflow/core/util/event.proto"; // TensorFlow graph being executed. message EventReply { message DebugOpStateChange { - enum Change { - DISABLE = 0; - ENABLE = 1; + enum State { + STATE_UNSPECIFIED = 0; + DISABLED = 1; + READ_ONLY = 2; + READ_WRITE = 3; } - Change change = 1; + State state = 1; string node_name = 2; int32 output_slot = 3; string debug_op = 4; } repeated DebugOpStateChange debug_op_state_changes = 1; + + // New tensor value to override the current tensor value with. + TensorProto tensor = 2; + // TODO(cais): Make use of this field to implement overriding of tensor value + // during debugging. } // EventListener: Receives Event protos, e.g., from debugged TensorFlow diff --git a/tensorflow/core/kernels/debug_ops_test.cc b/tensorflow/core/kernels/debug_ops_test.cc index 8f73a44a718..89bcbc9c373 100644 --- a/tensorflow/core/kernels/debug_ops_test.cc +++ b/tensorflow/core/kernels/debug_ops_test.cc @@ -259,10 +259,6 @@ class DebugNumericSummaryOpTest : public OpsTestBase { #if defined(PLATFORM_GOOGLE) void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); } - - void CreateEmptyEnabledSet(const string& grpc_debug_url) { - DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url); - } #endif }; @@ -604,7 +600,6 @@ TEST_F(DebugNumericSummaryOpTest, BoolSuccess) { #if defined(PLATFORM_GOOGLE) TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) { ClearEnabledWatchKeys(); - CreateEmptyEnabledSet("grpc://server:3333"); std::vector debug_urls({"grpc://server:3333"}); TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls)); @@ -617,8 +612,9 @@ TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) { TEST_F(DebugNumericSummaryOpTest, DisabledDueToNonMatchingWatchKey) { ClearEnabledWatchKeys(); - 
DebugGrpcIO::EnableWatchKey("grpc://server:3333", - "FakeTensor:1:DebugNumeriSummary"); + DebugGrpcIO::SetDebugNodeKeyGrpcState( + "grpc://server:3333", "FakeTensor:1:DebugNumeriSummary", + EventReply::DebugOpStateChange::READ_ONLY); std::vector debug_urls({"grpc://server:3333"}); TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls)); diff --git a/tensorflow/python/debug/lib/grpc_debug_server.py b/tensorflow/python/debug/lib/grpc_debug_server.py index 948bf97742b..abe229e564e 100644 --- a/tensorflow/python/debug/lib/grpc_debug_server.py +++ b/tensorflow/python/debug/lib/grpc_debug_server.py @@ -31,16 +31,20 @@ from tensorflow.core.debug import debug_service_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.lib import debug_service_pb2_grpc +from tensorflow.python.platform import tf_logging as logging DebugWatch = collections.namedtuple("DebugWatch", ["node_name", "output_slot", "debug_op"]) -def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op): - """Make EventReply proto to represent a request to watch/unwatch a debug op. +def _watch_key_event_reply(new_state, node_name, output_slot, debug_op): + """Make `EventReply` proto to represent a request to watch/unwatch a debug op. Args: - to_enable: (`bool`) whether the request is to enable the watch key. + new_state: (`debug_service_pb2.EventReply.DebugOpStateChange.State`) the new + state to set the debug node to, i.e., whether the debug node will become + disabled under the grpc mode (`DISABLED`), become a watchpoint + (`READ_ONLY`) or become a breakpoint (`READ_WRITE`). node_name: (`str`) name of the node. output_slot: (`int`) output slot of the tensor. 
debug_op: (`str`) the debug op attached to node_name:output_slot tensor to @@ -51,9 +55,7 @@ def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op): """ event_reply = debug_service_pb2.EventReply() state_change = event_reply.debug_op_state_changes.add() - state_change.change = ( - debug_service_pb2.EventReply.DebugOpStateChange.ENABLE - if to_enable else debug_service_pb2.EventReply.DebugOpStateChange.DISABLE) + state_change.state = new_state state_change.node_name = node_name state_change.output_slot = output_slot state_change.debug_op = debug_op @@ -125,6 +127,7 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer): self._event_reply_queue = queue.Queue() self._gated_grpc_debug_watches = set() + self._breakpoints = set() def SendEvents(self, request_iterator, context): """Implementation of the SendEvents service method. @@ -171,11 +174,30 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer): maybe_tensor_event = self._process_tensor_event_in_chunks( event, tensor_chunks) if maybe_tensor_event: - stream_handler.on_value_event(maybe_tensor_event) + event_reply = stream_handler.on_value_event(maybe_tensor_event) + if event_reply is not None: + yield event_reply # The server writes EventReply messages, if any. 
while not self._event_reply_queue.empty(): - yield self._event_reply_queue.get() + event_reply = self._event_reply_queue.get() + for state_change in event_reply.debug_op_state_changes: + if (state_change.state == + debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE): + logging.info("Adding breakpoint %s:%d:%s", state_change.node_name, + state_change.output_slot, state_change.debug_op) + self._breakpoints.add( + (state_change.node_name, state_change.output_slot, + state_change.debug_op)) + elif (state_change.state == + debug_service_pb2.EventReply.DebugOpStateChange.DISABLED): + logging.info("Removing watchpoint or breakpoint: %s:%d:%s", + state_change.node_name, state_change.output_slot, + state_change.debug_op) + self._breakpoints.discard( + (state_change.node_name, state_change.output_slot, + state_change.debug_op)) + yield event_reply def _process_tensor_event_in_chunks(self, event, tensor_chunks): """Possibly reassemble event chunks. @@ -336,8 +358,8 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer): finally: self._server_lock.release() - def request_watch(self, node_name, output_slot, debug_op): - """Request enabling a debug tensor watch. + def request_watch(self, node_name, output_slot, debug_op, breakpoint=False): + """Request enabling a debug tensor watchpoint or breakpoint. This will let the server send a EventReply to the client side (i.e., the debugged TensorFlow runtime process) to request adding a watch @@ -355,12 +377,19 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer): output_slot: (`int`) output slot index of the tensor to watch. debug_op: (`str`) name of the debug op to enable. This should not include any attribute substrings. + breakpoint: (`bool`) Iff `True`, the debug op will block and wait until it + receives an `EventReply` response from the server. The `EventReply` + proto may carry a TensorProto that modifies the value of the debug op's + output tensor. 
""" self._event_reply_queue.put( - _watch_key_event_reply(True, node_name, output_slot, debug_op)) + _watch_key_event_reply( + debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE + if breakpoint else debug_service_pb2.EventReply.DebugOpStateChange. + READ_ONLY, node_name, output_slot, debug_op)) def request_unwatch(self, node_name, output_slot, debug_op): - """Request disabling a debug tensor watch. + """Request disabling a debug tensor watchpoint or breakpoint. The request will take effect on the next debugged `Session.run()` call. @@ -374,7 +403,18 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer): any attribute substrings. """ self._event_reply_queue.put( - _watch_key_event_reply(False, node_name, output_slot, debug_op)) + _watch_key_event_reply(debug_service_pb2.EventReply.DebugOpStateChange. + DISABLED, node_name, output_slot, debug_op)) + + @property + def breakpoints(self): + """Get a set of the currently-activated breakpoints. + + Returns: + A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g., + {("MatMul", 0, "DebugIdentity")}. + """ + return self._breakpoints def gated_grpc_debug_watches(self): """Get the list of debug watches with attribute gated_grpc=True. diff --git a/tensorflow/python/debug/lib/grpc_debug_test_server.py b/tensorflow/python/debug/lib/grpc_debug_test_server.py index c83f529d160..19d04ee654e 100644 --- a/tensorflow/python/debug/lib/grpc_debug_test_server.py +++ b/tensorflow/python/debug/lib/grpc_debug_test_server.py @@ -31,6 +31,7 @@ import time import portpicker +from tensorflow.core.debug import debug_service_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.util import event_pb2 from tensorflow.python.client import session @@ -138,13 +139,26 @@ class EventListenerTestStreamHandler( Args: event: The Event proto carrying a tensor value. + + Returns: + If the debug node belongs to the set of currently activated breakpoints, + a `EventReply` proto will be returned. 
""" if self._dump_dir: self._write_value_event(event) else: value = event.summary.value[0] + tensor_value = debug_data.load_tensor_from_event(event) self._event_listener_servicer.debug_tensor_values[value.node_name].append( - debug_data.load_tensor_from_event(event)) + tensor_value) + + items = event.summary.value[0].node_name.split(":") + node_name = items[0] + output_slot = int(items[1]) + debug_op = items[2] + if ((node_name, output_slot, debug_op) in + self._event_listener_servicer.breakpoints): + return debug_service_pb2.EventReply() def _try_makedirs(self, dir_path): if not os.path.isdir(dir_path): diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index f97b4debd31..640b1489e4c 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -556,6 +556,55 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): [50 + 5.0 * i], self._server_2.debug_tensor_values["v:0:DebugIdentity"]) + def testToggleBreakpointWorks(self): + with session.Session(config=no_rewrite_session_config()) as sess: + v = variables.Variable(50.0, name="v") + delta = constant_op.constant(5.0, name="delta") + inc_v = state_ops.assign_add(v, delta, name="inc_v") + + sess.run(v.initializer) + + run_metadata = config_pb2.RunMetadata() + run_options = config_pb2.RunOptions(output_partition_graphs=True) + debug_utils.watch_graph( + run_options, + sess.graph, + debug_ops=["DebugIdentity(gated_grpc=true)"], + debug_urls=[self._debug_server_url_1]) + + for i in xrange(4): + self._server_1.clear_data() + + # N.B.: These requests will be fulfilled not in this debugged + # Session.run() invocation, but in the next one. + if i in (0, 2): + # Enable breakpoint at delta:0:DebugIdentity in runs 0 and 2. + self._server_1.request_watch( + "delta", 0, "DebugIdentity", breakpoint=True) + else: + # Disable the breakpoint in runs 1 and 3. 
+ self._server_1.request_unwatch("delta", 0, "DebugIdentity") + + output = sess.run(inc_v, options=run_options, run_metadata=run_metadata) + self.assertAllClose(50.0 + 5.0 * (i + 1), output) + + if i in (0, 2): + # After the end of runs 0 and 2, the server has received the requests + # to enable the breakpoint at delta:0:DebugIdentity. So the server + # should keep track of the correct breakpoints. + self.assertSetEqual({("delta", 0, "DebugIdentity")}, + self._server_1.breakpoints) + else: + # During runs 1 and 3, the server should have received the published + # debug tensor delta:0:DebugIdentity. The breakpoint should have been + # unblocked by EventReply responses from the server. + self.assertAllClose( + [5.0], + self._server_1.debug_tensor_values["delta:0:DebugIdentity"]) + # After the runs, the server should have properly removed the + # breakpoints due to the request_unwatch calls. + self.assertSetEqual(set(), self._server_1.breakpoints) + def testGetGrpcDebugWatchesReturnsCorrectAnswer(self): with session.Session() as sess: v = variables.Variable(50.0, name="v")