Fix skipped tests in error_interpolation_test.py

These only make sense in graph mode, so I made this explicit within the tests.

PiperOrigin-RevId: 339050796
Change-Id: I97674d036f95882d8d4275e6b65e7c432ec0d943
This commit is contained in:
James Keeling 2020-10-26 09:11:17 -07:00 committed by TensorFlower Gardener
parent 8c78dae8fb
commit b8f5303051

View File

@ -25,7 +25,6 @@ import re
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import traceable_stack from tensorflow.python.framework import traceable_stack
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -126,13 +125,8 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
# so it is excluded from eager tests. Even when used in eager mode, it is # so it is excluded from eager tests. Even when used in eager mode, it is
# via FunctionGraphs, and directly verifying in graph mode is the narrowest # via FunctionGraphs, and directly verifying in graph mode is the narrowest
# way to unit test the functionality. # way to unit test the functionality.
@test_util.run_deprecated_v1
class CreateGraphDebugInfoDefTest(test.TestCase): class CreateGraphDebugInfoDefTest(test.TestCase):
def setUp(self):
super(CreateGraphDebugInfoDefTest, self).setUp()
ops.reset_default_graph()
def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index): def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index):
self.assertIn(key, graph_debug_info.traces) self.assertIn(key, graph_debug_info.traces)
stack_trace = graph_debug_info.traces[key] stack_trace = graph_debug_info.traces[key]
@ -146,201 +140,205 @@ class CreateGraphDebugInfoDefTest(test.TestCase):
return found_flc return found_flc
def testStackTraceExtraction(self): def testStackTraceExtraction(self):
# Since the create_graph_debug_info_def() function does not actually # This test is verifying stack trace information added in graph mode, so
# do anything special with functions except name mangling, just verify # only makes sense in graph mode.
# it with a loose op and manually provided function name. with ops.Graph().as_default():
# The following ops *must* be on consecutive lines (it will be verified # Since the create_graph_debug_info_def() function does not actually
# in the resulting trace). # do anything special with functions except name mangling, just verify
# pyformat: disable # it with a loose op and manually provided function name.
global_op = constant_op.constant(0, name="Global").op # The following ops *must* be on consecutive lines (it will be verified
op1 = constant_op.constant(1, name="One").op # in the resulting trace).
op2 = constant_op.constant(2, name="Two").op # pyformat: disable
non_traceback_op = constant_op.constant(3, name="NonTraceback").op global_op = constant_op.constant(0, name="Global").op
# Ensure op without traceback does not fail op1 = constant_op.constant(1, name="One").op
del non_traceback_op._traceback op2 = constant_op.constant(2, name="Two").op
# pyformat: enable non_traceback_op = constant_op.constant(3, name="NonTraceback").op
# Ensure op without traceback does not fail
del non_traceback_op._traceback
# pyformat: enable
export_ops = [("", global_op), ("func1", op1), ("func2", op2), export_ops = [("", global_op), ("func1", op1), ("func2", op2),
("func2", non_traceback_op)] ("func2", non_traceback_op)]
graph_debug_info = error_interpolation.create_graph_debug_info_def( graph_debug_info = error_interpolation.create_graph_debug_info_def(
export_ops) export_ops)
this_file_index = -1 this_file_index = -1
for file_index, file_name in enumerate(graph_debug_info.files): for file_index, file_name in enumerate(graph_debug_info.files):
if "{}error_interpolation_test.py".format(os.sep) in file_name: if "{}error_interpolation_test.py".format(os.sep) in file_name:
this_file_index = file_index this_file_index = file_index
self.assertGreaterEqual( self.assertGreaterEqual(
this_file_index, 0, this_file_index, 0,
"Could not find this file in trace:" + repr(graph_debug_info)) "Could not find this file in trace:" + repr(graph_debug_info))
# Verify the traces exist for each op. # Verify the traces exist for each op.
global_flc = self._getFirstStackTraceForFile(graph_debug_info, "Global@", global_flc = self._getFirstStackTraceForFile(graph_debug_info, "Global@",
this_file_index) this_file_index)
op1_flc = self._getFirstStackTraceForFile(graph_debug_info, "One@func1", op1_flc = self._getFirstStackTraceForFile(graph_debug_info, "One@func1",
this_file_index) this_file_index)
op2_flc = self._getFirstStackTraceForFile(graph_debug_info, "Two@func2", op2_flc = self._getFirstStackTraceForFile(graph_debug_info, "Two@func2",
this_file_index) this_file_index)
global_line = global_flc.line global_line = global_flc.line
self.assertEqual(op1_flc.line, global_line + 1, "op1 not on next line") self.assertEqual(op1_flc.line, global_line + 1, "op1 not on next line")
self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line") self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line")
@test_util.run_deprecated_v1
class InterpolateFilenamesAndLineNumbersTest(test.TestCase): class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
def setUp(self):
super(InterpolateFilenamesAndLineNumbersTest, self).setUp()
ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later.
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
three = constant_op.constant(3, name="Three")
self.graph = three.graph
def testFindIndexOfDefiningFrameForOp(self): def testFindIndexOfDefiningFrameForOp(self):
local_op = constant_op.constant(42).op with ops.Graph().as_default():
user_filename = "hope.py" local_op = constant_op.constant(42).op
_modify_op_stack_with_filenames( user_filename = "hope.py"
local_op, _modify_op_stack_with_filenames(
num_user_frames=3, local_op,
user_filename=user_filename, num_user_frames=3,
num_inner_tf_frames=5) user_filename=user_filename,
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback) num_inner_tf_frames=5)
# Expected frame is 6th from the end because there are 5 inner frames witih idx = error_interpolation._find_index_of_defining_frame(
# TF filenames. local_op._traceback)
expected_frame = len(local_op._traceback) - 6 # Expected frame is 6th from the end because there are 5 inner frames with
self.assertEqual(expected_frame, idx) # TF filenames.
expected_frame = len(local_op._traceback) - 6
self.assertEqual(expected_frame, idx)
def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self): def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
local_op = constant_op.constant(43).op with ops.Graph().as_default():
# Truncate stack to known length. local_op = constant_op.constant(43).op
local_op._traceback = local_op._traceback[:7] # Truncate stack to known length.
# Ensure all frames look like TF frames. local_op._traceback = local_op._traceback[:7]
_modify_op_stack_with_filenames( # Ensure all frames look like TF frames.
local_op, _modify_op_stack_with_filenames(
num_user_frames=0, local_op,
user_filename="user_file.py", num_user_frames=0,
num_inner_tf_frames=7) user_filename="user_file.py",
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback) num_inner_tf_frames=7)
self.assertEqual(0, idx) idx = error_interpolation._find_index_of_defining_frame(
local_op._traceback)
self.assertEqual(0, idx)
def testNothingToDo(self): def testNothingToDo(self):
normal_string = "This is just a normal string" with ops.Graph().as_default():
interpolated_string = error_interpolation.interpolate( constant_op.constant(1, name="One")
normal_string, self.graph) normal_string = "This is just a normal string"
self.assertEqual(interpolated_string, normal_string) interpolated_string = error_interpolation.interpolate(
normal_string, ops.get_default_graph())
self.assertEqual(interpolated_string, normal_string)
def testOneTagWithAFakeNameResultsInPlaceholders(self): def testOneTagWithAFakeNameResultsInPlaceholders(self):
one_tag_string = "{{node MinusOne}}" with ops.Graph().as_default():
interpolated_string = error_interpolation.interpolate( one_tag_string = "{{node MinusOne}}"
one_tag_string, self.graph) interpolated_string = error_interpolation.interpolate(
self.assertEqual(one_tag_string, interpolated_string) one_tag_string, ops.get_default_graph())
self.assertEqual(one_tag_string, interpolated_string)
def testTwoTagsNoSeps(self): def testTwoTagsNoSeps(self):
two_tags_no_seps = "{{node One}}{{node Three}}" with ops.Graph().as_default():
interpolated_string = error_interpolation.interpolate( constant_op.constant(1, name="One")
two_tags_no_seps, self.graph) constant_op.constant(2, name="Two")
self.assertRegex( constant_op.constant(3, name="Three")
interpolated_string, r"error_interpolation_test\.py:[0-9]+." two_tags_no_seps = "{{node One}}{{node Three}}"
r"*error_interpolation_test\.py:[0-9]+") interpolated_string = error_interpolation.interpolate(
two_tags_no_seps, ops.get_default_graph())
self.assertRegex(
interpolated_string, r"error_interpolation_test\.py:[0-9]+."
r"*error_interpolation_test\.py:[0-9]+")
def testTwoTagsWithSeps(self): def testTwoTagsWithSeps(self):
two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;" with ops.Graph().as_default():
interpolated_string = error_interpolation.interpolate( constant_op.constant(1, name="One")
two_tags_with_seps, self.graph) constant_op.constant(2, name="Two")
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) " constant_op.constant(3, name="Three")
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$") two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
self.assertRegex(interpolated_string, expected_regex) interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, ops.get_default_graph())
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegex(interpolated_string, expected_regex)
def testNewLine(self): def testNewLine(self):
newline = "\n\n{{node One}}" with ops.Graph().as_default():
interpolated_string = error_interpolation.interpolate(newline, self.graph) constant_op.constant(1, name="One")
self.assertRegex(interpolated_string, constant_op.constant(2, name="Two")
r"error_interpolation_test\.py:[0-9]+.*") newline = "\n\n{{node One}}"
interpolated_string = error_interpolation.interpolate(
newline, ops.get_default_graph())
self.assertRegex(interpolated_string,
r"error_interpolation_test\.py:[0-9]+.*")
@test_util.run_deprecated_v1
class InputNodesTest(test.TestCase): class InputNodesTest(test.TestCase):
def setUp(self):
super(InputNodesTest, self).setUp()
# Add nodes to the graph for retrieval by name later.
one = constant_op.constant(1, name="One")
two = constant_op.constant(2, name="Two")
three = math_ops.add(one, two, name="Three")
non_traceback_op = constant_op.constant(3, name="NonTraceback")
# Ensure op without traceback does not fail
del non_traceback_op.op._traceback
self.graph = three.graph
def testNoInputs(self): def testNoInputs(self):
two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;" with ops.Graph().as_default():
interpolated_string = error_interpolation.interpolate( one = constant_op.constant(1, name="One")
two_tags_with_seps, self.graph) two = constant_op.constant(2, name="Two")
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) " _ = math_ops.add(one, two, name="Three")
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$") two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;"
self.assertRegex(interpolated_string, expected_regex) interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, ops.get_default_graph())
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegex(interpolated_string, expected_regex)
def testBasicInputs(self): def testBasicInputs(self):
tag = ";;;{{node Three}};;;" with ops.Graph().as_default():
interpolated_string = error_interpolation.interpolate(tag, self.graph) one = constant_op.constant(1, name="One")
expected_regex = re.compile( two = constant_op.constant(2, name="Two")
r"^;;;.*error_interpolation_test\.py:[0-9]+\) " _ = math_ops.add(one, two, name="Three")
r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL) tag = ";;;{{node Three}};;;"
self.assertRegex(interpolated_string, expected_regex) interpolated_string = error_interpolation.interpolate(
tag, ops.get_default_graph())
expected_regex = re.compile(
r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL)
self.assertRegex(interpolated_string, expected_regex)
@test_util.run_deprecated_v1
class InterpolateDeviceSummaryTest(test.TestCase): class InterpolateDeviceSummaryTest(test.TestCase):
def _fancy_device_function(self, unused_op): def _fancy_device_function(self, unused_op):
return "/cpu:*" return "/cpu:*"
def setUp(self):
super(InterpolateDeviceSummaryTest, self).setUp()
ops.reset_default_graph()
self.zero = constant_op.constant([0.0], name="zero")
with ops.device("/cpu"):
self.one = constant_op.constant([1.0], name="one")
with ops.device("/cpu:0"):
self.two = constant_op.constant([2.0], name="two")
with ops.device(self._fancy_device_function):
self.three = constant_op.constant(3.0, name="three")
self.graph = self.three.graph
def testNodeZeroHasNoDeviceSummaryInfo(self): def testNodeZeroHasNoDeviceSummaryInfo(self):
message = "{{colocation_node zero}}" with ops.Graph().as_default():
result = error_interpolation.interpolate(message, self.graph) self.zero = constant_op.constant([0.0], name="zero")
self.assertIn("No device assignments were active", result) message = "{{colocation_node zero}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("No device assignments were active", result)
def testNodeOneHasExactlyOneInterpolatedDevice(self): def testNodeOneHasExactlyOneInterpolatedDevice(self):
message = "{{colocation_node one}}" with ops.Graph().as_default():
result = error_interpolation.interpolate(message, self.graph) with ops.device("/cpu"):
self.assertEqual(2, result.count("tf.device(/cpu)")) self.one = constant_op.constant([1.0], name="one")
message = "{{colocation_node one}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertEqual(2, result.count("tf.device(/cpu)"))
def testNodeTwoHasTwoInterpolatedDevice(self): def testNodeTwoHasTwoInterpolatedDevice(self):
message = "{{colocation_node two}}" with ops.Graph().as_default():
result = error_interpolation.interpolate(message, self.graph) with ops.device("/cpu"):
self.assertEqual(2, result.count("tf.device(/cpu)")) with ops.device("/cpu:0"):
self.assertEqual(2, result.count("tf.device(/cpu:0)")) self.two = constant_op.constant([2.0], name="two")
message = "{{colocation_node two}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertEqual(2, result.count("tf.device(/cpu)"))
self.assertEqual(2, result.count("tf.device(/cpu:0)"))
def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self): def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
message = "{{colocation_node three}}" with ops.Graph().as_default():
result = error_interpolation.interpolate(message, self.graph) with ops.device(self._fancy_device_function):
num_devices = result.count("tf.device") self.three = constant_op.constant(3.0, name="three")
self.assertEqual(2, num_devices) message = "{{colocation_node three}}"
name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>" result = error_interpolation.interpolate(message, ops.get_default_graph())
expected_re = r"with tf.device\(.*%s\)" % name_re num_devices = result.count("tf.device")
self.assertRegex(result, expected_re) self.assertEqual(2, num_devices)
name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>"
expected_re = r"with tf.device\(.*%s\)" % name_re
self.assertRegex(result, expected_re)
@test_util.run_deprecated_v1
class InterpolateColocationSummaryTest(test.TestCase): class InterpolateColocationSummaryTest(test.TestCase):
def setUp(self): def _set_up_graph(self):
super(InterpolateColocationSummaryTest, self).setUp()
ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later. # Add nodes to the graph for retrieval by name later.
node_one = constant_op.constant(1, name="One") node_one = constant_op.constant(1, name="One")
node_two = constant_op.constant(2, name="Two") node_two = constant_op.constant(2, name="Two")
@ -359,32 +357,39 @@ class InterpolateColocationSummaryTest(test.TestCase):
with ops.colocate_with(node_one): with ops.colocate_with(node_one):
constant_op.constant(5, name="Five_with_one_with_two") constant_op.constant(5, name="Five_with_one_with_two")
self.graph = node_three.graph
def testNodeThreeHasColocationInterpolation(self): def testNodeThreeHasColocationInterpolation(self):
message = "{{colocation_node Three_with_one}}" with ops.Graph().as_default():
result = error_interpolation.interpolate(message, self.graph) self._set_up_graph()
self.assertIn("colocate_with(One)", result) message = "{{colocation_node Three_with_one}}"
result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("colocate_with(One)", result)
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self): def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
message = "{{colocation_node Four_with_three}}" with ops.Graph().as_default():
result = error_interpolation.interpolate(message, self.graph) self._set_up_graph()
self.assertIn("colocate_with(Three_with_one)", result) message = "{{colocation_node Four_with_three}}"
self.assertNotIn( result = error_interpolation.interpolate(message, ops.get_default_graph())
"One", result, self.assertIn("colocate_with(Three_with_one)", result)
"Node One should not appear in Four_with_three's summary:\n%s" % result) self.assertNotIn(
"One", result,
"Node One should not appear in Four_with_three's summary:\n%s" %
result)
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self): def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
message = "{{colocation_node Five_with_one_with_two}}" with ops.Graph().as_default():
result = error_interpolation.interpolate(message, self.graph) self._set_up_graph()
self.assertIn("colocate_with(One)", result) message = "{{colocation_node Five_with_one_with_two}}"
self.assertIn("colocate_with(Two)", result) result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("colocate_with(One)", result)
self.assertIn("colocate_with(Two)", result)
def testColocationInterpolationForNodeLackingColocation(self): def testColocationInterpolationForNodeLackingColocation(self):
message = "{{colocation_node One}}" with ops.Graph().as_default():
result = error_interpolation.interpolate(message, self.graph) self._set_up_graph()
self.assertIn("No node-device colocations", result) message = "{{colocation_node One}}"
self.assertNotIn("Two", result) result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("No node-device colocations", result)
self.assertNotIn("Two", result)
class IsFrameworkFilenameTest(test.TestCase): class IsFrameworkFilenameTest(test.TestCase):