[tfdbg2] Ensure Const ops in graphs are captured by op_callbacks

Details of the changes:
- In the Python API of TensorFlow, Const ops are created by calling
  `_create_op_internal()` from constant_op.py. This differs from how most
  other ops are created, but is similar to Placeholder ops, which are already
  instrumented by tfdbg2's op_callbacks. In this CL, we add an op_callback
  hook to the code path in constant_op.py that makes that call, so that Const
  ops can be instrumented as well. (A usage sketch follows this list.)
- In `_ConstantValue()` in tensor_util.py, add a special case for the
  `CheckNumericsV2` op, so that `constant_value()` does not treat a
  `CheckNumericsV2` op's output as the constant tensor value. Similarly, add
  special cases for the `Identity` and `DebugIdentityV2` ops. (See the
  example after the `_ConstantValue()` hunk below.)
- In `dumping_callback_test.py`, replace the use of a deprecated Dataset API
  (`make_one_shot_iterator()`) with the non-deprecated APIs (`iter()` and
  `next()`); a reference snippet follows the first hunk below.
- Make other necessary changes to tfdbg2's tests to accommodate the Const
  ops, which were previously not instrumented but now are.
- Increase the shard_count of learning/brain/python/debug/tpu_callbacks_test.py
  to 6 to avoid timeouts under the increased number of instrumented ops.
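
For illustration, a minimal sketch (not part of this CL; the `op_callbacks`
module is a TF-internal API and subject to change) of how a registered
op_callback now observes Const ops. The callback signature matches the one
exercised in op_callbacks_test.py below:

  import tensorflow as tf
  from tensorflow.python.framework import op_callbacks

  observed_op_types = []

  def observer(op_type, inputs, attrs, outputs, op_name=None, graph=None):
    # Record each op; op_type may arrive as str or bytes depending on mode.
    # Returning None leaves the op's outputs unmodified.
    observed_op_types.append(op_type)
    return None

  op_callbacks.add_op_callback(observer)

  @tf.function
  def times_two_plus_three(x):
    return x * 2.0 + 3.0  # 2.0 and 3.0 are lowered to Const ops in the graph.

  times_two_plus_three(tf.constant(10.0))
  op_callbacks.remove_op_callback(observer)
  # With this CL, "Const" now appears among the observed graph op types.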

PiperOrigin-RevId: 307723353
Change-Id: Iecdbfcb439f6e04fc12c1503ad5339d42703e8bc
Author: Shanqing Cai
Date: 2020-04-21 18:39:29 -07:00
Committed by: TensorFlower Gardener
Parent: 1608a3a24a
Commit: e6f22ee5f4
8 changed files with 124 additions and 23 deletions


@@ -94,10 +94,16 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase):
     dataset = dataset_ops.Dataset.from_tensor_slices(tensor).batch(2).map(
         map_fn)
-    iterator = dataset_ops.make_one_shot_iterator(dataset)
-    self.assertAllClose(self.evaluate(iterator.get_next()), np.log([1.25, 2]))
-    self.assertAllClose(self.evaluate(iterator.get_next()), np.log([3.25, 5]))
+
+    @def_function.function
+    def get_batches():
+      iterator = iter(dataset)
+      return [next(iterator), next(iterator)]
+
+    batches = self.evaluate(get_batches())
+    self.assertLen(batches, 2)
+    self.assertAllClose(batches[0], np.log([1.25, 2]))
+    self.assertAllClose(batches[1], np.log([3.25, 5]))
 
 
 class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
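
For reference, the non-deprecated iteration pattern used in the new test code
above, shown standalone in eager mode (a sketch, assuming TF 2.x):

  import tensorflow as tf

  dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0]).batch(2)
  iterator = iter(dataset)  # replaces dataset_ops.make_one_shot_iterator(dataset)
  batch = next(iterator)    # replaces iterator.get_next()
  print(batch.numpy())      # => [1. 2.]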
@@ -267,6 +273,23 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase):
     self.assertTrue(re.search(r"Stack trace of op's creation", message))
     self.assertIn("accum.assign(accum * 2.0)", message)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testNanInConstIsCaptured(self):
+    check_numerics_callback.enable_check_numerics()
+    v = variables.Variable(3.0, dtype=dtypes.float32)
+    @def_function.function
+    def add_a_bad_constant(x):
+      c = constant_op.constant(np.nan)
+      return x + c
+    if not context.executing_eagerly():
+      self.evaluate(v.initializer)
+    message = self._assertRaisesInvalidArgumentErrorAndGetMessage(
+        lambda: self.evaluate(add_a_bad_constant(v)))
+    self.assertTrue(re.search(r"graph op.*\"Const\"", message))
+    self.assertTrue(re.search(r"dtype:.*float32", message))
+    self.assertTrue(re.search(r"shape:.*\(\)", message))
+    self.assertTrue(re.search(r"Graph name:.*add_a_bad_constant", message))
+
   @test_util.run_in_graph_and_eager_modes
   def testCatchInfinityInDatasetMapFunction(self):
     """Test that callback catches NaN in a tf.dataset map function."""


@@ -173,7 +173,8 @@ class DebugEventsMonitorTest(dumping_callback_test_lib.DumpingCallbackTestBase,
       self.assertLen(traces[1].debug_tensor_value, 11)
       self.assertLen(traces[2].debug_tensor_value, 11)
     elif tensor_debug_mode == "FULL_TENSOR":
-      self.assertLen(traces, 4)  # [Placeholder:0, Unique:0, Unique:1, Sum:0].
+      # [Placeholder:0, Unique:0, Unique:1, Const:0, Sum:0].
+      self.assertLen(traces, 5)
       self.assertEqual(traces[0].op_type, "Placeholder")
       self.assertEqual(traces[0].output_slot, 0)
       self.assertIsNone(traces[0].debug_tensor_value)
@@ -192,11 +193,16 @@ class DebugEventsMonitorTest(dumping_callback_test_lib.DumpingCallbackTestBase,
       self.assertAllEqual(
           reader.graph_execution_trace_to_tensor_value(traces[2]),
           [0, 1, 2, 3, 0])
-      self.assertEqual(traces[3].op_type, "Sum")
+      self.assertEqual(traces[3].op_type, "Const")
       self.assertEqual(traces[3].output_slot, 0)
       self.assertIsNone(traces[3].debug_tensor_value)
       self.assertAllClose(
-          reader.graph_execution_trace_to_tensor_value(traces[3]), 17.)
+          reader.graph_execution_trace_to_tensor_value(traces[3]), [0])
+      self.assertEqual(traces[4].op_type, "Sum")
+      self.assertEqual(traces[4].output_slot, 0)
+      self.assertIsNone(traces[4].debug_tensor_value)
+      self.assertAllClose(
+          reader.graph_execution_trace_to_tensor_value(traces[4]), 17.)
 
 
 class AlertDataObjectsTest(test_util.TensorFlowTestCase):


@@ -292,7 +292,12 @@ class _DumpingCallback(object):
       # TODO(cais): Evaluate performance optimization options. For the
       # `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as a
       # control dependency of `tensor.op` without an additional identity op.
-      if tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
+      if (tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR and
+          op_type != "Const"):
+        # NOTE(b/153716279): Under v1 graph mode, overriding the output tensor
+        # of Const ops can lead to downstream errors related to shapes. We opt
+        # to use an identity op to avoid this issue at the cost of slightly
+        # larger graph size.
         return debug_tensor
       else:
         identity = array_ops.identity(tensor)
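
To make the NOTE above concrete: wrapping a Const in an identity op leaves the
Const's own output (and hence its statically-known shape) intact, while still
giving the debugger a tensor to trace. A sketch, with a plain tf.identity
standing in for the debug-op wrapper:

  import tensorflow as tf

  with tf.Graph().as_default() as g:  # v1-style graph construction
    c = tf.constant([1.0, 2.0])  # a Const op
    wrapped = tf.identity(c)     # stand-in for the debug-op wrapper
    # The Const op and its output survive unchanged; only a node is added.
    print([op.type for op in g.get_operations()])  # => ['Const', 'Identity']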
@@ -530,8 +535,8 @@ class _DumpingCallback(object):
     is_v1_graph_mode = not ops.executing_eagerly_outside_functions()
     context_id = self._get_context_id(graph)  # Innermost context ID.
     output_tensor_ids = self._get_symbolic_tensor_ids(len(outputs))
-    if op_type in ("Placeholder", "PlaceholderWithDefault"):
-      # In some cases, the op name of a Placeholder op in a graph
+    if op_type in ("Const", "Placeholder", "PlaceholderWithDefault"):
+      # In some cases, the op name of a Const or Placeholder op in a graph
       # can be duplicate (e.g., with the name "resource").
       # When this happens, we give the op a debugger-generated name
       # in order to prevent problems and check failures down the pipe.


@@ -289,7 +289,8 @@ class DumpingCallbackTest(
     with debug_events_reader.DebugDataReader(self.dump_root) as reader:
       reader.update()
       graph_exec_traces = reader.graph_execution_traces()
-      executed_op_types = [trace.op_type for trace in graph_exec_traces]
+      executed_op_types = [trace.op_type for trace in graph_exec_traces
+                           if trace.op_type != "Const"]
       self.assertCountEqual(
           executed_op_types,
           ["Placeholder", "Placeholder", "AddV2", "Sub", "RealDiv"])
@@ -344,6 +345,46 @@ class DumpingCallbackTest(
         self.assertAllClose(trace.debug_tensor_value,
                             [tensor_id, 19, 1, 8, 8, 0, 0, 0, 0, 0])
 
+  @parameterized.named_parameters(
+      ("CurtHealth", "CURT_HEALTH"),
+      ("FullTensor", "FULL_TENSOR"),
+  )
+  @test_util.run_in_graph_and_eager_modes
+  def testConstTensorsAreCaptured(self, tensor_debug_mode):
+    writer = dumping_callback.enable_dump_debug_info(
+        self.dump_root, tensor_debug_mode=tensor_debug_mode)
+
+    @def_function.function
+    def times_two_plus_three(x):
+      return x * constant_op.constant(2.0) + constant_op.constant(3.0)
+    self.assertAllEqual(
+        self.evaluate(times_two_plus_three(10.0)), 23.0)
+    writer.FlushNonExecutionFiles()
+    writer.FlushExecutionFiles()
+
+    with debug_events_reader.DebugDataReader(self.dump_root) as reader:
+      reader.update()
+      const_traces = [trace for trace in reader.graph_execution_traces()
+                      if trace.op_type == "Const"]
+      self.assertGreaterEqual(len(const_traces), 3)
+      if tensor_debug_mode == "CURT_HEALTH":
+        # Under CURT_HEALTH, each debug tensor value has the form
+        # [tensor_id, has_inf_or_nan].
+        self.assertLen(const_traces[0].debug_tensor_value, 2)
+        self.assertEqual(const_traces[0].debug_tensor_value[1], 0)
+        self.assertLen(const_traces[1].debug_tensor_value, 2)
+        self.assertEqual(const_traces[1].debug_tensor_value[1], 0)
+        self.assertLen(const_traces[2].debug_tensor_value, 2)
+        self.assertEqual(const_traces[2].debug_tensor_value[1], 0)
+      else:  # FULL_TENSOR.
+        const_tensor_values = [
+            reader.graph_execution_trace_to_tensor_value(const_trace)
+            for const_trace in const_traces]
+        # Avoid making assertions on the particular order of the debug
+        # tensors for the three Consts, because it may be indeterminate.
+        self.assertIn(10.0, const_tensor_values)
+        self.assertIn(2.0, const_tensor_values)
+        self.assertIn(3.0, const_tensor_values)
+
   @parameterized.named_parameters(
       ("Shape", "SHAPE"),
   )
@@ -367,7 +408,8 @@ class DumpingCallbackTest(
     with debug_events_reader.DebugDataReader(self.dump_root) as reader:
       reader.update()
       graph_exec_traces = reader.graph_execution_traces()
-      executed_op_types = [trace.op_type for trace in graph_exec_traces]
+      executed_op_types = [trace.op_type for trace in graph_exec_traces
+                           if trace.op_type != "Const"]
       self.assertEqual(
           executed_op_types,
           ["Placeholder", "Placeholder", "LogicalAnd", "LogicalNot"])
@@ -489,7 +531,8 @@ class DumpingCallbackTest(
         _, stack_frames = reader.read_graph_op_creation_stack_trace(op_digest)
         self._verifyStackFrames(stack_frames)
 
-      graph_exec_traces = reader.graph_execution_traces()
+      graph_exec_traces = [trace for trace in reader.graph_execution_traces()
+                           if trace.op_type != "Const"]
       executed_op_types = [digest.op_type for digest in graph_exec_traces]
       self.assertEqual(
           executed_op_types,
@@ -902,10 +945,10 @@ class DumpingCallbackTest(
       reader.update()
       graph_exec_digests = reader.graph_execution_traces(digest=True)
       executed_op_types = [digest.op_type for digest in graph_exec_digests
-                           if digest.op_type != "Placeholder"]
+                           if digest.op_type not in ("Const", "Placeholder")]
       tensor_values = [reader.graph_execution_trace_to_tensor_value(digest)
                        for digest in graph_exec_digests
-                       if digest.op_type != "Placeholder"]
+                       if digest.op_type not in ("Const", "Placeholder")]
 
       if tensor_dtypes == [dtypes.float32] and not op_regex:
         self.assertEqual(executed_op_types, ["Unique", "Sum"])
@@ -1003,7 +1046,8 @@ class DumpingCallbackTest(
       self.assertAllClose(tensor_values, [8.0])
 
       graph_exec_traces = reader.graph_execution_traces()
-      executed_op_types = [trace.op_type for trace in graph_exec_traces]
+      executed_op_types = [trace.op_type for trace in graph_exec_traces
+                           if trace.op_type != "Const"]
       if tensor_debug_mode != "CURT_HEALTH":
         # Less outputs a boolean tensor, which is not tracked under CURT_HEALTH.
         # The Less op should have been executed 5 times.


@@ -28,6 +28,7 @@ from tensorflow.core.framework import types_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import op_callbacks
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
@@ -299,11 +300,17 @@ def _constant_impl(
           value, dtype=dtype, shape=shape, verify_shape=verify_shape,
           allow_broadcast=allow_broadcast))
   dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
-  const_tensor = g._create_op_internal(  # pylint: disable=protected-access
-      "Const", [], [dtype_value.type],
-      attrs={"value": tensor_value,
-             "dtype": dtype_value},
-      name=name).outputs[0]
+  attrs = {"value": tensor_value, "dtype": dtype_value}
+  const_tensor = g._create_op_internal(  # pylint: disable=protected-access
+      "Const", [], [dtype_value.type], attrs=attrs, name=name).outputs[0]
+
+  if op_callbacks.should_invoke_op_callbacks():
+    # TODO(b/147670703): Once the special-op creation code paths are unified,
+    # remove this `if` block.
+    callback_outputs = op_callbacks.invoke_op_callbacks(
+        "Const", tuple(), attrs, (const_tensor,), op_name=name, graph=g)
+    if callback_outputs is not None:
+      const_tensor, = callback_outputs
   return const_tensor
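
The hook follows the same guarded-invoke pattern already used for other
specially-created ops such as Placeholder. Condensed into a hypothetical
standalone helper (names illustrative, not part of this CL):

  from tensorflow.python.framework import op_callbacks

  def _maybe_invoke_callbacks(op_type, attrs, output, name, graph):
    # Invoke callbacks only when at least one is registered; a callback may
    # optionally return replacement outputs, overriding the original tensor.
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op_type, tuple(), attrs, (output,), op_name=name, graph=graph)
      if callback_outputs is not None:
        output, = callback_outputs
    return output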


@@ -109,7 +109,8 @@ class _NumpyFunctionCallback(object):
       if compat.as_bytes(op_type) in (_ENTER_OP, _EXIT_OP, _IF_OP, _MERGE_OP,
                                       _NEXT_ITERATION_OP, _STATELESS_IF_OP,
                                       _SWITCH_OP, _WHILE_OP, _IDENTITY_OP,
-                                      _VAR_HANDLE_OP, _PLACEHOLDER_OP):
+                                      _VAR_HANDLE_OP, _PLACEHOLDER_OP,
+                                      _CONSTANT_OP):
         # TODO(cais): Overriding the output of StatelessIf, If and While ops
         # currently fails with error. Investigate (b/139668453).
         # Avoid instrumenting Identity ops as well, as they are inserted
@@ -724,7 +725,7 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
   def testOverrideDTypeInFuncGraph(self):
     def to_float64(op_type, inputs, attrs, outputs, op_name=None, graph=None):
       del inputs, attrs, op_name, graph  # Unused.
-      if op_type == "Placeholder":
+      if op_type in ("Const", "Placeholder"):
         return outputs
       else:
         return [math_ops.cast(output, dtypes.float64) for output in outputs]
@@ -751,6 +752,17 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
     self.assertIsNone(w)
     self.assertEqual(instrument.eager_op_types, [_ADD_OP])
 
+  def testOpCallbackCapturesConstTensors(self):
+    instrument = _NumpyFunctionCallback()
+    op_callbacks.add_op_callback(instrument.callback)
+
+    @def_function.function
+    def times_two_plus_three(x):
+      return x * 2.0 + 3.0
+
+    self.assertAllClose(times_two_plus_three(constant_op.constant(10.0)), 23.0)
+    self.assertEqual(instrument.graph_op_types.count(b"Const"), 2)
+
   @test_util.run_in_graph_and_eager_modes
   def testOpCallbackWorksWithGradientTape(self):
     instrument = _NumpyFunctionCallback()


@@ -791,6 +791,10 @@ def _ConstantValue(tensor, partial):
     return np.not_equal(value1, value2)
   elif tensor.op.type == "StopGradient":
     return constant_value(tensor.op.inputs[0], partial)
+  elif tensor.op.type == "Identity":
+    return constant_value(tensor.op.inputs[0], partial)
+  elif tensor.op.type in ("CheckNumericsV2", "DebugIdentityV2"):
+    return constant_value(tensor.op.inputs[0], partial)
   else:
     return None
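
A quick way to observe the effect of the new special cases, via the public
wrapper `tf.get_static_value()` around `constant_value()` (a sketch, assuming
TF at this revision):

  import tensorflow as tf

  with tf.Graph().as_default():
    c = tf.constant([1.0, 2.0])
    ident = tf.identity(c)
    # Tracing previously stopped at the Identity op and returned None; it now
    # follows the op's input back to the underlying Const.
    print(tf.get_static_value(ident))  # => array([1., 2.], dtype=float32)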


@@ -188,7 +188,7 @@ class ConfusionMatrixTest(test.TestCase):
   def testLabelsTooLarge(self):
     labels = np.asarray([1, 1, 0, 3, 5], dtype=np.int32)
     predictions = np.asarray([2, 1, 0, 2, 2], dtype=np.int32)
-    with self.assertRaisesOpError("`labels`.*x < y"):
+    with self.assertRaisesOpError("`labels`.*out of bound"):
       self._testConfMatrix(
           labels=labels, predictions=predictions, num_classes=3, truth=None)
@@ -203,7 +203,7 @@ class ConfusionMatrixTest(test.TestCase):
   def testPredictionsTooLarge(self):
     labels = np.asarray([1, 1, 0, 2, 2], dtype=np.int32)
     predictions = np.asarray([2, 1, 0, 3, 5], dtype=np.int32)
-    with self.assertRaisesOpError("`predictions`.*x < y"):
+    with self.assertRaisesOpError("`predictions`.*out of bound"):
       self._testConfMatrix(
           labels=labels, predictions=predictions, num_classes=3, truth=None)