tfdbg: fix a bug in string representation of SparseTensors
Fixes: #12059 PiperOrigin-RevId: 164462827
This commit is contained in:
parent
324e2f7a19
commit
ecb544ae12
tensorflow/python/debug
@ -330,8 +330,12 @@ def get_run_start_intro(run_call_count,
|
||||
else:
|
||||
feed_dict_lines = []
|
||||
for feed_key in feed_dict:
|
||||
feed_key_name = (feed_key if isinstance(feed_key, six.string_types)
|
||||
else feed_key.name)
|
||||
if isinstance(feed_key, six.string_types):
|
||||
feed_key_name = feed_key
|
||||
elif hasattr(feed_key, "name"):
|
||||
feed_key_name = feed_key.name
|
||||
else:
|
||||
feed_key_name = str(feed_key)
|
||||
feed_dict_line = debugger_cli_common.RichLine(" ")
|
||||
feed_dict_line += debugger_cli_common.RichLine(
|
||||
feed_key_name,
|
||||
@ -456,7 +460,8 @@ def get_run_short_description(run_call_count,
|
||||
if len(feed_dict) == 1:
|
||||
for key in feed_dict:
|
||||
description += "1 feed (%s)" % (
|
||||
key if isinstance(key, six.string_types) else key.name)
|
||||
key if isinstance(key, six.string_types) or not hasattr(key, "name")
|
||||
else key.name)
|
||||
else:
|
||||
description += "%d feeds" % len(feed_dict)
|
||||
|
||||
|
@ -166,6 +166,17 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
|
||||
menu.caption_to_item("invoke_stepper").content)
|
||||
self.assertEqual("exit", menu.caption_to_item("exit").content)
|
||||
|
||||
def testSparseTensorAsFeedShouldHandleNoNameAttribute(self):
|
||||
sparse_feed_val = ([[0, 0], [1, 1]], [10.0, 20.0])
|
||||
run_start_intro = cli_shared.get_run_start_intro(
|
||||
1, self.sparse_d, {self.sparse_d: sparse_feed_val}, {})
|
||||
self.assertEqual(str(self.sparse_d), run_start_intro.lines[7].strip())
|
||||
|
||||
short_description = cli_shared.get_run_short_description(
|
||||
1, self.sparse_d, {self.sparse_d: sparse_feed_val})
|
||||
self.assertEqual(
|
||||
"run #1: 1 fetch; 1 feed (%s)" % self.sparse_d, short_description)
|
||||
|
||||
def testSparseTensorAsFetchShouldHandleNoNameAttribute(self):
|
||||
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
|
||||
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())
|
||||
|
@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
|
||||
from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
@ -135,6 +136,10 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
[[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]], dtype=dtypes.float32, name="m")
|
||||
self.y = math_ops.matmul(self.m, self.xph, name="y")
|
||||
|
||||
self.sparse_ph = array_ops.sparse_placeholder(
|
||||
dtypes.float32, shape=([5, 5]), name="sparse_placeholder")
|
||||
self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph)
|
||||
|
||||
self.sess = session.Session()
|
||||
|
||||
# Initialize variable.
|
||||
@ -248,6 +253,16 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
|
||||
wrapped_sess.observers["run_start_cli_run_numbers"])
|
||||
self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])
|
||||
|
||||
def testRunningWithSparsePlaceholderFeedWorks(self):
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
|
||||
|
||||
sparse_feed = ([[0, 1], [0, 2]], [10.0, 20.0])
|
||||
sparse_result = wrapped_sess.run(
|
||||
self.sparse_add, feed_dict={self.sparse_ph: sparse_feed})
|
||||
self.assertAllEqual([[0, 1], [0, 2]], sparse_result.indices)
|
||||
self.assertAllClose([20.0, 40.0], sparse_result.values)
|
||||
|
||||
def testRunsUnderNonDebugThenDebugMode(self):
|
||||
# Do two NON_DEBUG_RUNs, followed by DEBUG_RUNs.
|
||||
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
|
||||
|
Loading…
Reference in New Issue
Block a user