From e6f22ee5f4d483c1b05fbdaf8e8f5d55033f2bdb Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Tue, 21 Apr 2020 18:39:29 -0700 Subject: [PATCH] [tfdbg2] Ensure Const ops in graphs are captured by op_callbacks Details of the changes: - In the Python API of tensorflow, Const ops are created by calling `_create_op_internal()` from constant_op.py. This differs from how most other ops are created, and is similar to Placeholder ops, which are already instrumented by tfdbg2' op_callbacks. In this CL, we add a op_callback hook to the code in constant_op.py to allow instrumentation of Const ops. that makes that call. - In `_ConstantValue()` in tensor_util.py, add a special case for `CheckNumericsV2` op, so the `constant_value()` does not treat the `CheckNumericsV2` op as the constant tensor value. Similarly, add special cases for `Identity` and `DebugIdentityV2`. - In `dumping_callback_test.py`, replace use of a deprecated Dataset API (`make_one_shot_iterator()`) with non-deprecated API (`iter()` and `next()`) - Make other necessary changes to tfdbg2's tests to accommodate the Const ops which were previously not instrumented, but are now. - Increase the shard_count of learning/brain/python/debug/tpu_callbacks_test.py to 6 to avoid timeouts under the instrumented number of instrumented ops. PiperOrigin-RevId: 307723353 Change-Id: Iecdbfcb439f6e04fc12c1503ad5339d42703e8bc --- .../debug/lib/check_numerics_callback_test.py | 29 +++++++++- .../debug/lib/debug_events_monitors_test.py | 12 +++- .../python/debug/lib/dumping_callback.py | 11 +++- .../python/debug/lib/dumping_callback_test.py | 56 +++++++++++++++++-- tensorflow/python/framework/constant_op.py | 15 +++-- .../python/framework/op_callbacks_test.py | 16 +++++- tensorflow/python/framework/tensor_util.py | 4 ++ .../kernel_tests/confusion_matrix_test.py | 4 +- 8 files changed, 124 insertions(+), 23 deletions(-) diff --git a/tensorflow/python/debug/lib/check_numerics_callback_test.py b/tensorflow/python/debug/lib/check_numerics_callback_test.py index ea5d70f0d08..5f578da03c3 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback_test.py +++ b/tensorflow/python/debug/lib/check_numerics_callback_test.py @@ -94,10 +94,16 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase): dataset = dataset_ops.Dataset.from_tensor_slices(tensor).batch(2).map( map_fn) - iterator = dataset_ops.make_one_shot_iterator(dataset) - self.assertAllClose(self.evaluate(iterator.get_next()), np.log([1.25, 2])) - self.assertAllClose(self.evaluate(iterator.get_next()), np.log([3.25, 5])) + @def_function.function + def get_batches(): + iterator = iter(dataset) + return [next(iterator), next(iterator)] + + batches = self.evaluate(get_batches()) + self.assertLen(batches, 2) + self.assertAllClose(batches[0], np.log([1.25, 2])) + self.assertAllClose(batches[1], np.log([3.25, 5])) class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase): @@ -267,6 +273,23 @@ class CheckNumericsCallbackUnhealthyTest(test_util.TensorFlowTestCase): self.assertTrue(re.search(r"Stack trace of op's creation", message)) self.assertIn("accum.assign(accum * 2.0)", message) + @test_util.run_in_graph_and_eager_modes + def testNanInConstIsCaptured(self): + check_numerics_callback.enable_check_numerics() + v = variables.Variable(3.0, dtype=dtypes.float32) + @def_function.function + def add_a_bad_constant(x): + c = constant_op.constant(np.nan) + return x + c + if not context.executing_eagerly(): + self.evaluate(v.initializer) + message = self._assertRaisesInvalidArgumentErrorAndGetMessage( + lambda: self.evaluate(add_a_bad_constant(v))) + self.assertTrue(re.search(r"graph op.*\"Const\"", message)) + self.assertTrue(re.search(r"dtype:.*float32", message)) + self.assertTrue(re.search(r"shape:.*\(\)", message)) + self.assertTrue(re.search(r"Graph name:.*add_a_bad_constant", message)) + @test_util.run_in_graph_and_eager_modes def testCatchInfinityInDatasetMapFunction(self): """Test that callback catches NaN in a tf.dataset map function.""" diff --git a/tensorflow/python/debug/lib/debug_events_monitors_test.py b/tensorflow/python/debug/lib/debug_events_monitors_test.py index 05eaa510648..e8dcd6e4329 100644 --- a/tensorflow/python/debug/lib/debug_events_monitors_test.py +++ b/tensorflow/python/debug/lib/debug_events_monitors_test.py @@ -173,7 +173,8 @@ class DebugEventsMonitorTest(dumping_callback_test_lib.DumpingCallbackTestBase, self.assertLen(traces[1].debug_tensor_value, 11) self.assertLen(traces[2].debug_tensor_value, 11) elif tensor_debug_mode == "FULL_TENSOR": - self.assertLen(traces, 4) # [Placeholder:0, Unique:0, Unique:1, Sum:0]. + # [Placeholder:0, Unique:0, Unique:1, Const:0, Sum:0]. + self.assertLen(traces, 5) self.assertEqual(traces[0].op_type, "Placeholder") self.assertEqual(traces[0].output_slot, 0) self.assertIsNone(traces[0].debug_tensor_value) @@ -192,11 +193,16 @@ class DebugEventsMonitorTest(dumping_callback_test_lib.DumpingCallbackTestBase, self.assertAllEqual( reader.graph_execution_trace_to_tensor_value(traces[2]), [0, 1, 2, 3, 0]) - self.assertEqual(traces[3].op_type, "Sum") + self.assertEqual(traces[3].op_type, "Const") self.assertEqual(traces[3].output_slot, 0) self.assertIsNone(traces[3].debug_tensor_value) self.assertAllClose( - reader.graph_execution_trace_to_tensor_value(traces[3]), 17.) + reader.graph_execution_trace_to_tensor_value(traces[3]), [0]) + self.assertEqual(traces[4].op_type, "Sum") + self.assertEqual(traces[4].output_slot, 0) + self.assertIsNone(traces[4].debug_tensor_value) + self.assertAllClose( + reader.graph_execution_trace_to_tensor_value(traces[4]), 17.) class AlertDataObjectsTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py index 921891033ab..efc5caae321 100644 --- a/tensorflow/python/debug/lib/dumping_callback.py +++ b/tensorflow/python/debug/lib/dumping_callback.py @@ -292,7 +292,12 @@ class _DumpingCallback(object): # TODO(cais): Evaluate performance optimization options. For the # `NO_TENSOR` debug mode, an alternative is to add `debug_tensor` as a # control dependency of `tensor.op` without an additional identity op. - if tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR: + if (tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR and + op_type != "Const"): + # NOTE(b/153716279): Under v1 graph mode, overriding the output tensor + # of Const ops can lead to downstream errors related to shapes. We opt + # to use an identity op to avoid this issue at the cost of slightly + # larger graph size. return debug_tensor else: identity = array_ops.identity(tensor) @@ -530,8 +535,8 @@ class _DumpingCallback(object): is_v1_graph_mode = not ops.executing_eagerly_outside_functions() context_id = self._get_context_id(graph) # Innermost context ID. output_tensor_ids = self._get_symbolic_tensor_ids(len(outputs)) - if op_type in ("Placeholder", "PlaceholderWithDefault"): - # In some cases, the op name of a Placeholder op in a graph + if op_type in ("Const", "Placeholder", "PlaceholderWithDefault"): + # In some cases, the op name of a Const or Placeholder op in a graph # can be duplicate (e.g., with the name "resource"). # When this happens, we give the op an debugger-generated name # in order to prevent problems and check failures down the pipe. diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py index 5f932ef87b2..3486430ccfa 100644 --- a/tensorflow/python/debug/lib/dumping_callback_test.py +++ b/tensorflow/python/debug/lib/dumping_callback_test.py @@ -289,7 +289,8 @@ class DumpingCallbackTest( with debug_events_reader.DebugDataReader(self.dump_root) as reader: reader.update() graph_exec_traces = reader.graph_execution_traces() - executed_op_types = [trace.op_type for trace in graph_exec_traces] + executed_op_types = [trace.op_type for trace in graph_exec_traces + if trace.op_type != "Const"] self.assertCountEqual( executed_op_types, ["Placeholder", "Placeholder", "AddV2", "Sub", "RealDiv"]) @@ -344,6 +345,46 @@ class DumpingCallbackTest( self.assertAllClose(trace.debug_tensor_value, [tensor_id, 19, 1, 8, 8, 0, 0, 0, 0, 0]) + @parameterized.named_parameters( + ("CurtHealth", "CURT_HEALTH"), + ("FullTensor", "FULL_TENSOR"), + ) + @test_util.run_in_graph_and_eager_modes + def testConstTensorsAreCaptured(self, tensor_debug_mode): + writer = dumping_callback.enable_dump_debug_info( + self.dump_root, tensor_debug_mode=tensor_debug_mode) + @def_function.function + def times_two_plus_three(x): + return x * constant_op.constant(2.0) + constant_op.constant(3.0) + self.assertAllEqual( + self.evaluate(times_two_plus_three(10.0)), 23.0) + writer.FlushNonExecutionFiles() + writer.FlushExecutionFiles() + + with debug_events_reader.DebugDataReader(self.dump_root) as reader: + reader.update() + const_traces = [trace for trace in reader.graph_execution_traces() + if trace.op_type == "Const"] + self.assertGreaterEqual(len(const_traces), 3) + if tensor_debug_mode == "CURT_HEALTH": + # Under CURT_HEALTH, each debug tensor value has the form + # [tensor_id, has_inf_or_nan]. + self.assertLen(const_traces[0].debug_tensor_value, 2) + self.assertEqual(const_traces[0].debug_tensor_value[1], 0) + self.assertLen(const_traces[1].debug_tensor_value, 2) + self.assertEqual(const_traces[1].debug_tensor_value[1], 0) + self.assertLen(const_traces[2].debug_tensor_value, 2) + self.assertEqual(const_traces[2].debug_tensor_value[1], 0) + else: # FULL_TENSOR. + const_tensor_values = [ + reader.graph_execution_trace_to_tensor_value(const_trace) + for const_trace in const_traces] + # Avoid making assertion on the particular order of the debug tensors + # for the three Consts because it may be indeterminate. + self.assertIn(10.0, const_tensor_values) + self.assertIn(2.0, const_tensor_values) + self.assertIn(3.0, const_tensor_values) + @parameterized.named_parameters( ("Shape", "SHAPE"), ) @@ -367,7 +408,8 @@ class DumpingCallbackTest( with debug_events_reader.DebugDataReader(self.dump_root) as reader: reader.update() graph_exec_traces = reader.graph_execution_traces() - executed_op_types = [trace.op_type for trace in graph_exec_traces] + executed_op_types = [trace.op_type for trace in graph_exec_traces + if trace.op_type != "Const"] self.assertEqual( executed_op_types, ["Placeholder", "Placeholder", "LogicalAnd", "LogicalNot"]) @@ -489,7 +531,8 @@ class DumpingCallbackTest( _, stack_frames = reader.read_graph_op_creation_stack_trace(op_digest) self._verifyStackFrames(stack_frames) - graph_exec_traces = reader.graph_execution_traces() + graph_exec_traces = [trace for trace in reader.graph_execution_traces() + if trace.op_type != "Const"] executed_op_types = [digest.op_type for digest in graph_exec_traces] self.assertEqual( executed_op_types, @@ -902,10 +945,10 @@ class DumpingCallbackTest( reader.update() graph_exec_digests = reader.graph_execution_traces(digest=True) executed_op_types = [digest.op_type for digest in graph_exec_digests - if digest.op_type != "Placeholder"] + if digest.op_type not in ("Const", "Placeholder")] tensor_values = [reader.graph_execution_trace_to_tensor_value(digest) for digest in graph_exec_digests - if digest.op_type != "Placeholder"] + if digest.op_type not in ("Const", "Placeholder")] if tensor_dtypes == [dtypes.float32] and not op_regex: self.assertEqual(executed_op_types, ["Unique", "Sum"]) @@ -1003,7 +1046,8 @@ class DumpingCallbackTest( self.assertAllClose(tensor_values, [8.0]) graph_exec_traces = reader.graph_execution_traces() - executed_op_types = [trace.op_type for trace in graph_exec_traces] + executed_op_types = [trace.op_type for trace in graph_exec_traces + if trace.op_type != "Const"] if tensor_debug_mode != "CURT_HEALTH": # Less outputs a boolean tensor, which is not tracked under CURT_HEALTH. # The Less op should have been executed 5 times. diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 9736bb8b78b..af9a0f7738c 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -28,6 +28,7 @@ from tensorflow.core.framework import types_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.framework import dtypes +from tensorflow.python.framework import op_callbacks from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -299,11 +300,17 @@ def _constant_impl( value, dtype=dtype, shape=shape, verify_shape=verify_shape, allow_broadcast=allow_broadcast)) dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype) + attrs = {"value": tensor_value, "dtype": dtype_value} const_tensor = g._create_op_internal( # pylint: disable=protected-access - "Const", [], [dtype_value.type], - attrs={"value": tensor_value, - "dtype": dtype_value}, - name=name).outputs[0] + "Const", [], [dtype_value.type], attrs=attrs, name=name).outputs[0] + + if op_callbacks.should_invoke_op_callbacks(): + # TODO(b/147670703): Once the special-op creation code paths + # are unified. Remove this `if` block. + callback_outputs = op_callbacks.invoke_op_callbacks( + "Const", tuple(), attrs, (const_tensor,), op_name=name, graph=g) + if callback_outputs is not None: + const_tensor, = callback_outputs return const_tensor diff --git a/tensorflow/python/framework/op_callbacks_test.py b/tensorflow/python/framework/op_callbacks_test.py index 31b6a583b8e..8868ffd664e 100644 --- a/tensorflow/python/framework/op_callbacks_test.py +++ b/tensorflow/python/framework/op_callbacks_test.py @@ -109,7 +109,8 @@ class _NumpyFunctionCallback(object): if compat.as_bytes(op_type) in (_ENTER_OP, _EXIT_OP, _IF_OP, _MERGE_OP, _NEXT_ITERATION_OP, _STATELESS_IF_OP, _SWITCH_OP, _WHILE_OP, _IDENTITY_OP, - _VAR_HANDLE_OP, _PLACEHOLDER_OP): + _VAR_HANDLE_OP, _PLACEHOLDER_OP, + _CONSTANT_OP): # TODO(cais): Overriding the output of StatelessIf, If and While ops # currently fails with error. Investigate (b/139668453). # Avoid instrumenting Identity ops as well, as they are inserted @@ -724,7 +725,7 @@ class OpCallbacksTest(test_util.TensorFlowTestCase): def testOverrideDTypeInFuncGraph(self): def to_float64(op_type, inputs, attrs, outputs, op_name=None, graph=None): del inputs, attrs, op_name, graph # Unused. - if op_type == "Placeholder": + if op_type in ("Const", "Placeholder"): return outputs else: return [math_ops.cast(output, dtypes.float64) for output in outputs] @@ -751,6 +752,17 @@ class OpCallbacksTest(test_util.TensorFlowTestCase): self.assertIsNone(w) self.assertEqual(instrument.eager_op_types, [_ADD_OP]) + def testOpCallbackCapturesConstTensors(self): + instrument = _NumpyFunctionCallback() + op_callbacks.add_op_callback(instrument.callback) + + @def_function.function + def times_two_plus_three(x): + return x * 2.0 + 3.0 + + self.assertAllClose(times_two_plus_three(constant_op.constant(10.0)), 23.0) + self.assertEqual(instrument.graph_op_types.count(b"Const"), 2) + @test_util.run_in_graph_and_eager_modes def testOpCallbackWorksWithGradientTape(self): instrument = _NumpyFunctionCallback() diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 8f22cad4135..9c386ffa0c6 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -791,6 +791,10 @@ def _ConstantValue(tensor, partial): return np.not_equal(value1, value2) elif tensor.op.type == "StopGradient": return constant_value(tensor.op.inputs[0], partial) + elif tensor.op.type == "Identity": + return constant_value(tensor.op.inputs[0], partial) + elif tensor.op.type in ("CheckNumericsV2", "DebugIdentityV2"): + return constant_value(tensor.op.inputs[0], partial) else: return None diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py index c1178253a4b..8ea9b9f83dd 100644 --- a/tensorflow/python/kernel_tests/confusion_matrix_test.py +++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py @@ -188,7 +188,7 @@ class ConfusionMatrixTest(test.TestCase): def testLabelsTooLarge(self): labels = np.asarray([1, 1, 0, 3, 5], dtype=np.int32) predictions = np.asarray([2, 1, 0, 2, 2], dtype=np.int32) - with self.assertRaisesOpError("`labels`.*x < y"): + with self.assertRaisesOpError("`labels`.*out of bound"): self._testConfMatrix( labels=labels, predictions=predictions, num_classes=3, truth=None) @@ -203,7 +203,7 @@ class ConfusionMatrixTest(test.TestCase): def testPredictionsTooLarge(self): labels = np.asarray([1, 1, 0, 2, 2], dtype=np.int32) predictions = np.asarray([2, 1, 0, 3, 5], dtype=np.int32) - with self.assertRaisesOpError("`predictions`.*x < y"): + with self.assertRaisesOpError("`predictions`.*out of bound"): self._testConfMatrix( labels=labels, predictions=predictions, num_classes=3, truth=None)