Fix skipped tests in error_interpolation_test.py

These only make sense in graph mode, so I made this explicit within the tests.

PiperOrigin-RevId: 339050796
Change-Id: I97674d036f95882d8d4275e6b65e7c432ec0d943
This commit is contained in:
James Keeling 2020-10-26 09:11:17 -07:00 committed by TensorFlower Gardener
parent 8c78dae8fb
commit b8f5303051

View File

@ -25,7 +25,6 @@ import re
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import error_interpolation from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import traceable_stack from tensorflow.python.framework import traceable_stack
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -126,13 +125,8 @@ class ComputeColocationSummaryFromOpTest(test.TestCase):
# so it is excluded from eager tests. Even when used in eager mode, it is # so it is excluded from eager tests. Even when used in eager mode, it is
# via FunctionGraphs, and directly verifying in graph mode is the narrowest # via FunctionGraphs, and directly verifying in graph mode is the narrowest
# way to unit test the functionality. # way to unit test the functionality.
@test_util.run_deprecated_v1
class CreateGraphDebugInfoDefTest(test.TestCase): class CreateGraphDebugInfoDefTest(test.TestCase):
def setUp(self):
super(CreateGraphDebugInfoDefTest, self).setUp()
ops.reset_default_graph()
def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index): def _getFirstStackTraceForFile(self, graph_debug_info, key, file_index):
self.assertIn(key, graph_debug_info.traces) self.assertIn(key, graph_debug_info.traces)
stack_trace = graph_debug_info.traces[key] stack_trace = graph_debug_info.traces[key]
@ -146,6 +140,9 @@ class CreateGraphDebugInfoDefTest(test.TestCase):
return found_flc return found_flc
def testStackTraceExtraction(self): def testStackTraceExtraction(self):
# This test is verifying stack trace information added in graph mode, so
# only makes sense in graph mode.
with ops.Graph().as_default():
# Since the create_graph_debug_info_def() function does not actually # Since the create_graph_debug_info_def() function does not actually
# do anything special with functions except name mangling, just verify # do anything special with functions except name mangling, just verify
# it with a loose op and manually provided function name. # it with a loose op and manually provided function name.
@ -185,19 +182,10 @@ class CreateGraphDebugInfoDefTest(test.TestCase):
self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line") self.assertEqual(op2_flc.line, global_line + 2, "op2 not on next line")
@test_util.run_deprecated_v1
class InterpolateFilenamesAndLineNumbersTest(test.TestCase): class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
def setUp(self):
super(InterpolateFilenamesAndLineNumbersTest, self).setUp()
ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later.
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
three = constant_op.constant(3, name="Three")
self.graph = three.graph
def testFindIndexOfDefiningFrameForOp(self): def testFindIndexOfDefiningFrameForOp(self):
with ops.Graph().as_default():
local_op = constant_op.constant(42).op local_op = constant_op.constant(42).op
user_filename = "hope.py" user_filename = "hope.py"
_modify_op_stack_with_filenames( _modify_op_stack_with_filenames(
@ -205,13 +193,15 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
num_user_frames=3, num_user_frames=3,
user_filename=user_filename, user_filename=user_filename,
num_inner_tf_frames=5) num_inner_tf_frames=5)
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback) idx = error_interpolation._find_index_of_defining_frame(
# Expected frame is 6th from the end because there are 5 inner frames witih local_op._traceback)
# Expected frame is 6th from the end because there are 5 inner frames with
# TF filenames. # TF filenames.
expected_frame = len(local_op._traceback) - 6 expected_frame = len(local_op._traceback) - 6
self.assertEqual(expected_frame, idx) self.assertEqual(expected_frame, idx)
def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self): def testFindIndexOfDefiningFrameForOpReturnsZeroOnError(self):
with ops.Graph().as_default():
local_op = constant_op.constant(43).op local_op = constant_op.constant(43).op
# Truncate stack to known length. # Truncate stack to known length.
local_op._traceback = local_op._traceback[:7] local_op._traceback = local_op._traceback[:7]
@ -221,113 +211,124 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase):
num_user_frames=0, num_user_frames=0,
user_filename="user_file.py", user_filename="user_file.py",
num_inner_tf_frames=7) num_inner_tf_frames=7)
idx = error_interpolation._find_index_of_defining_frame(local_op._traceback) idx = error_interpolation._find_index_of_defining_frame(
local_op._traceback)
self.assertEqual(0, idx) self.assertEqual(0, idx)
def testNothingToDo(self): def testNothingToDo(self):
with ops.Graph().as_default():
constant_op.constant(1, name="One")
normal_string = "This is just a normal string" normal_string = "This is just a normal string"
interpolated_string = error_interpolation.interpolate( interpolated_string = error_interpolation.interpolate(
normal_string, self.graph) normal_string, ops.get_default_graph())
self.assertEqual(interpolated_string, normal_string) self.assertEqual(interpolated_string, normal_string)
def testOneTagWithAFakeNameResultsInPlaceholders(self): def testOneTagWithAFakeNameResultsInPlaceholders(self):
with ops.Graph().as_default():
one_tag_string = "{{node MinusOne}}" one_tag_string = "{{node MinusOne}}"
interpolated_string = error_interpolation.interpolate( interpolated_string = error_interpolation.interpolate(
one_tag_string, self.graph) one_tag_string, ops.get_default_graph())
self.assertEqual(one_tag_string, interpolated_string) self.assertEqual(one_tag_string, interpolated_string)
def testTwoTagsNoSeps(self): def testTwoTagsNoSeps(self):
with ops.Graph().as_default():
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
constant_op.constant(3, name="Three")
two_tags_no_seps = "{{node One}}{{node Three}}" two_tags_no_seps = "{{node One}}{{node Three}}"
interpolated_string = error_interpolation.interpolate( interpolated_string = error_interpolation.interpolate(
two_tags_no_seps, self.graph) two_tags_no_seps, ops.get_default_graph())
self.assertRegex( self.assertRegex(
interpolated_string, r"error_interpolation_test\.py:[0-9]+." interpolated_string, r"error_interpolation_test\.py:[0-9]+."
r"*error_interpolation_test\.py:[0-9]+") r"*error_interpolation_test\.py:[0-9]+")
def testTwoTagsWithSeps(self): def testTwoTagsWithSeps(self):
with ops.Graph().as_default():
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
constant_op.constant(3, name="Three")
two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;" two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;"
interpolated_string = error_interpolation.interpolate( interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, self.graph) two_tags_with_seps, ops.get_default_graph())
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) " expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$") r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegex(interpolated_string, expected_regex) self.assertRegex(interpolated_string, expected_regex)
def testNewLine(self): def testNewLine(self):
with ops.Graph().as_default():
constant_op.constant(1, name="One")
constant_op.constant(2, name="Two")
newline = "\n\n{{node One}}" newline = "\n\n{{node One}}"
interpolated_string = error_interpolation.interpolate(newline, self.graph) interpolated_string = error_interpolation.interpolate(
newline, ops.get_default_graph())
self.assertRegex(interpolated_string, self.assertRegex(interpolated_string,
r"error_interpolation_test\.py:[0-9]+.*") r"error_interpolation_test\.py:[0-9]+.*")
@test_util.run_deprecated_v1
class InputNodesTest(test.TestCase): class InputNodesTest(test.TestCase):
def setUp(self): def testNoInputs(self):
super(InputNodesTest, self).setUp() with ops.Graph().as_default():
# Add nodes to the graph for retrieval by name later.
one = constant_op.constant(1, name="One") one = constant_op.constant(1, name="One")
two = constant_op.constant(2, name="Two") two = constant_op.constant(2, name="Two")
three = math_ops.add(one, two, name="Three") _ = math_ops.add(one, two, name="Three")
non_traceback_op = constant_op.constant(3, name="NonTraceback")
# Ensure op without traceback does not fail
del non_traceback_op.op._traceback
self.graph = three.graph
def testNoInputs(self):
two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;" two_tags_with_seps = ";;;{{node One}},,,{{node Two}};;;"
interpolated_string = error_interpolation.interpolate( interpolated_string = error_interpolation.interpolate(
two_tags_with_seps, self.graph) two_tags_with_seps, ops.get_default_graph())
expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) " expected_regex = (r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$") r",,,.*error_interpolation_test\.py:[0-9]+\) ;;;$")
self.assertRegex(interpolated_string, expected_regex) self.assertRegex(interpolated_string, expected_regex)
def testBasicInputs(self): def testBasicInputs(self):
with ops.Graph().as_default():
one = constant_op.constant(1, name="One")
two = constant_op.constant(2, name="Two")
_ = math_ops.add(one, two, name="Three")
tag = ";;;{{node Three}};;;" tag = ";;;{{node Three}};;;"
interpolated_string = error_interpolation.interpolate(tag, self.graph) interpolated_string = error_interpolation.interpolate(
tag, ops.get_default_graph())
expected_regex = re.compile( expected_regex = re.compile(
r"^;;;.*error_interpolation_test\.py:[0-9]+\) " r"^;;;.*error_interpolation_test\.py:[0-9]+\) "
r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL) r";;;.*Input.*error_interpolation_test\.py:[0-9]+\)", re.DOTALL)
self.assertRegex(interpolated_string, expected_regex) self.assertRegex(interpolated_string, expected_regex)
@test_util.run_deprecated_v1
class InterpolateDeviceSummaryTest(test.TestCase): class InterpolateDeviceSummaryTest(test.TestCase):
def _fancy_device_function(self, unused_op): def _fancy_device_function(self, unused_op):
return "/cpu:*" return "/cpu:*"
def setUp(self):
super(InterpolateDeviceSummaryTest, self).setUp()
ops.reset_default_graph()
self.zero = constant_op.constant([0.0], name="zero")
with ops.device("/cpu"):
self.one = constant_op.constant([1.0], name="one")
with ops.device("/cpu:0"):
self.two = constant_op.constant([2.0], name="two")
with ops.device(self._fancy_device_function):
self.three = constant_op.constant(3.0, name="three")
self.graph = self.three.graph
def testNodeZeroHasNoDeviceSummaryInfo(self): def testNodeZeroHasNoDeviceSummaryInfo(self):
with ops.Graph().as_default():
self.zero = constant_op.constant([0.0], name="zero")
message = "{{colocation_node zero}}" message = "{{colocation_node zero}}"
result = error_interpolation.interpolate(message, self.graph) result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("No device assignments were active", result) self.assertIn("No device assignments were active", result)
def testNodeOneHasExactlyOneInterpolatedDevice(self): def testNodeOneHasExactlyOneInterpolatedDevice(self):
with ops.Graph().as_default():
with ops.device("/cpu"):
self.one = constant_op.constant([1.0], name="one")
message = "{{colocation_node one}}" message = "{{colocation_node one}}"
result = error_interpolation.interpolate(message, self.graph) result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertEqual(2, result.count("tf.device(/cpu)")) self.assertEqual(2, result.count("tf.device(/cpu)"))
def testNodeTwoHasTwoInterpolatedDevice(self): def testNodeTwoHasTwoInterpolatedDevice(self):
with ops.Graph().as_default():
with ops.device("/cpu"):
with ops.device("/cpu:0"):
self.two = constant_op.constant([2.0], name="two")
message = "{{colocation_node two}}" message = "{{colocation_node two}}"
result = error_interpolation.interpolate(message, self.graph) result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertEqual(2, result.count("tf.device(/cpu)")) self.assertEqual(2, result.count("tf.device(/cpu)"))
self.assertEqual(2, result.count("tf.device(/cpu:0)")) self.assertEqual(2, result.count("tf.device(/cpu:0)"))
def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self): def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self):
with ops.Graph().as_default():
with ops.device(self._fancy_device_function):
self.three = constant_op.constant(3.0, name="three")
message = "{{colocation_node three}}" message = "{{colocation_node three}}"
result = error_interpolation.interpolate(message, self.graph) result = error_interpolation.interpolate(message, ops.get_default_graph())
num_devices = result.count("tf.device") num_devices = result.count("tf.device")
self.assertEqual(2, num_devices) self.assertEqual(2, num_devices)
name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>" name_re = r"_fancy_device_function<.*error_interpolation_test.py, [0-9]+>"
@ -335,12 +336,9 @@ class InterpolateDeviceSummaryTest(test.TestCase):
self.assertRegex(result, expected_re) self.assertRegex(result, expected_re)
@test_util.run_deprecated_v1
class InterpolateColocationSummaryTest(test.TestCase): class InterpolateColocationSummaryTest(test.TestCase):
def setUp(self): def _set_up_graph(self):
super(InterpolateColocationSummaryTest, self).setUp()
ops.reset_default_graph()
# Add nodes to the graph for retrieval by name later. # Add nodes to the graph for retrieval by name later.
node_one = constant_op.constant(1, name="One") node_one = constant_op.constant(1, name="One")
node_two = constant_op.constant(2, name="Two") node_two = constant_op.constant(2, name="Two")
@ -359,30 +357,37 @@ class InterpolateColocationSummaryTest(test.TestCase):
with ops.colocate_with(node_one): with ops.colocate_with(node_one):
constant_op.constant(5, name="Five_with_one_with_two") constant_op.constant(5, name="Five_with_one_with_two")
self.graph = node_three.graph
def testNodeThreeHasColocationInterpolation(self): def testNodeThreeHasColocationInterpolation(self):
with ops.Graph().as_default():
self._set_up_graph()
message = "{{colocation_node Three_with_one}}" message = "{{colocation_node Three_with_one}}"
result = error_interpolation.interpolate(message, self.graph) result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("colocate_with(One)", result) self.assertIn("colocate_with(One)", result)
def testNodeFourHasColocationInterpolationForNodeThreeOnly(self): def testNodeFourHasColocationInterpolationForNodeThreeOnly(self):
with ops.Graph().as_default():
self._set_up_graph()
message = "{{colocation_node Four_with_three}}" message = "{{colocation_node Four_with_three}}"
result = error_interpolation.interpolate(message, self.graph) result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("colocate_with(Three_with_one)", result) self.assertIn("colocate_with(Three_with_one)", result)
self.assertNotIn( self.assertNotIn(
"One", result, "One", result,
"Node One should not appear in Four_with_three's summary:\n%s" % result) "Node One should not appear in Four_with_three's summary:\n%s" %
result)
def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self): def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self):
with ops.Graph().as_default():
self._set_up_graph()
message = "{{colocation_node Five_with_one_with_two}}" message = "{{colocation_node Five_with_one_with_two}}"
result = error_interpolation.interpolate(message, self.graph) result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("colocate_with(One)", result) self.assertIn("colocate_with(One)", result)
self.assertIn("colocate_with(Two)", result) self.assertIn("colocate_with(Two)", result)
def testColocationInterpolationForNodeLackingColocation(self): def testColocationInterpolationForNodeLackingColocation(self):
with ops.Graph().as_default():
self._set_up_graph()
message = "{{colocation_node One}}" message = "{{colocation_node One}}"
result = error_interpolation.interpolate(message, self.graph) result = error_interpolation.interpolate(message, ops.get_default_graph())
self.assertIn("No node-device colocations", result) self.assertIn("No node-device colocations", result)
self.assertNotIn("Two", result) self.assertNotIn("Two", result)