STT-tensorflow/tensorflow/python/debug/cli/command_parser_test.py
Shanqing Cai 44113ce5bb tfdbg CLI: add eval of arbitrary Python / np expressions
RELNOTES: TensorFlow Debugger (tfdbg) command-line interface: Support evaluation of arbitrary Python and numpy (np) expressions with debug tensor names enclosed in pairs of backticks. E.g., tfdbg> eval 'np.sum(`Softmax:0`, axis=1)'.
PiperOrigin-RevId: 164217384
2017-08-03 19:41:46 -07:00

520 lines
21 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TensorFlow Debugger command parser."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from tensorflow.python.debug.cli import command_parser
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
class ParseCommandTest(test_util.TensorFlowTestCase):
  """Checks how command_parser.parse_command() splits a command string."""

  def _assertParsedAs(self, expected_items, command):
    """Asserts that parsing `command` yields exactly `expected_items`."""
    self.assertEqual(expected_items, command_parser.parse_command(command))

  def testParseNoBracketsOrQuotes(self):
    """Words separated by spaces and/or tabs come back as separate items."""
    self._assertParsedAs([], "")
    self._assertParsedAs(["a"], "a")

    four_words = ["foo", "bar", "baz", "qux"]
    self._assertParsedAs(four_words, "foo bar baz qux")
    self._assertParsedAs(four_words, "foo bar\tbaz\t qux")

  def testParseLeadingTrailingWhitespaces(self):
    """Surrounding spaces or newlines do not produce extra items."""
    four_words = ["foo", "bar", "baz", "qux"]
    for padded_command in (" foo bar baz qux ", "\nfoo bar baz qux\n"):
      self._assertParsedAs(four_words, padded_command)

  def testParseCommandsWithBrackets(self):
    """Whitespace inside a bracket pair does not split the argument."""
    self._assertParsedAs(["pt", "foo[1, 2, :]"], "pt foo[1, 2, :]")
    self._assertParsedAs(["pt", "foo[1, 2, :]", "-a"], "pt foo[1, 2, :] -a")
    self._assertParsedAs(["inject_value", "foo", "[1, 2,:]", "0"],
                         "inject_value foo [1, 2,:] 0")

  def testParseCommandWithTwoArgsContainingBrackets(self):
    """Two bracketed arguments in one command are kept apart."""
    self._assertParsedAs(["pt", "foo[1, :]", "bar[:, 2]"],
                         "pt foo[1, :] bar[:, 2]")
    self._assertParsedAs(["pt", "foo[]", "bar[:, 2]"], "pt foo[] bar[:, 2]")

  def testParseCommandWithUnmatchedBracket(self):
    """An unclosed bracket must not be parsed as if it had been closed."""
    parsed = command_parser.parse_command("pt foo[1, 2, :")
    self.assertNotEqual(["pt", "foo[1, 2, :]"], parsed)

  def testParseCommandsWithQuotes(self):
    """Quoted arguments stay in one piece and lose their enclosing quotes."""
    numpy_call_items = ["inject_value", "foo", "np.zeros([100, 500])"]

    # The pair of double quotes should have been stripped.
    self._assertParsedAs(numpy_call_items,
                         "inject_value foo \"np.zeros([100, 500])\"")

    # The pair of single quotes should have been stripped.
    self._assertParsedAs(numpy_call_items,
                         "inject_value foo 'np.zeros([100, 500])'")

    self._assertParsedAs(["command prefix with spaces", "arg1"],
                         "\"command prefix with spaces\" arg1")

  def testParseCommandWithTwoArgsContainingQuotes(self):
    """Several quoted arguments, including an empty one, are all kept."""
    self._assertParsedAs(["foo", "bar", "qux"], "foo \"bar\" \"qux\"")
    self._assertParsedAs(["foo", "", "qux"], "foo \"\" \"qux\"")
class ExtractOutputFilePathTest(test_util.TensorFlowTestCase):
  """Tests for command_parser.extract_output_file_path().

  Each case passes a list of already-split command arguments and checks the
  returned pair: the remaining arguments and the redirect output path.
  """

  def _assertExtraction(self, arg_list, expected_args, expected_path):
    """Runs the extraction on `arg_list` and checks both returned values.

    Args:
      arg_list: List of argument strings handed to the function under test.
      expected_args: Expected remaining arguments.
      expected_path: Expected output path, or None if no path is expected.
    """
    args, output_path = command_parser.extract_output_file_path(arg_list)
    self.assertEqual(expected_args, args)
    if expected_path is None:
      self.assertIsNone(output_path)
    else:
      self.assertEqual(expected_path, output_path)

  def testNoOutputFilePathIsReflected(self):
    self._assertExtraction(["pt", "a:0"], ["pt", "a:0"], None)

  def testHasOutputFilePathInOneArgsIsReflected(self):
    self._assertExtraction(
        ["pt", "a:0", ">/tmp/foo.txt"], ["pt", "a:0"], "/tmp/foo.txt")

  def testHasOutputFilePathInTwoArgsIsReflected(self):
    self._assertExtraction(
        ["pt", "a:0", ">", "/tmp/foo.txt"], ["pt", "a:0"], "/tmp/foo.txt")

  def testHasGreaterThanSignButNoFileNameCausesSyntaxError(self):
    dangling_redirect = ["pt", "a:0", ">"]
    with self.assertRaisesRegexp(SyntaxError, "Redirect file path is empty"):
      command_parser.extract_output_file_path(dangling_redirect)

  def testOutputPathMergedWithLastArgIsHandledCorrectly(self):
    self._assertExtraction(
        ["pt", "a:0>/tmp/foo.txt"], ["pt", "a:0"], "/tmp/foo.txt")

  def testOutputPathInLastArgGreaterThanInSecondLastIsHandledCorrectly(self):
    self._assertExtraction(
        ["pt", "a:0>", "/tmp/foo.txt"], ["pt", "a:0"], "/tmp/foo.txt")

  def testFlagWithEqualGreaterThanShouldIgnoreIntervalFlags(self):
    """A '>' that starts a numeric flag value is not treated as a redirect."""
    interval_flag_commands = (
        ["lp", "--execution_time=>100ms"],
        ["lp", "--execution_time", ">1.2s"],
        ["lp", "-e", ">1200"],
        ["lp", "--foo_value", ">-.2MB"],
        ["lp", "--bar_value", ">-42e3GB"],
        ["lp", "--execution_time", ">=100ms"],
        ["lp", "--execution_time=>=100ms"],
    )
    for arg_list in interval_flag_commands:
      # Hand over a copy so the expectation is independent of the input list.
      self._assertExtraction(list(arg_list), arg_list, None)

  def testFlagWithEqualGreaterThanShouldRecognizeFilePaths(self):
    """A '>' followed by something that is not a number is a redirect."""
    redirect_cases = (
        (["lp", ">1.2s"], ["lp"], "1.2s"),
        (["lp", "--execution_time", ">x.yms"],
         ["lp", "--execution_time"], "x.yms"),
        (["lp", "--memory", ">a.1kB"], ["lp", "--memory"], "a.1kB"),
        (["lp", "--memory", ">e002MB"], ["lp", "--memory"], "e002MB"),
    )
    for arg_list, remaining_args, path in redirect_cases:
      self._assertExtraction(arg_list, remaining_args, path)

  def testOneArgumentIsHandledCorrectly(self):
    self._assertExtraction(["lt"], ["lt"], None)

  def testEmptyArgumentIsHandledCorrectly(self):
    self._assertExtraction([], [], None)
class ParseTensorNameTest(test_util.TensorFlowTestCase):
  """Tests for command_parser.parse_tensor_name_with_slicing()."""

  def testParseTensorNameWithoutSlicing(self):
    """A bare tensor name is returned with an empty slicing string."""
    name, slicing = command_parser.parse_tensor_name_with_slicing(
        "hidden/weights/Variable:0")

    self.assertEqual("hidden/weights/Variable:0", name)
    self.assertEqual("", slicing)

  def testParseTensorNameWithSlicing(self):
    """A trailing bracketed slice is split off from the tensor name."""
    name, slicing = command_parser.parse_tensor_name_with_slicing(
        "hidden/weights/Variable:0[:, 1]")

    self.assertEqual("hidden/weights/Variable:0", name)
    self.assertEqual("[:, 1]", slicing)
class ValidateSlicingStringTest(test_util.TensorFlowTestCase):
  """Tests for command_parser.validate_slicing_string()."""

  def testValidateValidSlicingStrings(self):
    """Bracketed lists of integers and colons are accepted."""
    for slicing in ("[1]", "[2,3]", "[4, 5, 6]", "[7,:, :]"):
      self.assertTrue(command_parser.validate_slicing_string(slicing))

  def testValidateInvalidSlicingStrings(self):
    """Empty, unbalanced, or identifier-containing strings are rejected."""
    for slicing in ("", "[1,", "2,3]", "[4, foo()]", "[5, bar]"):
      self.assertFalse(command_parser.validate_slicing_string(slicing))
class ParseIndicesTest(test_util.TensorFlowTestCase):
  """Tests for command_parser.parse_indices()."""

  def testParseValidIndicesStringsWithBrackets(self):
    """Bracketed, comma-separated integers parse into a list of ints."""
    self.assertEqual([0], command_parser.parse_indices("[0]"))
    self.assertEqual([0], command_parser.parse_indices(" [0] "))
    self.assertEqual([-1, 2], command_parser.parse_indices("[-1, 2]"))
    self.assertEqual([3, 4, -5],
                     command_parser.parse_indices("[3,4,-5]"))

  def testParseValidIndicesStringsWithoutBrackets(self):
    """The enclosing brackets are optional."""
    self.assertEqual([0], command_parser.parse_indices("0"))
    self.assertEqual([0], command_parser.parse_indices(" 0 "))
    self.assertEqual([-1, 2], command_parser.parse_indices("-1, 2"))
    self.assertEqual([3, 4, -5], command_parser.parse_indices("3,4,-5"))

  def testParseInvalidIndicesStringsWithoutBrackets(self):
    """Malformed index strings raise ValueError naming the bad token."""
    # Each call is expected to raise, so its return value is deliberately not
    # asserted on: an assertion wrapped around the call could never execute.
    with self.assertRaisesRegexp(
        ValueError, r"invalid literal for int\(\) with base 10: 'a'"):
      command_parser.parse_indices("0,a")

    with self.assertRaisesRegexp(
        ValueError, r"invalid literal for int\(\) with base 10: '2\]'"):
      command_parser.parse_indices("1, 2]")

    with self.assertRaisesRegexp(
        ValueError, r"invalid literal for int\(\) with base 10: ''"):
      command_parser.parse_indices("3, 4,")
class ParseRangesTest(test_util.TensorFlowTestCase):
  """Tests for command_parser.parse_ranges()."""

  # Value the tests expect in place of "inf" in a parsed range.
  INF_VALUE = sys.float_info.max

  def testParseEmptyRangeString(self):
    """Empty or whitespace-only input yields an empty list."""
    for blank in ("", " "):
      self.assertEqual([], command_parser.parse_ranges(blank))

  def testParseSingleRange(self):
    """A single [min, max] pair parses into a one-element list of ranges."""
    inf = self.INF_VALUE
    single_range_cases = (
        ("[-0.1, 0.2]", [[-0.1, 0.2]]),
        ("[-0.1, inf]", [[-0.1, inf]]),
        ("[-inf, inf]", [[-inf, inf]]),
    )
    for range_str, expected in single_range_cases:
      self.assertAllClose(expected, command_parser.parse_ranges(range_str))

  def testParseSingleListOfRanges(self):
    """A list of [min, max] pairs parses into the matching list of ranges."""
    inf = self.INF_VALUE
    multi_range_cases = (
        ("[[-0.1, 0.2], [10, 12]]", [[-0.1, 0.2], [10.0, 12.0]]),
        ("[[-inf, -1.0],[1.0, inf]]", [[-inf, -1.0], [1.0, inf]]),
    )
    for range_str, expected in multi_range_cases:
      self.assertAllClose(expected, command_parser.parse_ranges(range_str))

  def testParseInvalidRangeString(self):
    """Unbalanced brackets, wrong element counts and wrong types all raise."""
    with self.assertRaises(SyntaxError):
      command_parser.parse_ranges("[[1,2]")

    value_error_cases = (
        ("[1,2,3]", "Incorrect number of elements in range"),
        ("[inf]", "Incorrect number of elements in range"),
        ("[1j, 1]", "Incorrect type in the 1st element of range"),
        ("[1, 1j]", "Incorrect type in the 2nd element of range"),
    )
    for range_str, message_regex in value_error_cases:
      with self.assertRaisesRegexp(ValueError, message_regex):
        command_parser.parse_ranges(range_str)
class ParseReadableSizeStrTest(test_util.TensorFlowTestCase):
  """Tests for command_parser.parse_readable_size_str().

  The expected values treat the result as a byte count with 1024-based
  units, e.g. "1024 kB" is expected to equal 1024**2.
  """

  def _assertSize(self, expected_size, size_str):
    """Asserts that `size_str` parses to `expected_size`."""
    self.assertEqual(expected_size,
                     command_parser.parse_readable_size_str(size_str))

  def testParseNoUnitWorks(self):
    self._assertSize(0, "0")
    self._assertSize(1024, "1024 ")
    self._assertSize(2000, " 2000 ")

  def testParseKiloBytesWorks(self):
    self._assertSize(0, "0kB")
    self._assertSize(1024**2, "1024 kB")
    self._assertSize(1024**2 * 2, "2048k")
    self._assertSize(1024**2 * 2, "2048kB")
    self._assertSize(1024 / 4, "0.25k")

  def testParseMegaBytesWorks(self):
    self._assertSize(0, "0MB")
    self._assertSize(1024**3, "1024 MB")
    self._assertSize(1024**3 * 2, "2048M")
    self._assertSize(1024**3 * 2, "2048MB")
    self._assertSize(1024**2 / 4, "0.25M")

  def testParseGigaBytesWorks(self):
    self._assertSize(0, "0GB")
    self._assertSize(1024**4, "1024 GB")
    self._assertSize(1024**4 * 2, "2048G")
    self._assertSize(1024**4 * 2, "2048GB")
    self._assertSize(1024**3 / 4, "0.25G")

  def testParseUnsupportedUnitRaisesException(self):
    """Unknown unit suffixes are rejected with a ValueError."""
    unsupported_cases = (
        ("0foo",
         "Failed to parsed human-readable byte size str: \"0foo\""),
        ("2EB",
         "Failed to parsed human-readable byte size str: \"2E\""),
    )
    for size_str, message_regex in unsupported_cases:
      with self.assertRaisesRegexp(ValueError, message_regex):
        command_parser.parse_readable_size_str(size_str)
class ParseReadableTimeStrTest(test_util.TensorFlowTestCase):
  """Tests for command_parser.parse_readable_time_str().

  The expected values treat the result as microseconds: "2us" is expected
  to equal 2, "2ms" 2e3 and "2s" 2e6.
  """

  def _assertTime(self, expected_time, time_str):
    """Asserts that `time_str` parses to `expected_time`."""
    self.assertEqual(expected_time,
                     command_parser.parse_readable_time_str(time_str))

  def testParseNoUnitWorks(self):
    self._assertTime(0, "0")
    self._assertTime(100, "100 ")
    self._assertTime(25, " 25 ")

  def testParseSeconds(self):
    self._assertTime(1e6, "1 s")
    self._assertTime(2e6, "2s")

  def testParseMicros(self):
    self._assertTime(2, "2us")

  def testParseMillis(self):
    self._assertTime(2e3, "2ms")

  def testParseUnsupportedUnitRaisesException(self):
    """Unknown units and negative durations are rejected with ValueError."""
    rejected_cases = (
        ("2uss", r".*float.*2us.*"),
        ("2m", r".*float.*2m.*"),
        ("-1s", r"Invalid time -1. Time value must be positive."),
    )
    for time_str, message_regex in rejected_cases:
      with self.assertRaisesRegexp(ValueError, message_regex):
        command_parser.parse_readable_time_str(time_str)
class ParseInterval(test_util.TensorFlowTestCase):
  """Tests for time/memory interval parsing and command_parser.Interval.

  Covers command_parser.parse_time_interval(),
  command_parser.parse_memory_interval() and Interval.contains().
  """

  def testParseTimeInterval(self):
    """Bracket/parenthesis and comparison forms map to matching Intervals."""
    inf = float("inf")
    cases = (
        ("[10us, 1ms]", command_parser.Interval(10, True, 1e3, True)),
        ("(10us, 1ms)", command_parser.Interval(10, False, 1e3, False)),
        ("(10us, 1ms]", command_parser.Interval(10, False, 1e3, True)),
        ("[10us, 1ms)", command_parser.Interval(10, True, 1e3, False)),
        ("<=1ms", command_parser.Interval(0, False, 1e3, True)),
        (">=1ms", command_parser.Interval(1e3, True, inf, False)),
        ("<1ms", command_parser.Interval(0, False, 1e3, False)),
        (">1ms", command_parser.Interval(1e3, False, inf, False)),
    )
    for interval_str, expected in cases:
      self.assertEqual(expected,
                       command_parser.parse_time_interval(interval_str))

  def testParseTimeGreaterLessThanWithInvalidValueStrings(self):
    """A non-numeric value after a comparison operator raises ValueError."""
    cases = (
        (">=wms", "Invalid value string after >= "),
        (">Yms", "Invalid value string after > "),
        ("<= _ms", "Invalid value string after <= "),
        ("<-ms", "Invalid value string after < "),
    )
    for interval_str, message_regex in cases:
      with self.assertRaisesRegexp(ValueError, message_regex):
        command_parser.parse_time_interval(interval_str)

  def testParseTimeIntervalsWithInvalidValueStrings(self):
    """A non-numeric endpoint in a two-value interval raises ValueError."""
    cases = (
        ("[wms, 10ms]", "Invalid first item in interval:"),
        ("[ 0ms, _ms]", "Invalid second item in interval:"),
        ("(xms, _ms]", "Invalid first item in interval:"),
        ("((3ms, _ms)", "Invalid first item in interval:"),
    )
    for interval_str, message_regex in cases:
      with self.assertRaisesRegexp(ValueError, message_regex):
        command_parser.parse_time_interval(interval_str)

  def testInvalidTimeIntervalRaisesException(self):
    """Malformed, over-long and reversed time intervals raise ValueError."""
    with self.assertRaisesRegexp(
        ValueError,
        r"Invalid interval format: \[10us, 1ms. Valid formats are: "
        r"\[min, max\], \(min, max\), <max, >min"):
      command_parser.parse_time_interval("[10us, 1ms")

    with self.assertRaisesRegexp(
        ValueError,
        r"Incorrect interval format: \[10us, 1ms, 2ms\]. Interval should "
        r"specify two values: \[min, max\] or \(min, max\)"):
      command_parser.parse_time_interval("[10us, 1ms, 2ms]")

    with self.assertRaisesRegexp(
        ValueError,
        r"Invalid interval \[1s, 1ms\]. Start must be before end of interval."):
      command_parser.parse_time_interval("[1s, 1ms]")

  def testParseMemoryInterval(self):
    """Memory intervals accept the same forms, with size-style values."""
    inf = float("inf")
    cases = (
        ("[1k, 2k]", command_parser.Interval(1024, True, 2048, True)),
        ("(1kB, 2kB)", command_parser.Interval(1024, False, 2048, False)),
        ("(1k, 2k]", command_parser.Interval(1024, False, 2048, True)),
        ("[1k, 2k)", command_parser.Interval(1024, True, 2048, False)),
        ("<=2k", command_parser.Interval(0, False, 2048, True)),
        (">=11", command_parser.Interval(11, True, inf, False)),
        ("<2k", command_parser.Interval(0, False, 2048, False)),
        (">11", command_parser.Interval(11, False, inf, False)),
    )
    for interval_str, expected in cases:
      self.assertEqual(expected,
                       command_parser.parse_memory_interval(interval_str))

  def testParseMemoryIntervalsWithInvalidValueStrings(self):
    """A non-numeric memory value after a comparison operator is rejected."""
    # These inputs go through parse_memory_interval so that the memory
    # parser's own validation is what gets exercised.
    # NOTE(review): assumes the memory parser reports the same
    # "Invalid value string after ..." messages as the time parser -- confirm.
    cases = (
        (">=wM", "Invalid value string after >= "),
        (">YM", "Invalid value string after > "),
        ("<= _MB", "Invalid value string after <= "),
        ("<-MB", "Invalid value string after < "),
    )
    for interval_str, message_regex in cases:
      with self.assertRaisesRegexp(ValueError, message_regex):
        command_parser.parse_memory_interval(interval_str)

  def testInvalidMemoryIntervalRaisesException(self):
    """A memory interval whose start exceeds its end raises ValueError."""
    with self.assertRaisesRegexp(
        ValueError,
        r"Invalid interval \[5k, 3k\]. Start of interval must be less than or "
        "equal to end of interval."):
      command_parser.parse_memory_interval("[5k, 3k]")

  def testIntervalContains(self):
    """Interval.contains() honors the inclusive/exclusive endpoint flags."""
    interval = command_parser.Interval(
        start=1, start_included=True, end=10, end_included=True)
    # Closed interval [1, 10]: both endpoints and interior points are inside.
    self.assertTrue(interval.contains(1))
    self.assertTrue(interval.contains(10))
    self.assertTrue(interval.contains(5))

    # Half-open (1, 10]: the start is now excluded.
    interval.start_included = False
    self.assertFalse(interval.contains(1))
    self.assertTrue(interval.contains(10))

    # Open (1, 10): neither endpoint is inside.
    interval.end_included = False
    self.assertFalse(interval.contains(1))
    self.assertFalse(interval.contains(10))

    # Half-open [1, 10): only the start is inside.
    interval.start_included = True
    self.assertTrue(interval.contains(1))
    self.assertFalse(interval.contains(10))
if __name__ == "__main__":
  # Executed as a script: hand control to the googletest entry point.
  googletest.main()