tfdbg: Watch output tensors without emitted edges
Currently, DebugNodeInserter is capable of watching only node output slots that have data (i.e., non-control) edges emitting from them. However, it would be useful to watch output slots that have no edges emitting from them, as well as output slots of nodes that emit only control edges. A common case arises in the optimization ops on the backward path of a training graph, where the train op receives control edges from a number of nodes that each update the value of a Variable. Each of these Variable-updating nodes has a data output slot (slot 0), but those slots are not connected to any data edges. Without this CL, specifying variable_updating_op:0 in the watch list does not cause that tensor to be watched. This CL adds the ability to watch such nodes' outputs; it also makes it possible to watch output slots that have no outgoing edges at all, control or non-control.

Change: 136825445
parent e9e56f2b18
commit 00526b6f86
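For illustration, here is a minimal sketch (not part of this CL; the node names and dump path are hypothetical) of how a watch on such a slot is requested through the RunOptions proto, using the same DebugTensorWatch fields (node_name, output_slot, debug_ops, debug_urls) that the tests in this diff exercise:

import tensorflow as tf

from tensorflow.core.protobuf import config_pb2

# A Variable-updating op whose output slot 0 has no outgoing data edge:
# the train op below depends on it only through a control edge.
v = tf.Variable(10.0, name="v")
update_v = tf.assign(v, 8.0, name="update_v")
with tf.control_dependencies([update_v]):
  train_op = tf.no_op(name="train_op")

run_options = config_pb2.RunOptions(output_partition_graphs=True)
# Request the watch "update_v:0"; before this CL such a slot would be
# skipped because no data edge emits from it.
watch = run_options.debug_tensor_watch_opts.add()
watch.node_name = "update_v"
watch.output_slot = 0
watch.debug_ops.append("DebugIdentity")
watch.debug_urls.append("file:///tmp/tfdbg_dump")  # hypothetical dump location

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  sess.run(train_op, options=run_options,
           run_metadata=config_pb2.RunMetadata())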
@@ -448,6 +448,108 @@ TEST_F(SessionDebugMinusAXTest,
   }
 }
 
+class SessionDebugOutputSlotWithoutOngoingEdgeTest : public ::testing::Test {
+ public:
+  void Initialize() {
+    Graph graph(OpRegistry::Global());
+
+#if GOOGLE_CUDA
+    const string kDeviceName = "/job:localhost/replica:0/task:0/gpu:0";
+#else
+    const string kDeviceName = "/job:localhost/replica:0/task:0/cpu:0";
+#endif
+
+    Tensor a_tensor(DT_FLOAT, TensorShape({1, 1}));
+    test::FillValues<float>(&a_tensor, {42.0});
+    Node* a = test::graph::Constant(&graph, a_tensor);
+    a->set_assigned_device_name(kDeviceName);
+
+    Node* c = test::graph::Constant(&graph, a_tensor);
+    c->set_assigned_device_name(kDeviceName);
+    c_ = c->name();
+
+    // Node c will be executed only because of the control edge from c to y.
+    // Its output slot (slot 0) does not have an outgoing edge. This test
+    // is for testing that the debugger can watch that slot properly.
+    Node* y = test::graph::NoOp(&graph, {c});
+    y->set_assigned_device_name(kDeviceName);
+    y_ = y->name();
+
+    test::graph::ToGraphDef(&graph, &def_);
+  }
+
+  string c_;
+  string y_;
+  GraphDef def_;
+};
+
+TEST_F(SessionDebugOutputSlotWithoutOngoingEdgeTest,
+       WatchSlotWithoutOutgoingEdge) {
+  Initialize();
+  std::unique_ptr<DirectSession> session(CreateSession());
+  ASSERT_TRUE(session != nullptr);
+
+  DebugGateway debug_gateway(session.get());
+
+  // Supply completion and value callbacks.
+  mutex mu;
+
+  string debug_identity_node_name = DebugNodeInserter::GetDebugNodeName(
+      strings::StrCat(c_, ":", 0), 0, "DebugIdentity");
+
+  Notification callbacks_done;
+
+  debug_gateway.SetNodeCompletionCallback(
+      [&mu, &callbacks_done](const string& node_name, const bool any_output) {
+        mutex_lock l(mu);
+        if (node_name == "_SINK" && !callbacks_done.HasBeenNotified()) {
+          callbacks_done.Notify();
+        }
+      });
+
+  std::vector<Tensor> debug_identity_tensor_vals;
+  debug_gateway.SetNodeValueCallback(
+      [this, &mu, &debug_identity_node_name, &debug_identity_tensor_vals](
+          const string& node_name, const int output_slot,
+          const Tensor& tensor_value, const bool is_ref) {
+        mutex_lock l(mu);
+
+        if (node_name == debug_identity_node_name && output_slot == 0) {
+          debug_identity_tensor_vals.push_back(tensor_value);
+        }
+      });
+
+  // Add a DebugIdentity watch on c:0, which does not have an outgoing edge.
+  RunOptions run_opts;
+  run_opts.set_output_partition_graphs(true);
+
+  DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
+  tensor_watch_opts->set_node_name(c_);
+  tensor_watch_opts->set_output_slot(0);
+  tensor_watch_opts->add_debug_ops("DebugIdentity");
+
+  TF_ASSERT_OK(session->Create(def_));
+
+  // Invoke Session::Run() on y.
+  std::vector<std::pair<string, Tensor>> inputs;
+  std::vector<string> output_names;
+  std::vector<string> target_nodes = {y_};
+  std::vector<Tensor> outputs;
+
+  RunMetadata run_metadata;
+  Status s = session->Run(run_opts, inputs, output_names, target_nodes,
+                          &outputs, &run_metadata);
+  TF_ASSERT_OK(s);
+
+  // Wait for callbacks to complete.
+  callbacks_done.WaitForNotification();
+
+  // Assert that the DebugIdentity node watching c:0 has been run.
+  ASSERT_EQ(1, debug_identity_tensor_vals.size());
+  auto mat_identity = debug_identity_tensor_vals[0].matrix<float>();
+  ASSERT_EQ(42.0, mat_identity(0, 0));
+}
+
 class SessionDebugVariableTest : public ::testing::Test {
  public:
   void Initialize() {
@@ -73,66 +73,57 @@ Status DebugNodeInserter::InsertNodes(
   }
 
   DeviceType device_type = DeviceType{device->device_type()};
-  // 1. Record existing edges in the graph.
-  std::vector<const Edge*> existing_edges;
-  for (const Edge* edge : graph->edges()) {
-    existing_edges.push_back(edge);
-  }
 
-  // A map from tensor names to edges to be removed
-  std::unordered_map<string, std::vector<const Edge*>> edges_to_remove;
-  // A map from tensor names to newly added debug nodes (maybe more than one
-  // for a given tensor).
-  std::unordered_map<string, std::vector<Node*>> added_debug_nodes;
-  std::unordered_map<string, Node*> added_copy_nodes;
+  // Keep track of all edges to be removed.
+  std::vector<const Edge*> edges_to_remove;
 
-  // 2. Iterate through the edges, look for edges that match the tensor watch
-  // list.
-  for (const Edge* edge : existing_edges) {
-    Node* src_node = edge->src();
-    Node* dst_node = edge->dst();
-
-    if (edge->IsControlEdge()) {
-      continue;
-    }
-
-    const bool is_ref = IsRefType(dst_node->input_type(edge->dst_input()));
-    MemoryType memory_type;
-    MemoryTypeForOutput(device_type, graph, src_node, edge->src_output(),
-                        &memory_type);
-
-    const string tensor_name =
-        strings::StrCat(src_node->name(), ":", edge->src_output());
-    if (tensor_watches.find(tensor_name) == tensor_watches.end()) {
-      // Add debug nodes only for edges with matching source node and source
-      // output slot.
-      continue;
-    }
-
-    if (added_copy_nodes.find(tensor_name) == added_copy_nodes.end()) {
-      // It is the first time an edge with this source tensor is encountered:
-      // we will:
-      //   1) Mark this edge as to be removed, iff the destination node has
-      //      non-Ref input
-      //   2) Create a Copy node
-      //   3) Add a new edge, from the source tensor to the Copy node
-      //   4) Add a new edge, from the Copy node to the destination node, iff
-      //      the destination node has non-Ref input
-      //   5) Create all the requested debug nodes and their edges to the Copy
-      //      node.
-      if (!is_ref) {
-        std::vector<const Edge*> node_edges_to_remove;
-        node_edges_to_remove.push_back(edge);
-        edges_to_remove[tensor_name] = node_edges_to_remove;
-      }
-
-      const DataType src_dt = src_node->output_type(edge->src_output());
-
-      // Create the copy node.
+  for (Node* src_node : graph->nodes()) {
+    // Make a map from output slot to outgoing edges from the slot.
+    std::unordered_map<int, std::vector<const Edge*>> output_slot_to_edges;
+    for (const Edge* edge : src_node->out_edges()) {
+      const int src_output = edge->src_output();
+      if (output_slot_to_edges.find(src_output) == output_slot_to_edges.end()) {
+        output_slot_to_edges[src_output] = {edge};
+      } else {
+        output_slot_to_edges[src_output].push_back(edge);
+      }
+    }
+
+    // Iterate through all output slots of the node.
+    for (int src_output_slot = 0; src_output_slot < src_node->num_outputs();
+         ++src_output_slot) {
+      const string tensor_name =
+          strings::StrCat(src_node->name(), ":", src_output_slot);
+      if (tensor_watches.find(tensor_name) == tensor_watches.end()) {
+        // Add debug nodes only for edges with matching source node and source
+        // output slot.
+        continue;
+      }
+
+      // Now we have encountered a watched tensor. We will:
+      //   1) Mark this edge as to be removed, iff this is a non-Reference
+      //      tensor
+      //   2) Create a Copy node for the tensor
+      //   3) Add a new edge, from the source tensor to the Copy node
+      //   4) Add a new edge, from the Copy node to the destination node, iff
+      //      this is a non-Reference tensor.
+      //   5) Create all the requested debug nodes and their edges to the Copy
+      //      node.
+      //   6) Add control edges from the debug nodes to the destination nodes
+      //      to ensure that the tensor values exported by the debug nodes
+      //      to the debug URLs reflect the values before the execution of
+      //      the destination nodes.
+
+      const DataType src_dt = src_node->output_type(src_output_slot);
+      MemoryType memory_type;
+      MemoryTypeForOutput(device_type, graph, src_node, src_output_slot,
+                          &memory_type);
+
+      // Create the copy node for the watched tensor.
       Node* copy_node;
       Status copy_s = CreateCopyNode(
           graph, device_type, memory_type == HOST_MEMORY, src_node->name(),
-          edge->src_output(), src_dt, tensor_name, &copy_node);
+          src_output_slot, src_dt, tensor_name, &copy_node);
       if (!copy_s.ok()) {
         return Status(
             error::FAILED_PRECONDITION,
@@ -140,20 +131,11 @@ Status DebugNodeInserter::InsertNodes(
                             tensor_name, ", due to: ", copy_s.error_message()));
       }
 
-      // Record the added copy node for later use.
-      added_copy_nodes[tensor_name] = copy_node;
-
       // Add edge from watched tensor to the copy node.
-      graph->AddEdge(src_node, edge->src_output(), copy_node, 0);
-
-      // Add edge from the copy node to the destination node, iff the
-      // destination node has non-Ref input.
-      if (!is_ref) {
-        graph->AddEdge(copy_node, 0, dst_node, edge->dst_input());
-      }
+      graph->AddEdge(src_node, src_output_slot, copy_node, 0);
 
       // Create all requested debug nodes and their edges to the Copy node.
-      std::vector<Node*> node_added_debug_nodes;
+      std::vector<Node*> debug_nodes;
       for (size_t i = 0; i < tensor_watches[tensor_name].size(); ++i) {
         const string& debug_op_name = tensor_watches[tensor_name][i];
 
@@ -169,47 +151,37 @@ Status DebugNodeInserter::InsertNodes(
                               debug_s.error_message()));
         }
 
-        node_added_debug_nodes.push_back(debug_node);
-
         // Create edges from the Copy node to the debug node.
         graph->AddEdge(copy_node, 0, debug_node, 0);
 
+        debug_nodes.push_back(debug_node);
+      }
+
+      // Is the output a reference?
+      const bool is_ref = IsRefType(src_node->output_type(src_output_slot));
+
+      // Iterate through all outgoing edges attached to the slot.
+      for (const Edge* edge : output_slot_to_edges[src_output_slot]) {
+        // Mark the edge for removal.
+        if (!is_ref) {
+          edges_to_remove.push_back(edge);
+          graph->AddEdge(copy_node, 0, edge->dst(), edge->dst_input());
+        }
+
         // Add control edges from the debug nodes to the destination node
         // to ensure that the debug nodes are executed before the destination
        // node.
-        graph->AddEdge(debug_node, Graph::kControlSlot, dst_node,
-                       Graph::kControlSlot);
-      }
-      added_debug_nodes[tensor_name] = node_added_debug_nodes;
-    } else {
-      // It is not the first time an edge with this source is encountered.
-      // We will do the following iff the destination node has non-Ref input
-      //   1) Mark the edge for removal
-      //   2) Create an edge from the copy node to the destination node
-      // Iff the destination has Ref-input, the edge will not change.
-      // Regardless of whether the destination has Ref-input, we will
-      //   3) Add control edges from the already-created debug node(s) for the
-      //      watched tensor to the destination node.
-      if (!is_ref) {
-        edges_to_remove[tensor_name].push_back(edge);
-        graph->AddEdge(added_copy_nodes[tensor_name], 0, dst_node,
-                       edge->dst_input());
-      }
-
-      for (Node* debug_node : added_debug_nodes[tensor_name]) {
-        graph->AddEdge(debug_node, Graph::kControlSlot, dst_node,
-                       Graph::kControlSlot);
+        for (Node* debug_node : debug_nodes) {
+          graph->AddEdge(debug_node, Graph::kControlSlot, edge->dst(),
+                         Graph::kControlSlot);
+        }
       }
     }
   }
 
   // Remove all edges marked for removal.
-  for (auto it : edges_to_remove) {
-    std::vector<const Edge*> edges = it.second;
-
-    for (const Edge* edge : edges) {
-      graph->RemoveEdge(edge);
-    }
+  for (const Edge* edge : edges_to_remove) {
+    graph->RemoveEdge(edge);
   }
 
   return Status::OK();
@@ -23,6 +23,7 @@ import tempfile
 
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
@@ -30,6 +31,7 @@ from tensorflow.python.debug import debug_data
 from tensorflow.python.debug import debug_utils
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
@@ -644,6 +646,158 @@ class SessionDebugTest(test_util.TensorFlowTestCase):
         partition_graphs=run_metadata.partition_graphs,
         validate=False)
 
+  def testWatchingOutputSlotWithoutOutgoingEdge(self):
+    """Test watching output slots not attached to any outgoing edges."""
+
+    with session.Session() as sess:
+      u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
+      u = constant_op.constant(u_init_val, shape=[2, 2], name="u")
+
+      # Create a control edge from a node with an output: from u to z.
+      # Node u will get executed only because of the control edge. The output
+      # tensor u:0 is not attached to any outgoing edge in the graph. This test
+      # checks that the debugger can watch such a tensor.
+      with ops.control_dependencies([u]):
+        z = control_flow_ops.no_op(name="z")
+
+      run_options = config_pb2.RunOptions(output_partition_graphs=True)
+      debug_utils.watch_graph(
+          run_options,
+          sess.graph,
+          debug_ops=["DebugIdentity"],
+          debug_urls="file://%s" % self._dump_root)
+
+      run_metadata = config_pb2.RunMetadata()
+      sess.run(z, options=run_options, run_metadata=run_metadata)
+
+      dump = debug_data.DebugDumpDir(
+          self._dump_root, partition_graphs=run_metadata.partition_graphs)
+
+      # Assert that the DebugIdentity watch on u works properly.
+      self.assertEqual(1, len(dump.dumped_tensor_data))
+      datum = dump.dumped_tensor_data[0]
+      self.assertEqual("u", datum.node_name)
+      self.assertEqual(0, datum.output_slot)
+      self.assertEqual("DebugIdentity", datum.debug_op)
+      self.assertAllClose([[5.0, 3.0], [-1.0, 0.0]], datum.get_tensor())
+
+  def testWatchingVariableUpdateOps(self):
+    """Watch output slots on Variable-updating ops, with no emitted edges."""
+
+    with session.Session() as sess:
+      u_init = constant_op.constant(10.0)
+      u = variables.Variable(u_init, name="gdo/u")
+      v_init = constant_op.constant(20.0)
+      v = variables.Variable(v_init, name="gdo/v")
+
+      w = math_ops.mul(u, v, name="gdo/w")
+      # gdo stands for GradientDescentOptimizer.
+
+      train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(
+          w, name="gdo/train")
+
+      u.initializer.run()
+      v.initializer.run()
+
+      run_options = config_pb2.RunOptions(output_partition_graphs=True)
+      debug_utils.watch_graph(
+          run_options,
+          sess.graph,
+          debug_ops=["DebugIdentity"],
+          debug_urls="file://%s" % self._dump_root)
+
+      run_metadata = config_pb2.RunMetadata()
+      sess.run(train_op, options=run_options, run_metadata=run_metadata)
+
+      dump = debug_data.DebugDumpDir(
+          self._dump_root, partition_graphs=run_metadata.partition_graphs)
+
+      update_u_data = dump.watch_key_to_data(
+          "gdo/train/update_gdo/u/ApplyGradientDescent:0:DebugIdentity")
+      self.assertEqual(1, len(update_u_data))
+
+      # Gradient descent on u: w = u * v, so dw / du = v.
+      # Updated value of u should be:
+      #   10.0 - learning_rate * v = 10.0 - 0.1 * 20.0 = 8.0
+      self.assertAllClose(8.0, update_u_data[0].get_tensor())
+
+      update_v_data = dump.watch_key_to_data(
+          "gdo/train/update_gdo/v/ApplyGradientDescent:0:DebugIdentity")
+      self.assertEqual(1, len(update_v_data))
+
+      # Gradient descent on v: w = u * v, so dw / dv = u.
+      # Updated value of v should be:
+      #   20.0 - learning_rate * u = 20.0 - 0.1 * 10.0 = 19.0
+      self.assertAllClose(19.0, update_v_data[0].get_tensor())
+
+      # Verify that the Variables u and v are updated properly.
+      self.assertAllClose(8.0, sess.run(u))
+      self.assertAllClose(19.0, sess.run(v))
+
+  def testWatchingUnconnectedOutputTensor(self):
+    """Watch an output slot not emitting any edges.
+
+    (Not even control edges from the node.)
+    """
+
+    with session.Session() as sess:
+      x_init = constant_op.constant([2, 2, 3, 5, 5])
+      x = variables.Variable(x_init, name="unconnected/x")
+
+      # The UniqueOp (tf.unique) has two output slots. Use only slot 0 in the
+      # graph. Let the debugger watch the unused slot 1.
+      unique_x, _ = tf.unique(x, name="unconnected/unique_x")
+      y = tf.add(unique_x, [0, 1, 2], name="unconnected/y")
+
+      x.initializer.run()
+
+      # Verify that only slot 0 of unique_x has recipients, while slot 1 of the
+      # same node does not have recipients.
+      unique_x_slot_0_recipients = []
+      unique_x_slot_1_recipients = []
+      for op in sess.graph.get_operations():
+        for inp in op.inputs:
+          if inp.name == "unconnected/unique_x:0":
+            unique_x_slot_0_recipients.append(op.name)
+          elif inp.name == "unconnected/unique_x:1":
+            unique_x_slot_1_recipients.append(op.name)
+
+      self.assertEqual(["unconnected/y"], unique_x_slot_0_recipients)
+      self.assertEqual([], unique_x_slot_1_recipients)
+
+      run_options = config_pb2.RunOptions(output_partition_graphs=True)
+      debug_utils.watch_graph(
+          run_options,
+          sess.graph,
+          debug_ops=["DebugIdentity"],
+          debug_urls="file://%s" % self._dump_root)
+
+      run_metadata = config_pb2.RunMetadata()
+      result = sess.run(y, options=run_options, run_metadata=run_metadata)
+      self.assertAllClose([2, 4, 7], result)
+
+      dump = debug_data.DebugDumpDir(
+          self._dump_root, partition_graphs=run_metadata.partition_graphs)
+
+      # Assert that the connected slot (slot 0) is dumped properly.
+      unique_x_slot_0_dumps = dump.watch_key_to_data(
+          "unconnected/unique_x:0:DebugIdentity")
+      self.assertEqual(1, len(unique_x_slot_0_dumps))
+      self.assertEqual("unconnected/unique_x",
+                       unique_x_slot_0_dumps[0].node_name)
+      self.assertEqual(0, unique_x_slot_0_dumps[0].output_slot)
+      self.assertAllClose([2, 3, 5], unique_x_slot_0_dumps[0].get_tensor())
+
+      # Assert that the unconnected slot (slot 1) is dumped properly.
+      unique_x_slot_1_dumps = dump.watch_key_to_data(
+          "unconnected/unique_x:1:DebugIdentity")
+      self.assertEqual(1, len(unique_x_slot_1_dumps))
+      self.assertEqual("unconnected/unique_x",
+                       unique_x_slot_1_dumps[0].node_name)
+      self.assertEqual(1, unique_x_slot_1_dumps[0].output_slot)
+      self.assertAllClose([0, 0, 1, 2, 2],
+                          unique_x_slot_1_dumps[0].get_tensor())
+
 
 if __name__ == "__main__":
   googletest.main()