# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for source_utils."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import ast
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import zipfile
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.core.protobuf import config_pb2
|
|
from tensorflow.python.client import session
|
|
from tensorflow.python.debug.lib import debug_data
|
|
from tensorflow.python.debug.lib import debug_utils
|
|
from tensorflow.python.debug.lib import source_utils
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.lib.io import file_io
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import math_ops
|
|
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
|
|
from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import
|
|
from tensorflow.python.ops import variables
|
|
from tensorflow.python.platform import googletest
|
|
from tensorflow.python.util import tf_inspect
|
|
|
|
|
|
def line_number_above():
  """Get lineno of the AST node immediately above this function's call site.

  It is assumed that there are no empty lines between the call site and the
  preceding AST node.

  Returns:
    The lineno of the preceding AST node, at the same level of the AST.
    If the preceding AST node spans multiple lines:
      - In Python 3.8+, the lineno of the first line is returned.
      - In older Python versions, the lineno of the last line is returned.
  """
  # https://bugs.python.org/issue12458: In Python 3.8, traceback started
  # to return the lineno of the first line of a multi-line continuation block,
  # instead of that of the last line. Therefore, in Python 3.8+, we use `ast`
  # to get the lineno of the first line.
  call_site_lineno = tf_inspect.stack()[1][2]
  if sys.version_info < (3, 8):
    return call_site_lineno - 1
  else:
    with open(__file__, "rb") as f:
      source_text = f.read().decode("utf-8")
    source_tree = ast.parse(source_text)
    prev_node = _find_preceding_ast_node(source_tree, call_site_lineno)
    return prev_node.lineno
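
# Illustrative usage of line_number_above() (a sketch, mirroring the tests
# below): it is called on the line immediately after the statement of
# interest, e.g.:
#
#   x = constant_op.constant(42.0, name="x")
#   x_line_number = line_number_above()  # lineno of the constant_op call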


def _find_preceding_ast_node(node, lineno):
  """Find the ast node immediately before and not including lineno."""
  for i, child_node in enumerate(node.body):
    if child_node.lineno == lineno:
      return node.body[i - 1]
    if hasattr(child_node, "body"):
      found_node = _find_preceding_ast_node(child_node, lineno)
      if found_node:
        return found_node


class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):

  def setUp(self):
    self.curr_file_path = os.path.normpath(os.path.abspath(__file__))

  def tearDown(self):
    ops.reset_default_graph()

  def testGuessedBaseDirIsProbablyCorrect(self):
    # In the non-pip world, the code resides under "tensorflow/".
    # In the pip world, under the "virtual pip" package layout, it resides
    # under "tensorflow_core/". So we have to check for both.
    self.assertIn(
        os.path.basename(source_utils._TENSORFLOW_BASEDIR),
        ["tensorflow", "tensorflow_core"])

  def testUnitTestFileReturnsFalse(self):
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(self.curr_file_path))

  def testSourceUtilModuleReturnsTrue(self):
    self.assertTrue(
        source_utils.guess_is_tensorflow_py_library(source_utils.__file__))

  @test_util.run_v1_only("Tensor.op is not available in TF 2.x")
  def testFileInPythonKernelsPathReturnsTrue(self):
    x = constant_op.constant(42.0, name="x")
    self.assertTrue(
        source_utils.guess_is_tensorflow_py_library(x.op.traceback[-1][0]))

  def testDebuggerExampleFilePathReturnsFalse(self):
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/debug_mnist.py")))
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/v1/example_v1.py")))
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/v2/example_v2.py")))
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(os.path.normpath(
            "site-packages/tensorflow/python/debug/examples/v3/example_v3.py")))

  def testReturnsFalseForNonPythonFile(self):
    self.assertFalse(
        source_utils.guess_is_tensorflow_py_library(
            os.path.join(os.path.dirname(self.curr_file_path), "foo.cc")))

  def testReturnsFalseForStdin(self):
    self.assertFalse(source_utils.guess_is_tensorflow_py_library("<stdin>"))

  def testReturnsFalseForEmptyFileName(self):
    self.assertFalse(source_utils.guess_is_tensorflow_py_library(""))


class SourceHelperTest(test_util.TensorFlowTestCase):

  def createAndRunGraphHelper(self):
    """Create and run a TensorFlow Graph to generate debug dumps.

    This is intentionally done in a separate method, to make it easier to
    test the stack-top mode of source annotation.
    """

    self.dump_root = self.get_temp_dir()
    self.curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    with session.Session() as sess:
      self.u_init = constant_op.constant(
          np.array([[5.0, 3.0], [-1.0, 0.0]]), shape=[2, 2], name="u_init")
      self.u_init_line_number = line_number_above()

      self.u = variables.Variable(self.u_init, name="u")
      self.u_line_number = line_number_above()

      self.v_init = constant_op.constant(
          np.array([[2.0], [-1.0]]), shape=[2, 1], name="v_init")
      self.v_init_line_number = line_number_above()

      self.v = variables.Variable(self.v_init, name="v")
      self.v_line_number = line_number_above()

      self.w = math_ops.matmul(self.u, self.v, name="w")
      self.w_line_number = line_number_above()

      self.evaluate(self.u.initializer)
      self.evaluate(self.v.initializer)

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata()
      sess.run(self.w, options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
      self.dump.set_python_graph(sess.graph)

  def setUp(self):
    self.createAndRunGraphHelper()
    self.helper_line_number = line_number_above()

  def tearDown(self):
    if os.path.isdir(self.dump_root):
      file_io.delete_recursively(self.dump_root)
    ops.reset_default_graph()

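  # A note on the return value of annotate_source() (inferred from the
  # assertions below): it is a dict mapping line numbers in the given source
  # file to lists of names of the ops created at those lines (or, with
  # do_dumped_tensors=True, names of the dumped tensors).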
  def testAnnotateWholeValidSourceFileGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(self.dump,
                                                     self.curr_file_path)

    self.assertIn(self.u_init.op.name,
                  source_annotation[self.u_init_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.v_init_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])

    # In the non-stack-top (default) mode, the helper line should be annotated
    # with all the ops as well.
    self.assertIn(self.u_init.op.name,
                  source_annotation[self.helper_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.helper_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.helper_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.helper_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.helper_line_number])

  def testAnnotateWithStackTopGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump, self.curr_file_path, file_stack_top=True)

    self.assertIn(self.u_init.op.name,
                  source_annotation[self.u_init_line_number])
    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertIn(self.v_init.op.name,
                  source_annotation[self.v_init_line_number])
    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])

    # In the stack-top mode, the helper line should not have been annotated.
    self.assertNotIn(self.helper_line_number, source_annotation)

  def testAnnotateSubsetOfLinesGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump,
        self.curr_file_path,
        min_line=self.u_line_number,
        max_line=self.u_line_number + 1)

    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertNotIn(self.v_line_number, source_annotation)

  def testAnnotateDumpedTensorsGivesCorrectResult(self):
    source_annotation = source_utils.annotate_source(
        self.dump, self.curr_file_path, do_dumped_tensors=True)

    # Note: Constant Tensors u_init and v_init may not get dumped due to
    # constant-folding.
    self.assertIn(self.u.name, source_annotation[self.u_line_number])
    self.assertIn(self.v.name, source_annotation[self.v_line_number])
    self.assertIn(self.w.name, source_annotation[self.w_line_number])

    self.assertNotIn(self.u.op.name, source_annotation[self.u_line_number])
    self.assertNotIn(self.v.op.name, source_annotation[self.v_line_number])
    self.assertNotIn(self.w.op.name, source_annotation[self.w_line_number])

    self.assertIn(self.u.name, source_annotation[self.helper_line_number])
    self.assertIn(self.v.name, source_annotation[self.helper_line_number])
    self.assertIn(self.w.name, source_annotation[self.helper_line_number])

  def testCallingAnnotateSourceWithoutPythonGraphRaisesException(self):
    self.dump.set_python_graph(None)
    with self.assertRaises(ValueError):
      source_utils.annotate_source(self.dump, self.curr_file_path)

  def testCallingAnnotateSourceOnUnrelatedSourceFileDoesNotError(self):
    # Create an unrelated source file.
    unrelated_source_path = tempfile.mktemp()
    with open(unrelated_source_path, "wt") as source_file:
      source_file.write("print('hello, world')\n")

    self.assertEqual({},
                     source_utils.annotate_source(self.dump,
                                                  unrelated_source_path))

    # Clean up the unrelated source file.
    os.remove(unrelated_source_path)

  def testLoadingPythonSourceFileWithNonAsciiChars(self):
    source_path = tempfile.mktemp()
    with open(source_path, "wb") as source_file:
      source_file.write(u"print('\U0001f642')\n".encode("utf-8"))
    source_lines, _ = source_utils.load_source(source_path)
    self.assertEqual(source_lines, [u"print('\U0001f642')", u""])
    # Clean up the temporary source file.
    os.remove(source_path)

  def testLoadNonexistentNonParPathFailsWithIOError(self):
    bad_path = os.path.join(self.get_temp_dir(), "nonexistent.py")
    with self.assertRaisesRegex(IOError,
                                "neither exists nor can be loaded.*par.*"):
      source_utils.load_source(bad_path)

  def testLoadingPythonSourceFileInParFileSucceeds(self):
    # Create the .par file first.
    temp_file_path = os.path.join(self.get_temp_dir(), "model.py")
    with open(temp_file_path, "wb") as f:
      f.write(b"import tensorflow as tf\nx = tf.constant(42.0)\n")
    par_path = os.path.join(self.get_temp_dir(), "train_model.par")
    with zipfile.ZipFile(par_path, "w") as zf:
      zf.write(temp_file_path, os.path.join("tensorflow_models", "model.py"))

    source_path = os.path.join(par_path, "tensorflow_models", "model.py")
    source_lines, _ = source_utils.load_source(source_path)
    self.assertEqual(
        source_lines, ["import tensorflow as tf", "x = tf.constant(42.0)", ""])

  def testLoadingPythonSourceFileInParFileFailsRaisingIOError(self):
    # Create the .par file first.
    temp_file_path = os.path.join(self.get_temp_dir(), "model.py")
    with open(temp_file_path, "wb") as f:
      f.write(b"import tensorflow as tf\nx = tf.constant(42.0)\n")
    par_path = os.path.join(self.get_temp_dir(), "train_model.par")
    with zipfile.ZipFile(par_path, "w") as zf:
      zf.write(temp_file_path, os.path.join("tensorflow_models", "model.py"))

    source_path = os.path.join(par_path, "tensorflow_models", "nonexistent.py")
    with self.assertRaisesRegex(IOError,
                                "neither exists nor can be loaded.*par.*"):
      source_utils.load_source(source_path)


@test_util.run_v1_only("Sessions are not available in TF 2.x")
class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):

  def createAndRunGraphWithWhileLoop(self):
    """Create and run a TensorFlow Graph with a while loop to generate dumps."""

    self.dump_root = self.get_temp_dir()
    self.curr_file_path = os.path.abspath(
        tf_inspect.getfile(tf_inspect.currentframe()))

    # Run a simple TF graph to generate some debug dumps that can be used in
    # source annotation.
    with session.Session() as sess:
      loop_body = lambda i: math_ops.add(i, 2)
      self.traceback_first_line = line_number_above()

      loop_cond = lambda i: math_ops.less(i, 16)

      i = constant_op.constant(10, name="i")
      loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])

      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
      run_metadata = config_pb2.RunMetadata()
      sess.run(loop, options=run_options, run_metadata=run_metadata)

      self.dump = debug_data.DebugDumpDir(
          self.dump_root, partition_graphs=run_metadata.partition_graphs)
      self.dump.set_python_graph(sess.graph)

  def setUp(self):
    self.createAndRunGraphWithWhileLoop()

  def tearDown(self):
    if os.path.isdir(self.dump_root):
      file_io.delete_recursively(self.dump_root)
    ops.reset_default_graph()

  def testGenerateSourceList(self):
    source_list = source_utils.list_source_files_against_dump(self.dump)

    # Assert that the file paths are sorted and unique.
    file_paths = [item[0] for item in source_list]
    self.assertEqual(sorted(file_paths), file_paths)
    self.assertEqual(len(set(file_paths)), len(file_paths))

    # Assert that each item of source_list has length 6.
    for item in source_list:
      self.assertTrue(isinstance(item, tuple))
      self.assertEqual(6, len(item))

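    # For reference (the field names below are descriptive, matching the
    # unpacking further down, not an official API): each item is
    #   (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
    #    first_line)
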
    # The while loop body should have executed 3 times (i goes
    # 10 -> 12 -> 14 -> 16, at which point the condition i < 16 fails). The
    # following table lists the tensors and how many times each of them is
    # dumped.
    #   Tensor name            # of times dumped:
    #   i:0                    1
    #   while/Enter:0          1
    #   while/Merge:0          4
    #   while/Merge:1          4
    #   while/Less/y:0         4
    #   while/Less:0           4
    #   while/LoopCond:0       4
    #   while/Switch:0         1
    #   while/Switch:1         3
    #   while/Identity:0       3
    #   while/Add/y:0          3
    #   while/Add:0            3
    #   while/NextIteration:0  3
    #   while/Exit:0           1
    #   ----------------------------
    #   (Total)                39
    #
    # The total number of nodes is 12.
    # The total number of tensors is 14 (2 of the nodes have 2 outputs:
    # while/Merge, while/Switch).

    _, is_tf_py_library, num_nodes, num_tensors, num_dumps, first_line = (
        source_list[file_paths.index(self.curr_file_path)])
    self.assertFalse(is_tf_py_library)
    self.assertEqual(12, num_nodes)
    self.assertEqual(14, num_tensors)
    self.assertEqual(39, num_dumps)
    self.assertEqual(self.traceback_first_line, first_line)

  def testGenerateSourceListWithNodeNameFilter(self):
    source_list = source_utils.list_source_files_against_dump(
        self.dump, node_name_regex_allowlist=r"while/Add.*")

    # Assert that the file paths are sorted and unique.
    file_paths = [item[0] for item in source_list]
    self.assertEqual(sorted(file_paths), file_paths)
    self.assertEqual(len(set(file_paths)), len(file_paths))

    # Assert that each item of source_list has length 6.
    for item in source_list:
      self.assertTrue(isinstance(item, tuple))
      self.assertEqual(6, len(item))

    # Due to the node-name filtering, the result should only contain 2 nodes
    # and 2 tensors. The total number of dumped tensors should be 6:
    #   while/Add/y:0  3
    #   while/Add:0    3
    _, is_tf_py_library, num_nodes, num_tensors, num_dumps, _ = (
        source_list[file_paths.index(self.curr_file_path)])
    self.assertFalse(is_tf_py_library)
    self.assertEqual(2, num_nodes)
    self.assertEqual(2, num_tensors)
    self.assertEqual(6, num_dumps)

  def testGenerateSourceListWithPathRegexFilter(self):
    curr_file_basename = os.path.basename(self.curr_file_path)
    source_list = source_utils.list_source_files_against_dump(
        self.dump,
        path_regex_allowlist=(".*" + curr_file_basename.replace(".", "\\.") +
                              "$"))

    self.assertEqual(1, len(source_list))
    (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
     first_line) = source_list[0]
    self.assertEqual(self.curr_file_path, file_path)
    self.assertFalse(is_tf_py_library)
    self.assertEqual(12, num_nodes)
    self.assertEqual(14, num_tensors)
    self.assertEqual(39, num_dumps)
    self.assertEqual(self.traceback_first_line, first_line)


if __name__ == "__main__":
  googletest.main()