[tfdbg] enable_dump_debug_info(): Support CONCISE_HEALTH and SHAPE modes

- Add eager and graph instrumentation logic for CONCISE_HEALTH and SHAPE
  modes to dumping_callback.py (see the usage sketch below)
- Add unit tests for the new modes on CPU, GPU and TPU:
  - To dumping_callback_test.py
  - To tpu_callbacks_test.py
- Add CONCISE_HEALTH and SHAPE to the continuously running TensorFlowGpuRegression
  benchmarks on Guitar.
- Register the DebugNumericSummaryV2 ops with more succinct template definitions.
- Expand the kernel registration to int and bool dtypes.
- Minor changes: Simplify and clean up some test code.
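
For reference, a minimal usage sketch of the Python API this change extends (the public endpoint and the dump path are assumptions, not part of this change):

import tensorflow as tf

# Minimal sketch, assuming the tf.debugging.experimental endpoint.
# CONCISE_HEALTH records, per tensor:
# [tensor_id, element_count, neg_inf_count, pos_inf_count, nan_count].
tf.debugging.experimental.enable_dump_debug_info(
    "/tmp/tfdbg2_dump", tensor_debug_mode="CONCISE_HEALTH")

x = tf.constant([1.0, 0.0])
y = x / x  # 0/0 yields a nan, which shows up in the dumped counts.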

PiperOrigin-RevId: 286097114
Change-Id: I4e62984b7f11034dcb13fc10ee55b92d08011507
Author: Shanqing Cai
Committed-by: TensorFlower Gardener
Date: 2019-12-17 17:40:09 -08:00
Commit: 316bd31e02 (parent a51291d9b4)
5 changed files with 269 additions and 35 deletions

@@ -160,6 +160,9 @@ TF_CALL_half(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
TF_CALL_bfloat16(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
TF_CALL_float(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
TF_CALL_double(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
TF_CALL_INTEGRAL_TYPES(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
TF_CALL_bool(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
// TODO(cais): Add string support.
#define REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE(type) \
REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2") \
@@ -171,39 +174,31 @@ TF_CALL_half(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
TF_CALL_bfloat16(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
TF_CALL_float(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
TF_CALL_double(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
TF_CALL_INTEGRAL_TYPES(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
TF_CALL_bool(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
// TODO(cais): Add string support.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
.Device(DEVICE_GPU)
.TypeConstraint<Eigen::half>("T")
.TypeConstraint<float>("output_dtype"),
DebugNumericSummaryV2Op<GPUDevice, Eigen::half, float>);
REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
.Device(DEVICE_GPU)
.TypeConstraint<float>("T")
.TypeConstraint<float>("output_dtype"),
DebugNumericSummaryV2Op<GPUDevice, float, float>);
REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
.Device(DEVICE_GPU)
.TypeConstraint<double>("T")
.TypeConstraint<float>("output_dtype"),
DebugNumericSummaryV2Op<GPUDevice, double, float>);
REGISTER_KERNEL_BUILDER(
Name("DebugNumericSummaryV2")
.Device(DEVICE_GPU)
.TypeConstraint<Eigen::half>("T")
.TypeConstraint<double>("output_dtype"),
DebugNumericSummaryV2Op<GPUDevice, Eigen::half, double>);
REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
.Device(DEVICE_GPU)
.TypeConstraint<float>("T")
.TypeConstraint<double>("output_dtype"),
DebugNumericSummaryV2Op<GPUDevice, float, double>);
REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
.Device(DEVICE_GPU)
.TypeConstraint<double>("T")
.TypeConstraint<double>("output_dtype"),
DebugNumericSummaryV2Op<GPUDevice, double, double>);
#define REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(in_type, out_type) \
REGISTER_KERNEL_BUILDER( \
Name("DebugNumericSummaryV2") \
.Device(DEVICE_GPU) \
.TypeConstraint<in_type>("T") \
.TypeConstraint<out_type>("output_dtype"), \
DebugNumericSummaryV2Op<GPUDevice, in_type, out_type>);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(Eigen::half, float);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(float, float);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(double, float);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(int16, float);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(int32, float);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(Eigen::half, double);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(float, double);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(double, double);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(int16, double);
REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(int32, double);
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow
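
With the integral and bool registrations above, DebugNumericSummaryV2 can summarize non-float tensors. A hedged sketch using the internal gen_debug_ops wrapper (the same wrapper the dumping callback below calls); the expected output is inferred from the tests in this change:

import tensorflow as tf
from tensorflow.core.protobuf import debug_event_pb2
from tensorflow.python.ops import gen_debug_ops

# SHAPE-mode summary of an int32 tensor:
# [tensor_id, dtype_enum, rank, element_count, shape padded to length 6].
x = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int32)
summary = gen_debug_ops.debug_numeric_summary_v2(
    x,
    tensor_debug_mode=debug_event_pb2.TensorDebugMode.SHAPE,
    output_dtype=tf.float64)
print(summary.numpy())  # Expected: [-1., 3., 2., 6., 2., 3., 0., 0., 0., 0.]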


@@ -168,9 +168,13 @@ struct CurtHealthLaunch {
template struct CurtHealthLaunch<Eigen::half, float>;
template struct CurtHealthLaunch<float, float>;
template struct CurtHealthLaunch<double, float>;
template struct CurtHealthLaunch<int16, float>;
template struct CurtHealthLaunch<int32, float>;
template struct CurtHealthLaunch<Eigen::half, double>;
template struct CurtHealthLaunch<float, double>;
template struct CurtHealthLaunch<double, double>;
template struct CurtHealthLaunch<int16, double>;
template struct CurtHealthLaunch<int32, double>;
template <typename Tin, typename Tout>
struct ConciseHealthLaunch {
@@ -188,9 +192,13 @@ struct ConciseHealthLaunch {
template struct ConciseHealthLaunch<Eigen::half, float>;
template struct ConciseHealthLaunch<float, float>;
template struct ConciseHealthLaunch<double, float>;
template struct ConciseHealthLaunch<int16, float>;
template struct ConciseHealthLaunch<int32, float>;
template struct ConciseHealthLaunch<Eigen::half, double>;
template struct ConciseHealthLaunch<float, double>;
template struct ConciseHealthLaunch<double, double>;
template struct ConciseHealthLaunch<int16, double>;
template struct ConciseHealthLaunch<int32, double>;
template <typename Tin, typename Tout>
struct FullHealthLaunch {
@@ -208,9 +216,13 @@ struct FullHealthLaunch {
template struct FullHealthLaunch<Eigen::half, float>;
template struct FullHealthLaunch<float, float>;
template struct FullHealthLaunch<double, float>;
template struct FullHealthLaunch<int16, float>;
template struct FullHealthLaunch<int32, float>;
template struct FullHealthLaunch<Eigen::half, double>;
template struct FullHealthLaunch<float, double>;
template struct FullHealthLaunch<double, double>;
template struct FullHealthLaunch<int16, double>;
template struct FullHealthLaunch<int32, double>;
template <typename Tin, typename Tout>
struct ReduceInfNanThreeSlotsLaunch {
@@ -229,9 +241,13 @@ struct ReduceInfNanThreeSlotsLaunch {
template struct ReduceInfNanThreeSlotsLaunch<Eigen::half, float>;
template struct ReduceInfNanThreeSlotsLaunch<float, float>;
template struct ReduceInfNanThreeSlotsLaunch<double, float>;
template struct ReduceInfNanThreeSlotsLaunch<int16, float>;
template struct ReduceInfNanThreeSlotsLaunch<int32, float>;
template struct ReduceInfNanThreeSlotsLaunch<Eigen::half, double>;
template struct ReduceInfNanThreeSlotsLaunch<float, double>;
template struct ReduceInfNanThreeSlotsLaunch<double, double>;
template struct ReduceInfNanThreeSlotsLaunch<int16, double>;
template struct ReduceInfNanThreeSlotsLaunch<int32, double>;
} // namespace tensorflow
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
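
The instantiations above back four GPU summary flavors. For the smallest of them, CURT_HEALTH, a hedged sketch of the two-value output (expected values inferred from the tests below):

import tensorflow as tf
from tensorflow.core.protobuf import debug_event_pb2
from tensorflow.python.ops import gen_debug_ops

# CURT_HEALTH packs [tensor_id, any_inf_or_nan] into two output values.
t = tf.constant([1.0, float("inf"), 3.0])
curt = gen_debug_ops.debug_numeric_summary_v2(
    t,
    tensor_debug_mode=debug_event_pb2.TensorDebugMode.CURT_HEALTH,
    output_dtype=tf.float64)
print(curt.numpy())  # Expected: [-1., 1.] -- the 1. flags the inf.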


@@ -100,7 +100,8 @@ def parse_args():
type=str,
default="NO_TENSOR",
help="Mode for dumping tensor values. Options: NO_TENSOR, CURT_HEALTH, "
"FULL_TENSOR. This is relevant only when --dump_dir is set.")
"CONCISE_HEALTH, SHAPE, FULL_TENSOR. This is relevant only when "
"--dump_dir is set.")
# TODO(cais): Add more tensor debug mode strings once they are supported.
parser.add_argument(
"--dump_circular_buffer_size",


@@ -324,10 +324,20 @@ class _DumpingCallback(object):
debug_tensor.op)
instrumented_tensors.append(identity)
return instrumented_tensors
elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.CURT_HEALTH:
elif tensor_debug_mode in (debug_event_pb2.TensorDebugMode.CURT_HEALTH,
debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
debug_event_pb2.TensorDebugMode.SHAPE):
for output_slot, tensor in enumerate(tensors):
dtype = tensor.dtype
dtype_is_dumpable = (
tensor_debug_mode in (
debug_event_pb2.TensorDebugMode.CURT_HEALTH,
debug_event_pb2.TensorDebugMode.CONCISE_HEALTH) and
dtype.is_floating or
tensor_debug_mode == debug_event_pb2.TensorDebugMode.SHAPE and
(dtype.is_floating or dtype.is_integer or dtype.is_bool))
if (not self._should_dump_tensor(op_type, tensor.dtype) or
not tensor.dtype.is_floating):
not dtype_is_dumpable):
if is_v1_graph_mode:
instrumented_tensors.append(tensor)
continue
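
The dumpability check above relies on Python operator precedence (and binds tighter than or). A restated sketch with explicit grouping (the function name is hypothetical):

from tensorflow.core.protobuf import debug_event_pb2

def _dtype_is_dumpable(tensor_debug_mode, dtype):
  # Health modes summarize only floating-point tensors; SHAPE also
  # accepts integer and boolean tensors.
  if tensor_debug_mode in (debug_event_pb2.TensorDebugMode.CURT_HEALTH,
                           debug_event_pb2.TensorDebugMode.CONCISE_HEALTH):
    return dtype.is_floating
  if tensor_debug_mode == debug_event_pb2.TensorDebugMode.SHAPE:
    return dtype.is_floating or dtype.is_integer or dtype.is_bool
  return False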
@@ -409,6 +419,8 @@
tensor_debug_mode=tensor_debug_mode,
code_location=self._process_stack_frames())
elif tensor_debug_mode in (debug_event_pb2.TensorDebugMode.CURT_HEALTH,
debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
debug_event_pb2.TensorDebugMode.SHAPE,
debug_event_pb2.TensorDebugMode.FULL_TENSOR):
execution_proto = debug_event_pb2.Execution(
op_type=op_type,
@@ -421,7 +433,9 @@
for tensor in tensors:
if (self._should_dump_tensor(op_type, tensor.dtype) and
tensor.dtype.is_numpy_compatible):
if tensor_debug_mode == debug_event_pb2.TensorDebugMode.CURT_HEALTH:
if tensor_debug_mode in (
debug_event_pb2.TensorDebugMode.CURT_HEALTH,
debug_event_pb2.TensorDebugMode.CONCISE_HEALTH):
if tensor.dtype.is_floating:
tensor_proto = _concrete_tensor_to_proto(
gen_debug_ops.debug_numeric_summary_v2(
@@ -431,6 +445,17 @@
else:
# A placeholder for non-floating-type output tensors.
tensor_proto = tensor_pb2.TensorProto()
elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.SHAPE:
if (tensor.dtype.is_floating or tensor.dtype.is_integer or
tensor.dtype.is_bool):
tensor_proto = _concrete_tensor_to_proto(
gen_debug_ops.debug_numeric_summary_v2(
tensor,
tensor_debug_mode=tensor_debug_mode,
output_dtype=dtypes.float64))
else:
# A placeholder for tensors with dtypes not summarized by SHAPE mode.
tensor_proto = tensor_pb2.TensorProto()
elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
tensor_proto = _concrete_tensor_to_proto(tensor)
if tensor_proto:
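
For dtypes a mode cannot summarize (e.g. strings), the branches above record an empty TensorProto as a placeholder. A small sketch of that placeholder:

from tensorflow.core.framework import tensor_pb2

placeholder = tensor_pb2.TensorProto()
assert placeholder.ByteSize() == 0  # No dtype, shape, or content set.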
@@ -657,6 +682,8 @@ def enable_dump_debug_info(dump_root,
tensor_debug_mode = debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode)
if tensor_debug_mode not in (debug_event_pb2.TensorDebugMode.NO_TENSOR,
debug_event_pb2.TensorDebugMode.CURT_HEALTH,
debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
debug_event_pb2.TensorDebugMode.SHAPE,
debug_event_pb2.TensorDebugMode.FULL_TENSOR):
raise NotImplementedError(
"tfdbg dumping: support for tensor debug mode %s is not "


@@ -88,6 +88,8 @@ class TracingCallbackTest(
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
("CurtHealth", "CURT_HEALTH"),
("ConciseHealth", "CONCISE_HEALTH"),
("Shape", "SHAPE"),
("FullTensor", "FULL_TENSOR"),
)
def testPureEagerOpExecution(self, tensor_debug_mode):
@@ -146,6 +148,26 @@
self.assertAllClose(
tensor_util.MakeNdarray(execution.tensor_protos[0]),
[-1.0, 0.0])
elif tensor_debug_mode == "CONCISE_HEALTH":
self.assertLen(execution.tensor_protos, 1)
if execution.op_type in ("AddV2", "Mul", "RealDiv"):
# 1st element: -1 is the unset tensor_id for eager op execution.
# 2nd element: element count (1 for these scalar tensors).
# Remaining 3 elements: the counts of -inf, inf and nan (all zero here).
self.assertAllClose(
tensor_util.MakeNdarray(execution.tensor_protos[0]),
[-1, 1, 0, 0, 0])
elif tensor_debug_mode == "SHAPE":
self.assertLen(execution.tensor_protos, 1)
if execution.op_type in ("AddV2", "Mul", "RealDiv"):
# 1st element: -1 is the unset tensor_id for eager op execution.
# 2nd element: dtype enum value (float32).
# 3rd element: rank (0 for a scalar).
# 4th element: element count (1).
# Remaining elements: shape, padded to a fixed length of 6.
self.assertAllClose(
tensor_util.MakeNdarray(execution.tensor_protos[0]),
[-1, 1, 0, 1, 0, 0, 0, 0, 0, 0])
elif tensor_debug_mode == "FULL_TENSOR":
# Under the FULL_TENSOR mode, the value of the tensor should be
# available through `tensor_protos`.
@@ -202,9 +224,127 @@
with self.assertRaises(StopIteration):
next(graph_trace_iter)
@parameterized.named_parameters(
("CurtHealth", "CURT_HEALTH"),
("ConciseHealth", "CONCISE_HEALTH"),
("Shape", "SHAPE"),
)
@test_util.run_in_graph_and_eager_modes
def testModesSummarizingBadNumericalValue(self, tensor_debug_mode):
writer = dumping_callback.enable_dump_debug_info(
self.dump_root, tensor_debug_mode=tensor_debug_mode)
@def_function.function
def func(x, y):
return (x + y) / (x - y)
x = np.array([-3, -1, 0, 0, 1, 1, 1, 2], dtype=np.float16)
y = np.array([2, -1, 0, 0, 1, 1, 1, 3], dtype=np.float16)
# (x + y) / (x - y) = [0.2, -inf, nan, nan, inf, inf, inf, -5].
self.evaluate(func(x, y))
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids,
_, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
(op_names, _, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
self.assertCountEqual(executed_op_types, ["AddV2", "Sub", "RealDiv"])
if tensor_debug_mode == "CURT_HEALTH":
for op_type, tensor_value in zip(executed_op_types, tensor_values):
self.assertLen(tensor_value, 2)
# 1st element: tensor_id, should be >= 0.
# TODO(cais): Assert on detailed value once Function-graph association
# is in place.
self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: 1 if there is any inf or nan, 0 otherwise.
if op_type == "RealDiv":
self.assertEqual(tensor_value[1], 1)
else:
self.assertEqual(tensor_value[1], 0)
elif tensor_debug_mode == "CONCISE_HEALTH":
for op_type, tensor_value in zip(executed_op_types, tensor_values):
self.assertLen(tensor_value, 5)
# 1st element: tensor_id, should be >= 0.
# TODO(cais): Assert on detailed value once Function-graph association
# is in place.
self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: element count.
self.assertEqual(tensor_value[1], 8)
# Remaining 3 elements: The counts of -inf, inf and nan.
if op_type == "RealDiv":
self.assertAllClose(tensor_value[2:], [1, 3, 2])
else:
self.assertAllClose(tensor_value[2:], [0, 0, 0])
else: # SHAPE.
for op_type, tensor_value in zip(executed_op_types, tensor_values):
self.assertLen(tensor_value, 10)
# 1st element: tensor_id, should be >= 0.
# TODO(cais): Assert on detailed value once Function-graph association
# is in place.
self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: dtype enum value (float16).
self.assertEqual(tensor_value[1], 19)
# 3rd element: rank (1)
self.assertEqual(tensor_value[2], 1)
# 4th element: element count.
self.assertEqual(tensor_value[3], 8)
# Remaining elements: shape, padded to a fixed length of 6.
self.assertAllClose(tensor_value[4:], [8, 0, 0, 0, 0, 0])
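
The RealDiv expectations above can be reproduced by hand. A numpy sketch of the CONCISE_HEALTH payload (tensor_id pinned to -1 for illustration; graph mode assigns a real id):

import numpy as np

t = np.array([0.2, -np.inf, np.nan, np.nan, np.inf, np.inf, np.inf, -5.0])
payload = [
    -1.0,                         # tensor_id (illustrative).
    float(t.size),                # element count: 8.
    float(np.isneginf(t).sum()),  # -inf count: 1.
    float(np.isposinf(t).sum()),  # inf count: 3.
    float(np.isnan(t).sum()),     # nan count: 2.
]
assert payload == [-1.0, 8.0, 1.0, 3.0, 2.0]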
@parameterized.named_parameters(
("Shape", "SHAPE"),
)
@test_util.run_in_graph_and_eager_modes
def testBooleanTensors(self, tensor_debug_mode):
writer = dumping_callback.enable_dump_debug_info(
self.dump_root, tensor_debug_mode=tensor_debug_mode)
@def_function.function
def func(x, y):
return math_ops.logical_not(math_ops.logical_and(x, y))
x = np.array([[False, False], [True, True]], dtype=np.bool)
y = np.array([[False, True], [False, True]], dtype=np.bool)
self.assertAllEqual(
self.evaluate(func(x, y)), [[True, True], [True, False]])
writer.FlushNonExecutionFiles()
writer.FlushExecutionFiles()
stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
(context_ids,
_, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
(op_names, _, _,
tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
self.assertEqual(executed_op_types, ["LogicalAnd", "LogicalNot"])
for tensor_value in tensor_values:
# 1st element: tensor_id, should be >= 0.
# TODO(cais): Assert on detailed value once Function-graph association
# is in place.
self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: dtype enum value (bool).
self.assertEqual(tensor_value[1], 10)
# 3rd element: rank (2)
self.assertEqual(tensor_value[2], 2)
# 4th element: element count.
self.assertEqual(tensor_value[3], 4)
# Remaining elements: shape at fixed length.
self.assertAllClose(tensor_value[4:], [2, 2, 0, 0, 0, 0])
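
The magic numbers asserted for the 2nd element come from TensorFlow's DataType enum, which DType.as_datatype_enum exposes:

import tensorflow as tf

assert tf.float32.as_datatype_enum == 1   # DT_FLOAT
assert tf.bool.as_datatype_enum == 10     # DT_BOOL
assert tf.float16.as_datatype_enum == 19  # DT_HALF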
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
("CurtHealth", "CURT_HEALTH"),
("ConciseHealth", "CONCISE_HEALTH"),
("Shape", "SHAPE"),
("FullTensor", "FULL_TENSOR"),
)
@test_util.run_in_graph_and_eager_modes
@@ -276,6 +416,30 @@
self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: 0 means there is no inf or nan.
self.assertEqual(tensor_value[1], 0)
elif tensor_debug_mode == "CONCISE_HEALTH":
for tensor_value in tensor_values:
self.assertLen(tensor_value, 5)
# 1st element: tensor_id, should be >= 0.
# TODO(cais): Assert on detailed value once Function-graph association
# is in place.
self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: element count. Remaining elements: all zero because there
# is no -inf, inf or nan.
self.assertAllClose(tensor_value[1:], [1, 0, 0, 0])
elif tensor_debug_mode == "SHAPE":
for tensor_value in tensor_values:
# 1st element: tensor_id, should be >= 0.
# TODO(cais): Assert on detailed value once Function-graph association
# is in place.
self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: dtype (float32).
self.assertGreaterEqual(tensor_value[1], 1)
# 3rd element: rank (scalar).
self.assertGreaterEqual(tensor_value[2], 0)
# 4th element: element count.
self.assertGreaterEqual(tensor_value[3], 1)
# Remaining elements: shape padded to fixed length.
self.assertAllClose(tensor_value[4:], [0, 0, 0, 0, 0, 0])
elif tensor_debug_mode == "FULL_TENSOR":
self.assertAllClose(tensor_values[0], 5.0) # 1st AddV2 op.
self.assertAllClose(tensor_values[1], np.log(5.0)) # Log op.
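
These assertions read dumped values back via tensor_util.MakeNdarray. A standalone round-trip sketch:

import numpy as np
from tensorflow.python.framework import tensor_util

proto = tensor_util.make_tensor_proto(np.log(5.0))  # value -> TensorProto
assert np.isclose(tensor_util.MakeNdarray(proto), np.log(5.0))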
@@ -713,6 +877,8 @@ class TracingCallbackTest(
@parameterized.named_parameters(
("NoTensor", "NO_TENSOR"),
("CurtHealth", "CURT_HEALTH"),
("ConciseHealth", "CONCISE_HEALTH"),
("Shape", "SHAPE"),
("FullTensor", "FULL_TENSOR"),
)
def testMultiThreadedExecutionWithSameSetting(self, tensor_debug_mode):
@@ -774,6 +940,35 @@
self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: 0 means there is no inf or nan.
self.assertEqual(tensor_value[1], 0)
elif tensor_debug_mode == "CONCISE_HEALTH":
for tensor_value in tensor_values:
self.assertLen(tensor_value, 5)
# 1st element: tensor_id, should be >= 0.
# TODO(cais): Assert on detailed value once Function-graph association
# is in place.
self.assertGreaterEqual(tensor_value[0], 0)
# 2nd element: element count. Remaining elements: all zero because there
# is no -inf, inf or nan.
self.assertAllClose(tensor_value[1:], [1, 0, 0, 0])
elif tensor_debug_mode == "SHAPE":
mul_values = [
tensor_values[i]
for i, op_type in enumerate(executed_op_types)
if op_type == "Mul"
]
for mul_value in mul_values:
# 1st element: tensor_id, should be >= 0.
# TODO(cais): Assert on detailed value once Function-graph association
# is in place.
self.assertGreaterEqual(mul_value[0], 0)
# 2nd element: dtype enum value (float32).
self.assertEqual(mul_value[1], 1)
# 3rd element: rank.
self.assertEqual(mul_value[2], 0)
# 4th element: element count.
self.assertEqual(mul_value[3], 1)
# Remaining elements: shape padded to a fixed length.
self.assertAllClose(mul_value[4:], [0, 0, 0, 0, 0, 0])
elif tensor_debug_mode == "FULL_TENSOR":
mul_values = [
tensor_values[i]