Add a :debug BUILD target which, when linked into a binary, enables
DirectSession support for TensorFlow Debugger (tfdbg). Binaries that do not want debugging support can avoid this dependency and its transitive deps. This replaces the previous approach that was based on a preprocessor flag (-DNOTFDBG). Change: 141321165
This commit is contained in:
parent
3b99c7947a
commit
db2a81a82c
tensorflow
@ -16,5 +16,3 @@ list(REMOVE_ITEM tf_core_direct_session_srcs ${tf_core_direct_session_test_srcs}
|
||||
add_library(tf_core_direct_session OBJECT ${tf_core_direct_session_srcs})
|
||||
|
||||
add_dependencies(tf_core_direct_session tf_core_cpu)
|
||||
|
||||
add_definitions(-DNOTFDBG)
|
@ -103,6 +103,7 @@ set(tf_proto_text_srcs
|
||||
"tensorflow/core/framework/versions.proto"
|
||||
"tensorflow/core/lib/core/error_codes.proto"
|
||||
"tensorflow/core/protobuf/config.proto"
|
||||
"tensorflow/core/protobuf/debug.proto"
|
||||
"tensorflow/core/protobuf/tensor_bundle.proto"
|
||||
"tensorflow/core/protobuf/saver.proto"
|
||||
"tensorflow/core/util/memmapped_file_system.proto"
|
||||
|
@ -138,7 +138,7 @@ $(shell mkdir -p $(DEPDIR) >/dev/null)
|
||||
# Settings for the target compiler.
|
||||
CXX := $(CC_PREFIX) gcc
|
||||
OPTFLAGS := -O2
|
||||
CXXFLAGS := --std=c++11 -DIS_SLIM_BUILD -fno-exceptions -DNDEBUG -DNOTFDBG $(OPTFLAGS)
|
||||
CXXFLAGS := --std=c++11 -DIS_SLIM_BUILD -fno-exceptions -DNDEBUG $(OPTFLAGS)
|
||||
LDFLAGS := \
|
||||
-L/usr/local/lib
|
||||
DEPFLAGS = -MT $@ -MMD -MP -MF $(DEPDIR)/$*.Td
|
||||
|
@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.pb.cc
|
||||
tensorflow/core/protobuf/named_tensor.pb.cc
|
||||
tensorflow/core/protobuf/meta_graph.pb.cc
|
||||
tensorflow/core/protobuf/config.pb.cc
|
||||
tensorflow/core/protobuf/debug.pb.cc
|
||||
tensorflow/core/lib/core/error_codes.pb.cc
|
||||
tensorflow/core/framework/versions.pb.cc
|
||||
tensorflow/core/framework/variable.pb.cc
|
||||
|
@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.pb.h
|
||||
tensorflow/core/protobuf/named_tensor.pb.h
|
||||
tensorflow/core/protobuf/meta_graph.pb.h
|
||||
tensorflow/core/protobuf/config.pb.h
|
||||
tensorflow/core/protobuf/debug.pb.h
|
||||
tensorflow/core/protobuf/tensor_bundle.pb.h
|
||||
tensorflow/core/lib/core/error_codes.pb.h
|
||||
tensorflow/core/framework/versions.pb.h
|
||||
|
@ -2,6 +2,7 @@ tensorflow/core/util/saved_tensor_slice.pb_text.cc
|
||||
tensorflow/core/util/memmapped_file_system.pb_text.cc
|
||||
tensorflow/core/protobuf/saver.pb_text.cc
|
||||
tensorflow/core/protobuf/config.pb_text.cc
|
||||
tensorflow/core/protobuf/debug.pb_text.cc
|
||||
tensorflow/core/protobuf/tensor_bundle.pb_text.cc
|
||||
tensorflow/core/lib/core/error_codes.pb_text.cc
|
||||
tensorflow/core/framework/versions.pb_text.cc
|
||||
|
@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.proto
|
||||
tensorflow/core/protobuf/named_tensor.proto
|
||||
tensorflow/core/protobuf/meta_graph.proto
|
||||
tensorflow/core/protobuf/config.proto
|
||||
tensorflow/core/protobuf/debug.proto
|
||||
tensorflow/core/protobuf/tensor_bundle.proto
|
||||
tensorflow/core/lib/core/error_codes.proto
|
||||
tensorflow/core/framework/versions.proto
|
||||
|
@ -135,6 +135,7 @@ CORE_PROTO_SRCS = [
|
||||
"framework/versions.proto",
|
||||
"lib/core/error_codes.proto",
|
||||
"protobuf/config.proto",
|
||||
"protobuf/debug.proto",
|
||||
"protobuf/tensor_bundle.proto",
|
||||
"protobuf/saver.proto",
|
||||
"util/memmapped_file_system.proto",
|
||||
@ -768,7 +769,6 @@ cc_library(
|
||||
srcs = if_android(["//tensorflow/core:android_srcs"]),
|
||||
copts = tf_copts() + [
|
||||
"-Os",
|
||||
"-DNOTFDBG",
|
||||
],
|
||||
linkopts = ["-lz"],
|
||||
tags = [
|
||||
@ -795,7 +795,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels:android_core_ops",
|
||||
"//tensorflow/core/kernels:android_extended_ops",
|
||||
]),
|
||||
copts = tf_copts() + ["-Os"] + ["-std=c++11"] + ["-DNOTFDBG"],
|
||||
copts = tf_copts() + ["-Os"] + ["-std=c++11"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":protos_cc",
|
||||
@ -828,7 +828,7 @@ cc_library(
|
||||
cc_library(
|
||||
name = "android_tensorflow_lib",
|
||||
srcs = if_android([":android_op_registrations_and_gradients"]),
|
||||
copts = tf_copts() + ["-DNOTFDBG"],
|
||||
copts = tf_copts(),
|
||||
tags = [
|
||||
"manual",
|
||||
"notap",
|
||||
@ -853,7 +853,6 @@ cc_library(
|
||||
copts = tf_copts() + [
|
||||
"-Os",
|
||||
"-DSUPPORT_SELECTIVE_REGISTRATION",
|
||||
"-DNOTFDBG",
|
||||
],
|
||||
tags = [
|
||||
"manual",
|
||||
@ -1366,7 +1365,6 @@ tf_cuda_library(
|
||||
":lib_internal",
|
||||
":proto_text",
|
||||
":protos_all_cc",
|
||||
"//tensorflow/core/debug:debug_graph_utils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -1937,6 +1935,46 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# This is identical to :common_runtime_direct_session_test with the addition of
|
||||
# a dependency on alwayslink target //third_party/tensorflow/core/debug, which
|
||||
# enables support for TensorFlow Debugger (tfdbg).
|
||||
tf_cc_test(
|
||||
name = "common_runtime_direct_session_with_debug_test",
|
||||
size = "small",
|
||||
srcs = ["common_runtime/direct_session_test.cc"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
deps = [
|
||||
":core",
|
||||
":core_cpu",
|
||||
":core_cpu_internal",
|
||||
":direct_session_internal",
|
||||
":framework",
|
||||
":framework_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":ops",
|
||||
":protos_all_cc",
|
||||
":test",
|
||||
":test_main",
|
||||
":testlib",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
# Link with support for TensorFlow Debugger (tfdbg).
|
||||
"//tensorflow/core/debug",
|
||||
"//tensorflow/core/kernels:control_flow_ops",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:dense_update_ops",
|
||||
"//tensorflow/core/kernels:fifo_queue_op",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
"//tensorflow/core/kernels:identity_op",
|
||||
"//tensorflow/core/kernels:matmul_op",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
"//tensorflow/core/kernels:queue_ops",
|
||||
"//tensorflow/core/kernels:session_ops",
|
||||
"//tensorflow/core/kernels:variable_ops",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "common_runtime_direct_session_with_tracking_alloc_test",
|
||||
size = "small",
|
||||
|
37
tensorflow/core/common_runtime/debugger_state_interface.cc
Normal file
37
tensorflow/core/common_runtime/debugger_state_interface.cc
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DebuggerStateFactory* DebuggerStateRegistry::factory_ = nullptr;
|
||||
|
||||
// static
|
||||
void DebuggerStateRegistry::RegisterFactory(
|
||||
const DebuggerStateFactory& factory) {
|
||||
delete factory_;
|
||||
factory_ = new DebuggerStateFactory(factory);
|
||||
}
|
||||
|
||||
// static
|
||||
std::unique_ptr<DebuggerStateInterface> DebuggerStateRegistry::CreateState(
|
||||
const DebugOptions& debug_options) {
|
||||
return (factory_ == nullptr || *factory_ == nullptr)
|
||||
? nullptr
|
||||
: (*factory_)(debug_options);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
72
tensorflow/core/common_runtime/debugger_state_interface.h
Normal file
72
tensorflow/core/common_runtime/debugger_state_interface.h
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
|
||||
#define TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class DebugOptions; // Defined in core/protobuf/debug.h.
|
||||
class Device;
|
||||
class Graph;
|
||||
|
||||
// An abstract interface for storing and retrieving debugging information.
|
||||
class DebuggerStateInterface {
|
||||
public:
|
||||
virtual ~DebuggerStateInterface() {}
|
||||
|
||||
// Returns a summary string for RepeatedPtrFields of DebugTensorWatches.
|
||||
virtual const string SummarizeDebugTensorWatches() = 0;
|
||||
|
||||
// Insert special-purpose debug nodes to graph. See the documentation of
|
||||
// DebugNodeInserter::InsertNodes() for details.
|
||||
virtual Status InsertNodes(Graph* graph, Device* device) = 0;
|
||||
};
|
||||
|
||||
typedef std::function<std::unique_ptr<DebuggerStateInterface>(
|
||||
const DebugOptions& options)>
|
||||
DebuggerStateFactory;
|
||||
|
||||
// Contains only static methods for registering DebuggerStateFactory.
|
||||
// We don't expect to create any instances of this class.
|
||||
// Call DebuggerStateRegistry::RegisterFactory() at initialization time to
|
||||
// define a global factory that creates instances of DebuggerState, then call
|
||||
// DebuggerStateRegistry::CreateState() to create a single instance.
|
||||
class DebuggerStateRegistry {
|
||||
public:
|
||||
// Registers a function that creates a concrete DebuggerStateInterface
|
||||
// implementation based on DebugOptions.
|
||||
static void RegisterFactory(const DebuggerStateFactory& factory);
|
||||
|
||||
// If RegisterFactory() has been called, creates and returns a concrete
|
||||
// DebuggerStateInterface implementation using the registered factory,
|
||||
// owned by the caller. Otherwise returns nullptr.
|
||||
static std::unique_ptr<DebuggerStateInterface> CreateState(
|
||||
const DebugOptions& debug_options);
|
||||
|
||||
private:
|
||||
static DebuggerStateFactory* factory_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DebuggerStateRegistry);
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/constant_folding.h"
|
||||
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
@ -395,12 +396,10 @@ Status DirectSession::Run(const RunOptions& run_options,
|
||||
ExecutorsAndKeys* executors_and_keys;
|
||||
RunStateArgs run_state_args;
|
||||
|
||||
#ifndef NOTFDBG
|
||||
// EXPERIMENTAL: Options that allow the client to insert nodes into partition
|
||||
// graphs for debugging.
|
||||
run_state_args.debugger_state.reset(
|
||||
new DebuggerState(run_options.debug_tensor_watch_opts()));
|
||||
#endif
|
||||
run_state_args.debugger_state =
|
||||
DebuggerStateRegistry::CreateState(run_options.debug_options());
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetOrCreateExecutors(pool, input_tensor_names, output_names, target_nodes,
|
||||
@ -880,12 +879,10 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
std::sort(tn_sorted.begin(), tn_sorted.end());
|
||||
|
||||
string debug_tensor_watches_summary;
|
||||
#ifndef NOTFDBG
|
||||
if (run_state_args->debugger_state) {
|
||||
debug_tensor_watches_summary =
|
||||
run_state_args->debugger_state->SummarizeDebugTensorWatches();
|
||||
}
|
||||
#endif
|
||||
const string key = strings::StrCat(
|
||||
str_util::Join(inputs_sorted, ","), "->",
|
||||
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
|
||||
@ -985,12 +982,10 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
optimizer.Optimize(lib, options_.env, device, &partition_graph);
|
||||
|
||||
// EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph
|
||||
#ifndef NOTFDBG
|
||||
if (run_state_args->debugger_state) {
|
||||
TF_RETURN_IF_ERROR(run_state_args->debugger_state->InsertNodes(
|
||||
partition_graph, params.device));
|
||||
}
|
||||
#endif
|
||||
iter->second.reset(partition_graph);
|
||||
|
||||
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
|
||||
|
@ -24,15 +24,13 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/costmodel_manager.h"
|
||||
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/session_factory.h"
|
||||
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
|
||||
#ifndef NOTFDBG
|
||||
#include "tensorflow/core/debug/debug_graph_utils.h"
|
||||
#endif
|
||||
#include "tensorflow/core/framework/cancellation.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/session_state.h"
|
||||
@ -48,9 +46,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
class CostModel;
|
||||
#ifndef NOTFDBG
|
||||
class DebugGateway;
|
||||
#endif
|
||||
class Device;
|
||||
class DirectSessionFactory;
|
||||
|
||||
@ -164,9 +160,7 @@ class DirectSession : public Session {
|
||||
bool is_partial_run = false;
|
||||
string handle;
|
||||
std::unique_ptr<Graph> graph;
|
||||
#ifndef NOTFDBG
|
||||
std::unique_ptr<DebuggerState> debugger_state;
|
||||
#endif
|
||||
std::unique_ptr<DebuggerStateInterface> debugger_state;
|
||||
};
|
||||
|
||||
// Initializes the base execution state given the 'graph',
|
||||
@ -303,10 +297,8 @@ class DirectSession : public Session {
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
|
||||
|
||||
#ifndef NOTFDBG
|
||||
// EXPERIMENTAL: debugger (tfdbg) related
|
||||
friend class DebugGateway;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -39,6 +39,21 @@ tf_proto_library_cc(
|
||||
cc_libs = ["//tensorflow/core:protos_all_cc"],
|
||||
)
|
||||
|
||||
# Depending on this target causes a concrete DebuggerState implementation
|
||||
# to be registered at initialization time. For details, please see
|
||||
# core/common_runtime/debugger_state_interface.h.
|
||||
cc_library(
|
||||
name = "debug",
|
||||
srcs = ["debug.cc"],
|
||||
copts = tf_copts(),
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":debug_graph_utils",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "debug_gateway_internal",
|
||||
srcs = ["debug_gateway.cc"],
|
||||
@ -46,6 +61,7 @@ tf_cuda_library(
|
||||
copts = tf_copts(),
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":debug",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:direct_session_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -136,6 +152,7 @@ tf_cc_test_gpu(
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["nomac"],
|
||||
deps = [
|
||||
":debug",
|
||||
":debug_gateway_internal",
|
||||
":debug_graph_utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
|
40
tensorflow/core/debug/debug.cc
Normal file
40
tensorflow/core/debug/debug.cc
Normal file
@ -0,0 +1,40 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
|
||||
#include "tensorflow/core/debug/debug_graph_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Registers a concrete implementation of DebuggerState for use by
|
||||
// DirectSession.
|
||||
class DebuggerStateRegistration {
|
||||
public:
|
||||
static std::unique_ptr<DebuggerStateInterface> CreateDebuggerState(
|
||||
const DebugOptions& options) {
|
||||
return std::unique_ptr<DebuggerStateInterface>(new DebuggerState(options));
|
||||
}
|
||||
|
||||
DebuggerStateRegistration() {
|
||||
DebuggerStateRegistry::RegisterFactory(CreateDebuggerState);
|
||||
}
|
||||
};
|
||||
static DebuggerStateRegistration register_debugger_state_implementation;
|
||||
|
||||
} // end namespace
|
||||
} // end namespace tensorflow
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/debug/debug_graph_utils.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
@ -228,7 +229,8 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
|
||||
|
||||
const string debug_identity = "DebugIdentity";
|
||||
const string debug_nan_count = "DebugNanCount";
|
||||
DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
|
||||
DebugTensorWatch* tensor_watch_opts =
|
||||
run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
|
||||
tensor_watch_opts->set_node_name(y_);
|
||||
tensor_watch_opts->set_output_slot(0);
|
||||
tensor_watch_opts->add_debug_ops(debug_identity);
|
||||
@ -409,7 +411,7 @@ TEST_F(SessionDebugMinusAXTest,
|
||||
run_opts.set_output_partition_graphs(true);
|
||||
|
||||
DebugTensorWatch* tensor_watch_opts =
|
||||
run_opts.add_debug_tensor_watch_opts();
|
||||
run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
|
||||
tensor_watch_opts->set_output_slot(0);
|
||||
tensor_watch_opts->add_debug_ops(debug_identity);
|
||||
|
||||
@ -561,7 +563,8 @@ TEST_F(SessionDebugOutputSlotWithoutOngoingEdgeTest,
|
||||
RunOptions run_opts;
|
||||
run_opts.set_output_partition_graphs(true);
|
||||
|
||||
DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
|
||||
DebugTensorWatch* tensor_watch_opts =
|
||||
run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
|
||||
tensor_watch_opts->set_node_name(c_);
|
||||
tensor_watch_opts->set_output_slot(0);
|
||||
tensor_watch_opts->add_debug_ops("DebugIdentity");
|
||||
@ -659,7 +662,8 @@ TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) {
|
||||
// Set up DebugTensorWatch for an uninitialized tensor (in node var).
|
||||
RunOptions run_opts;
|
||||
const string debug_identity = "DebugIdentity";
|
||||
DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
|
||||
DebugTensorWatch* tensor_watch_opts =
|
||||
run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
|
||||
tensor_watch_opts->set_node_name(var_node_name_);
|
||||
tensor_watch_opts->set_output_slot(0);
|
||||
tensor_watch_opts->add_debug_ops(debug_identity);
|
||||
@ -746,7 +750,8 @@ TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) {
|
||||
run_opts.set_output_partition_graphs(true);
|
||||
const string debug_identity = "DebugIdentity";
|
||||
const string debug_nan_count = "DebugNanCount";
|
||||
DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
|
||||
DebugTensorWatch* tensor_watch_opts =
|
||||
run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
|
||||
tensor_watch_opts->set_node_name(var_node_name_);
|
||||
tensor_watch_opts->set_output_slot(0);
|
||||
tensor_watch_opts->add_debug_ops(debug_identity);
|
||||
@ -904,7 +909,8 @@ TEST_F(SessionDebugGPUSwitchTest, RunSwitchWithHostMemoryDebugOp) {
|
||||
const string watched_tensor = strings::StrCat(pred_node_name_, "/_1");
|
||||
|
||||
const string debug_identity = "DebugIdentity";
|
||||
DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
|
||||
DebugTensorWatch* tensor_watch_opts =
|
||||
run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
|
||||
tensor_watch_opts->set_node_name(watched_tensor);
|
||||
tensor_watch_opts->set_output_slot(0);
|
||||
tensor_watch_opts->add_debug_ops(debug_identity);
|
||||
|
@ -22,12 +22,12 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/protobuf/debug.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
DebuggerState::DebuggerState(
|
||||
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches)
|
||||
: watches(watches), debug_urls_() {
|
||||
DebuggerState::DebuggerState(const DebugOptions& debug_options)
|
||||
: watches(debug_options.debug_tensor_watch_opts()), debug_urls_() {
|
||||
for (const DebugTensorWatch& watch : watches) {
|
||||
for (const string& url : watch.debug_urls()) {
|
||||
debug_urls_.insert(url);
|
||||
|
@ -20,26 +20,26 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/protobuf/debug.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class DebuggerState {
|
||||
class DebuggerState : public DebuggerStateInterface {
|
||||
public:
|
||||
DebuggerState(
|
||||
const protobuf::RepeatedPtrField<DebugTensorWatch>& debug_tensor_watches);
|
||||
DebuggerState(const DebugOptions& debug_options);
|
||||
virtual ~DebuggerState();
|
||||
|
||||
// Returns a summary string for RepeatedPtrFields of DebugTensorWatches.
|
||||
const string SummarizeDebugTensorWatches();
|
||||
const string SummarizeDebugTensorWatches() override;
|
||||
|
||||
// Insert special-purpose debug nodes to graph. See the documentation of
|
||||
// DebugNodeInserter::InsertNodes() for details.
|
||||
Status InsertNodes(Graph* graph, Device* device);
|
||||
Status InsertNodes(Graph* graph, Device* device) override;
|
||||
|
||||
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches;
|
||||
|
||||
|
@ -9,6 +9,7 @@ option java_package = "org.tensorflow.framework";
|
||||
import "tensorflow/core/framework/cost_graph.proto";
|
||||
import "tensorflow/core/framework/graph.proto";
|
||||
import "tensorflow/core/framework/step_stats.proto";
|
||||
import "tensorflow/core/protobuf/debug.proto";
|
||||
|
||||
message GPUOptions {
|
||||
// A value between 0 and 1 that indicates what fraction of the
|
||||
@ -222,30 +223,6 @@ message ConfigProto {
|
||||
int64 operation_timeout_in_ms = 11;
|
||||
};
|
||||
|
||||
// EXPERIMENTAL. Option for watching a node.
|
||||
message DebugTensorWatch {
|
||||
// Name of the node to watch.
|
||||
string node_name = 1;
|
||||
|
||||
// Output slot to watch.
|
||||
// The semantics of output_slot == -1 is that the node is only watched for
|
||||
// completion, but not for any output tensors. See NodeCompletionCallback
|
||||
// in debug_gateway.h.
|
||||
// TODO(cais): Implement this semantics.
|
||||
int32 output_slot = 2;
|
||||
|
||||
// Name(s) of the debugging op(s).
|
||||
// One or more than one probes on a tensor.
|
||||
// e.g., {"DebugIdentity", "DebugNanCount"}
|
||||
repeated string debug_ops = 3;
|
||||
|
||||
// URL(s) for debug targets(s).
|
||||
// E.g., "file:///foo/tfdbg_dump", "grpc://localhost:11011"
|
||||
// Each debug op listed in debug_ops will publish its output tensor (debug
|
||||
// signal) to all URLs in debug_urls.
|
||||
repeated string debug_urls = 4;
|
||||
}
|
||||
|
||||
// EXPERIMENTAL. Options for a single Run() call.
|
||||
message RunOptions {
|
||||
// TODO(pbar) Turn this into a TraceOptions proto which allows
|
||||
@ -264,12 +241,14 @@ message RunOptions {
|
||||
// The thread pool to use, if session_inter_op_thread_pool is configured.
|
||||
int32 inter_op_thread_pool = 3;
|
||||
|
||||
// Debugging options
|
||||
repeated DebugTensorWatch debug_tensor_watch_opts = 4;
|
||||
|
||||
// Whether the partition graph(s) executed by the executor(s) should be
|
||||
// outputted via RunMetadata.
|
||||
bool output_partition_graphs = 5;
|
||||
|
||||
// EXPERIMENTAL. Options used to intialize DebuggerState, if enabled.
|
||||
DebugOptions debug_options = 6;
|
||||
|
||||
reserved 4;
|
||||
}
|
||||
|
||||
// EXPERIMENTAL. Metadata output (i.e., non-Tensor) for a single Run() call.
|
||||
|
37
tensorflow/core/protobuf/debug.proto
Normal file
37
tensorflow/core/protobuf/debug.proto
Normal file
@ -0,0 +1,37 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
option cc_enable_arenas = true;
|
||||
option java_outer_classname = "DebugProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.framework";
|
||||
|
||||
// EXPERIMENTAL. Option for watching a node.
|
||||
message DebugTensorWatch {
|
||||
// Name of the node to watch.
|
||||
string node_name = 1;
|
||||
|
||||
// Output slot to watch.
|
||||
// The semantics of output_slot == -1 is that the node is only watched for
|
||||
// completion, but not for any output tensors. See NodeCompletionCallback
|
||||
// in debug_gateway.h.
|
||||
// TODO(cais): Implement this semantics.
|
||||
int32 output_slot = 2;
|
||||
|
||||
// Name(s) of the debugging op(s).
|
||||
// One or more than one probes on a tensor.
|
||||
// e.g., {"DebugIdentity", "DebugNanCount"}
|
||||
repeated string debug_ops = 3;
|
||||
|
||||
// URL(s) for debug targets(s).
|
||||
// E.g., "file:///foo/tfdbg_dump", "grpc://localhost:11011"
|
||||
// Each debug op listed in debug_ops will publish its output tensor (debug
|
||||
// signal) to all URLs in debug_urls.
|
||||
repeated string debug_urls = 4;
|
||||
}
|
||||
|
||||
// EXPERIMENTAL. Options for initializing DebuggerState.
|
||||
message DebugOptions {
|
||||
// Debugging options
|
||||
repeated DebugTensorWatch debug_tensor_watch_opts = 4;
|
||||
}
|
@ -2018,6 +2018,7 @@ tf_py_wrap_cc(
|
||||
"//tensorflow/c:checkpoint_reader",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/debug",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/tools/tfprof/internal:print_model_analysis",
|
||||
"//util/python:python_headers",
|
||||
|
@ -992,7 +992,7 @@ class AnalyzerCLIWhileLoopTest(test_util.TensorFlowTestCase):
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_url = "file://%s" % cls._dump_root
|
||||
|
||||
watch_opts = run_options.debug_tensor_watch_opts
|
||||
watch_opts = run_options.debug_options.debug_tensor_watch_opts
|
||||
|
||||
# Add debug tensor watch for "while/Identity".
|
||||
watch = watch_opts.add()
|
||||
|
@ -114,7 +114,7 @@ def _is_copy_node(node_name):
|
||||
"""Determine whether a node name is that of a debug Copy node.
|
||||
|
||||
Such nodes are inserted by TensorFlow core upon request in
|
||||
RunOptions.debug_tensor_watch_opts.
|
||||
RunOptions.debug_options.debug_tensor_watch_opts.
|
||||
|
||||
Args:
|
||||
node_name: Name of the node.
|
||||
@ -130,7 +130,7 @@ def _is_debug_node(node_name):
|
||||
"""Determine whether a node name is that of a debug node.
|
||||
|
||||
Such nodes are inserted by TensorFlow core upon request in
|
||||
RunOptions.debug_tensor_watch_opts.
|
||||
RunOptions.debug_options.debug_tensor_watch_opts.
|
||||
|
||||
Args:
|
||||
node_name: Name of the node.
|
||||
|
@ -41,7 +41,7 @@ def add_debug_tensor_watch(run_options,
|
||||
string with only one element.
|
||||
"""
|
||||
|
||||
watch_opts = run_options.debug_tensor_watch_opts
|
||||
watch_opts = run_options.debug_options.debug_tensor_watch_opts
|
||||
|
||||
watch = watch_opts.add()
|
||||
watch.node_name = node_name
|
||||
|
@ -98,16 +98,16 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
debug_utils.add_debug_tensor_watch(
|
||||
self._run_options, "foo/node_b", 0, debug_urls="file:///tmp/tfdbg_2")
|
||||
|
||||
self.assertEqual(2, len(self._run_options.debug_tensor_watch_opts))
|
||||
debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
|
||||
self.assertEqual(2, len(debug_watch_opts))
|
||||
|
||||
watch_0 = self._run_options.debug_tensor_watch_opts[0]
|
||||
watch_1 = self._run_options.debug_tensor_watch_opts[1]
|
||||
watch_0 = debug_watch_opts[0]
|
||||
watch_1 = debug_watch_opts[1]
|
||||
|
||||
self.assertEqual("foo/node_a", watch_0.node_name)
|
||||
self.assertEqual(1, watch_0.output_slot)
|
||||
self.assertEqual("foo/node_b", watch_1.node_name)
|
||||
self.assertEqual(0, watch_1.output_slot)
|
||||
|
||||
# Verify default debug op name.
|
||||
self.assertEqual(["DebugIdentity"], watch_0.debug_ops)
|
||||
self.assertEqual(["DebugIdentity"], watch_1.debug_ops)
|
||||
@ -124,9 +124,10 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
debug_ops="DebugNanCount",
|
||||
debug_urls="file:///tmp/tfdbg_1")
|
||||
|
||||
self.assertEqual(1, len(self._run_options.debug_tensor_watch_opts))
|
||||
debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
|
||||
self.assertEqual(1, len(debug_watch_opts))
|
||||
|
||||
watch_0 = self._run_options.debug_tensor_watch_opts[0]
|
||||
watch_0 = debug_watch_opts[0]
|
||||
|
||||
self.assertEqual("foo/node_a", watch_0.node_name)
|
||||
self.assertEqual(0, watch_0.output_slot)
|
||||
@ -145,9 +146,10 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
debug_ops=["DebugNanCount", "DebugIdentity"],
|
||||
debug_urls="file:///tmp/tfdbg_1")
|
||||
|
||||
self.assertEqual(1, len(self._run_options.debug_tensor_watch_opts))
|
||||
debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
|
||||
self.assertEqual(1, len(debug_watch_opts))
|
||||
|
||||
watch_0 = self._run_options.debug_tensor_watch_opts[0]
|
||||
watch_0 = debug_watch_opts[0]
|
||||
|
||||
self.assertEqual("foo/node_a", watch_0.node_name)
|
||||
self.assertEqual(0, watch_0.output_slot)
|
||||
@ -166,9 +168,10 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
debug_ops="DebugNanCount",
|
||||
debug_urls=["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"])
|
||||
|
||||
self.assertEqual(1, len(self._run_options.debug_tensor_watch_opts))
|
||||
debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
|
||||
self.assertEqual(1, len(debug_watch_opts))
|
||||
|
||||
watch_0 = self._run_options.debug_tensor_watch_opts[0]
|
||||
watch_0 = debug_watch_opts[0]
|
||||
|
||||
self.assertEqual("foo/node_a", watch_0.node_name)
|
||||
self.assertEqual(0, watch_0.output_slot)
|
||||
@ -187,13 +190,13 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
debug_ops=["DebugIdentity", "DebugNanCount"],
|
||||
debug_urls="file:///tmp/tfdbg_1")
|
||||
|
||||
self.assertEqual(self._expected_num_nodes,
|
||||
len(self._run_options.debug_tensor_watch_opts))
|
||||
debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
|
||||
self.assertEqual(self._expected_num_nodes, len(debug_watch_opts))
|
||||
|
||||
# Verify that each of the nodes in the graph with output tensors in the
|
||||
# graph have debug tensor watch.
|
||||
node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
|
||||
0, ["DebugIdentity", "DebugNanCount"],
|
||||
node_names = self._verify_watches(debug_watch_opts, 0,
|
||||
["DebugIdentity", "DebugNanCount"],
|
||||
["file:///tmp/tfdbg_1"])
|
||||
|
||||
# Verify the node names.
|
||||
@ -218,9 +221,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
debug_urls="file:///tmp/tfdbg_1",
|
||||
node_name_regex_whitelist="(a1$|a1_init$|a1/.*|p1$)")
|
||||
|
||||
node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
|
||||
0, ["DebugIdentity"],
|
||||
["file:///tmp/tfdbg_1"])
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||
self.assertEqual(
|
||||
sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"]),
|
||||
sorted(node_names))
|
||||
@ -232,9 +235,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
debug_urls="file:///tmp/tfdbg_1",
|
||||
op_type_regex_whitelist="(Variable|MatMul)")
|
||||
|
||||
node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
|
||||
0, ["DebugIdentity"],
|
||||
["file:///tmp/tfdbg_1"])
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||
self.assertEqual(sorted(["a1", "b", "p1"]), sorted(node_names))
|
||||
|
||||
def testWatchGraph_nodeNameAndOpTypeWhitelists(self):
|
||||
@ -245,9 +248,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
node_name_regex_whitelist="([a-z]+1$)",
|
||||
op_type_regex_whitelist="(MatMul)")
|
||||
|
||||
node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
|
||||
0, ["DebugIdentity"],
|
||||
["file:///tmp/tfdbg_1"])
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||
self.assertEqual(["p1"], node_names)
|
||||
|
||||
def testWatchGraph_nodeNameBlacklist(self):
|
||||
@ -257,9 +260,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
debug_urls="file:///tmp/tfdbg_1",
|
||||
node_name_regex_blacklist="(a1$|a1_init$|a1/.*|p1$)")
|
||||
|
||||
node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
|
||||
0, ["DebugIdentity"],
|
||||
["file:///tmp/tfdbg_1"])
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||
self.assertEqual(
|
||||
sorted(["b_init", "b", "b/Assign", "b/read", "c", "s"]),
|
||||
sorted(node_names))
|
||||
@ -271,9 +274,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
debug_urls="file:///tmp/tfdbg_1",
|
||||
op_type_regex_blacklist="(Variable|Identity|Assign|Const)")
|
||||
|
||||
node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
|
||||
0, ["DebugIdentity"],
|
||||
["file:///tmp/tfdbg_1"])
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||
self.assertEqual(sorted(["p1", "s"]), sorted(node_names))
|
||||
|
||||
def testWatchGraph_nodeNameAndOpTypeBlacklists(self):
|
||||
@ -284,9 +287,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
node_name_regex_blacklist="p1$",
|
||||
op_type_regex_blacklist="(Variable|Identity|Assign|Const)")
|
||||
|
||||
node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
|
||||
0, ["DebugIdentity"],
|
||||
["file:///tmp/tfdbg_1"])
|
||||
node_names = self._verify_watches(
|
||||
self._run_options.debug_options.debug_tensor_watch_opts, 0,
|
||||
["DebugIdentity"], ["file:///tmp/tfdbg_1"])
|
||||
self.assertEqual(["s"], node_names)
|
||||
|
||||
|
||||
|
@ -79,7 +79,7 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook,
|
||||
self.on_run_end(on_run_end_request)
|
||||
|
||||
def _decorate_options_for_debug(self, options, graph):
|
||||
"""Modify RunOptions.debug_tensor_watch_opts for debugging.
|
||||
"""Modify RunOptions.debug_options.debug_tensor_watch_opts for debugging.
|
||||
|
||||
Args:
|
||||
options: (config_pb2.RunOptions) The RunOptions instance to be modified.
|
||||
|
@ -795,5 +795,5 @@ class _HookedSession(_WrappedSession):
|
||||
options.output_partition_graphs,
|
||||
incoming_options.output_partition_graphs)
|
||||
|
||||
options.debug_tensor_watch_opts.extend(
|
||||
incoming_options.debug_tensor_watch_opts)
|
||||
options.debug_options.debug_tensor_watch_opts.extend(
|
||||
incoming_options.debug_options.debug_tensor_watch_opts)
|
||||
|
@ -29,6 +29,7 @@ import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib import testing
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import debug_pb2
|
||||
from tensorflow.python.training import monitored_session
|
||||
|
||||
|
||||
@ -693,7 +694,8 @@ class RunOptionsMetadataHook(tf.train.SessionRunHook):
|
||||
trace_level=self._trace_level,
|
||||
timeout_in_ms=self._timeout_in_ms,
|
||||
output_partition_graphs=self._output_partition_graphs)
|
||||
options.debug_tensor_watch_opts.extend([self._debug_tensor_watch])
|
||||
options.debug_options.debug_tensor_watch_opts.extend(
|
||||
[self._debug_tensor_watch])
|
||||
return tf.train.SessionRunArgs(None, None, options=options)
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
@ -1019,13 +1021,13 @@ class MonitoredSessionTest(tf.test.TestCase):
|
||||
my_const = tf.constant(42, name='my_const')
|
||||
_ = tf.constant(24, name='my_const_2')
|
||||
|
||||
watch_a = config_pb2.DebugTensorWatch(
|
||||
watch_a = debug_pb2.DebugTensorWatch(
|
||||
node_name='my_const',
|
||||
output_slot=0,
|
||||
debug_ops=['DebugIdentity'],
|
||||
debug_urls=[])
|
||||
hook_a = RunOptionsMetadataHook(2, 30000, False, watch_a)
|
||||
watch_b = config_pb2.DebugTensorWatch(
|
||||
watch_b = debug_pb2.DebugTensorWatch(
|
||||
node_name='my_const_2',
|
||||
output_slot=0,
|
||||
debug_ops=['DebugIdentity'],
|
||||
@ -1044,7 +1046,8 @@ class MonitoredSessionTest(tf.test.TestCase):
|
||||
trace_level=3,
|
||||
timeout_in_ms=60000,
|
||||
output_partition_graphs=True,
|
||||
debug_tensor_watch_opts=[watch_a, watch_b])
|
||||
debug_options=debug_pb2.DebugOptions(
|
||||
debug_tensor_watch_opts=[watch_a, watch_b]))
|
||||
],
|
||||
hook_b.run_options_list)
|
||||
self.assertEqual(1, len(hook_b.run_metadata_list))
|
||||
@ -1059,21 +1062,22 @@ class MonitoredSessionTest(tf.test.TestCase):
|
||||
my_const = tf.constant(42, name='my_const')
|
||||
_ = tf.constant(24, name='my_const_2')
|
||||
|
||||
hook_watch = config_pb2.DebugTensorWatch(
|
||||
hook_watch = debug_pb2.DebugTensorWatch(
|
||||
node_name='my_const_2',
|
||||
output_slot=0,
|
||||
debug_ops=['DebugIdentity'],
|
||||
debug_urls=[])
|
||||
hook = RunOptionsMetadataHook(2, 60000, False, hook_watch)
|
||||
with tf.train.MonitoredSession(hooks=[hook]) as session:
|
||||
caller_watch = config_pb2.DebugTensorWatch(
|
||||
caller_watch = debug_pb2.DebugTensorWatch(
|
||||
node_name='my_const',
|
||||
output_slot=0,
|
||||
debug_ops=['DebugIdentity'],
|
||||
debug_urls=[])
|
||||
caller_options = config_pb2.RunOptions(
|
||||
trace_level=3, timeout_in_ms=30000, output_partition_graphs=True)
|
||||
caller_options.debug_tensor_watch_opts.extend([caller_watch])
|
||||
caller_options.debug_options.debug_tensor_watch_opts.extend(
|
||||
[caller_watch])
|
||||
self.assertEqual(42, session.run(my_const, options=caller_options))
|
||||
|
||||
# trace_level=3 from the caller should override 2 from the hook.
|
||||
@ -1088,7 +1092,8 @@ class MonitoredSessionTest(tf.test.TestCase):
|
||||
trace_level=3,
|
||||
timeout_in_ms=60000,
|
||||
output_partition_graphs=True,
|
||||
debug_tensor_watch_opts=[caller_watch, hook_watch])
|
||||
debug_options=debug_pb2.DebugOptions(
|
||||
debug_tensor_watch_opts=[caller_watch, hook_watch]))
|
||||
],
|
||||
hook.run_options_list)
|
||||
self.assertEqual(1, len(hook.run_metadata_list))
|
||||
|
Loading…
Reference in New Issue
Block a user