Fix skipped tests in error_interpolation_test.py

These only make sense in graph mode, so I made this explicit within the tests.

PiperOrigin-RevId: 339050796
Change-Id: I97674d036f95882d8d4275e6b65e7c432ec0d943
This commit is contained in:
James Keeling 2020-10-26 09:11:17 -07:00 committed by TensorFlower Gardener
parent 8c78dae8fb
commit b8f5303051

View File

@ -25,7 +25,6 @@ import re
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import traceable_stack
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -126,13 +125,8 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
# so it is excluded from eager tests. Even when used in eager mode, it is
# via FunctionGraphs, and directly verifying in graph mode is the narrowest
# way to unit test the functionality.
@test_util.run_deprecated_v1
class CreateGraphDebugInfoDefTest(test.TestCase):
def setUp(self):
super(CreateGraphDebugInfoDefTest, self).setUp()
ops.reset_default_graph()
def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index):
self.assertIn(key, graph_debug_info.traces)
stack_trace = graph_debug_info.traces[key]
@ -146,201 +140,205 @@ class CreateGraphDebugInfoDefTest(test.TestCase):
return found_flc
def testStackTraceExtraction(self):
# Since the create_graph_debug_info_def() function does not actually
# do anything special with functions except name mangling, just verify
# it with a loose op and manually provided function name.
# The following ops *must* be on consecutive lines (it will be verified
# in the resulting trace).
# pyformat: disable
global_op = constant_op.constant(0, name="Global").op
op1 = constant_op.constant(1, name="One").op
op2 = constant_op.constant(2, name="Two").op
non_traceback_op = constant_op.constant(3, name="NonTraceback").op
# Ensure op without traceback does not fail
del non_traceback_op._traceback
# pyformat: enable
# This test is verifying stack trace information added in graph mode, so
# only makes sense in graph mode.
with ops.Graph().as_default():
# Since the create_graph_debug_info_def() function does not actually
# do anything special with functions except name mangling, just verify
# it with a loose op and manually provided function name.
# The following ops *must* be on consecutive lines (it will be verified
# in the resulting trace).
# pyformat: disable
global_op = constant_op.constant(0, name="Global").op
op1 = constant_op.constant(1, name="One").op
op2 = constant_op.constant(2, name="Two").op
non_traceback_op = constant_op.constant(3, name="NonTraceback").op
# Ensure op without traceback does not fail
del non_traceback_op._traceback
# pyformat: enable
export_ops = [("", global_op), ("func1", op1), ("func2", op2),
("func2", non_traceback_op)]
graph_debug_info = error_interpolation.create_graph_debug_info_def(
export_ops)
this_file_index = -1
for file_index, file_name in enumerate(graph_debug_info.files):
if "{}error_interpolation_test.py".format(os.sep) in file_name:
this_file_index = file_index
self.assertGreaterEqual(
this_file_index, 0,
"Could not find this file in trace:" + repr(graph_debug_info))
export_ops = [("", global_op), ("func1", op1), ("func2", op2),
("func2", non_traceback_op)]
graph_debug_info = error_interpolation.create_graph_debug_info_def(
export_ops)
this_file_index = -1
for file_index, file_name in enumerate(graph_debug_info.files):
if "{}error_interpolation_test.py".format(os.sep) in file_name:
this_file_index = file_index
self.assertGreaterEqual(
this_file_index, 0,
"Could not find this file in trace:" + repr(graph_debug_info))
# Verify the traces exist for each op.
global_flc = self._getFirstStackTraceForFile(graph_debug_info, "Global@",
this_file_index)
op1_flc = self._getFirstStackTraceForFile(graph_debug_info, "One@func1",
this_file_index)
op2_flc = self._getFirstStackTraceForFile(graph_debug_info, "Two@func2",
this_file_index)
# Verify the traces exist for each op.
global_flc = self._getFirstStackTraceForFile(graph_debug_info, "Global@",
this_file_index)
op1_flc = self._getFirstStackTraceForFile(graph_debug_info, "One@func1",
this_file_index)
op2_flc = self._getFirstStackTraceForFile(graph_debug_info, "Two@func2",
this_file_index)
global_line = global_flc.line
self.assertEqual(op1_flc.line, global_line + 1, "op1 not on next line")
self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line")
global_line = global_flc.line
self.assertEqual(op1_flc.line, global_line + 1, "op1 not on next line")
self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line")
@test_util.run_deprecated_v1
class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
def setUp(self):
super(InterpolateFilenamesAndLineNumbersTest, self).setUp()
ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later.
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
three = constant_op.constant(3, name="Three")
self.graph = three.graph
def testFindIndexOfDefiningFrameForOp(self):
local_op = constant_op.constant(42).op
user_filename = "hope.py"
_modify_op_stack_with_filenames(
local_op,
num_user_frames=3,
user_filename=user_filename,
num_inner_tf_frames=5)
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback)
# Expected frame is 6th from the end because there are 5 inner frames witih
# TF filenames.
expected_frame = len(local_op._traceback) - 6
self.assertEqual(expected_frame, idx)
with ops.Graph().as_default():
local_op = constant_op.constant(42).op
user_filename = "hope.py"
_modify_op_stack_with_filenames(
local_op,
num_user_frames=3,
user_filename=user_filename,
num_inner_tf_frames=5)
idx = error_interpolation._find_index_of_defining_frame(
local_op._traceback)
# Expected frame is 6th from the end because there are 5 inner frames with
# TF filenames.
expected_frame = len(local_op._traceback) - 6
self.assertEqual(expected_frame, idx)
def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
local_op = constant_op.constant(43).op
# Truncate stack to known length.
local_op._traceback = local_op._traceback[:7]
# Ensure all frames look like TF frames.
_modify_op_stack_with_filenames(
local_op,
num_user_frames=0,
user_filename="user_file.py",
num_inner_tf_frames=7)
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback)
self.assertEqual(0, idx)
with ops.Graph().as_default():
local_op = constant_op.constant(43).op
# Truncate stack to known length.
local_op._traceback = local_op._traceback[:7]
# Ensure all frames look like TF frames.
_modify_op_stack_with_filenames(
local_op,
num_user_frames=0,
user_filename="user_file.py",
num_inner_tf_frames=7)
idx = error_interpolation._find_index_of_defining_frame(
local_op._traceback)
self.assertEqual(0, idx)
def testNothingToDo(self):
normal_string = "This is just a normal string"
interpolated_string = error_interpolation.interpolate(
normal_string, self.graph)
self.assertEqual(interpolated_string, normal_string)
with ops.Graph().as_default():
constant_op.constant(1, name="One")
normal_string = "This is just a normal string"
interpolated_string = error_interpolation.interpolate(
normal_string, ops.get_default_graph())
self.assertEqual(interpolated_string, normal_string)
def testOneTagWithAFakeNameResultsInPlaceholders(self):
one_tag_string = "{{node MinusOne}}"
interpolated_string = error_interpolation.interpolate(
one_tag_string, self.graph)
self.assertEqual(one_tag_string, interpolated_string)
with ops.Graph().as_default():
one_tag_string = "{{node MinusOne}}"
interpolated_string = error_interpolation.interpolate(
one_tag_string, ops.get_default_graph())
self.assertEqual(one_tag_string, interpolated_string)
def testTwoTagsNoSeps(self):
two_tags_no_seps = "{{node One}}{{node Three}}"
interpolated_string = error_interpolation.interpolate(
two_tags_no_seps, self.graph)
self.assertRegex(
interpolated_string, r"error_interpolation_test\.py:[0-9]+."
r"*error_interpolation_test\.py:[0-9]+")
with ops.Graph().as_default():
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
constant_op.constant(3, name="Three")
two_tags_no_seps = "{{node One}}{{node Three}}"
interpolated_string = error_interpolation.interpolate(
two_tags_no_seps, ops.get_default_graph())
self.assertRegex(
interpolated_string, r"error_interpolation_test\.py:[0-9]+."
r"*error_interpolation_test\.py:[0-9]+")
def testTwoTagsWithSeps(self):
two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, self.graph)
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegex(interpolated_string, expected_regex)
with ops.Graph().as_default():
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
constant_op.constant(3, name="Three")
two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, ops.get_default_graph())
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegex(interpolated_string, expected_regex)
def testNewLine(self):
newline = "\n\n{{node One}}"
interpolated_string = error_interpolation.interpolate(newline, self.graph)
self.assertRegex(interpolated_string,
r"error_interpolation_test\.py:[0-9]+.*")
with ops.Graph().as_default():
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
newline = "\n\n{{node One}}"
interpolated_string = error_interpolation.interpolate(
newline, ops.get_default_graph())
self.assertRegex(interpolated_string,
r"error_interpolation_test\.py:[0-9]+.*")
@test_util.run_deprecated_v1
class InputNodesTest(test.TestCase):
def setUp(self):
super(InputNodesTest, self).setUp()
# Add nodes to the graph for retrieval by name later.
one = constant_op.constant(1, name="One")
two = constant_op.constant(2, name="Two")
three = math_ops.add(one, two, name="Three")
non_traceback_op = constant_op.constant(3, name="NonTraceback")
# Ensure op without traceback does not fail
del non_traceback_op.op._traceback
self.graph = three.graph
def testNoInputs(self):
two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;"
interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, self.graph)
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegex(interpolated_string, expected_regex)
with ops.Graph().as_default():
one = constant_op.constant(1, name="One")
two = constant_op.constant(2, name="Two")
_ = math_ops.add(one, two, name="Three")
two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;"
interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, ops.get_default_graph())
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegex(interpolated_string, expected_regex)
def testBasicInputs(self):
tag = ";;;{{node Three}};;;"
interpolated_string = error_interpolation.interpolate(tag, self.graph)
expected_regex = re.compile(
r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL)
self.assertRegex(interpolated_string, expected_regex)
with ops.Graph().as_default():
one = constant_op.constant(1, name="One")
two = constant_op.constant(2, name="Two")
_ = math_ops.add(one, two, name="Three")
tag = ";;;{{node Three}};;;"
interpolated_string = error_interpolation.interpolate(
tag, ops.get_default_graph())
expected_regex = re.compile(
r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL)
self.assertRegex(interpolated_string, expected_regex)
@test_util.run_deprecated_v1
class InterpolateDeviceSummaryTest(test.TestCase):
def _fancy_device_function(self, unused_op):
return "/cpu:*"
def setUp(self):
super(InterpolateDeviceSummaryTest, self).setUp()
ops.reset_default_graph()
self.zero = constant_op.constant([0.0], name="zero")
with ops.device("/cpu"):
self.one = constant_op.constant([1.0], name="one")
with ops.device("/cpu:0"):
self.two = constant_op.constant([2.0], name="two")
with ops.device(self._fancy_device_function):
self.three = constant_op.constant(3.0, name="three")
self.graph = self.three.graph
def testNodeZeroHasNoDeviceSummaryInfo(self):
message = "{{colocation_node zero}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No device assignments were active", result)
with ops.Graph().as_default():
self.zero = constant_op.constant([0.0], name="zero")
message = "{{colocation_node zero}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("No device assignments were active", result)
def testNodeOneHasExactlyOneInterpolatedDevice(self):
message = "{{colocation_node one}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertEqual(2, result.count("tf.device(/cpu)"))
with ops.Graph().as_default():
with ops.device("/cpu"):
self.one = constant_op.constant([1.0], name="one")
message = "{{colocation_node one}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertEqual(2, result.count("tf.device(/cpu)"))
def testNodeTwoHasTwoInterpolatedDevice(self):
message = "{{colocation_node two}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertEqual(2, result.count("tf.device(/cpu)"))
self.assertEqual(2, result.count("tf.device(/cpu:0)"))
with ops.Graph().as_default():
with ops.device("/cpu"):
with ops.device("/cpu:0"):
self.two = constant_op.constant([2.0], name="two")
message = "{{colocation_node two}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertEqual(2, result.count("tf.device(/cpu)"))
self.assertEqual(2, result.count("tf.device(/cpu:0)"))
def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
message = "{{colocation_node three}}"
result = error_interpolation.interpolate(message, self.graph)
num_devices = result.count("tf.device")
self.assertEqual(2, num_devices)
name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>"
expected_re = r"with tf.device\(.*%s\)" % name_re
self.assertRegex(result, expected_re)
with ops.Graph().as_default():
with ops.device(self._fancy_device_function):
self.three = constant_op.constant(3.0, name="three")
message = "{{colocation_node three}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
num_devices = result.count("tf.device")
self.assertEqual(2, num_devices)
name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>"
expected_re = r"with tf.device\(.*%s\)" % name_re
self.assertRegex(result, expected_re)
@test_util.run_deprecated_v1
class InterpolateColocationSummaryTest(test.TestCase):
def setUp(self):
super(InterpolateColocationSummaryTest, self).setUp()
ops.reset_default_graph()
def _set_up_graph(self):
# Add nodes to the graph for retrieval by name later.
node_one = constant_op.constant(1, name="One")
node_two = constant_op.constant(2, name="Two")
@ -359,32 +357,39 @@ class InterpolateColocationSummaryTest(test.TestCase):
with ops.colocate_with(node_one):
constant_op.constant(5, name="Five_with_one_with_two")
self.graph = node_three.graph
def testNodeThreeHasColocationInterpolation(self):
message = "{{colocation_node Three_with_one}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
with ops.Graph().as_default():
self._set_up_graph()
message = "{{colocation_node Three_with_one}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("colocate_with(One)", result)
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
message = "{{colocation_node Four_with_three}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(Three_with_one)", result)
self.assertNotIn(
"One", result,
"Node One should not appear in Four_with_three's summary:\n%s" % result)
with ops.Graph().as_default():
self._set_up_graph()
message = "{{colocation_node Four_with_three}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("colocate_with(Three_with_one)", result)
self.assertNotIn(
"One", result,
"Node One should not appear in Four_with_three's summary:\n%s" %
result)
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
message = "{{colocation_node Five_with_one_with_two}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("colocate_with(One)", result)
self.assertIn("colocate_with(Two)", result)
with ops.Graph().as_default():
self._set_up_graph()
message = "{{colocation_node Five_with_one_with_two}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("colocate_with(One)", result)
self.assertIn("colocate_with(Two)", result)
def testColocationInterpolationForNodeLackingColocation(self):
message = "{{colocation_node One}}"
result = error_interpolation.interpolate(message, self.graph)
self.assertIn("No node-device colocations", result)
self.assertNotIn("Two", result)
with ops.Graph().as_default():
self._set_up_graph()
message = "{{colocation_node One}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("No node-device colocations", result)
self.assertNotIn("Two", result)
class IsFrameworkFilenameTest(test.TestCase):