tfdbg: fix a bug in string representation of SparseTensors

Fixes: 
PiperOrigin-RevId: 164462827
This commit is contained in:
Shanqing Cai 2017-08-07 08:25:36 -07:00 committed by TensorFlower Gardener
parent 324e2f7a19
commit ecb544ae12
3 changed files with 34 additions and 3 deletions

View File

@ -330,8 +330,12 @@ def get_run_start_intro(run_call_count,
else:
feed_dict_lines = []
for feed_key in feed_dict:
feed_key_name = (feed_key if isinstance(feed_key, six.string_types)
else feed_key.name)
if isinstance(feed_key, six.string_types):
feed_key_name = feed_key
elif hasattr(feed_key, "name"):
feed_key_name = feed_key.name
else:
feed_key_name = str(feed_key)
feed_dict_line = debugger_cli_common.RichLine(" ")
feed_dict_line += debugger_cli_common.RichLine(
feed_key_name,
@ -456,7 +460,8 @@ def get_run_short_description(run_call_count,
if len(feed_dict) == 1:
for key in feed_dict:
description += "1 feed (%s)" % (
key if isinstance(key, six.string_types) else key.name)
key if isinstance(key, six.string_types) or not hasattr(key, "name")
else key.name)
else:
description += "%d feeds" % len(feed_dict)

View File

@ -166,6 +166,17 @@ class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
menu.caption_to_item("invoke_stepper").content)
self.assertEqual("exit", menu.caption_to_item("exit").content)
def testSparseTensorAsFeedShouldHandleNoNameAttribute(self):
sparse_feed_val = ([[0, 0], [1, 1]], [10.0, 20.0])
run_start_intro = cli_shared.get_run_start_intro(
1, self.sparse_d, {self.sparse_d: sparse_feed_val}, {})
self.assertEqual(str(self.sparse_d), run_start_intro.lines[7].strip())
short_description = cli_shared.get_run_short_description(
1, self.sparse_d, {self.sparse_d: sparse_feed_val})
self.assertEqual(
"run #1: 1 fetch; 1 feed (%s)" % self.sparse_d, short_description)
def testSparseTensorAsFetchShouldHandleNoNameAttribute(self):
run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())

View File

@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@ -135,6 +136,10 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
[[0.0, 1.0, 2.0], [-4.0, -1.0, 0.0]], dtype=dtypes.float32, name="m")
self.y = math_ops.matmul(self.m, self.xph, name="y")
self.sparse_ph = array_ops.sparse_placeholder(
dtypes.float32, shape=([5, 5]), name="sparse_placeholder")
self.sparse_add = sparse_ops.sparse_add(self.sparse_ph, self.sparse_ph)
self.sess = session.Session()
# Initialize variable.
@ -248,6 +253,16 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
wrapped_sess.observers["run_start_cli_run_numbers"])
self.assertEqual([], wrapped_sess.observers["run_end_cli_run_numbers"])
def testRunningWithSparsePlaceholderFeedWorks(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
sparse_feed = ([[0, 1], [0, 2]], [10.0, 20.0])
sparse_result = wrapped_sess.run(
self.sparse_add, feed_dict={self.sparse_ph: sparse_feed})
self.assertAllEqual([[0, 1], [0, 2]], sparse_result.indices)
self.assertAllClose([20.0, 40.0], sparse_result.values)
def testRunsUnderNonDebugThenDebugMode(self):
# Do two NON_DEBUG_RUNs, followed by DEBUG_RUNs.
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(