[tfdbg2] Fix graph-mode path_length_limit and stack_heigth_limit in enable_check_numerics()
Cause of the bug: - Previously, the helper method get_check_numerics_error_message() was called with the proper kwargs only under eager mode. The graph mode code path incorrectly omitted the kwargs. This CL fixes that. The fix is covered by mock-based unit tests. PiperOrigin-RevId: 312994212 Change-Id: I8800ec85741da6efe8fb8f3115ea7f57a38f0882
This commit is contained in:
parent
1727b70d6a
commit
a1f496664e
|
@ -275,7 +275,9 @@ class CheckNumericsCallback(object):
|
|||
output,
|
||||
inputs,
|
||||
graph=graph,
|
||||
traceback=output.op.traceback))
|
||||
traceback=output.op.traceback,
|
||||
stack_height_limit=self._stack_height_limit,
|
||||
path_length_limit=self._path_length_limit))
|
||||
_CHECK_NUMERICS_INPUT_LOOKUP[graph][checked_output.name] = output
|
||||
instrumented_outputs.append(self._get_output_tensor(
|
||||
op_type_bytes, output, checked_output, is_v1_graph_mode))
|
||||
|
|
|
@ -39,6 +39,7 @@ from tensorflow.python.ops import math_grad # pylint: disable=unused-import
|
|||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class LimitStringLengthTest(test_util.TensorFlowTestCase):
|
||||
|
@ -105,6 +106,27 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase):
|
|||
self.assertAllClose(batches[0], np.log([1.25, 2]))
|
||||
self.assertAllClose(batches[1], np.log([3.25, 5]))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testGraphModeUsesCorrectPathLengthAndStackHeightLimits(self):
|
||||
check_numerics_callback.enable_check_numerics(
|
||||
stack_height_limit=123, path_length_limit=1200)
|
||||
|
||||
@def_function.function
|
||||
def add_fn(x, y):
|
||||
return x + y
|
||||
|
||||
fake_get_check_numerics_error_message = test.mock.MagicMock(
|
||||
return_value="dummy_message")
|
||||
with test.mock.patch.object(check_numerics_callback,
|
||||
"get_check_numerics_error_message",
|
||||
fake_get_check_numerics_error_message):
|
||||
x = constant_op.constant(2.0)
|
||||
y = constant_op.constant(3.0)
|
||||
self.assertAllClose(self.evaluate(add_fn(x, y)), 5.0)
|
||||
(_, call_kwargs) = fake_get_check_numerics_error_message.call_args
|
||||
self.assertEqual(call_kwargs["stack_height_limit"], 123)
|
||||
self.assertEqual(call_kwargs["path_length_limit"], 1200)
|
||||
|
||||
|
||||
class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
|
||||
"""Test for cases in which enable_check_numerics() catches infs or nans."""
|
||||
|
@ -372,6 +394,22 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
|
|||
re.search(r"graph op.*\"Xdivy\"", message)))
|
||||
self.assertTrue(re.search(r"dtype.*float32", message))
|
||||
|
||||
def testEagerModeUsesCorrectPathLengthAndStackHeightLimits(self):
|
||||
check_numerics_callback.enable_check_numerics(
|
||||
stack_height_limit=123, path_length_limit=1200)
|
||||
fake_get_check_numerics_error_message = test.mock.MagicMock(
|
||||
return_value="dummy_message")
|
||||
with test.mock.patch.object(check_numerics_callback,
|
||||
"get_check_numerics_error_message",
|
||||
fake_get_check_numerics_error_message):
|
||||
x = constant_op.constant(2.0)
|
||||
y = constant_op.constant(0.0)
|
||||
self._assertRaisesInvalidArgumentErrorAndGetMessage(
|
||||
lambda: x / y) # Expected to generate an inf.
|
||||
(_, call_kwargs) = fake_get_check_numerics_error_message.call_args
|
||||
self.assertEqual(call_kwargs["stack_height_limit"], 123)
|
||||
self.assertEqual(call_kwargs["path_length_limit"], 1200)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testExpectedNaNOpOutputs(self):
|
||||
"""Test calling operations with benign NaN output."""
|
||||
|
|
Loading…
Reference in New Issue