STT-tensorflow/tensorflow/python/debug/cli/cli_shared_test.py
Gaurav Jain 24f578cd66 Add @run_deprecated_v1 annotation to tests failing in v2
PiperOrigin-RevId: 223422907
2018-11-29 15:43:25 -08:00

394 lines
16 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for the shared functions and classes for tfdbg CLI."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.debug.cli import debugger_cli_common
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
class BytesToReadableStrTest(test_util.TensorFlowTestCase):
  """Tests for cli_shared.bytes_to_readable_str()."""

  def _assertReadable(self, cases, include_b=None):
    """Checks each (expected_str, num_bytes) pair in `cases`, in order.

    Args:
      cases: Iterable of (expected_str, num_bytes) tuples.
      include_b: When None, bytes_to_readable_str is called without the
        `include_b` keyword; otherwise the value is forwarded as `include_b`.
    """
    for expected, num_bytes in cases:
      if include_b is None:
        readable = cli_shared.bytes_to_readable_str(num_bytes)
      else:
        readable = cli_shared.bytes_to_readable_str(
            num_bytes, include_b=include_b)
      self.assertEqual(expected, readable)

  def testNoneSizeWorks(self):
    self._assertReadable([(str(None), None)])

  def testSizesBelowOneKiloByteWorks(self):
    self._assertReadable([("0", 0), ("500", 500), ("1023", 1023)])

  def testSizesBetweenOneKiloByteandOneMegaByteWorks(self):
    self._assertReadable([
        ("1.00k", 1024),
        ("2.40k", int(1024 * 2.4)),
        ("1023.00k", 1024 * 1023),
    ])

  def testSizesBetweenOneMegaByteandOneGigaByteWorks(self):
    self._assertReadable([
        ("1.00M", 1024**2),
        ("2.40M", int(1024**2 * 2.4)),
        ("1023.00M", 1024**2 * 1023),
    ])

  def testSizeAboveOneGigaByteWorks(self):
    self._assertReadable([
        ("1.00G", 1024**3),
        ("2000.00G", 1024**3 * 2000),
    ])

  def testReadableStrIncludesBAtTheEndOnRequest(self):
    self._assertReadable(
        [
            ("0B", 0),
            ("1.00kB", 1024),
            ("1.00MB", 1024**2),
            ("1.00GB", 1024**3),
        ],
        include_b=True)
class TimeToReadableStrTest(test_util.TensorFlowTestCase):
  """Tests for cli_shared.time_to_readable_str() (input values in microseconds)."""

  def testNoneTimeWorks(self):
    readable = cli_shared.time_to_readable_str(None)
    self.assertEqual("0", readable)

  def testMicrosecondsTime(self):
    readable = cli_shared.time_to_readable_str(40)
    self.assertEqual("40us", readable)

  def testMillisecondTime(self):
    readable = cli_shared.time_to_readable_str(40e3)
    self.assertEqual("40ms", readable)

  def testSecondTime(self):
    readable = cli_shared.time_to_readable_str(40e6)
    self.assertEqual("40s", readable)

  def testForceTimeUnit(self):
    # Each entry: (expected string, input value, forced time unit).
    forced_cases = (
        ("40s", 40e6, cli_shared.TIME_UNIT_S),
        ("40000ms", 40e6, cli_shared.TIME_UNIT_MS),
        ("40000000us", 40e6, cli_shared.TIME_UNIT_US),
        ("4e-05s", 40, cli_shared.TIME_UNIT_S),
        ("0", 0, cli_shared.TIME_UNIT_S),
    )
    for expected, value, unit in forced_cases:
      self.assertEqual(
          expected,
          cli_shared.time_to_readable_str(value, force_time_unit=unit))

    # An unrecognized unit string must be rejected.
    # NOTE(review): assertRaisesRegexp is a deprecated alias (removed in
    # Python 3.12); it is kept here because this file still targets py2.
    with self.assertRaisesRegexp(ValueError, r"Invalid time unit: ks"):
      cli_shared.time_to_readable_str(100, force_time_unit="ks")
class GetRunStartIntroAndDescriptionTest(test_util.TensorFlowTestCase):
  """Tests for cli_shared.get_run_start_intro / get_run_short_description."""

  def setUp(self):
    # Graph-mode constants whose tensor names are a:0, b:0 and c:0, plus a
    # SparseTensor, which (per the tests below) exposes no `name` attribute.
    self.const_a = constant_op.constant(11.0, name="a")
    self.const_b = constant_op.constant(22.0, name="b")
    self.const_c = constant_op.constant(33.0, name="c")
    self.sparse_d = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1.0, 2.0], dense_shape=[3, 3])

  def tearDown(self):
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testSingleFetchNoFeeds(self):
    run_start_intro = cli_shared.get_run_start_intro(12, self.const_a, None, {})

    # Verify line about run() call number.
    self.assertEndsWith(run_start_intro.lines[1], "run() call #12:")

    # Verify line about fetch.
    const_a_name_line = run_start_intro.lines[4]
    self.assertEqual(self.const_a.name, const_a_name_line.strip())

    # Verify line about feeds.
    feeds_line = run_start_intro.lines[7]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify lines about possible commands and their font attributes.
    self.assertEqual("run:", run_start_intro.lines[11][2:])
    annot = run_start_intro.font_attr_segs[11][0]
    self.assertEqual(2, annot[0])
    self.assertEqual(5, annot[1])
    self.assertEqual("run", annot[2][0].content)
    self.assertEqual("bold", annot[2][1])
    annot = run_start_intro.font_attr_segs[13][0]
    self.assertEqual(2, annot[0])
    self.assertEqual(8, annot[1])
    self.assertEqual("run -n", annot[2][0].content)
    self.assertEqual("bold", annot[2][1])
    self.assertEqual("run -t <T>:", run_start_intro.lines[15][2:])
    self.assertEqual([(2, 12, "bold")], run_start_intro.font_attr_segs[15])
    self.assertEqual("run -f <filter_name>:", run_start_intro.lines[17][2:])
    self.assertEqual([(2, 22, "bold")], run_start_intro.font_attr_segs[17])
    annot = run_start_intro.font_attr_segs[21][0]
    self.assertEqual(2, annot[0])
    self.assertEqual(16, annot[1])
    self.assertEqual("invoke_stepper", annot[2][0].content)

    # Verify short description.
    description = cli_shared.get_run_short_description(12, self.const_a, None)
    self.assertEqual("run #12: 1 fetch (a:0); 0 feeds", description)

    # Verify the main menu associated with the run_start_intro.
    self.assertIn(debugger_cli_common.MAIN_MENU_KEY,
                  run_start_intro.annotations)
    menu = run_start_intro.annotations[debugger_cli_common.MAIN_MENU_KEY]
    self.assertEqual("run", menu.caption_to_item("run").content)
    self.assertEqual("invoke_stepper",
                     menu.caption_to_item("invoke_stepper").content)
    self.assertEqual("exit", menu.caption_to_item("exit").content)

  def testSparseTensorAsFeedShouldHandleNoNameAttribute(self):
    sparse_feed_val = ([[0, 0], [1, 1]], [10.0, 20.0])
    run_start_intro = cli_shared.get_run_start_intro(
        1, self.sparse_d, {self.sparse_d: sparse_feed_val}, {})
    # The feed key is rendered with str() because it has no name.
    self.assertEqual(str(self.sparse_d), run_start_intro.lines[7].strip())

    short_description = cli_shared.get_run_short_description(
        1, self.sparse_d, {self.sparse_d: sparse_feed_val})
    self.assertEqual(
        "run #1: 1 fetch; 1 feed (%s)" % self.sparse_d, short_description)

  def testSparseTensorAsFetchShouldHandleNoNameAttribute(self):
    run_start_intro = cli_shared.get_run_start_intro(1, self.sparse_d, None, {})
    self.assertEqual(str(self.sparse_d), run_start_intro.lines[4].strip())

  @test_util.run_deprecated_v1
  def testTwoFetchesListNoFeeds(self):
    fetches = [self.const_a, self.const_b]
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    const_a_name_line = run_start_intro.lines[4]
    const_b_name_line = run_start_intro.lines[5]
    self.assertEqual(self.const_a.name, const_a_name_line.strip())
    self.assertEqual(self.const_b.name, const_b_name_line.strip())

    feeds_line = run_start_intro.lines[8]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 2 fetches; 0 feeds", description)

  @test_util.run_deprecated_v1
  def testNestedListAsFetches(self):
    fetches = [self.const_c, [self.const_a, self.const_b]]
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    # Verify lines about the fetches: nested lists are flattened in order.
    self.assertEqual(self.const_c.name, run_start_intro.lines[4].strip())
    self.assertEqual(self.const_a.name, run_start_intro.lines[5].strip())
    self.assertEqual(self.const_b.name, run_start_intro.lines[6].strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 3 fetches; 0 feeds", description)

  @test_util.run_deprecated_v1
  def testNestedDictAsFetches(self):
    fetches = {"c": self.const_c, "ab": {"a": self.const_a, "b": self.const_b}}
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    # Verify lines about the fetches. The ordering of the dict keys is
    # indeterminate, so compare as a set.
    fetch_names = {run_start_intro.lines[i].strip() for i in (4, 5, 6)}
    self.assertEqual({"a:0", "b:0", "c:0"}, fetch_names)

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 3 fetches; 0 feeds", description)

  @test_util.run_deprecated_v1
  def testTwoFetchesAsTupleNoFeeds(self):
    fetches = (self.const_a, self.const_b)
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    const_a_name_line = run_start_intro.lines[4]
    const_b_name_line = run_start_intro.lines[5]
    self.assertEqual(self.const_a.name, const_a_name_line.strip())
    self.assertEqual(self.const_b.name, const_b_name_line.strip())

    feeds_line = run_start_intro.lines[8]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 2 fetches; 0 feeds", description)

  @test_util.run_deprecated_v1
  def testTwoFetchesAsNamedTupleNoFeeds(self):
    fetches_namedtuple = namedtuple("fetches", "x y")
    fetches = fetches_namedtuple(self.const_b, self.const_c)
    run_start_intro = cli_shared.get_run_start_intro(1, fetches, None, {})

    const_b_name_line = run_start_intro.lines[4]
    const_c_name_line = run_start_intro.lines[5]
    self.assertEqual(self.const_b.name, const_b_name_line.strip())
    self.assertEqual(self.const_c.name, const_c_name_line.strip())

    feeds_line = run_start_intro.lines[8]
    self.assertEqual("(Empty)", feeds_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, fetches, None)
    self.assertEqual("run #1: 2 fetches; 0 feeds", description)

  @test_util.run_deprecated_v1
  def testWithFeedDict(self):
    feed_dict = {
        self.const_a: 10.0,
        self.const_b: 20.0,
    }

    run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict,
                                                     {})

    const_c_name_line = run_start_intro.lines[4]
    self.assertEqual(self.const_c.name, const_c_name_line.strip())

    # Verify lines about the feed dict.
    feed_a_line = run_start_intro.lines[7]
    feed_b_line = run_start_intro.lines[8]
    self.assertEqual(self.const_a.name, feed_a_line.strip())
    self.assertEqual(self.const_b.name, feed_b_line.strip())

    # Verify short description.
    description = cli_shared.get_run_short_description(1, self.const_c,
                                                       feed_dict)
    self.assertEqual("run #1: 1 fetch (c:0); 2 feeds", description)

  @test_util.run_deprecated_v1
  def testTensorFilters(self):
    feed_dict = {self.const_a: 10.0}
    tensor_filters = {
        "filter_a": lambda x: True,
        "filter_b": lambda x: False,
    }

    run_start_intro = cli_shared.get_run_start_intro(1, self.const_c, feed_dict,
                                                     tensor_filters)

    # Verify the listed names of the tensor filters. Dict ordering is not
    # guaranteed, so compare as a set.
    filter_names = {run_start_intro.lines[i].split(" ")[-1] for i in (20, 21)}
    self.assertEqual({"filter_a", "filter_b"}, filter_names)

    # Verify short description.
    description = cli_shared.get_run_short_description(1, self.const_c,
                                                       feed_dict)
    self.assertEqual("run #1: 1 fetch (c:0); 1 feed (a:0)", description)

    # Verify the command links for the two filters.
    command_set = {
        run_start_intro.font_attr_segs[i][0][2].content for i in (20, 21)
    }
    self.assertEqual({"run -f filter_a", "run -f filter_b"}, command_set)

  @test_util.run_deprecated_v1
  def testGetRunShortDescriptionWorksForTensorFeedKey(self):
    short_description = cli_shared.get_run_short_description(
        1, self.const_a, {self.const_a: 42.0})
    self.assertEqual("run #1: 1 fetch (a:0); 1 feed (a:0)", short_description)

  @test_util.run_deprecated_v1
  def testGetRunShortDescriptionWorksForUnicodeFeedKey(self):
    short_description = cli_shared.get_run_short_description(
        1, self.const_a, {u"foo": 42.0})
    self.assertEqual("run #1: 1 fetch (a:0); 1 feed (foo)", short_description)
class GetErrorIntroTest(test_util.TensorFlowTestCase):
  """Tests for cli_shared.get_error_intro()."""

  def setUp(self):
    self.var_a = variables.Variable(42.0, name="a")

  def tearDown(self):
    ops.reset_default_graph()

  @test_util.run_deprecated_v1
  def testShapeError(self):
    # NOTE: despite the test name, this builds a generic OpError attached to
    # the variable's initializer op (named "a/Assign", as asserted below).
    tf_error = errors.OpError(None, self.var_a.initializer, "foo description",
                              None)

    error_intro = cli_shared.get_error_intro(tf_error)

    self.assertEqual("!!! An error occurred during the run !!!",
                     error_intro.lines[1])
    self.assertEqual([(0, len(error_intro.lines[1]), "blink")],
                     error_intro.font_attr_segs[1])

    # Suggested "ni" (node info) command for the failing op.
    self.assertEqual(2, error_intro.lines[4].index("ni -a -d -t a/Assign"))
    self.assertEqual(2, error_intro.font_attr_segs[4][0][0])
    self.assertEqual(22, error_intro.font_attr_segs[4][0][1])
    self.assertEqual("ni -a -d -t a/Assign",
                     error_intro.font_attr_segs[4][0][2][0].content)
    self.assertEqual("bold", error_intro.font_attr_segs[4][0][2][1])

    # Suggested "li -r" command for the failing op.
    self.assertEqual(2, error_intro.lines[6].index("li -r a/Assign"))
    self.assertEqual(2, error_intro.font_attr_segs[6][0][0])
    self.assertEqual(16, error_intro.font_attr_segs[6][0][1])
    self.assertEqual("li -r a/Assign",
                     error_intro.font_attr_segs[6][0][2][0].content)
    self.assertEqual("bold", error_intro.font_attr_segs[6][0][2][1])

    # Suggested "lt" command.
    self.assertEqual(2, error_intro.lines[8].index("lt"))
    self.assertEqual(2, error_intro.font_attr_segs[8][0][0])
    self.assertEqual(4, error_intro.font_attr_segs[8][0][1])
    self.assertEqual("lt", error_intro.font_attr_segs[8][0][2][0].content)
    self.assertEqual("bold", error_intro.font_attr_segs[8][0][2][1])

    # Summary of the error itself. assertEndsWith (rather than
    # assertTrue(...endswith(...))) reports the actual line on failure.
    self.assertStartsWith(error_intro.lines[11], "Op name:")
    self.assertEndsWith(error_intro.lines[11], "a/Assign")

    self.assertStartsWith(error_intro.lines[12], "Error type:")
    self.assertEndsWith(error_intro.lines[12], str(type(tf_error)))

    self.assertEqual("Details:", error_intro.lines[14])
    self.assertStartsWith(error_intro.lines[15], "foo description")

  def testGetErrorIntroForNoOpName(self):
    # An OpError with no associated op should yield an explanatory message
    # instead of op-specific command suggestions.
    tf_error = errors.OpError(None, None, "Fake OpError", -1)
    error_intro = cli_shared.get_error_intro(tf_error)
    self.assertIn("Cannot determine the name of the op", error_intro.lines[3])
if __name__ == "__main__":
  # Run every TestCase class defined in this module via TF's googletest runner.
  googletest.main()