tfdbg: extend grpc_debug_server protocol for interactive debugging
Previously, a grpc-gated debug op had two modes: DISABLED and ENABLED. This CL splits the ENABLED state into two states: READ_ONLY and READ_WRITE.
* READ_ONLY is equivalent to the previous ENABLED state, wherein a debug op publishes the debug tensor to the grpc debug server and proceeds. It can be regarded as a "watchpoint" that doesn't block execution.
* READ_WRITE is a "breakpoint". In addition to publishing the debug tensor, it blocks and awaits an EventReply proto response from the grpc debug server before proceeding.
PiperOrigin-RevId: 164987725
This commit is contained in:
parent
c0f9b0a91e
commit
3c482c66b5
@ -65,10 +65,6 @@ class GrpcDebugTest : public ::testing::Test {
|
|||||||
|
|
||||||
void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
|
void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
|
||||||
|
|
||||||
void CreateEmptyEnabledSet(const string& grpc_debug_url) {
|
|
||||||
DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64 GetChannelConnectionTimeoutMicros() {
|
const int64 GetChannelConnectionTimeoutMicros() {
|
||||||
return DebugGrpcIO::channel_connection_timeout_micros;
|
return DebugGrpcIO::channel_connection_timeout_micros;
|
||||||
}
|
}
|
||||||
@ -261,7 +257,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
|
TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
|
||||||
// Prepare the tensor to send.
|
// Prepare the tensor to send.
|
||||||
const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
|
const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
|
||||||
"test_namescope/test_node", 0,
|
"test_namescope/test_node", 0,
|
||||||
@ -283,8 +279,60 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
|
|||||||
TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
|
TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
|
||||||
urls, enable_gated_grpc));
|
urls, enable_gated_grpc));
|
||||||
|
|
||||||
server_data_.server->RequestDebugOpStateChangeAtNextStream(i == 0,
|
server_data_.server->RequestDebugOpStateChangeAtNextStream(
|
||||||
kDebugNodeKey);
|
i == 0 ? EventReply::DebugOpStateChange::READ_ONLY
|
||||||
|
: EventReply::DebugOpStateChange::DISABLED,
|
||||||
|
kDebugNodeKey);
|
||||||
|
|
||||||
|
// Close the debug gRPC stream.
|
||||||
|
Status close_status = DebugIO::CloseDebugURL(server_data_.url);
|
||||||
|
ASSERT_TRUE(close_status.ok());
|
||||||
|
|
||||||
|
// Check dumped files according to the expected gating results.
|
||||||
|
if (i < 2) {
|
||||||
|
ASSERT_EQ(1, server_data_.server->node_names.size());
|
||||||
|
ASSERT_EQ(1, server_data_.server->output_slots.size());
|
||||||
|
ASSERT_EQ(1, server_data_.server->debug_ops.size());
|
||||||
|
EXPECT_EQ(kDebugNodeKey.device_name,
|
||||||
|
server_data_.server->device_names[0]);
|
||||||
|
EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]);
|
||||||
|
EXPECT_EQ(kDebugNodeKey.output_slot,
|
||||||
|
server_data_.server->output_slots[0]);
|
||||||
|
EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]);
|
||||||
|
} else {
|
||||||
|
ASSERT_EQ(0, server_data_.server->node_names.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUnderReadWriteMode) {
|
||||||
|
// Prepare the tensor to send.
|
||||||
|
const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
|
||||||
|
"test_namescope/test_node", 0,
|
||||||
|
"DebugIdentity");
|
||||||
|
Tensor tensor(DT_INT32, TensorShape({1, 1}));
|
||||||
|
tensor.flat<int>()(0) = 42;
|
||||||
|
|
||||||
|
const std::vector<string> urls({server_data_.url});
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
server_data_.server->ClearReceivedDebugData();
|
||||||
|
const uint64 wall_time = Env::Default()->NowMicros();
|
||||||
|
|
||||||
|
// On the 1st send (i == 0), gating is disabled, so data should be sent.
|
||||||
|
// On the 2nd send (i == 1), gating is enabled, and the server has enabled
|
||||||
|
// the watch key in the previous send (READ_WRITE), so data should be
|
||||||
|
// sent. In this iteration, the server responds with an EventReply proto to
|
||||||
|
// unblock the debug node.
|
||||||
|
// On the 3rd send (i == 2), gating is enabled, but the server has disabled
|
||||||
|
// the watch key in the previous send, so data should not be sent.
|
||||||
|
const bool enable_gated_grpc = (i != 0);
|
||||||
|
TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
|
||||||
|
urls, enable_gated_grpc));
|
||||||
|
|
||||||
|
server_data_.server->RequestDebugOpStateChangeAtNextStream(
|
||||||
|
i == 0 ? EventReply::DebugOpStateChange::READ_WRITE
|
||||||
|
: EventReply::DebugOpStateChange::DISABLED,
|
||||||
|
kDebugNodeKey);
|
||||||
|
|
||||||
// Close the debug gRPC stream.
|
// Close the debug gRPC stream.
|
||||||
Status close_status = DebugIO::CloseDebugURL(server_data_.url);
|
Status close_status = DebugIO::CloseDebugURL(server_data_.url);
|
||||||
@ -308,8 +356,6 @@ TEST_F(GrpcDebugTest, SendeDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) {
|
TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) {
|
||||||
CreateEmptyEnabledSet("grpc://localhost:3333");
|
|
||||||
|
|
||||||
ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
|
ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
|
||||||
{"grpc://localhost:3333"}));
|
{"grpc://localhost:3333"}));
|
||||||
|
|
||||||
@ -322,8 +368,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) {
|
|||||||
const string kGrpcUrl1 = "grpc://localhost:3333";
|
const string kGrpcUrl1 = "grpc://localhost:3333";
|
||||||
const string kGrpcUrl2 = "grpc://localhost:3334";
|
const string kGrpcUrl2 = "grpc://localhost:3334";
|
||||||
|
|
||||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
|
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "bar:0:DebugIdentity");
|
kGrpcUrl1, "foo:0:DebugIdentity",
|
||||||
|
EventReply::DebugOpStateChange::READ_ONLY);
|
||||||
|
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||||
|
kGrpcUrl1, "bar:0:DebugIdentity",
|
||||||
|
EventReply::DebugOpStateChange::READ_ONLY);
|
||||||
|
|
||||||
ASSERT_FALSE(
|
ASSERT_FALSE(
|
||||||
DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", {kGrpcUrl1}));
|
DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", {kGrpcUrl1}));
|
||||||
@ -350,9 +400,12 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
|
|||||||
const string kGrpcUrl2 = "grpc://localhost:3334";
|
const string kGrpcUrl2 = "grpc://localhost:3334";
|
||||||
const string kGrpcUrl3 = "grpc://localhost:3335";
|
const string kGrpcUrl3 = "grpc://localhost:3335";
|
||||||
|
|
||||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
|
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl2, "bar:0:DebugIdentity");
|
kGrpcUrl1, "foo:0:DebugIdentity",
|
||||||
CreateEmptyEnabledSet(kGrpcUrl3);
|
EventReply::DebugOpStateChange::READ_ONLY);
|
||||||
|
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||||
|
kGrpcUrl2, "bar:0:DebugIdentity",
|
||||||
|
EventReply::DebugOpStateChange::READ_ONLY);
|
||||||
|
|
||||||
ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1}));
|
ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1}));
|
||||||
ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2}));
|
ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2}));
|
||||||
@ -375,7 +428,9 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
|
TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
|
||||||
DebugGrpcIO::EnableWatchKey("grpc://localhost:3333", "foo:0:DebugIdentity");
|
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||||
|
"grpc://localhost:3333", "foo:0:DebugIdentity",
|
||||||
|
EventReply::DebugOpStateChange::READ_ONLY);
|
||||||
|
|
||||||
std::vector<string> debug_urls_1;
|
std::vector<string> debug_urls_1;
|
||||||
ASSERT_FALSE(
|
ASSERT_FALSE(
|
||||||
@ -385,7 +440,6 @@ TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
|
|||||||
TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) {
|
TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) {
|
||||||
const string kGrpcUrl1 = "grpc://localhost:3333";
|
const string kGrpcUrl1 = "grpc://localhost:3333";
|
||||||
const string kWatch1 = "foo:0:DebugIdentity";
|
const string kWatch1 = "foo:0:DebugIdentity";
|
||||||
CreateEmptyEnabledSet(kGrpcUrl1);
|
|
||||||
|
|
||||||
ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
|
ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
|
||||||
{DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
|
{DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
|
||||||
@ -404,9 +458,8 @@ TEST_F(GrpcDebugTest, TestGateCopyNodeOnNonEmptyEnabledSet) {
|
|||||||
const string kGrpcUrl2 = "grpc://localhost:3334";
|
const string kGrpcUrl2 = "grpc://localhost:3334";
|
||||||
const string kWatch1 = "foo:0:DebugIdentity";
|
const string kWatch1 = "foo:0:DebugIdentity";
|
||||||
const string kWatch2 = "foo:1:DebugIdentity";
|
const string kWatch2 = "foo:1:DebugIdentity";
|
||||||
CreateEmptyEnabledSet(kGrpcUrl1);
|
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||||
CreateEmptyEnabledSet(kGrpcUrl2);
|
kGrpcUrl1, kWatch1, EventReply::DebugOpStateChange::READ_ONLY);
|
||||||
DebugGrpcIO::EnableWatchKey(kGrpcUrl1, kWatch1);
|
|
||||||
|
|
||||||
ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
|
ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
|
||||||
{DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
|
{DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
|
||||||
|
@ -72,27 +72,44 @@ namespace test {
|
|||||||
output_slots.push_back(metadata.output_slot());
|
output_slots.push_back(metadata.output_slot());
|
||||||
debug_ops.push_back(debug_op);
|
debug_ops.push_back(debug_op);
|
||||||
debug_tensors.push_back(tensor);
|
debug_tensors.push_back(tensor);
|
||||||
|
|
||||||
|
// If the debug node is currently in the READ_WRITE mode, send an
|
||||||
|
// EventReply to 1) unblock the execution and 2) optionally modify the
|
||||||
|
// value.
|
||||||
|
const DebugNodeKey debug_node_key(metadata.device(), node_name,
|
||||||
|
metadata.output_slot(), debug_op);
|
||||||
|
if (write_enabled_debug_node_keys_.find(debug_node_key) !=
|
||||||
|
write_enabled_debug_node_keys_.end()) {
|
||||||
|
stream->Write(EventReply());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
mutex_lock l(changes_mu_);
|
mutex_lock l(states_mu_);
|
||||||
for (size_t i = 0; i < changes_to_enable_.size(); ++i) {
|
for (size_t i = 0; i < new_states_.size(); ++i) {
|
||||||
EventReply event_reply;
|
EventReply event_reply;
|
||||||
EventReply::DebugOpStateChange* change =
|
EventReply::DebugOpStateChange* change =
|
||||||
event_reply.add_debug_op_state_changes();
|
event_reply.add_debug_op_state_changes();
|
||||||
change->set_change(changes_to_enable_[i]
|
|
||||||
? EventReply::DebugOpStateChange::ENABLE
|
// State changes will take effect in the next stream, i.e., next debugged
|
||||||
: EventReply::DebugOpStateChange::DISABLE);
|
// Session.run() call.
|
||||||
change->set_node_name(changes_node_names_[i]);
|
change->set_state(new_states_[i]);
|
||||||
change->set_output_slot(changes_output_slots_[i]);
|
const DebugNodeKey& debug_node_key = debug_node_keys_[i];
|
||||||
change->set_debug_op(changes_debug_ops_[i]);
|
change->set_node_name(debug_node_key.node_name);
|
||||||
|
change->set_output_slot(debug_node_key.output_slot);
|
||||||
|
change->set_debug_op(debug_node_key.debug_op);
|
||||||
stream->Write(event_reply);
|
stream->Write(event_reply);
|
||||||
|
|
||||||
|
if (new_states_[i] == EventReply::DebugOpStateChange::READ_WRITE) {
|
||||||
|
write_enabled_debug_node_keys_.insert(debug_node_key);
|
||||||
|
} else {
|
||||||
|
write_enabled_debug_node_keys_.erase(debug_node_key);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
changes_to_enable_.clear();
|
|
||||||
changes_node_names_.clear();
|
debug_node_keys_.clear();
|
||||||
changes_output_slots_.clear();
|
new_states_.clear();
|
||||||
changes_debug_ops_.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ::grpc::Status::OK;
|
return ::grpc::Status::OK;
|
||||||
@ -109,13 +126,12 @@ void TestEventListenerImpl::ClearReceivedDebugData() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TestEventListenerImpl::RequestDebugOpStateChangeAtNextStream(
|
void TestEventListenerImpl::RequestDebugOpStateChangeAtNextStream(
|
||||||
bool to_enable, const DebugNodeKey& debug_node_key) {
|
const EventReply::DebugOpStateChange::State new_state,
|
||||||
mutex_lock l(changes_mu_);
|
const DebugNodeKey& debug_node_key) {
|
||||||
|
mutex_lock l(states_mu_);
|
||||||
|
|
||||||
changes_to_enable_.push_back(to_enable);
|
debug_node_keys_.push_back(debug_node_key);
|
||||||
changes_node_names_.push_back(debug_node_key.node_name);
|
new_states_.push_back(new_state);
|
||||||
changes_output_slots_.push_back(debug_node_key.output_slot);
|
|
||||||
changes_debug_ops_.push_back(debug_node_key.debug_op);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestEventListenerImpl::RunServer(const int server_port) {
|
void TestEventListenerImpl::RunServer(const int server_port) {
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
|
#define TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
|
||||||
|
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "grpc++/grpc++.h"
|
#include "grpc++/grpc++.h"
|
||||||
#include "tensorflow/core/debug/debug_io_utils.h"
|
#include "tensorflow/core/debug/debug_io_utils.h"
|
||||||
@ -44,7 +45,8 @@ class TestEventListenerImpl final : public EventListener::Service {
|
|||||||
void ClearReceivedDebugData();
|
void ClearReceivedDebugData();
|
||||||
|
|
||||||
void RequestDebugOpStateChangeAtNextStream(
|
void RequestDebugOpStateChangeAtNextStream(
|
||||||
bool to_enable, const DebugNodeKey& debug_node_key);
|
const EventReply::DebugOpStateChange::State new_state,
|
||||||
|
const DebugNodeKey& debug_node_key);
|
||||||
|
|
||||||
std::vector<string> debug_metadata_strings;
|
std::vector<string> debug_metadata_strings;
|
||||||
std::vector<string> encoded_graph_defs;
|
std::vector<string> encoded_graph_defs;
|
||||||
@ -58,12 +60,13 @@ class TestEventListenerImpl final : public EventListener::Service {
|
|||||||
std::atomic_bool stop_requested_;
|
std::atomic_bool stop_requested_;
|
||||||
std::atomic_bool stopped_;
|
std::atomic_bool stopped_;
|
||||||
|
|
||||||
std::vector<bool> changes_to_enable_ GUARDED_BY(changes_mu_);
|
std::vector<DebugNodeKey> debug_node_keys_ GUARDED_BY(states_mu_);
|
||||||
std::vector<string> changes_node_names_ GUARDED_BY(changes_mu_);
|
std::vector<EventReply::DebugOpStateChange::State> new_states_
|
||||||
std::vector<int32> changes_output_slots_ GUARDED_BY(changes_mu_);
|
GUARDED_BY(states_mu_);
|
||||||
std::vector<string> changes_debug_ops_ GUARDED_BY(changes_mu_);
|
|
||||||
|
|
||||||
mutex changes_mu_;
|
std::unordered_set<DebugNodeKey> write_enabled_debug_node_keys_;
|
||||||
|
|
||||||
|
mutex states_mu_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Poll a gRPC debug server by sending a small tensor repeatedly till success.
|
// Poll a gRPC debug server by sending a small tensor repeatedly till success.
|
||||||
|
@ -15,6 +15,11 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/debug/debug_io_utils.h"
|
#include "tensorflow/core/debug/debug_io_utils.h"
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <cmath>
|
||||||
|
#include <limits>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#ifndef PLATFORM_WINDOWS
|
#ifndef PLATFORM_WINDOWS
|
||||||
@ -297,6 +302,15 @@ DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
|
|||||||
strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
|
strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
|
||||||
device_path(DeviceNameToDevicePath(device_name)) {}
|
device_path(DeviceNameToDevicePath(device_name)) {}
|
||||||
|
|
||||||
|
bool DebugNodeKey::operator==(const DebugNodeKey& other) const {
|
||||||
|
return (device_name == other.device_name && node_name == other.node_name &&
|
||||||
|
output_slot == other.output_slot && debug_op == other.debug_op);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DebugNodeKey::operator!=(const DebugNodeKey& other) const {
|
||||||
|
return !((*this) == other);
|
||||||
|
}
|
||||||
|
|
||||||
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
|
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
|
||||||
Env* env(Env::Default());
|
Env* env(Env::Default());
|
||||||
|
|
||||||
@ -537,7 +551,7 @@ bool DebugIO::IsCopyNodeGateOpen(
|
|||||||
DebugIO::kGrpcURLScheme)) {
|
DebugIO::kGrpcURLScheme)) {
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
if (DebugGrpcIO::IsGateOpen(spec.watch_key, spec.url)) {
|
if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -557,7 +571,7 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
|
|||||||
DebugIO::kGrpcURLScheme)) {
|
DebugIO::kGrpcURLScheme)) {
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
if (DebugGrpcIO::IsGateOpen(watch_key, debug_url)) {
|
if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -575,7 +589,7 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
|
|||||||
if (debug_url.find(kGrpcURLScheme) != 0) {
|
if (debug_url.find(kGrpcURLScheme) != 0) {
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
return DebugGrpcIO::IsGateOpen(watch_key, debug_url);
|
return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
return true;
|
return true;
|
||||||
@ -731,6 +745,12 @@ bool DebugGrpcChannel::WriteEvent(const Event& event) {
|
|||||||
return reader_writer_->Write(event);
|
return reader_writer_->Write(event);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
|
||||||
|
return reader_writer_->Read(event_reply);
|
||||||
|
}
|
||||||
|
|
||||||
Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
|
Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
|
||||||
@ -744,13 +764,8 @@ Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
|
|||||||
string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
|
string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
|
||||||
debug_op_state_change.output_slot(),
|
debug_op_state_change.output_slot(),
|
||||||
":", debug_op_state_change.debug_op());
|
":", debug_op_state_change.debug_op());
|
||||||
if (debug_op_state_change.change() ==
|
DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
|
||||||
EventReply::DebugOpStateChange::ENABLE) {
|
debug_op_state_change.state());
|
||||||
DebugGrpcIO::EnableWatchKey(url_, watch_key);
|
|
||||||
} else if (debug_op_state_change.change() ==
|
|
||||||
EventReply::DebugOpStateChange::DISABLE) {
|
|
||||||
DebugGrpcIO::DisableWatchKey(url_, watch_key);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -789,7 +804,8 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
|
|||||||
const DebugNodeKey& debug_node_key, const Tensor& tensor,
|
const DebugNodeKey& debug_node_key, const Tensor& tensor,
|
||||||
const uint64 wall_time_us, const string& grpc_stream_url,
|
const uint64 wall_time_us, const string& grpc_stream_url,
|
||||||
const bool gated) {
|
const bool gated) {
|
||||||
if (gated && !IsGateOpen(debug_node_key.debug_node_name, grpc_stream_url)) {
|
if (gated &&
|
||||||
|
!IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else {
|
} else {
|
||||||
std::vector<Event> events;
|
std::vector<Event> events;
|
||||||
@ -799,10 +815,35 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream(
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
SendEventProtoThroughGrpcStream(event, grpc_stream_url));
|
SendEventProtoThroughGrpcStream(event, grpc_stream_url));
|
||||||
}
|
}
|
||||||
|
if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
|
||||||
|
EventReply event_reply;
|
||||||
|
TF_RETURN_IF_ERROR(ReceiveEventReplyProtoThroughGrpcStream(
|
||||||
|
&event_reply, grpc_stream_url));
|
||||||
|
// TODO(cais): Support new tensor value carried in the EventReply for
|
||||||
|
// overriding the value of the tensor being published.
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// static
|
||||||
|
Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
|
||||||
|
EventReply* event_reply, const string& grpc_stream_url) {
|
||||||
|
std::shared_ptr<DebugGrpcChannel> debug_grpc_channel;
|
||||||
|
{
|
||||||
|
mutex_lock l(streams_mu);
|
||||||
|
std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
|
||||||
|
stream_channels = GetStreamChannels();
|
||||||
|
debug_grpc_channel = (*stream_channels)[grpc_stream_url];
|
||||||
|
}
|
||||||
|
if (debug_grpc_channel->ReadEventReply(event_reply)) {
|
||||||
|
return Status::OK();
|
||||||
|
} else {
|
||||||
|
return errors::Cancelled(strings::StrCat(
|
||||||
|
"Reading EventReply from stream URL ", grpc_stream_url, " failed."));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
|
Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
|
||||||
const Event& event_proto, const string& grpc_stream_url) {
|
const Event& event_proto, const string& grpc_stream_url) {
|
||||||
@ -821,9 +862,7 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
|
|||||||
debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr));
|
debug_grpc_channel.reset(new DebugGrpcChannel(server_stream_addr));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
debug_grpc_channel->Connect(channel_connection_timeout_micros));
|
debug_grpc_channel->Connect(channel_connection_timeout_micros));
|
||||||
|
|
||||||
(*stream_channels)[grpc_stream_url] = debug_grpc_channel;
|
(*stream_channels)[grpc_stream_url] = debug_grpc_channel;
|
||||||
CreateEmptyEnabledSet(grpc_stream_url);
|
|
||||||
} else {
|
} else {
|
||||||
debug_grpc_channel = (*stream_channels)[grpc_stream_url];
|
debug_grpc_channel = (*stream_channels)[grpc_stream_url];
|
||||||
}
|
}
|
||||||
@ -838,16 +877,22 @@ Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// static
|
bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
|
||||||
bool DebugGrpcIO::IsGateOpen(const string& watch_key,
|
const string& watch_key) {
|
||||||
const string& grpc_debug_url) {
|
const DebugNodeName2State* enabled_node_to_state =
|
||||||
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
|
GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
|
||||||
GetEnabledWatchKeys();
|
return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
|
||||||
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
|
}
|
||||||
|
|
||||||
|
bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
|
||||||
|
const string& watch_key) {
|
||||||
|
const DebugNodeName2State* enabled_node_to_state =
|
||||||
|
GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
|
||||||
|
auto it = enabled_node_to_state->find(watch_key);
|
||||||
|
if (it == enabled_node_to_state->end()) {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
const auto& url_enabled = (*enabled_watch_keys)[grpc_debug_url];
|
return it->second == EventReply::DebugOpStateChange::READ_WRITE;
|
||||||
return url_enabled.find(watch_key) != url_enabled.end();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -871,56 +916,46 @@ Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
std::unordered_map<string, std::unordered_set<string>>*
|
std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
|
||||||
DebugGrpcIO::GetEnabledWatchKeys() {
|
DebugGrpcIO::GetEnabledDebugOpStates() {
|
||||||
static std::unordered_map<string, std::unordered_set<string>>*
|
static std::unordered_map<string, DebugNodeName2State>*
|
||||||
enabled_watch_keys =
|
enabled_debug_op_states =
|
||||||
new std::unordered_map<string, std::unordered_set<string>>();
|
new std::unordered_map<string, DebugNodeName2State>();
|
||||||
return enabled_watch_keys;
|
return enabled_debug_op_states;
|
||||||
}
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
void DebugGrpcIO::EnableWatchKey(const string& grpc_debug_url,
|
DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
|
||||||
const string& watch_key) {
|
const string& grpc_debug_url) {
|
||||||
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
|
std::unordered_map<string, DebugNodeName2State>* states =
|
||||||
GetEnabledWatchKeys();
|
GetEnabledDebugOpStates();
|
||||||
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
|
if (states->find(grpc_debug_url) == states->end()) {
|
||||||
CreateEmptyEnabledSet(grpc_debug_url);
|
DebugNodeName2State url_enabled_debug_op_states;
|
||||||
|
(*states)[grpc_debug_url] = url_enabled_debug_op_states;
|
||||||
}
|
}
|
||||||
(*enabled_watch_keys)[grpc_debug_url].insert(watch_key);
|
return &(*states)[grpc_debug_url];
|
||||||
}
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
void DebugGrpcIO::DisableWatchKey(const string& grpc_debug_url,
|
void DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||||
const string& watch_key) {
|
const string& grpc_debug_url, const string& watch_key,
|
||||||
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
|
const EventReply::DebugOpStateChange::State new_state) {
|
||||||
GetEnabledWatchKeys();
|
DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
|
||||||
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
|
if (new_state == EventReply::DebugOpStateChange::DISABLED) {
|
||||||
LOG(WARNING) << "Attempt to disable a watch key for an unregistered gRPC "
|
if (states->find(watch_key) == states->end()) {
|
||||||
<< "debug URL: " << grpc_debug_url;
|
LOG(ERROR) << "Attempt to disable a watch key that is not currently "
|
||||||
} else {
|
<< "enabled at " << grpc_debug_url << ": " << watch_key;
|
||||||
std::unordered_set<string>& url_enabled =
|
|
||||||
(*enabled_watch_keys)[grpc_debug_url];
|
|
||||||
if (url_enabled.find(watch_key) == url_enabled.end()) {
|
|
||||||
LOG(WARNING) << "Attempt to disable a watch key that is not currently "
|
|
||||||
<< "enabled at " << grpc_debug_url << ": " << watch_key;
|
|
||||||
} else {
|
} else {
|
||||||
url_enabled.erase(watch_key);
|
states->erase(watch_key);
|
||||||
}
|
}
|
||||||
|
} else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) {
|
||||||
|
(*states)[watch_key] = new_state;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
void DebugGrpcIO::ClearEnabledWatchKeys() { GetEnabledWatchKeys()->clear(); }
|
void DebugGrpcIO::ClearEnabledWatchKeys() {
|
||||||
|
GetEnabledDebugOpStates()->clear();
|
||||||
// static
|
|
||||||
void DebugGrpcIO::CreateEmptyEnabledSet(const string& grpc_debug_url) {
|
|
||||||
std::unordered_map<string, std::unordered_set<string>>* enabled_watch_keys =
|
|
||||||
GetEnabledWatchKeys();
|
|
||||||
if (enabled_watch_keys->find(grpc_debug_url) == enabled_watch_keys->end()) {
|
|
||||||
std::unordered_set<string> empty_watch_keys;
|
|
||||||
(*enabled_watch_keys)[grpc_debug_url] = empty_watch_keys;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // #ifndef PLATFORM_WINDOWS
|
#endif // #ifndef PLATFORM_WINDOWS
|
||||||
|
@ -16,8 +16,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
|
#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
|
||||||
#define TENSORFLOW_DEBUG_IO_UTILS_H_
|
#define TENSORFLOW_DEBUG_IO_UTILS_H_
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
@ -49,6 +54,9 @@ struct DebugNodeKey {
|
|||||||
// ,job_localhost,replica_0,task_0,cpu_0.
|
// ,job_localhost,replica_0,task_0,cpu_0.
|
||||||
static const string DeviceNameToDevicePath(const string& device_name);
|
static const string DeviceNameToDevicePath(const string& device_name);
|
||||||
|
|
||||||
|
bool operator==(const DebugNodeKey& other) const;
|
||||||
|
bool operator!=(const DebugNodeKey& other) const;
|
||||||
|
|
||||||
const string device_name;
|
const string device_name;
|
||||||
const string node_name;
|
const string node_name;
|
||||||
const int32 output_slot;
|
const int32 output_slot;
|
||||||
@ -219,6 +227,19 @@ class DebugFileIO {
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
namespace std {
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct hash<::tensorflow::DebugNodeKey> {
|
||||||
|
size_t operator()(const ::tensorflow::DebugNodeKey& k) const {
|
||||||
|
return ::tensorflow::Hash64(
|
||||||
|
::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":",
|
||||||
|
k.output_slot, ":", k.debug_op, ":"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace std
|
||||||
|
|
||||||
// TODO(cais): Support grpc:// debug URLs in open source once Python grpc
|
// TODO(cais): Support grpc:// debug URLs in open source once Python grpc
|
||||||
// genrule becomes available. See b/23796275.
|
// genrule becomes available. See b/23796275.
|
||||||
#ifndef PLATFORM_WINDOWS
|
#ifndef PLATFORM_WINDOWS
|
||||||
@ -252,7 +273,8 @@ class DebugGrpcChannel {
|
|||||||
// Write an Event proto to the debug gRPC stream.
|
// Write an Event proto to the debug gRPC stream.
|
||||||
//
|
//
|
||||||
// Thread-safety: Safe with respect to other calls to the same method and
|
// Thread-safety: Safe with respect to other calls to the same method and
|
||||||
// call to Close().
|
// calls to ReadEventReply() and Close().
|
||||||
|
//
|
||||||
// Args:
|
// Args:
|
||||||
// event: The event proto to be written to the stream.
|
// event: The event proto to be written to the stream.
|
||||||
//
|
//
|
||||||
@ -260,6 +282,19 @@ class DebugGrpcChannel {
|
|||||||
// True iff the write is successful.
|
// True iff the write is successful.
|
||||||
bool WriteEvent(const Event& event);
|
bool WriteEvent(const Event& event);
|
||||||
|
|
||||||
|
// Read an EventReply proto from the debug gRPC stream.
|
||||||
|
//
|
||||||
|
// This method blocks and waits for an EventReply from the server.
|
||||||
|
// Thread-safety: Safe with respect to other calls to the same method and
|
||||||
|
// calls to WriteEvent() and Close().
|
||||||
|
//
|
||||||
|
// Args:
|
||||||
|
// event_reply: the to-be-modified EventReply proto passed as reference.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// True iff the read is successful.
|
||||||
|
bool ReadEventReply(EventReply* event_reply);
|
||||||
|
|
||||||
// Receive EventReplies from server (if any) and close the stream and the
|
// Receive EventReplies from server (if any) and close the stream and the
|
||||||
// channel.
|
// channel.
|
||||||
Status ReceiveServerRepliesAndClose();
|
Status ReceiveServerRepliesAndClose();
|
||||||
@ -294,48 +329,49 @@ class DebugGrpcIO {
|
|||||||
static Status SendEventProtoThroughGrpcStream(const Event& event_proto,
|
static Status SendEventProtoThroughGrpcStream(const Event& event_proto,
|
||||||
const string& grpc_stream_url);
|
const string& grpc_stream_url);
|
||||||
|
|
||||||
// Checks whether a debug watch key is allowed to send data to a given grpc://
|
// Receive an EventReply proto through a debug gRPC stream.
|
||||||
// debug URL given the current gating status.
|
static Status ReceiveEventReplyProtoThroughGrpcStream(
|
||||||
//
|
EventReply* event_reply, const string& grpc_stream_url);
|
||||||
// Args:
|
|
||||||
// watch_key: debug tensor watch key, in the format of
|
// Check whether a debug watch key is read-activated at a given gRPC URL.
|
||||||
// tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
|
static bool IsReadGateOpen(const string& grpc_debug_url,
|
||||||
// grpc_debug_url: the debug URL, e.g., "grpc://localhost:3333",
|
const string& watch_key);
|
||||||
//
|
|
||||||
// Returns:
|
// Check whether a debug watch key is write-activated (i.e., read- and
|
||||||
// Whether the sending of debug data to grpc_debug_url should
|
// write-activated) at a given gRPC URL.
|
||||||
// proceed.
|
static bool IsWriteGateOpen(const string& grpc_debug_url,
|
||||||
static bool IsGateOpen(const string& watch_key, const string& grpc_debug_url);
|
const string& watch_key);
|
||||||
|
|
||||||
// Closes a gRPC stream to the given address, if it exists.
|
// Closes a gRPC stream to the given address, if it exists.
|
||||||
// Thread-safety: Safe with respect to other calls to the same method and
|
// Thread-safety: Safe with respect to other calls to the same method and
|
||||||
// calls to SendTensorThroughGrpcStream().
|
// calls to SendTensorThroughGrpcStream().
|
||||||
static Status CloseGrpcStream(const string& grpc_stream_url);
|
static Status CloseGrpcStream(const string& grpc_stream_url);
|
||||||
|
|
||||||
// Enables a debug watch key at a grpc:// debug URL.
|
// Set the gRPC state of a debug node key.
|
||||||
static void EnableWatchKey(const string& grpc_debug_url,
|
// TODO(cais): Include device information in watch_key.
|
||||||
const string& watch_key);
|
static void SetDebugNodeKeyGrpcState(
|
||||||
|
const string& grpc_debug_url, const string& watch_key,
|
||||||
// Disables a debug watch key at a grpc:// debug URL.
|
const EventReply::DebugOpStateChange::State new_state);
|
||||||
static void DisableWatchKey(const string& grpc_debug_url,
|
|
||||||
const string& watch_key);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
using DebugNodeName2State =
|
||||||
|
std::unordered_map<string, EventReply::DebugOpStateChange::State>;
|
||||||
|
|
||||||
// Returns a global map from grpc debug URLs to the corresponding
|
// Returns a global map from grpc debug URLs to the corresponding
|
||||||
// DebugGrpcChannels.
|
// DebugGrpcChannels.
|
||||||
static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
|
static std::unordered_map<string, std::shared_ptr<DebugGrpcChannel>>*
|
||||||
GetStreamChannels();
|
GetStreamChannels();
|
||||||
|
|
||||||
// Returns a global map from grpc debug URLs to the enabled gated debug nodes.
|
// Returns a map from debug URL to a map from debug op name to enabled state.
|
||||||
// The keys are grpc:// URLs of the debug servers, e.g., "grpc://foo:3333".
|
static std::unordered_map<string, DebugNodeName2State>*
|
||||||
// Each value element of the value has the format
|
GetEnabledDebugOpStates();
|
||||||
// <node_name>:<output_slot>:<debug_op>", e.g.,
|
|
||||||
// "Weights_1:0:DebugNumericSummary".
|
|
||||||
static std::unordered_map<string, std::unordered_set<string>>*
|
|
||||||
GetEnabledWatchKeys();
|
|
||||||
|
|
||||||
|
// Returns a map from debug op names to enabled state, for a given debug URL.
|
||||||
|
static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl(
|
||||||
|
const string& grpc_debug_url);
|
||||||
|
|
||||||
|
// Clear enabled debug op state from all debug URLs (if any).
|
||||||
static void ClearEnabledWatchKeys();
|
static void ClearEnabledWatchKeys();
|
||||||
static void CreateEmptyEnabledSet(const string& grpc_debug_url);
|
|
||||||
|
|
||||||
static mutex streams_mu;
|
static mutex streams_mu;
|
||||||
static int64 channel_connection_timeout_micros;
|
static int64 channel_connection_timeout_micros;
|
||||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "tensorflow/core/debug/debug_io_utils.h"
|
#include "tensorflow/core/debug/debug_io_utils.h"
|
||||||
|
|
||||||
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
||||||
@ -62,6 +64,39 @@ TEST_F(DebugIOUtilsTest, ConstructDebugNodeKey) {
|
|||||||
debug_node_key.device_path);
|
debug_node_key.device_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DebugIOUtilsTest, EqualityOfDebugNodeKeys) {
|
||||||
|
const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2",
|
||||||
|
"hidden_1/MatMul", 0, "DebugIdentity");
|
||||||
|
const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2",
|
||||||
|
"hidden_1/MatMul", 0, "DebugIdentity");
|
||||||
|
const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2",
|
||||||
|
"hidden_1/BiasAdd", 0, "DebugIdentity");
|
||||||
|
const DebugNodeKey debug_node_key_4("/job:worker/replica:1/task:0/gpu:2",
|
||||||
|
"hidden_1/MatMul", 0,
|
||||||
|
"DebugNumericSummary");
|
||||||
|
EXPECT_EQ(debug_node_key_1, debug_node_key_2);
|
||||||
|
EXPECT_NE(debug_node_key_1, debug_node_key_3);
|
||||||
|
EXPECT_NE(debug_node_key_1, debug_node_key_4);
|
||||||
|
EXPECT_NE(debug_node_key_3, debug_node_key_4);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DebugIOUtilsTest, DebugNodeKeysIsHashable) {
|
||||||
|
const DebugNodeKey debug_node_key_1("/job:worker/replica:1/task:0/gpu:2",
|
||||||
|
"hidden_1/MatMul", 0, "DebugIdentity");
|
||||||
|
const DebugNodeKey debug_node_key_2("/job:worker/replica:1/task:0/gpu:2",
|
||||||
|
"hidden_1/MatMul", 0, "DebugIdentity");
|
||||||
|
const DebugNodeKey debug_node_key_3("/job:worker/replica:1/task:0/gpu:2",
|
||||||
|
"hidden_1/BiasAdd", 0, "DebugIdentity");
|
||||||
|
|
||||||
|
std::unordered_set<DebugNodeKey> keys;
|
||||||
|
keys.insert(debug_node_key_1);
|
||||||
|
ASSERT_EQ(1, keys.size());
|
||||||
|
keys.insert(debug_node_key_3);
|
||||||
|
ASSERT_EQ(2, keys.size());
|
||||||
|
keys.erase(debug_node_key_2);
|
||||||
|
ASSERT_EQ(1, keys.size());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) {
|
TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) {
|
||||||
Initialize();
|
Initialize();
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ syntax = "proto3";
|
|||||||
|
|
||||||
package tensorflow;
|
package tensorflow;
|
||||||
|
|
||||||
|
import "tensorflow/core/framework/tensor.proto";
|
||||||
import "tensorflow/core/util/event.proto";
|
import "tensorflow/core/util/event.proto";
|
||||||
|
|
||||||
// Reply message from EventListener to the client, i.e., to the source of the
|
// Reply message from EventListener to the client, i.e., to the source of the
|
||||||
@ -24,18 +25,25 @@ import "tensorflow/core/util/event.proto";
|
|||||||
// TensorFlow graph being executed.
|
// TensorFlow graph being executed.
|
||||||
message EventReply {
|
message EventReply {
|
||||||
message DebugOpStateChange {
|
message DebugOpStateChange {
|
||||||
enum Change {
|
enum State {
|
||||||
DISABLE = 0;
|
STATE_UNSPECIFIED = 0;
|
||||||
ENABLE = 1;
|
DISABLED = 1;
|
||||||
|
READ_ONLY = 2;
|
||||||
|
READ_WRITE = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
Change change = 1;
|
State state = 1;
|
||||||
string node_name = 2;
|
string node_name = 2;
|
||||||
int32 output_slot = 3;
|
int32 output_slot = 3;
|
||||||
string debug_op = 4;
|
string debug_op = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
repeated DebugOpStateChange debug_op_state_changes = 1;
|
repeated DebugOpStateChange debug_op_state_changes = 1;
|
||||||
|
|
||||||
|
// New tensor value to override the current tensor value with.
|
||||||
|
TensorProto tensor = 2;
|
||||||
|
// TODO(cais): Make use of this field to implement overriding of tensor value
|
||||||
|
// during debugging.
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventListener: Receives Event protos, e.g., from debugged TensorFlow
|
// EventListener: Receives Event protos, e.g., from debugged TensorFlow
|
||||||
|
@ -259,10 +259,6 @@ class DebugNumericSummaryOpTest : public OpsTestBase {
|
|||||||
|
|
||||||
#if defined(PLATFORM_GOOGLE)
|
#if defined(PLATFORM_GOOGLE)
|
||||||
void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
|
void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
|
||||||
|
|
||||||
void CreateEmptyEnabledSet(const string& grpc_debug_url) {
|
|
||||||
DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -604,7 +600,6 @@ TEST_F(DebugNumericSummaryOpTest, BoolSuccess) {
|
|||||||
#if defined(PLATFORM_GOOGLE)
|
#if defined(PLATFORM_GOOGLE)
|
||||||
TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) {
|
TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) {
|
||||||
ClearEnabledWatchKeys();
|
ClearEnabledWatchKeys();
|
||||||
CreateEmptyEnabledSet("grpc://server:3333");
|
|
||||||
|
|
||||||
std::vector<string> debug_urls({"grpc://server:3333"});
|
std::vector<string> debug_urls({"grpc://server:3333"});
|
||||||
TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));
|
TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));
|
||||||
@ -617,8 +612,9 @@ TEST_F(DebugNumericSummaryOpTest, DisabledDueToEmptyEnabledSet) {
|
|||||||
|
|
||||||
TEST_F(DebugNumericSummaryOpTest, DisabledDueToNonMatchingWatchKey) {
|
TEST_F(DebugNumericSummaryOpTest, DisabledDueToNonMatchingWatchKey) {
|
||||||
ClearEnabledWatchKeys();
|
ClearEnabledWatchKeys();
|
||||||
DebugGrpcIO::EnableWatchKey("grpc://server:3333",
|
DebugGrpcIO::SetDebugNodeKeyGrpcState(
|
||||||
"FakeTensor:1:DebugNumeriSummary");
|
"grpc://server:3333", "FakeTensor:1:DebugNumeriSummary",
|
||||||
|
EventReply::DebugOpStateChange::READ_ONLY);
|
||||||
|
|
||||||
std::vector<string> debug_urls({"grpc://server:3333"});
|
std::vector<string> debug_urls({"grpc://server:3333"});
|
||||||
TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));
|
TF_ASSERT_OK(InitGated(DT_FLOAT, debug_urls));
|
||||||
|
@ -31,16 +31,20 @@ from tensorflow.core.debug import debug_service_pb2
|
|||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.python.debug.lib import debug_data
|
from tensorflow.python.debug.lib import debug_data
|
||||||
from tensorflow.python.debug.lib import debug_service_pb2_grpc
|
from tensorflow.python.debug.lib import debug_service_pb2_grpc
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
DebugWatch = collections.namedtuple("DebugWatch",
|
DebugWatch = collections.namedtuple("DebugWatch",
|
||||||
["node_name", "output_slot", "debug_op"])
|
["node_name", "output_slot", "debug_op"])
|
||||||
|
|
||||||
|
|
||||||
def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op):
|
def _watch_key_event_reply(new_state, node_name, output_slot, debug_op):
|
||||||
"""Make EventReply proto to represent a request to watch/unwatch a debug op.
|
"""Make `EventReply` proto to represent a request to watch/unwatch a debug op.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
to_enable: (`bool`) whether the request is to enable the watch key.
|
new_state: (`debug_service_pb2.EventReply.DebugOpStateChange.State`) the new
|
||||||
|
state to set the debug node to, i.e., whether the debug node will become
|
||||||
|
disabled under the grpc mode (`DISABLED`), become a watchpoint
|
||||||
|
(`READ_ONLY`) or become a breakpoint (`READ_WRITE`).
|
||||||
node_name: (`str`) name of the node.
|
node_name: (`str`) name of the node.
|
||||||
output_slot: (`int`) output slot of the tensor.
|
output_slot: (`int`) output slot of the tensor.
|
||||||
debug_op: (`str`) the debug op attached to node_name:output_slot tensor to
|
debug_op: (`str`) the debug op attached to node_name:output_slot tensor to
|
||||||
@ -51,9 +55,7 @@ def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op):
|
|||||||
"""
|
"""
|
||||||
event_reply = debug_service_pb2.EventReply()
|
event_reply = debug_service_pb2.EventReply()
|
||||||
state_change = event_reply.debug_op_state_changes.add()
|
state_change = event_reply.debug_op_state_changes.add()
|
||||||
state_change.change = (
|
state_change.state = new_state
|
||||||
debug_service_pb2.EventReply.DebugOpStateChange.ENABLE
|
|
||||||
if to_enable else debug_service_pb2.EventReply.DebugOpStateChange.DISABLE)
|
|
||||||
state_change.node_name = node_name
|
state_change.node_name = node_name
|
||||||
state_change.output_slot = output_slot
|
state_change.output_slot = output_slot
|
||||||
state_change.debug_op = debug_op
|
state_change.debug_op = debug_op
|
||||||
@ -125,6 +127,7 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
|||||||
|
|
||||||
self._event_reply_queue = queue.Queue()
|
self._event_reply_queue = queue.Queue()
|
||||||
self._gated_grpc_debug_watches = set()
|
self._gated_grpc_debug_watches = set()
|
||||||
|
self._breakpoints = set()
|
||||||
|
|
||||||
def SendEvents(self, request_iterator, context):
|
def SendEvents(self, request_iterator, context):
|
||||||
"""Implementation of the SendEvents service method.
|
"""Implementation of the SendEvents service method.
|
||||||
@ -171,11 +174,30 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
|||||||
maybe_tensor_event = self._process_tensor_event_in_chunks(
|
maybe_tensor_event = self._process_tensor_event_in_chunks(
|
||||||
event, tensor_chunks)
|
event, tensor_chunks)
|
||||||
if maybe_tensor_event:
|
if maybe_tensor_event:
|
||||||
stream_handler.on_value_event(maybe_tensor_event)
|
event_reply = stream_handler.on_value_event(maybe_tensor_event)
|
||||||
|
if event_reply is not None:
|
||||||
|
yield event_reply
|
||||||
|
|
||||||
# The server writes EventReply messages, if any.
|
# The server writes EventReply messages, if any.
|
||||||
while not self._event_reply_queue.empty():
|
while not self._event_reply_queue.empty():
|
||||||
yield self._event_reply_queue.get()
|
event_reply = self._event_reply_queue.get()
|
||||||
|
for state_change in event_reply.debug_op_state_changes:
|
||||||
|
if (state_change.state ==
|
||||||
|
debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE):
|
||||||
|
logging.info("Adding breakpoint %s:%d:%s", state_change.node_name,
|
||||||
|
state_change.output_slot, state_change.debug_op)
|
||||||
|
self._breakpoints.add(
|
||||||
|
(state_change.node_name, state_change.output_slot,
|
||||||
|
state_change.debug_op))
|
||||||
|
elif (state_change.state ==
|
||||||
|
debug_service_pb2.EventReply.DebugOpStateChange.DISABLED):
|
||||||
|
logging.info("Removing watchpoint or breakpoint: %s:%d:%s",
|
||||||
|
state_change.node_name, state_change.output_slot,
|
||||||
|
state_change.debug_op)
|
||||||
|
self._breakpoints.discard(
|
||||||
|
(state_change.node_name, state_change.output_slot,
|
||||||
|
state_change.debug_op))
|
||||||
|
yield event_reply
|
||||||
|
|
||||||
def _process_tensor_event_in_chunks(self, event, tensor_chunks):
|
def _process_tensor_event_in_chunks(self, event, tensor_chunks):
|
||||||
"""Possibly reassemble event chunks.
|
"""Possibly reassemble event chunks.
|
||||||
@ -336,8 +358,8 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
|||||||
finally:
|
finally:
|
||||||
self._server_lock.release()
|
self._server_lock.release()
|
||||||
|
|
||||||
def request_watch(self, node_name, output_slot, debug_op):
|
def request_watch(self, node_name, output_slot, debug_op, breakpoint=False):
|
||||||
"""Request enabling a debug tensor watch.
|
"""Request enabling a debug tensor watchpoint or breakpoint.
|
||||||
|
|
||||||
This will let the server send an EventReply to the client side
|
This will let the server send an EventReply to the client side
|
||||||
(i.e., the debugged TensorFlow runtime process) to request adding a watch
|
(i.e., the debugged TensorFlow runtime process) to request adding a watch
|
||||||
@ -355,12 +377,19 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
|||||||
output_slot: (`int`) output slot index of the tensor to watch.
|
output_slot: (`int`) output slot index of the tensor to watch.
|
||||||
debug_op: (`str`) name of the debug op to enable. This should not include
|
debug_op: (`str`) name of the debug op to enable. This should not include
|
||||||
any attribute substrings.
|
any attribute substrings.
|
||||||
|
breakpoint: (`bool`) Iff `True`, the debug op will block and wait until it
|
||||||
|
receives an `EventReply` response from the server. The `EventReply`
|
||||||
|
proto may carry a TensorProto that modifies the value of the debug op's
|
||||||
|
output tensor.
|
||||||
"""
|
"""
|
||||||
self._event_reply_queue.put(
|
self._event_reply_queue.put(
|
||||||
_watch_key_event_reply(True, node_name, output_slot, debug_op))
|
_watch_key_event_reply(
|
||||||
|
debug_service_pb2.EventReply.DebugOpStateChange.READ_WRITE
|
||||||
|
if breakpoint else debug_service_pb2.EventReply.DebugOpStateChange.
|
||||||
|
READ_ONLY, node_name, output_slot, debug_op))
|
||||||
|
|
||||||
def request_unwatch(self, node_name, output_slot, debug_op):
|
def request_unwatch(self, node_name, output_slot, debug_op):
|
||||||
"""Request disabling a debug tensor watch.
|
"""Request disabling a debug tensor watchpoint or breakpoint.
|
||||||
|
|
||||||
The request will take effect on the next debugged `Session.run()` call.
|
The request will take effect on the next debugged `Session.run()` call.
|
||||||
|
|
||||||
@ -374,7 +403,18 @@ class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
|
|||||||
any attribute substrings.
|
any attribute substrings.
|
||||||
"""
|
"""
|
||||||
self._event_reply_queue.put(
|
self._event_reply_queue.put(
|
||||||
_watch_key_event_reply(False, node_name, output_slot, debug_op))
|
_watch_key_event_reply(debug_service_pb2.EventReply.DebugOpStateChange.
|
||||||
|
DISABLED, node_name, output_slot, debug_op))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def breakpoints(self):
|
||||||
|
"""Get a set of the currently-activated breakpoints.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g.,
|
||||||
|
{("MatMul", 0, "DebugIdentity")}.
|
||||||
|
"""
|
||||||
|
return self._breakpoints
|
||||||
|
|
||||||
def gated_grpc_debug_watches(self):
|
def gated_grpc_debug_watches(self):
|
||||||
"""Get the list of debug watches with attribute gated_grpc=True.
|
"""Get the list of debug watches with attribute gated_grpc=True.
|
||||||
|
@ -31,6 +31,7 @@ import time
|
|||||||
|
|
||||||
import portpicker
|
import portpicker
|
||||||
|
|
||||||
|
from tensorflow.core.debug import debug_service_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.core.util import event_pb2
|
from tensorflow.core.util import event_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
@ -138,13 +139,26 @@ class EventListenerTestStreamHandler(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
event: The Event proto carrying a tensor value.
|
event: The Event proto carrying a tensor value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the debug node belongs to the set of currently activated breakpoints,
|
||||||
|
an `EventReply` proto will be returned.
|
||||||
"""
|
"""
|
||||||
if self._dump_dir:
|
if self._dump_dir:
|
||||||
self._write_value_event(event)
|
self._write_value_event(event)
|
||||||
else:
|
else:
|
||||||
value = event.summary.value[0]
|
value = event.summary.value[0]
|
||||||
|
tensor_value = debug_data.load_tensor_from_event(event)
|
||||||
self._event_listener_servicer.debug_tensor_values[value.node_name].append(
|
self._event_listener_servicer.debug_tensor_values[value.node_name].append(
|
||||||
debug_data.load_tensor_from_event(event))
|
tensor_value)
|
||||||
|
|
||||||
|
items = event.summary.value[0].node_name.split(":")
|
||||||
|
node_name = items[0]
|
||||||
|
output_slot = int(items[1])
|
||||||
|
debug_op = items[2]
|
||||||
|
if ((node_name, output_slot, debug_op) in
|
||||||
|
self._event_listener_servicer.breakpoints):
|
||||||
|
return debug_service_pb2.EventReply()
|
||||||
|
|
||||||
def _try_makedirs(self, dir_path):
|
def _try_makedirs(self, dir_path):
|
||||||
if not os.path.isdir(dir_path):
|
if not os.path.isdir(dir_path):
|
||||||
|
@ -556,6 +556,55 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
|
|||||||
[50 + 5.0 * i],
|
[50 + 5.0 * i],
|
||||||
self._server_2.debug_tensor_values["v:0:DebugIdentity"])
|
self._server_2.debug_tensor_values["v:0:DebugIdentity"])
|
||||||
|
|
||||||
|
def testToggleBreakpointWorks(self):
|
||||||
|
with session.Session(config=no_rewrite_session_config()) as sess:
|
||||||
|
v = variables.Variable(50.0, name="v")
|
||||||
|
delta = constant_op.constant(5.0, name="delta")
|
||||||
|
inc_v = state_ops.assign_add(v, delta, name="inc_v")
|
||||||
|
|
||||||
|
sess.run(v.initializer)
|
||||||
|
|
||||||
|
run_metadata = config_pb2.RunMetadata()
|
||||||
|
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||||
|
debug_utils.watch_graph(
|
||||||
|
run_options,
|
||||||
|
sess.graph,
|
||||||
|
debug_ops=["DebugIdentity(gated_grpc=true)"],
|
||||||
|
debug_urls=[self._debug_server_url_1])
|
||||||
|
|
||||||
|
for i in xrange(4):
|
||||||
|
self._server_1.clear_data()
|
||||||
|
|
||||||
|
# N.B.: These requests will be fulfilled not in this debugged
|
||||||
|
# Session.run() invocation, but in the next one.
|
||||||
|
if i in (0, 2):
|
||||||
|
# Enable breakpoint at delta:0:DebugIdentity in runs 0 and 2.
|
||||||
|
self._server_1.request_watch(
|
||||||
|
"delta", 0, "DebugIdentity", breakpoint=True)
|
||||||
|
else:
|
||||||
|
# Disable the breakpoint in runs 1 and 3.
|
||||||
|
self._server_1.request_unwatch("delta", 0, "DebugIdentity")
|
||||||
|
|
||||||
|
output = sess.run(inc_v, options=run_options, run_metadata=run_metadata)
|
||||||
|
self.assertAllClose(50.0 + 5.0 * (i + 1), output)
|
||||||
|
|
||||||
|
if i in (0, 2):
|
||||||
|
# After the end of runs 0 and 2, the server has received the requests
|
||||||
|
# to enable the breakpoint at delta:0:DebugIdentity. So the server
|
||||||
|
# should keep track of the correct breakpoints.
|
||||||
|
self.assertSetEqual({("delta", 0, "DebugIdentity")},
|
||||||
|
self._server_1.breakpoints)
|
||||||
|
else:
|
||||||
|
# During runs 1 and 3, the server should have received the published
|
||||||
|
# debug tensor delta:0:DebugIdentity. The breakpoint should have been
|
||||||
|
# unblocked by EventReply responses from the server.
|
||||||
|
self.assertAllClose(
|
||||||
|
[5.0],
|
||||||
|
self._server_1.debug_tensor_values["delta:0:DebugIdentity"])
|
||||||
|
# After the runs, the server should have properly removed the
|
||||||
|
# breakpoints due to the request_unwatch calls.
|
||||||
|
self.assertSetEqual(set(), self._server_1.breakpoints)
|
||||||
|
|
||||||
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
|
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
v = variables.Variable(50.0, name="v")
|
v = variables.Variable(50.0, name="v")
|
||||||
|
Loading…
Reference in New Issue
Block a user