Extending the core DebugIdentity TensorFlow operation with support for writing

to a singleton in-memory data structure that records a mapping from debug_urls
to debug events.  This simplifies reading a large number of states without
writing to disk or making internal RPC calls for arbitrary nodes.

PiperOrigin-RevId: 169337269
This commit is contained in:
A. Unique TensorFlower 2017-09-19 18:55:43 -07:00 committed by TensorFlower Gardener
parent 7ad8e25495
commit 5ce3523bcc
9 changed files with 313 additions and 56 deletions

View File

@ -7,6 +7,10 @@
# DebuggerState to be constructed at initialization time, enabling
# TensorFlow Debugger (tfdbg) support. For details, please see
# core/common_runtime/debugger_state_interface.h.
# ":debug_callback_registry" - Depending on this target exposes a global
# callback registry that will be used to record any observed tensors matching
# a watch state.
# ":debug_node_key" - Defines a struct used for tracking tensors.
package(
default_visibility = ["//tensorflow:internal"],
@ -134,6 +138,8 @@ tf_cuda_library(
copts = tf_copts(),
linkstatic = 1,
deps = [
":debug_callback_registry",
":debug_node_key",
":debug_service_proto_cc",
":debugger_event_metadata_proto_cc",
"//tensorflow/core:core_cpu_internal",
@ -167,6 +173,18 @@ tf_cuda_library(
alwayslink = 1,
)
# Library for the DebugNodeKey struct used for tracking tensors
# (debug_node_key.h / debug_node_key.cc). It is a separate, publicly visible
# target so that other targets can depend on the struct directly; in this
# BUILD file ":debug_callback_registry" does so.
tf_cuda_library(
name = "debug_node_key",
srcs = ["debug_node_key.cc"],
hdrs = ["debug_node_key.h"],
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib",
],
)
# TODO(cais): Fix flakiness on GPU and change this back to a tf_cc_test_gpu.
# See b/34081273.
tf_cc_test(
@ -206,8 +224,10 @@ tf_cc_test(
srcs = ["debug_io_utils_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":debug_callback_registry",
":debug_grpc_testlib",
":debug_io_utils",
":debug_node_key",
":debug_service_proto_cc",
":debugger_event_metadata_proto_cc",
"//tensorflow/core:core_cpu",
@ -286,6 +306,19 @@ tf_cc_test(
],
)
# Library for DebugCallbackRegistry, the global callback registry used to
# record observed tensors matching a watch state. Depends on
# ":debug_node_key" because callbacks receive a DebugNodeKey together with
# the observed Tensor.
cc_library(
name = "debug_callback_registry",
srcs = ["debug_callback_registry.cc"],
hdrs = ["debug_callback_registry.h"],
visibility = ["//visibility:public"],
deps = [
":debug_node_key",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
# TODO(cais): Add the following back in when tfdbg is supported on Android.
# filegroup(
# name = "android_srcs",

View File

@ -0,0 +1,49 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/debug/debug_callback_registry.h"
namespace tensorflow {
// The registry starts out empty; there is nothing to set up beyond the
// default construction of its members.
DebugCallbackRegistry::DebugCallbackRegistry() = default;

// Storage for the process-wide instance pointer; remains null until
// singleton() creates the registry.
/*static */ DebugCallbackRegistry* DebugCallbackRegistry::instance_ = nullptr;
// Returns the process-wide registry, creating it on first use.
//
// The instance is heap-allocated and intentionally never deleted, so the
// returned pointer stays valid for the lifetime of the process.
DebugCallbackRegistry* DebugCallbackRegistry::singleton() {
  // Initialization of a function-local static is guaranteed to run exactly
  // once even when several threads make the first call concurrently. The
  // previous unsynchronized "if null then new" on instance_ was a data race
  // that could construct (and hand out) two different registries.
  // instance_ is still assigned so the static member keeps mirroring the
  // singleton.
  static DebugCallbackRegistry* const registry =
      (instance_ = new DebugCallbackRegistry());
  return registry;
}
void DebugCallbackRegistry::RegisterCallback(const string& key,
EventCallback callback) {
mutex_lock lock(mu_);
keyed_callback_[key] = std::move(callback);
}
// Looks up the callback registered for `key` while holding mu_.
//
// Returns a pointer to the callback stored inside the registry, or nullptr
// when nothing is registered under `key`.
DebugCallbackRegistry::EventCallback* DebugCallbackRegistry::GetCallback(
    const string& key) {
  mutex_lock guard(mu_);
  const auto entry = keyed_callback_.find(key);
  if (entry == keyed_callback_.end()) {
    return nullptr;
  }
  return &entry->second;
}
// Drops the callback registered for `key`, if any, while holding mu_.
// Unregistering a key that is not present is a no-op.
void DebugCallbackRegistry::UnregisterCallback(const string& key) {
  mutex_lock guard(mu_);
  const auto entry = keyed_callback_.find(key);
  if (entry != keyed_callback_.end()) {
    keyed_callback_.erase(entry);
  }
}
} // namespace tensorflow

View File

@ -0,0 +1,71 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
#define TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
#include <functional>
#include <map>
#include <string>
#include <vector>
#include "tensorflow/core/debug/debug_node_key.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
// Supports exporting observed debug events to clients using registered
// callbacks. Users can register a callback for each debug_url stored using
// DebugTensorWatch. The callback key must be equivalent to what follows
// "memcbk://" in the debug URL (e.g., the URL "memcbk://my_key" is routed to
// the callback registered under "my_key").
//
// All events generated for a watched node will be sent to the callback in the
// order that they are observed.
//
// This callback router should not be used in production or training steps. It
// is optimized for deep inspection of graph state rather than performance.
class DebugCallbackRegistry {
public:
// Signature of a registered callback: it receives the key identifying the
// watched tensor and the observed tensor itself.
using EventCallback = std::function<void(const DebugNodeKey&, const Tensor&)>;
// Provides singleton access to the in memory event store. The instance is
// created on the first call.
static DebugCallbackRegistry* singleton();
// Returns the registered callback, or nullptr, for key.
// The returned pointer refers to the entry stored inside the registry and is
// handed out after the internal lock has been released, so it is invalidated
// by a later UnregisterCallback() for the same key.
EventCallback* GetCallback(const string& key);
// Associates callback with key. This must be called by clients observing
// nodes to be exported by this callback router before running a session.
// Registering a key that already has a callback replaces the old callback.
void RegisterCallback(const string& key, EventCallback callback);
// Removes the callback associated with key. Does nothing if key has no
// registered callback.
void UnregisterCallback(const string& key);
private:
// Private so that the only way to obtain an instance is singleton().
DebugCallbackRegistry();
// Mutex to ensure that keyed events are never updated in parallel.
mutex mu_;
// Maps debug_url keys to callbacks for routing observed tensors.
std::map<string, EventCallback> keyed_callback_ GUARDED_BY(mu_);
// The process-wide instance returned by singleton(); null until the first
// call to singleton().
static DebugCallbackRegistry* instance_;
};
} // namespace tensorflow
#endif // TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_

View File

@ -29,6 +29,7 @@ limitations under the License.
#pragma comment(lib, "Ws2_32.lib")
#endif // #ifndef PLATFORM_WINDOWS
#include "tensorflow/core/debug/debug_callback_registry.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
@ -280,35 +281,12 @@ Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
const char* const DebugIO::kDebuggerPluginName = "debugger";
const char* const DebugIO::kMetadataFilePrefix = "_tfdbg_";
const char* const DebugIO::kCoreMetadataTag = "core_metadata_";
const char* const DebugIO::kDeviceTag = "device_";
const char* const DebugIO::kGraphTag = "graph_";
const char* const DebugIO::kHashTag = "hash";
// Builds the key for one watched tensor.
//   device_name: name of the device the node runs on.
//   node_name:   name of the watched node.
//   output_slot: index of the watched output of that node.
//   debug_op:    name of the debug op attached to the output.
// Inside the initializer expressions the bare names refer to the constructor
// parameters, which shadow the members of the same name.
DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
const int32 output_slot, const string& debug_op)
: device_name(device_name),
node_name(node_name),
output_slot(output_slot),
debug_op(debug_op),
// debug_node_name has the form "<node_name>:<output_slot>:<debug_op>".
debug_node_name(
strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
// device_path is derived from device_name via DeviceNameToDevicePath().
device_path(DeviceNameToDevicePath(device_name)) {}
// Two keys are equal when device name, node name, output slot and debug op
// all match. debug_node_name and device_path are derived from those fields
// and therefore are not compared separately.
bool DebugNodeKey::operator==(const DebugNodeKey& other) const {
  if (device_name != other.device_name) return false;
  if (node_name != other.node_name) return false;
  if (output_slot != other.output_slot) return false;
  return debug_op == other.debug_op;
}
// Exact negation of operator==.
bool DebugNodeKey::operator!=(const DebugNodeKey& other) const {
  const bool same_key = (*this == other);
  return !same_key;
}
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
Env* env(Env::Default());
@ -338,16 +316,9 @@ Status ReadEventFromFile(const string& dump_file_path, Event* event) {
return Status::OK();
}
// Converts a device name into a device path: every ":" becomes "_", every
// "/" becomes ",", and the result is prefixed with
// DebugIO::kMetadataFilePrefix followed by DebugIO::kDeviceTag.
const string DebugNodeKey::DeviceNameToDevicePath(const string& device_name) {
  const string without_colons =
      str_util::StringReplace(device_name, ":", "_", true);
  const string without_slashes =
      str_util::StringReplace(without_colons, "/", ",", true);
  return strings::StrCat(DebugIO::kMetadataFilePrefix, DebugIO::kDeviceTag,
                         without_slashes);
}
const char* const DebugIO::kFileURLScheme = "file://";
const char* const DebugIO::kGrpcURLScheme = "grpc://";
const char* const DebugIO::kMemoryURLScheme = "memcbk://";
// Publishes debug metadata to a set of debug URLs.
Status DebugIO::PublishDebugMetadata(
@ -423,7 +394,7 @@ Status DebugIO::PublishDebugMetadata(
const string core_metadata_path = AppendTimestampToFilePath(
io::JoinPath(
dump_root_dir,
strings::StrCat(DebugIO::kMetadataFilePrefix,
strings::StrCat(DebugNodeKey::kMetadataFilePrefix,
DebugIO::kCoreMetadataTag, "sessionrun",
strings::Printf("%.14lld", session_run_index))),
Env::Default()->NowMicros());
@ -465,6 +436,12 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
#else
GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
} else if (str_util::Lowercase(url).find(kMemoryURLScheme) == 0) {
const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
auto* callback_registry = DebugCallbackRegistry::singleton();
auto* callback = callback_registry->GetCallback(dump_root_dir);
CHECK(callback) << "No callback registered for: " << dump_root_dir;
(*callback)(debug_node_key, tensor);
} else {
return Status(error::UNAVAILABLE,
strings::StrCat("Invalid debug target URL: ", url));
@ -515,7 +492,7 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
DebugNodeKey::DeviceNameToDevicePath(device_name));
const uint64 graph_hash = ::tensorflow::Hash64(buf);
const string file_name =
strings::StrCat(DebugIO::kMetadataFilePrefix, DebugIO::kGraphTag,
strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag,
DebugIO::kHashTag, graph_hash, "_", now_micros);
status.Update(

View File

@ -24,6 +24,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/core/debug/debug_node_key.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
@ -45,39 +46,18 @@ struct DebugWatchAndURLSpec {
const bool gated_grpc;
};
// Identifies one watched tensor by device, node, output slot and debug op.
struct DebugNodeKey {
// Builds a key; debug_node_name and device_path are derived from the
// arguments.
DebugNodeKey(const string& device_name, const string& node_name,
const int32 output_slot, const string& debug_op);
// Converts a device name string to a device path string: ":" is replaced by
// "_", "/" by ",", and the result is prefixed with
// DebugIO::kMetadataFilePrefix and DebugIO::kDeviceTag.
// E.g., /job:localhost/replica:0/task:0/cpu:0 will be converted to
// _tfdbg_device_,job_localhost,replica_0,task_0,cpu_0.
static const string DeviceNameToDevicePath(const string& device_name);
// Equality compares device_name, node_name, output_slot and debug_op.
bool operator==(const DebugNodeKey& other) const;
bool operator!=(const DebugNodeKey& other) const;
const string device_name;
const string node_name;
const int32 output_slot;
const string debug_op;
// "<node_name>:<output_slot>:<debug_op>".
const string debug_node_name;
// DeviceNameToDevicePath(device_name).
const string device_path;
};
// TODO(cais): Put static functions and members in a namespace, not a class.
class DebugIO {
public:
static const char* const kDebuggerPluginName;
static const char* const kMetadataFilePrefix;
static const char* const kCoreMetadataTag;
static const char* const kDeviceTag;
static const char* const kGraphTag;
static const char* const kHashTag;
static const char* const kFileURLScheme;
static const char* const kGrpcURLScheme;
static const char* const kMemoryURLScheme;
static Status PublishDebugMetadata(
const int64 global_step, const int64 session_run_index,

View File

@ -17,6 +17,8 @@ limitations under the License.
#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/debug/debug_callback_registry.h"
#include "tensorflow/core/debug/debug_node_key.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@ -307,6 +309,38 @@ TEST_F(DebugIOUtilsTest, PublishTensorToMultipleFileURLs) {
}
}
// Verifies that a tensor published to a "memcbk://" debug URL is handed,
// together with its DebugNodeKey, to the callback registered under the URL's
// key in the DebugCallbackRegistry singleton.
TEST_F(DebugIOUtilsTest, PublishTensorToMemoryCallback) {
  Initialize();

  const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
                                   "foo/bar/qux/tensor_a", 0, "DebugIdentity");
  const uint64 wall_time = env_->NowMicros();

  bool called = false;
  std::vector<string> urls = {"memcbk://test_callback"};

  auto* callback_registry = DebugCallbackRegistry::singleton();
  callback_registry->RegisterCallback(
      "test_callback", [this, &kDebugNodeKey, &called](const DebugNodeKey& key,
                                                       const Tensor& tensor) {
        called = true;
        ASSERT_EQ(kDebugNodeKey.device_name, key.device_name);
        ASSERT_EQ(kDebugNodeKey.node_name, key.node_name);
        ASSERT_EQ(tensor_a_->shape(), tensor.shape());
        for (int i = 0; i < tensor.flat<float>().size(); ++i) {
          ASSERT_EQ(tensor_a_->flat<float>()(i), tensor.flat<float>()(i));
        }
      });

  Status s =
      DebugIO::PublishDebugTensor(kDebugNodeKey, *tensor_a_, wall_time, urls);

  // Unregister before the fatal assertions below. The callback captures
  // `called` and `kDebugNodeKey` by reference; a failing ASSERT_* returns
  // from the test immediately, which would otherwise leave a callback with
  // dangling references registered in the process-wide singleton. The
  // callback is invoked synchronously by PublishDebugTensor, so it has
  // already run (or not) by this point.
  callback_registry->UnregisterCallback("test_callback");

  ASSERT_TRUE(s.ok());
  ASSERT_TRUE(called);
}
TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) {
Initialize();

View File

@ -0,0 +1,53 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/debug/debug_node_key.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
// Leading component of every device path built by DeviceNameToDevicePath().
const char* const DebugNodeKey::kMetadataFilePrefix = "_tfdbg_";
// Follows kMetadataFilePrefix in a device path, directly before the converted
// device name.
const char* const DebugNodeKey::kDeviceTag = "device_";
// Builds the key for one watched tensor.
//   device_name: name of the device the node runs on.
//   node_name:   name of the watched node.
//   output_slot: index of the watched output of that node.
//   debug_op:    name of the debug op attached to the output.
// Inside the initializer expressions the bare names refer to the constructor
// parameters, which shadow the members of the same name.
DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
const int32 output_slot, const string& debug_op)
: device_name(device_name),
node_name(node_name),
output_slot(output_slot),
debug_op(debug_op),
// debug_node_name has the form "<node_name>:<output_slot>:<debug_op>".
debug_node_name(
strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
// device_path is derived from device_name via DeviceNameToDevicePath().
device_path(DeviceNameToDevicePath(device_name)) {}
// Two keys are equal when device name, node name, output slot and debug op
// all match. debug_node_name and device_path are derived from those fields
// and therefore are not compared separately.
bool DebugNodeKey::operator==(const DebugNodeKey& other) const {
  if (device_name != other.device_name) return false;
  if (node_name != other.node_name) return false;
  if (output_slot != other.output_slot) return false;
  return debug_op == other.debug_op;
}
// Exact negation of operator==.
bool DebugNodeKey::operator!=(const DebugNodeKey& other) const {
  const bool same_key = (*this == other);
  return !same_key;
}
// Converts a device name into a device path: every ":" becomes "_", every
// "/" becomes ",", and the result is prefixed with kMetadataFilePrefix
// followed by kDeviceTag.
const string DebugNodeKey::DeviceNameToDevicePath(const string& device_name) {
  const string without_colons =
      str_util::StringReplace(device_name, ":", "_", true);
  const string without_slashes =
      str_util::StringReplace(without_colons, "/", ",", true);
  return strings::StrCat(kMetadataFilePrefix, kDeviceTag, without_slashes);
}
} // namespace tensorflow

View File

@ -0,0 +1,51 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_DEBUG_NODE_KEY_H_
#define TENSORFLOW_DEBUG_NODE_KEY_H_
#include <string>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
// Encapsulates debug information for a node that was observed: identifies one
// watched tensor by device, node, output slot and debug op.
struct DebugNodeKey {
// Prefix and tag that DeviceNameToDevicePath() prepends, in this order, to
// the converted device name.
static const char* const kMetadataFilePrefix;
static const char* const kDeviceTag;
// Builds a key; debug_node_name and device_path are derived from the
// arguments.
DebugNodeKey(const string& device_name, const string& node_name,
const int32 output_slot, const string& debug_op);
// Converts a device name string to a device path string: ":" is replaced by
// "_", "/" by ",", and the result is prefixed with kMetadataFilePrefix and
// kDeviceTag.
// E.g., /job:localhost/replica:0/task:0/cpu:0 will be converted to
// _tfdbg_device_,job_localhost,replica_0,task_0,cpu_0.
static const string DeviceNameToDevicePath(const string& device_name);
// Equality compares device_name, node_name, output_slot and debug_op.
bool operator==(const DebugNodeKey& other) const;
bool operator!=(const DebugNodeKey& other) const;
const string device_name;
const string node_name;
const int32 output_slot;
const string debug_op;
// "<node_name>:<output_slot>:<debug_op>".
const string debug_node_name;
// DeviceNameToDevicePath(device_name).
const string device_path;
};
} // namespace tensorflow
#endif // TENSORFLOW_DEBUG_NODE_KEY_H_

View File

@ -24,7 +24,16 @@ message DebugTensorWatch {
repeated string debug_ops = 3;
// URL(s) for debug target(s).
// E.g., "file:///foo/tfdbg_dump", "grpc://localhost:11011"
//
// Supported URL formats are:
// - file:///foo/tfdbg_dump: Writes out Event content to file
// /foo/tfdbg_dump. Assumes all directories can be created if they don't
// already exist.
// - grpc://localhost:11011: Sends an RPC request to an EventListener
// service running at localhost:11011 with the event.
// - memcbk://event_key: Routes tensors to clients using the
// callback registered with the DebugCallbackRegistry for event_key.
//
// Each debug op listed in debug_ops will publish its output tensor (debug
// signal) to all URLs in debug_urls.
//