[tfdbg] Improve compatiblity with Grappler
- Make tensors from Grappler-created nodes visible to tfdbg. - To this end, add a wildcard node name to the DebugTensorWatch proto. - Add unit test based on tfdbg's filesystem dump mode. - A few unit tests are updated to account for the fact that additional tensors get watched (mainly under GPU tests) with all runtime graph nodes now being watched. PiperOrigin-RevId: 260997920
This commit is contained in:
parent
7e297bab3f
commit
fb7da355b0
tensorflow
@ -56,6 +56,10 @@ Status DebugNodeInserter::InsertNodes(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Debug ops and URLs for wildcard node names (if any).
|
||||
std::vector<string> default_debug_ops;
|
||||
std::vector<string> default_debug_urls;
|
||||
|
||||
// A map from tensor name (e.g., "node_a:0") to list of debug op names
|
||||
// (e.g., {"DebugIdentity", "DebugNanCount"})
|
||||
std::unordered_map<string, std::vector<string>> tensor_watches;
|
||||
@ -65,16 +69,39 @@ Status DebugNodeInserter::InsertNodes(
|
||||
|
||||
// Cache the proto content for fast lookup later
|
||||
for (const DebugTensorWatch& watch : watches) {
|
||||
if (watch.output_slot() < 0) {
|
||||
// The semantics of output_slot == -1 is that the node is watched only
|
||||
// for completion, but not for output tensor values (see
|
||||
// NodeCompletionCallback in debug_gateway.h).
|
||||
continue;
|
||||
}
|
||||
if (watch.debug_ops().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (watch.debug_urls().empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (watch.node_name() == "*") {
|
||||
if (watch.output_slot() == -1) {
|
||||
default_debug_ops.insert(default_debug_ops.end(),
|
||||
watch.debug_ops().begin(),
|
||||
watch.debug_ops().end());
|
||||
default_debug_urls.insert(default_debug_urls.end(),
|
||||
watch.debug_urls().begin(),
|
||||
watch.debug_urls().end());
|
||||
} else {
|
||||
return Status(error::FAILED_PRECONDITION,
|
||||
strings::StrCat(
|
||||
"output_slot is expected to be -1 for wildcard ",
|
||||
"node name (\"*\"), but got ", watch.output_slot()));
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
if (watch.output_slot() < 0) {
|
||||
return Status(
|
||||
error::FAILED_PRECONDITION,
|
||||
strings::StrCat("A negative output_slot in DebugTensorWatch is ",
|
||||
"valid only for the wildcard node name (\"*\"), ",
|
||||
"but got node name ", watch.node_name()));
|
||||
}
|
||||
}
|
||||
|
||||
string tensor_name =
|
||||
strings::StrCat(watch.node_name(), ":", watch.output_slot());
|
||||
|
||||
@ -120,9 +147,9 @@ Status DebugNodeInserter::InsertNodes(
|
||||
++src_output_slot) {
|
||||
const string tensor_name =
|
||||
strings::StrCat(src_node->name(), ":", src_output_slot);
|
||||
if (tensor_watches.find(tensor_name) == tensor_watches.end()) {
|
||||
// Add debug nodes only for edges with matching source node and source
|
||||
// output slot.
|
||||
const bool explicit_tensor_match =
|
||||
tensor_watches.find(tensor_name) != tensor_watches.end();
|
||||
if (!explicit_tensor_match && default_debug_ops.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -146,11 +173,17 @@ Status DebugNodeInserter::InsertNodes(
|
||||
src_output_slot, &memory_type));
|
||||
|
||||
// Create the copy node for the watched tensor.
|
||||
const std::vector<string> debug_ops = explicit_tensor_match
|
||||
? tensor_watches[tensor_name]
|
||||
: default_debug_ops;
|
||||
const std::vector<string> debug_urls =
|
||||
explicit_tensor_match ? tensor_watch_urls[tensor_name]
|
||||
: default_debug_urls;
|
||||
Node* copy_node;
|
||||
Status copy_s = CreateCopyNode(
|
||||
graph, device_type, memory_type == HOST_MEMORY, src_node->name(),
|
||||
src_output_slot, src_dt, tensor_name, tensor_watches[tensor_name],
|
||||
tensor_watch_urls[tensor_name], ©_node);
|
||||
Status copy_s =
|
||||
CreateCopyNode(graph, device_type, memory_type == HOST_MEMORY,
|
||||
src_node->name(), src_output_slot, src_dt, tensor_name,
|
||||
debug_ops, debug_urls, ©_node);
|
||||
if (!copy_s.ok()) {
|
||||
return Status(
|
||||
error::FAILED_PRECONDITION,
|
||||
@ -163,13 +196,13 @@ Status DebugNodeInserter::InsertNodes(
|
||||
|
||||
// Create all requested debug nodes and their edges to the Copy node.
|
||||
std::vector<Node*> debug_nodes;
|
||||
for (size_t i = 0; i < tensor_watches[tensor_name].size(); ++i) {
|
||||
const string& debug_op_name = tensor_watches[tensor_name][i];
|
||||
for (size_t i = 0; i < debug_ops.size(); ++i) {
|
||||
const string& debug_op_name = debug_ops[i];
|
||||
|
||||
Node* debug_node;
|
||||
Status debug_s = CreateDebugNode(
|
||||
graph, *device, copy_node->name(), src_dt, tensor_name,
|
||||
tensor_watch_urls[tensor_name], i, debug_op_name, &debug_node);
|
||||
Status debug_s = CreateDebugNode(graph, *device, copy_node->name(),
|
||||
src_dt, tensor_name, debug_urls, i,
|
||||
debug_op_name, &debug_node);
|
||||
if (debug_s.ok()) {
|
||||
graph->AddEdge(copy_node, 0, debug_node, 0);
|
||||
debug_nodes.push_back(debug_node);
|
||||
|
@ -10,13 +10,15 @@ option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobu
|
||||
// Option for watching a node in TensorFlow Debugger (tfdbg).
|
||||
message DebugTensorWatch {
|
||||
// Name of the node to watch.
|
||||
// Use "*" for wildcard. But note: currently, regex is not supported in
|
||||
// general.
|
||||
string node_name = 1;
|
||||
|
||||
// Output slot to watch.
|
||||
// The semantics of output_slot == -1 is that the node is only watched for
|
||||
// completion, but not for any output tensors. See NodeCompletionCallback
|
||||
// in debug_gateway.h.
|
||||
// TODO(cais): Implement this semantics.
|
||||
// The semantics of output_slot == -1 is that all outputs of the node
|
||||
// will be watched (i.e., a wildcard).
|
||||
// Other negative values of output_slot are invalid and will lead to
|
||||
// errors currently.
|
||||
int32 output_slot = 2;
|
||||
|
||||
// Name(s) of the debugging op(s).
|
||||
|
@ -799,6 +799,21 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "debug_grappler_test",
|
||||
size = "small",
|
||||
srcs = ["lib/debug_grappler_test.py"],
|
||||
additional_deps = [
|
||||
":debug_data",
|
||||
":debug_utils",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "session_debug_file_test",
|
||||
size = "small",
|
||||
|
@ -142,14 +142,9 @@ def assert_listed_tensors(tst,
|
||||
attr_segs = out.font_attr_segs
|
||||
line_counter = 0
|
||||
|
||||
num_tensors = len(expected_tensor_names)
|
||||
|
||||
if tensor_filter_name is None:
|
||||
tst.assertEqual("%d dumped tensor(s):" % num_tensors, next(line_iter))
|
||||
else:
|
||||
tst.assertEqual("%d dumped tensor(s) passing filter \"%s\":" %
|
||||
(num_tensors, tensor_filter_name), next(line_iter))
|
||||
num_dumped_tensors = int(next(line_iter).split(" ")[0])
|
||||
line_counter += 1
|
||||
tst.assertGreaterEqual(num_dumped_tensors, len(expected_tensor_names))
|
||||
|
||||
if op_type_regex is not None:
|
||||
tst.assertEqual("Op type regex filter: \"%s\"" % op_type_regex,
|
||||
|
121
tensorflow/python/debug/lib/debug_grappler_test.py
Normal file
121
tensorflow/python/debug/lib/debug_grappler_test.py
Normal file
@ -0,0 +1,121 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for debugger functionalities in tf.Session."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.debug.lib import debug_data
|
||||
from tensorflow.python.debug.lib import debug_utils
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
def _grappler_enabled_session_config():
|
||||
"""Constructs a Session config proto that explicitly enables Grappler.
|
||||
|
||||
Returns:
|
||||
A config proto that obtains extra safety for the unit tests in this
|
||||
file by ensuring that the relevant Grappler rewrites are always enabled.
|
||||
"""
|
||||
rewriter_config = rewriter_config_pb2.RewriterConfig(
|
||||
disable_model_pruning=False,
|
||||
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.ON)
|
||||
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
|
||||
return config_pb2.ConfigProto(graph_options=graph_options)
|
||||
|
||||
|
||||
class SessionDebugGrapplerInteractionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(SessionDebugGrapplerInteractionTest, self).setUp()
|
||||
self._dump_root = tempfile.mkdtemp()
|
||||
self._debug_url = "file://%s" % self._dump_root
|
||||
|
||||
def tearDown(self):
|
||||
ops.reset_default_graph()
|
||||
if os.path.isdir(self._dump_root):
|
||||
shutil.rmtree(self._dump_root)
|
||||
super(SessionDebugGrapplerInteractionTest, self).tearDown()
|
||||
|
||||
def testArithmeticOptimizationActive(self):
|
||||
"""Tests that tfdbg can dump the tensor from nodes created by Grappler."""
|
||||
with session.Session(config=_grappler_enabled_session_config()) as sess:
|
||||
u = variables.VariableV1([[1, 2], [3, 4]], name="u", dtype=dtypes.float32)
|
||||
# The next two ops should be optimized by Grappler into a single op:
|
||||
# either an AddN op or a Mul op.
|
||||
x = math_ops.add(u, u)
|
||||
x = math_ops.add(x, u)
|
||||
y = math_ops.multiply(x, u)
|
||||
|
||||
sess.run(variables.global_variables_initializer())
|
||||
|
||||
run_options = config_pb2.RunOptions(output_partition_graphs=True)
|
||||
debug_utils.watch_graph(
|
||||
run_options,
|
||||
sess.graph,
|
||||
debug_ops=["DebugIdentity"],
|
||||
debug_urls=[self._debug_url])
|
||||
|
||||
run_metadata = config_pb2.RunMetadata()
|
||||
run_result = sess.run(y, options=run_options, run_metadata=run_metadata)
|
||||
self.assertAllClose(run_result, [[3, 12], [27, 48]])
|
||||
|
||||
dump_data = debug_data.DebugDumpDir(
|
||||
self._dump_root, partition_graphs=run_metadata.partition_graphs,
|
||||
validate=True)
|
||||
|
||||
original_node_names = set([op.name for op in sess.graph.get_operations()])
|
||||
dumped_node_names = set(dump_data.nodes())
|
||||
grappler_created_node_names = dumped_node_names - original_node_names
|
||||
grappler_removed_node_names = original_node_names - dumped_node_names
|
||||
|
||||
# Assert that Grappler should have replaced some of the nodes from the
|
||||
# original graph with new nodes.
|
||||
self.assertTrue(grappler_created_node_names)
|
||||
self.assertTrue(grappler_removed_node_names)
|
||||
|
||||
# Iterate through the nodes created by Grappler. One of them should be
|
||||
# be the result of replacing the original add ops with an AddN op or a
|
||||
# Mul op.
|
||||
found_optimized_node = False
|
||||
for grappler_node_name in grappler_created_node_names:
|
||||
node_op_type = dump_data.node_op_type(grappler_node_name)
|
||||
# Look for the node created by Grappler's arithmetic optimization.
|
||||
if node_op_type in ("AddN", "Mul"):
|
||||
datum = dump_data.get_tensors(grappler_node_name, 0, "DebugIdentity")
|
||||
self.assertEqual(1, len(datum))
|
||||
self.assertAllClose(datum[0], [[3, 6], [9, 12]])
|
||||
found_optimized_node = True
|
||||
break
|
||||
self.assertTrue(
|
||||
found_optimized_node,
|
||||
"Failed to find optimized node created by Grappler's arithmetic "
|
||||
"optimization.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
@ -134,6 +134,10 @@ def watch_graph(run_options,
|
||||
reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
|
||||
usage to zero (default: `False`).
|
||||
"""
|
||||
if not debug_ops:
|
||||
raise ValueError("debug_ops must not be empty or None.")
|
||||
if not debug_urls:
|
||||
raise ValueError("debug_urls must not be empty or None.")
|
||||
|
||||
if isinstance(debug_ops, str):
|
||||
debug_ops = [debug_ops]
|
||||
@ -173,6 +177,23 @@ def watch_graph(run_options,
|
||||
tolerate_debug_op_creation_failures=(
|
||||
tolerate_debug_op_creation_failures),
|
||||
global_step=global_step)
|
||||
|
||||
# If no filter for node or tensor is used, will add a wildcard node name, so
|
||||
# that all nodes, including the ones created internally by TensorFlow itself
|
||||
# (e.g., by Grappler), can be watched during debugging.
|
||||
use_node_name_wildcard = (not node_name_pattern and
|
||||
not op_type_pattern and
|
||||
not tensor_dtype_pattern)
|
||||
if use_node_name_wildcard:
|
||||
add_debug_tensor_watch(
|
||||
run_options,
|
||||
"*",
|
||||
output_slot=-1,
|
||||
debug_ops=debug_ops,
|
||||
debug_urls=debug_urls,
|
||||
tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
|
||||
global_step=global_step)
|
||||
|
||||
run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
|
||||
|
||||
|
||||
|
@ -59,11 +59,13 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
cls._graph = cls._sess.graph
|
||||
|
||||
# These are all the expected nodes in the graph:
|
||||
# Two variables (a, b), each with four nodes (Variable, init, Assign,
|
||||
# read).
|
||||
# One constant (c).
|
||||
# One add operation and one matmul operation.
|
||||
cls._expected_num_nodes = 4 * 2 + 1 + 1 + 1
|
||||
# - Two variables (a, b), each with four nodes (Variable, init, Assign,
|
||||
# read).
|
||||
# - One constant (c).
|
||||
# - One add operation and one matmul operation.
|
||||
# - One wildcard node name ("*") that covers nodes created internally
|
||||
# by TensorFlow itself (e.g., Grappler).
|
||||
cls._expected_num_nodes = 4 * 2 + 1 + 1 + 1 + 1
|
||||
|
||||
def setUp(self):
|
||||
self._run_options = config_pb2.RunOptions()
|
||||
@ -88,9 +90,14 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
for watch in watch_opts:
|
||||
node_names.append(watch.node_name)
|
||||
|
||||
self.assertEqual(expected_output_slot, watch.output_slot)
|
||||
self.assertEqual(expected_debug_ops, watch.debug_ops)
|
||||
self.assertEqual(expected_debug_urls, watch.debug_urls)
|
||||
if watch.node_name == "*":
|
||||
self.assertEqual(-1, watch.output_slot)
|
||||
self.assertEqual(expected_debug_ops, watch.debug_ops)
|
||||
self.assertEqual(expected_debug_urls, watch.debug_urls)
|
||||
else:
|
||||
self.assertEqual(expected_output_slot, watch.output_slot)
|
||||
self.assertEqual(expected_debug_ops, watch.debug_ops)
|
||||
self.assertEqual(expected_debug_urls, watch.debug_urls)
|
||||
|
||||
return node_names
|
||||
|
||||
@ -203,19 +210,22 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
|
||||
["file:///tmp/tfdbg_1"])
|
||||
|
||||
# Verify the node names.
|
||||
self.assertTrue("a1_init" in node_names)
|
||||
self.assertTrue("a1" in node_names)
|
||||
self.assertTrue("a1/Assign" in node_names)
|
||||
self.assertTrue("a1/read" in node_names)
|
||||
self.assertIn("a1_init", node_names)
|
||||
self.assertIn("a1", node_names)
|
||||
self.assertIn("a1/Assign", node_names)
|
||||
self.assertIn("a1/read", node_names)
|
||||
|
||||
self.assertTrue("b_init" in node_names)
|
||||
self.assertTrue("b" in node_names)
|
||||
self.assertTrue("b/Assign" in node_names)
|
||||
self.assertTrue("b/read" in node_names)
|
||||
self.assertIn("b_init", node_names)
|
||||
self.assertIn("b", node_names)
|
||||
self.assertIn("b/Assign", node_names)
|
||||
self.assertIn("b/read", node_names)
|
||||
|
||||
self.assertTrue("c" in node_names)
|
||||
self.assertTrue("p1" in node_names)
|
||||
self.assertTrue("s" in node_names)
|
||||
self.assertIn("c", node_names)
|
||||
self.assertIn("p1", node_names)
|
||||
self.assertIn("s", node_names)
|
||||
|
||||
# Assert that the wildcard node name has been created.
|
||||
self.assertIn("*", node_names)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testWatchGraph_nodeNameWhitelist(self):
|
||||
|
@ -164,7 +164,7 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
|
||||
self.assertAllClose(42.0, w_result)
|
||||
|
||||
dump = debug_data.DebugDumpDir(self._dump_root)
|
||||
self.assertEqual(5, dump.size)
|
||||
self.assertLessEqual(5, dump.size)
|
||||
self.assertAllClose([2.1], dump.get_tensors("u", 0, "DebugIdentity"))
|
||||
self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
|
||||
self.assertAllClose([20.0], dump.get_tensors("v", 0, "DebugIdentity"))
|
||||
|
@ -659,16 +659,15 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
|
||||
|
||||
# Verify that the nodes with bad values are caught through running find
|
||||
# on the debug dump.
|
||||
self.assertEqual(3, len(bad_data))
|
||||
self.assertEqual(x_name, bad_data[0].node_name)
|
||||
self.assertEqual(y_name, bad_data[1].node_name)
|
||||
self.assertEqual(z_name, bad_data[2].node_name)
|
||||
self.assertLessEqual(3, len(bad_data))
|
||||
node_names = [datum.node_name for datum in bad_data]
|
||||
self.assertIn(x_name, node_names)
|
||||
self.assertIn(y_name, node_names)
|
||||
self.assertIn(z_name, node_names)
|
||||
|
||||
# Test first_n kwarg of find(): Find the first offending tensor.
|
||||
first_bad_datum = dump.find(has_bad_value, first_n=1)
|
||||
|
||||
self.assertEqual(1, len(first_bad_datum))
|
||||
self.assertEqual(x_name, first_bad_datum[0].node_name)
|
||||
|
||||
def testFindInfOrNanWithOpNameExclusion(self):
|
||||
with session.Session() as sess:
|
||||
@ -708,16 +707,15 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
|
||||
|
||||
# Verify that the nodes with bad values are caught through running find
|
||||
# on the debug dump.
|
||||
self.assertEqual(2, len(bad_data))
|
||||
self.assertLessEqual(2, len(bad_data))
|
||||
# Assert that the node `x` should have been excluded.
|
||||
self.assertEqual(y_name, bad_data[0].node_name)
|
||||
self.assertEqual(z_name, bad_data[1].node_name)
|
||||
node_names = [datum.node_name for datum in bad_data]
|
||||
self.assertIn(y_name, node_names)
|
||||
self.assertIn(z_name, node_names)
|
||||
|
||||
first_bad_datum = dump.find(
|
||||
debug_data.has_inf_or_nan, first_n=1, exclude_node_names=".*/x$")
|
||||
|
||||
self.assertEqual(1, len(first_bad_datum))
|
||||
self.assertEqual(y_name, first_bad_datum[0].node_name)
|
||||
|
||||
def _session_run_for_graph_structure_lookup(self):
|
||||
with session.Session(config=no_rewrite_session_config()) as sess:
|
||||
@ -1378,7 +1376,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
|
||||
sess, y, debug_ops=["DebugNumericSummary(mute_if_healthy=true)"],
|
||||
validate=False)
|
||||
|
||||
self.assertEqual(2, dump.size)
|
||||
self.assertLessEqual(2, dump.size)
|
||||
self.assertAllClose([[
|
||||
1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, np.inf, -np.inf, np.nan,
|
||||
np.nan, 1.0, 0.0
|
||||
@ -1393,7 +1391,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
|
||||
shutil.rmtree(self._dump_root)
|
||||
_, dump = self._debug_run_and_get_dump(
|
||||
sess, y, debug_ops=["DebugNumericSummary()"])
|
||||
self.assertEqual(8, dump.size)
|
||||
self.assertLessEqual(8, dump.size)
|
||||
|
||||
def testDebugNumericSummaryMuteOnHealthyAndCustomBoundsWork(self):
|
||||
with session.Session() as sess:
|
||||
|
@ -459,7 +459,8 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(2, len(debug_dumps))
|
||||
for debug_dump in debug_dumps:
|
||||
node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
|
||||
self.assertItemsEqual(["callable_a", "callable_b"], node_names)
|
||||
self.assertIn("callable_a", node_names)
|
||||
self.assertIn("callable_b", node_names)
|
||||
|
||||
def testDebuggingMakeCallableFromOptionsWithTwoFeedsWorks(self):
|
||||
ph1 = array_ops.placeholder(dtypes.float32, name="callable_ph1")
|
||||
@ -486,7 +487,8 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(2, len(debug_dumps))
|
||||
for debug_dump in debug_dumps:
|
||||
node_names = [datum.node_name for datum in debug_dump.dumped_tensor_data]
|
||||
self.assertItemsEqual(["callable_a", "callable_b"], node_names)
|
||||
self.assertIn("callable_a", node_names)
|
||||
self.assertIn("callable_b", node_names)
|
||||
|
||||
def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self):
|
||||
variable_1 = variables.VariableV1(
|
||||
|
Loading…
Reference in New Issue
Block a user