From 316bd31e02b78a071d2f7f5a87898dd5f125f371 Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Tue, 17 Dec 2019 17:40:09 -0800
Subject: [PATCH] [tfdbg] enable_dump_debug_info(): Support CONCISE_HEALTH and
 SHAPE modes

- Add eager and graph instrumentation logic for CONCISE_HEALTH and SHAPE
  modes to dumping_callback.py
- Add unit tests for the new modes on CPU, GPU and TPU:
  - To dumping_callback_test.py
  - To tpu_callbacks_test.py
- Add CONCISE_HEALTH and SHAPE to the continuously running TensorFlowGpuRegression
  benchmarks on Guitar.
- Register the DebugNumericSummaryV2 GPU kernels via a more succinct
  registration macro.
- Expand the kernel registrations to integer and bool dtypes.
- Minor changes: Simplify and clean up some test code.
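
For reference, a minimal usage sketch of the new modes, mirroring the unit
tests below (the dump root is a placeholder path):

    from tensorflow.python.debug.lib import dumping_callback

    writer = dumping_callback.enable_dump_debug_info(
        "/tmp/tfdbg2_dump",                  # Placeholder dump root.
        tensor_debug_mode="CONCISE_HEALTH")  # Or "SHAPE".
    # ... run eager ops and/or tf.functions here ...
    writer.FlushNonExecutionFiles()
    writer.FlushExecutionFiles()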

PiperOrigin-RevId: 286097114
Change-Id: I4e62984b7f11034dcb13fc10ee55b92d08011507
---
 tensorflow/core/kernels/debug_ops.cc          |  57 +++--
 tensorflow/core/kernels/debug_ops_gpu.cu.cc   |  16 ++
 .../debug/examples/v2/debug_mnist_v2.py       |   3 +-
 .../python/debug/lib/dumping_callback.py      |  33 ++-
 .../python/debug/lib/dumping_callback_test.py | 195 ++++++++++++++++++
 5 files changed, 269 insertions(+), 35 deletions(-)

diff --git a/tensorflow/core/kernels/debug_ops.cc b/tensorflow/core/kernels/debug_ops.cc
index 03c7cfdac38..db42b9f6511 100644
--- a/tensorflow/core/kernels/debug_ops.cc
+++ b/tensorflow/core/kernels/debug_ops.cc
@@ -160,6 +160,9 @@ TF_CALL_half(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
 TF_CALL_bfloat16(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
 TF_CALL_float(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
 TF_CALL_double(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
+TF_CALL_INTEGRAL_TYPES(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
+TF_CALL_bool(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_FLOAT);
+// TODO(cais): Add string support.
 
 #define REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE(type)                 \
   REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")                \
@@ -171,39 +174,31 @@ TF_CALL_half(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
 TF_CALL_bfloat16(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
 TF_CALL_float(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
 TF_CALL_double(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
+TF_CALL_INTEGRAL_TYPES(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
+TF_CALL_bool(REGISTER_DEBUG_NUMERIC_SUMMARY_V2_DOUBLE);
+// TODO(cais): Add string support.
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<Eigen::half>("T")
-                            .TypeConstraint<float>("output_dtype"),
-                        DebugNumericSummaryV2Op<GPUDevice, Eigen::half, float>);
-REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<float>("T")
-                            .TypeConstraint<float>("output_dtype"),
-                        DebugNumericSummaryV2Op<GPUDevice, float, float>);
-REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<double>("T")
-                            .TypeConstraint<float>("output_dtype"),
-                        DebugNumericSummaryV2Op<GPUDevice, double, float>);
-REGISTER_KERNEL_BUILDER(
-    Name("DebugNumericSummaryV2")
-        .Device(DEVICE_GPU)
-        .TypeConstraint<Eigen::half>("T")
-        .TypeConstraint<double>("output_dtype"),
-    DebugNumericSummaryV2Op<GPUDevice, Eigen::half, double>);
-REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<float>("T")
-                            .TypeConstraint<double>("output_dtype"),
-                        DebugNumericSummaryV2Op<GPUDevice, float, double>);
-REGISTER_KERNEL_BUILDER(Name("DebugNumericSummaryV2")
-                            .Device(DEVICE_GPU)
-                            .TypeConstraint<double>("T")
-                            .TypeConstraint<double>("output_dtype"),
-                        DebugNumericSummaryV2Op<GPUDevice, double, double>);
+#define REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(in_type, out_type) \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("DebugNumericSummaryV2")                              \
+          .Device(DEVICE_GPU)                                    \
+          .TypeConstraint<in_type>("T")                          \
+          .TypeConstraint<out_type>("output_dtype"),             \
+      DebugNumericSummaryV2Op<GPUDevice, in_type, out_type>);
+
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(Eigen::half, float);
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(float, float);
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(double, float);
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(int16, float);
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(int32, float);
+
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(Eigen::half, double);
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(float, double);
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(double, double);
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(int16, double);
+REGISTER_DEBUG_NUMERIC_SUMMARY_V2_GPU(int32, double);
+
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
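
A minimal sketch (not part of this patch) that exercises the newly registered
integer kernels through the raw-op endpoint; the tensor-debug-mode enum values
come from debug_event.proto via debug_event_pb2:

    import tensorflow as tf
    from tensorflow.core.protobuf import debug_event_pb2

    x = tf.constant([[1, 2], [3, 4]], dtype=tf.int32)
    # SHAPE mode emits a 10-element vector:
    # [tensor_id, dtype enum, rank, element count, shape zero-padded to 6].
    summary = tf.raw_ops.DebugNumericSummaryV2(
        input=x,
        tensor_debug_mode=debug_event_pb2.TensorDebugMode.SHAPE,
        tensor_id=-1,  # -1 marks an unset tensor_id.
        output_dtype=tf.float64)
    # Expected per the layout asserted in dumping_callback_test.py:
    print(summary.numpy())  # [-1., 3., 2., 4., 2., 2., 0., 0., 0., 0.]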
diff --git a/tensorflow/core/kernels/debug_ops_gpu.cu.cc b/tensorflow/core/kernels/debug_ops_gpu.cu.cc
index 2e93c3ca24d..a388b067f99 100644
--- a/tensorflow/core/kernels/debug_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/debug_ops_gpu.cu.cc
@@ -168,9 +168,13 @@ struct CurtHealthLaunch {
 template struct CurtHealthLaunch<Eigen::half, float>;
 template struct CurtHealthLaunch<float, float>;
 template struct CurtHealthLaunch<double, float>;
+template struct CurtHealthLaunch<int16, float>;
+template struct CurtHealthLaunch<int32, float>;
 template struct CurtHealthLaunch<Eigen::half, double>;
 template struct CurtHealthLaunch<float, double>;
 template struct CurtHealthLaunch<double, double>;
+template struct CurtHealthLaunch<int16, double>;
+template struct CurtHealthLaunch<int32, double>;
 
 template <typename Tin, typename Tout>
 struct ConciseHealthLaunch {
@@ -188,9 +192,13 @@ struct ConciseHealthLaunch {
 template struct ConciseHealthLaunch<Eigen::half, float>;
 template struct ConciseHealthLaunch<float, float>;
 template struct ConciseHealthLaunch<double, float>;
+template struct ConciseHealthLaunch<int16, float>;
+template struct ConciseHealthLaunch<int32, float>;
 template struct ConciseHealthLaunch<Eigen::half, double>;
 template struct ConciseHealthLaunch<float, double>;
 template struct ConciseHealthLaunch<double, double>;
+template struct ConciseHealthLaunch<int16, double>;
+template struct ConciseHealthLaunch<int32, double>;
 
 template <typename Tin, typename Tout>
 struct FullHealthLaunch {
@@ -208,9 +216,13 @@ struct FullHealthLaunch {
 template struct FullHealthLaunch<Eigen::half, float>;
 template struct FullHealthLaunch<float, float>;
 template struct FullHealthLaunch<double, float>;
+template struct FullHealthLaunch<int16, float>;
+template struct FullHealthLaunch<int32, float>;
 template struct FullHealthLaunch<Eigen::half, double>;
 template struct FullHealthLaunch<float, double>;
 template struct FullHealthLaunch<double, double>;
+template struct FullHealthLaunch<int16, double>;
+template struct FullHealthLaunch<int32, double>;
 
 template <typename Tin, typename Tout>
 struct ReduceInfNanThreeSlotsLaunch {
@@ -229,9 +241,13 @@ struct ReduceInfNanThreeSlotsLaunch {
 template struct ReduceInfNanThreeSlotsLaunch<Eigen::half, float>;
 template struct ReduceInfNanThreeSlotsLaunch<float, float>;
 template struct ReduceInfNanThreeSlotsLaunch<double, float>;
+template struct ReduceInfNanThreeSlotsLaunch<int16, float>;
+template struct ReduceInfNanThreeSlotsLaunch<int32, float>;
 template struct ReduceInfNanThreeSlotsLaunch<Eigen::half, double>;
 template struct ReduceInfNanThreeSlotsLaunch<float, double>;
 template struct ReduceInfNanThreeSlotsLaunch<double, double>;
+template struct ReduceInfNanThreeSlotsLaunch<int16, double>;
+template struct ReduceInfNanThreeSlotsLaunch<int32, double>;
 
 }  // namespace tensorflow
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
diff --git a/tensorflow/python/debug/examples/v2/debug_mnist_v2.py b/tensorflow/python/debug/examples/v2/debug_mnist_v2.py
index 9d410b36c98..539be3cd54f 100644
--- a/tensorflow/python/debug/examples/v2/debug_mnist_v2.py
+++ b/tensorflow/python/debug/examples/v2/debug_mnist_v2.py
@@ -100,7 +100,8 @@ def parse_args():
       type=str,
       default="NO_TENSOR",
       help="Mode for dumping tensor values. Options: NO_TENSOR, CURT_HEALTH, "
-      "FULL_TENSOR. This is relevant only when --dump_dir is set.")
+      "CONCISE_HEALTH, SHAPE, FULL_TENSOR. This is relevant only when "
+      "--dump_dir is set.")
   # TODO(cais): Add more tensor debug mode strings once they are supported.
   parser.add_argument(
       "--dump_circular_buffer_size",
diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py
index 2532bd2e7e3..98e7292a785 100644
--- a/tensorflow/python/debug/lib/dumping_callback.py
+++ b/tensorflow/python/debug/lib/dumping_callback.py
@@ -324,10 +324,20 @@ class _DumpingCallback(object):
               debug_tensor.op)
           instrumented_tensors.append(identity)
       return instrumented_tensors
-    elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.CURT_HEALTH:
+    elif tensor_debug_mode in (debug_event_pb2.TensorDebugMode.CURT_HEALTH,
+                               debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
+                               debug_event_pb2.TensorDebugMode.SHAPE):
       for output_slot, tensor in enumerate(tensors):
+        dtype = tensor.dtype
+        dtype_is_dumpable = (
+            (tensor_debug_mode in (
+                debug_event_pb2.TensorDebugMode.CURT_HEALTH,
+                debug_event_pb2.TensorDebugMode.CONCISE_HEALTH) and
+             dtype.is_floating) or
+            (tensor_debug_mode == debug_event_pb2.TensorDebugMode.SHAPE and
+             (dtype.is_floating or dtype.is_integer or dtype.is_bool)))
         if (not self._should_dump_tensor(op_type, tensor.dtype) or
-            not tensor.dtype.is_floating):
+            not dtype_is_dumpable):
           if is_v1_graph_mode:
             instrumented_tensors.append(tensor)
           continue
@@ -409,6 +419,8 @@ class _DumpingCallback(object):
           tensor_debug_mode=tensor_debug_mode,
           code_location=self._process_stack_frames())
     elif tensor_debug_mode in (debug_event_pb2.TensorDebugMode.CURT_HEALTH,
+                               debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
+                               debug_event_pb2.TensorDebugMode.SHAPE,
                                debug_event_pb2.TensorDebugMode.FULL_TENSOR):
       execution_proto = debug_event_pb2.Execution(
           op_type=op_type,
@@ -421,7 +433,9 @@ class _DumpingCallback(object):
       for tensor in tensors:
         if (self._should_dump_tensor(op_type, tensor.dtype) and
             tensor.dtype.is_numpy_compatible):
-          if tensor_debug_mode == debug_event_pb2.TensorDebugMode.CURT_HEALTH:
+          if tensor_debug_mode in (
+              debug_event_pb2.TensorDebugMode.CURT_HEALTH,
+              debug_event_pb2.TensorDebugMode.CONCISE_HEALTH):
             if tensor.dtype.is_floating:
               tensor_proto = _concrete_tensor_to_proto(
                   gen_debug_ops.debug_numeric_summary_v2(
@@ -431,6 +445,17 @@ class _DumpingCallback(object):
             else:
               # A placeholder for non-floating-type output tensors.
               tensor_proto = tensor_pb2.TensorProto()
+          elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.SHAPE:
+            if (tensor.dtype.is_floating or tensor.dtype.is_integer or
+                tensor.dtype.is_bool):
+              tensor_proto = _concrete_tensor_to_proto(
+                  gen_debug_ops.debug_numeric_summary_v2(
+                      tensor,
+                      tensor_debug_mode=tensor_debug_mode,
+                      output_dtype=dtypes.float64))
+            else:
+              # A placeholder for tensors of dtypes unsupported by SHAPE mode.
+              tensor_proto = tensor_pb2.TensorProto()
           elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_TENSOR:
             tensor_proto = _concrete_tensor_to_proto(tensor)
           if tensor_proto:
@@ -657,6 +682,8 @@ def enable_dump_debug_info(dump_root,
   tensor_debug_mode = debug_event_pb2.TensorDebugMode.Value(tensor_debug_mode)
   if tensor_debug_mode not in (debug_event_pb2.TensorDebugMode.NO_TENSOR,
                                debug_event_pb2.TensorDebugMode.CURT_HEALTH,
+                               debug_event_pb2.TensorDebugMode.CONCISE_HEALTH,
+                               debug_event_pb2.TensorDebugMode.SHAPE,
                                debug_event_pb2.TensorDebugMode.FULL_TENSOR):
     raise NotImplementedError(
         "tfdbg dumping: support for tensor debug mode %s is not "
diff --git a/tensorflow/python/debug/lib/dumping_callback_test.py b/tensorflow/python/debug/lib/dumping_callback_test.py
index 9400610b946..b7e90f3179c 100644
--- a/tensorflow/python/debug/lib/dumping_callback_test.py
+++ b/tensorflow/python/debug/lib/dumping_callback_test.py
@@ -88,6 +88,8 @@ class TracingCallbackTest(
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
       ("CurtHealth", "CURT_HEALTH"),
+      ("ConciseHealth", "CONCISE_HEALTH"),
+      ("Shape", "SHAPE"),
       ("FullTensor", "FULL_TENSOR"),
   )
   def testPureEagerOpExecution(self, tensor_debug_mode):
@@ -146,6 +148,26 @@ class TracingCallbackTest(
             self.assertAllClose(
                 tensor_util.MakeNdarray(execution.tensor_protos[0]),
                 [-1.0, 0.0])
+        elif tensor_debug_mode == "CONCISE_HEALTH":
+          self.assertLen(execution.tensor_protos, 1)
+          if execution.op_type in ("AddV2", "Mul", "RealDiv"):
+            # 1st element: -1 is the unset tensor_id for eager op execution.
+            # 2nd element: element count (1, since each tensor is a scalar).
+            # Remaining elements: counts of -inf, inf and nan (all zero here).
+            self.assertAllClose(
+                tensor_util.MakeNdarray(execution.tensor_protos[0]),
+                [-1, 1, 0, 0, 0])
+        elif tensor_debug_mode == "SHAPE":
+          self.assertLen(execution.tensor_protos, 1)
+          if execution.op_type in ("AddV2", "Mul", "RealDiv"):
+            # 1st element: -1 is the unset tensor_id for eager op execution.
+            # 2nd element: dtype enum value (float32).
+            # 3rd element: rank (0 for a scalar).
+            # 4th element: element count (1).
+            # Remaining elements: shape, zero-padded to a fixed length of 6.
+            self.assertAllClose(
+                tensor_util.MakeNdarray(execution.tensor_protos[0]),
+                [-1, 1, 0, 1, 0, 0, 0, 0, 0, 0])
         elif tensor_debug_mode == "FULL_TENSOR":
           # Under the FULL_TENSOR mode, the value of the tensor should be
           # available through `tensor_protos`.
@@ -202,9 +224,127 @@ class TracingCallbackTest(
       with self.assertRaises(StopIteration):
         next(graph_trace_iter)
 
+  @parameterized.named_parameters(
+      ("CurtHealth", "CURT_HEALTH"),
+      ("ConciseHealth", "CONCISE_HEALTH"),
+      ("Shape", "SHAPE"),
+  )
+  @test_util.run_in_graph_and_eager_modes
+  def testModesSummarizingBadNumericalValue(self, tensor_debug_mode):
+    writer = dumping_callback.enable_dump_debug_info(
+        self.dump_root, tensor_debug_mode=tensor_debug_mode)
+
+    @def_function.function
+    def func(x, y):
+      return (x + y) / (x - y)
+
+    x = np.array([-3, -1, 0, 0, 1, 1, 1, 2], dtype=np.float16)
+    y = np.array([2, -1, 0, 0, 1, 1, 1, 3], dtype=np.float16)
+    # (x + y) / (x - y) = [0.2, -inf, nan, nan, inf, inf, inf, -5].
+    self.evaluate(func(x, y))
+
+    writer.FlushNonExecutionFiles()
+    writer.FlushExecutionFiles()
+
+    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
+    (context_ids,
+     _, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
+
+    (op_names, _, _,
+     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
+    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
+    self.assertCountEqual(executed_op_types, ["AddV2", "Sub", "RealDiv"])
+
+    if tensor_debug_mode == "CURT_HEALTH":
+      for op_type, tensor_value in zip(executed_op_types, tensor_values):
+        self.assertLen(tensor_value, 2)
+        # 1st element: tensor_id, should be >= 0.
+        # TODO(cais): Assert on detailed value once Function-graph association
+        # is in place.
+        self.assertGreaterEqual(tensor_value[0], 0)
+        # 2nd element: 1 if the tensor contains any inf or nan, else 0.
+        if op_type == "RealDiv":
+          self.assertEqual(tensor_value[1], 1)
+        else:
+          self.assertEqual(tensor_value[1], 0)
+    elif tensor_debug_mode == "CONCISE_HEALTH":
+      for op_type, tensor_value in zip(executed_op_types, tensor_values):
+        self.assertLen(tensor_value, 5)
+        # 1st element: tensor_id, should be >= 0.
+        # TODO(cais): Assert on detailed value once Function-graph association
+        # is in place.
+        self.assertGreaterEqual(tensor_value[0], 0)
+        # 2nd element: element count.
+        self.assertEqual(tensor_value[1], 8)
+        # Remaining 3 elements: The counts of -inf, inf and nan.
+        if op_type == "RealDiv":
+          self.assertAllClose(tensor_value[2:], [1, 3, 2])
+        else:
+          self.assertAllClose(tensor_value[2:], [0, 0, 0])
+    else:  # SHAPE.
+      for op_type, tensor_value in zip(executed_op_types, tensor_values):
+        self.assertLen(tensor_value, 10)
+        # 1st element: tensor_id, should be >= 0.
+        # TODO(cais): Assert on detailed value once Function-graph association
+        # is in place.
+        self.assertGreaterEqual(tensor_value[0], 0)
+        # 2nd element: dtype enum value (float16).
+        self.assertEqual(tensor_value[1], 19)
+        # 3rd element: rank (1).
+        self.assertEqual(tensor_value[2], 1)
+        # 4th element: element count.
+        self.assertEqual(tensor_value[3], 8)
+        # Remaining elements: shape, zero-padded to a fixed length of 6.
+        self.assertAllClose(tensor_value[4:], [8, 0, 0, 0, 0, 0])
+
+  @parameterized.named_parameters(
+      ("Shape", "SHAPE"),
+  )
+  @test_util.run_in_graph_and_eager_modes
+  def testBooleanTensors(self, tensor_debug_mode):
+    writer = dumping_callback.enable_dump_debug_info(
+        self.dump_root, tensor_debug_mode=tensor_debug_mode)
+
+    @def_function.function
+    def func(x, y):
+      return math_ops.logical_not(math_ops.logical_and(x, y))
+
+    x = np.array([[False, False], [True, True]], dtype=np.bool)
+    y = np.array([[False, True], [False, True]], dtype=np.bool)
+    self.assertAllEqual(
+        self.evaluate(func(x, y)), [[True, True], [True, False]])
+
+    writer.FlushNonExecutionFiles()
+    writer.FlushExecutionFiles()
+
+    stack_frame_by_id = self._readAndCheckSourceFilesAndStackFrames()
+    (context_ids,
+     _, op_name_to_op_type, _) = self._readAndCheckGraphsFile(stack_frame_by_id)
+
+    (op_names, _, _,
+     tensor_values) = self._readAndCheckGraphExecutionTracesFile(context_ids)
+    executed_op_types = [op_name_to_op_type[op_name] for op_name in op_names]
+    self.assertEqual(executed_op_types, ["LogicalAnd", "LogicalNot"])
+
+    for tensor_value in tensor_values:
+      # 1st element: tensor_id, should be >= 0.
+      # TODO(cais): Assert on detailed value once Function-graph association
+      # is in place.
+      self.assertGreaterEqual(tensor_value[0], 0)
+      # 2nd element: dtype enum value (bool).
+      self.assertEqual(tensor_value[1], 10)
+      # 3rd element: rank (2).
+      self.assertEqual(tensor_value[2], 2)
+      # 4th element: element count.
+      self.assertEqual(tensor_value[3], 4)
+      # Remaining elements: shape, zero-padded to a fixed length of 6.
+      self.assertAllClose(tensor_value[4:], [2, 2, 0, 0, 0, 0])
+
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
       ("CurtHealth", "CURT_HEALTH"),
+      ("ConciseHealth", "CONCISE_HEALTH"),
+      ("Shape", "SHAPE"),
       ("FullTensor", "FULL_TENSOR"),
   )
   @test_util.run_in_graph_and_eager_modes
@@ -276,6 +416,30 @@ class TracingCallbackTest(
         self.assertGreaterEqual(tensor_value[0], 0)
         # 2nd element: 0 means there is no inf or nan.
         self.assertEqual(tensor_value[1], 0)
+    elif tensor_debug_mode == "CONCISE_HEALTH":
+      for tensor_value in tensor_values:
+        self.assertLen(tensor_value, 5)
+        # 1st element: tensor_id, should be >= 0.
+        # TODO(cais): Assert on detailed value once Function-graph association
+        # is in place.
+        self.assertGreaterEqual(tensor_value[0], 0)
+        # 2nd element: element count. Remaining elements: all zero because there
+        # is no -inf, inf or nan.
+        self.assertAllClose(tensor_value[1:], [1, 0, 0, 0])
+    elif tensor_debug_mode == "SHAPE":
+      for tensor_value in tensor_values:
+        # 1st element: tensor_id, should be >= 0.
+        # TODO(cais): Assert on detailed value once Function-graph association
+        # is in place.
+        self.assertGreaterEqual(tensor_value[0], 0)
+        # 2nd element: dtype enum value (float32).
+        self.assertGreaterEqual(tensor_value[1], 1)
+        # 3rd element: rank (0 for a scalar).
+        self.assertGreaterEqual(tensor_value[2], 0)
+        # 4th element: element count.
+        self.assertGreaterEqual(tensor_value[3], 1)
+        # Remaining elements: shape, zero-padded to a fixed length of 6.
+        self.assertAllClose(tensor_value[4:], [0, 0, 0, 0, 0, 0])
     elif tensor_debug_mode == "FULL_TENSOR":
       self.assertAllClose(tensor_values[0], 5.0)  # 1st AddV2 op.
       self.assertAllClose(tensor_values[1], np.log(5.0))  # Log op.
@@ -713,6 +877,8 @@ class TracingCallbackTest(
   @parameterized.named_parameters(
       ("NoTensor", "NO_TENSOR"),
       ("CurtHealth", "CURT_HEALTH"),
+      ("ConciseHealth", "CONCISE_HEALTH"),
+      ("Shape", "SHAPE"),
       ("FullTensor", "FULL_TENSOR"),
   )
   def testMultiThreadedExecutionWithSameSetting(self, tensor_debug_mode):
@@ -774,6 +940,35 @@ class TracingCallbackTest(
         self.assertGreaterEqual(tensor_value[0], 0)
         # 2nd element: 0 means there is no inf or nan.
         self.assertEqual(tensor_value[1], 0)
+    elif tensor_debug_mode == "CONCISE_HEALTH":
+      for tensor_value in tensor_values:
+        self.assertLen(tensor_value, 5)
+        # 1st element: tensor_id, should be >= 0.
+        # TODO(cais): Assert on detailed value once Function-graph association
+        # is in place.
+        self.assertGreaterEqual(tensor_value[0], 0)
+        # 2nd element: element count. Remaining elements: all zero because there
+        # is no -inf, inf or nan.
+        self.assertAllClose(tensor_value[1:], [1, 0, 0, 0])
+    elif tensor_debug_mode == "SHAPE":
+      mul_values = [
+          tensor_values[i]
+          for i, op_type in enumerate(executed_op_types)
+          if op_type == "Mul"
+      ]
+      for mul_value in mul_values:
+        # 1st element: tensor_id, should be >= 0.
+        # TODO(cais): Assert on detailed value once Function-graph association
+        # is in place.
+        self.assertGreaterEqual(mul_value[0], 0)
+        # 2nd element: dtype enum value (float32).
+        self.assertEqual(mul_value[1], 1)
+        # 3rd element: rank.
+        self.assertEqual(mul_value[2], 0)
+        # 4th element: element count.
+        self.assertEqual(mul_value[3], 1)
+        # Remaining elements: shape padded to a fixed length.
+        self.assertAllClose(mul_value[4:], [0, 0, 0, 0, 0, 0])
     elif tensor_debug_mode == "FULL_TENSOR":
       mul_values = [
           tensor_values[i]