diff --git a/tensorflow/python/debug/lib/check_numerics_callback.py b/tensorflow/python/debug/lib/check_numerics_callback.py
index fc2124af2f2..5260015850f 100644
--- a/tensorflow/python/debug/lib/check_numerics_callback.py
+++ b/tensorflow/python/debug/lib/check_numerics_callback.py
@@ -50,9 +50,14 @@ def limit_string_length(string, max_len=50):
 _CHECK_NUMERICS_CALLBACK_SKIP_OPS = (
     # TODO(b/139668453): The following skipped ops are related to a limitation
     # in the op callback.
+    b"Enter",
+    b"Exit",
     b"Identity",
     b"If",
+    b"Merge",
+    b"NextIteration",
     b"StatelessIf",
+    b"Switch",
     b"While",
 )
 
diff --git a/tensorflow/python/debug/lib/check_numerics_callback_test.py b/tensorflow/python/debug/lib/check_numerics_callback_test.py
index 457d7b292dc..3f565d21c36 100644
--- a/tensorflow/python/debug/lib/check_numerics_callback_test.py
+++ b/tensorflow/python/debug/lib/check_numerics_callback_test.py
@@ -271,6 +271,27 @@ class CheckNumericsCallbackTest(test_util.TensorFlowTestCase):
     history = model.fit(xs, ys, epochs=epochs, verbose=0)
     self.assertEqual(len(history.history["loss"]), epochs)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testKerasModelWithRNNHealthyPredictAndFitCalls(self):
+    """Test a simple healthy keras recurrent model works under the callback."""
+    check_numerics_callback.enable_check_numerics()
+
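+    # In graph mode, the LSTM layer is built with a while_loop, which emits
+    # control-flow ops (Enter, Exit, Merge, NextIteration, Switch); the
+    # callback must skip these rather than try to instrument their outputs.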
+    model = models.Sequential()
+    model.add(layers.LSTM(1, input_shape=(2, 4)))
+    model.compile(loss="mse", optimizer="rmsprop")
+
+    xs = np.zeros([8, 2, 4], dtype=np.float32)
+    ys = np.zeros([8, 1], dtype=np.float32)
+
+    model.predict(xs)
+
+    epochs = 3
+    history = model.fit(xs, ys, epochs=epochs, verbose=0)
+    self.assertEqual(len(history.history["loss"]), epochs)
+
   @test_util.run_in_graph_and_eager_modes
   def testKerasModelUnhealthyPredictAndFitCallsWithLargeLearningRate(self):
     """Test keras model training crashes with Infinity is caught by callback."""
diff --git a/tensorflow/python/framework/op_callbacks_test.py b/tensorflow/python/framework/op_callbacks_test.py
index 0027c74ba7f..c55b9720a3b 100644
--- a/tensorflow/python/framework/op_callbacks_test.py
+++ b/tensorflow/python/framework/op_callbacks_test.py
@@ -49,13 +49,17 @@ _ADD_OP = b"AddV2"
 _ASSIGN_ADD_VARIABLE_OP = b"AssignAddVariableOp"
 _CONSTANT_OP = b"Const"
 _COS_OP = b"Cos"
+_ENTER_OP = b"Enter"
+_EXIT_OP = b"Exit"
 _GREATER_OP = b"Greater"
 _IDENTITY_OP = b"Identity"
 _IF_OP = b"If"
 _LESS_OP = b"Less"
 _LOG_OP = b"Log"
 _MATMUL_OP = b"MatMul"
+_MERGE_OP = b"Merge"
 _MUL_OP = b"Mul"
+_NEXT_ITERATION_OP = b"NextIteration"
 _PLACEHOLDER_OP = b"Placeholder"
 _POW_OP = b"Pow"
 _READ_VARIALBE_OP = b"ReadVariableOp"
@@ -64,6 +68,7 @@ _SPARSE_TENSOR_DENSE_MATMUL_OP = b"SparseTensorDenseMatMul"
 _SQRT_OP = b"Sqrt"
 _SQUARE_OP = b"Square"
 _STATELESS_IF_OP = b"StatelessIf"
+_SWITCH_OP = b"Switch"
 _UNIQUE_OP = b"Unique"
 _VAR_HANDLE_OP = b"VarHandleOp"
 _WHILE_OP = b"While"
@@ -71,8 +76,11 @@ _WHILE_OP = b"While"
 
 class _NumpyFunctionCallback(object):
 
-  def __init__(self, instrument_graph_ops=True):
+  def __init__(self, instrument_graph_ops=True, float_only=False):
     self.instrument_graph_ops = instrument_graph_ops
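+    # If float_only is True, only floating-point outputs are instrumented;
+    # outputs of other dtypes are passed through unmodified.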
+    self._float_only = float_only
     self.reset()
 
   def callback(self, op_type, inputs, attrs, outputs, op_name=None, graph=None):
@@ -102,7 +110,8 @@ class _NumpyFunctionCallback(object):
       instrumented_outputs = []
       for output in outputs:
         if compat.as_bytes(op_type) in (
-            _IF_OP, _STATELESS_IF_OP, _WHILE_OP, _IDENTITY_OP,
+            _ENTER_OP, _EXIT_OP, _IF_OP, _MERGE_OP, _NEXT_ITERATION_OP,
+            _STATELESS_IF_OP, _SWITCH_OP, _WHILE_OP, _IDENTITY_OP,
             _VAR_HANDLE_OP):
           # TODO(cais): Overriding the output of StatelessIf, If and While ops
           # currently fails with error. Investigate (b/139668453).
@@ -110,7 +119,6 @@
           # by tf.function/AutoGraph for marshalling outputs.
           instrumented_output = output
         else:
-
           def record(ndarray_value):
             if compat.as_bytes(op_name) not in self.graph_internal_ndarrays:
               self.graph_internal_ndarrays[compat.as_bytes(op_name)] = []
@@ -118,9 +126,15 @@
                 ndarray_value)
             return ndarray_value
 
-          instrumented_output = script_ops.numpy_function(
-              record, [output], output.dtype)
-          instrumented_output.set_shape(output.shape)
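+          # When float_only is set, leave non-floating-point outputs (e.g.,
+          # the int32 counters and boolean predicates created by control-flow
+          # ops) uninstrumented.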
+          if self._float_only and not output.dtype.is_floating:
+            instrumented_output = output
+          else:
+            instrumented_output = script_ops.numpy_function(
+                record, [output], output.dtype)
+            instrumented_output.set_shape(output.shape)
         instrumented_outputs.append(instrumented_output)
 
       return instrumented_outputs
@@ -397,6 +411,53 @@ class OpCallbacksTest(test_util.TensorFlowTestCase):
       self.assertEqual(len(log_op_outputs), 1)
     self.assertAllClose(log_op_outputs[0], [0.0, np.log(2.0)])
 
+  @test_util.run_in_graph_and_eager_modes
+  def testPadOp(self):
+    instrument = _NumpyFunctionCallback()
+
+    op_callbacks.add_op_callback(instrument.callback)
+
+    @def_function.function
+    def my_pad(x, padding):
+      return array_ops.pad(x, padding)
+
+    x = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float32)
+    paddings = [[1, 1], [2, 2]]
+
+    y = my_pad(x, paddings)
+    expected_output = np.array([
+        [0, 0, 0, 0, 0, 0],
+        [0, 0, 1, 2, 0, 0],
+        [0, 0, 3, 4, 0, 0],
+        [0, 0, 0, 0, 0, 0],
+    ], dtype=np.float32)
+    self.assertAllClose(y, expected_output)
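+    # The callback should also have recorded the graph-internal output of
+    # the Pad op under its op name.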
+    self.assertAllClose(
+        instrument.graph_internal_ndarrays[b"Pad"][0], expected_output)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testKerasLSTMPredict(self):
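+    # float_only avoids wrapping the non-floating-point tensors (e.g., loop
+    # counters) created by the LSTM's control flow in numpy_function.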
+    instrument = _NumpyFunctionCallback(float_only=True)
+
+    op_callbacks.add_op_callback(instrument.callback)
+
+    model = keras.Sequential()
+    model.add(keras.layers.LSTM(1, input_shape=(2, 4)))
+    model.compile(loss="mse", optimizer="sgd")
+
+    xs = np.zeros([8, 2, 4], dtype=np.float32)
+    ys = model.predict(xs)
+
+    self.assertAllClose(ys, np.zeros([8, 1]))
+    # We avoid asserting on the internal details of the LSTM implementation.
+    # Instead, we just assert that some graph-internal intermediate tensor
+    # values are recorded by the callback.
+    self.assertTrue(instrument.graph_internal_ndarrays)
+
   @test_util.run_in_graph_and_eager_modes
   def testSimpleGraphConstructionWithCallbackReturningNone(self):
     """Test that callbacks that return None works."""
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 3084efeaee7..94a5a1b0674 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2952,8 +2952,10 @@ def pad(tensor, paddings, mode="CONSTANT", name=None, constant_values=0):  # pyl
 
   # Restore shape information where possible.
   if not context.executing_eagerly():
-    paddings_constant = tensor_util.constant_value(
-        result.op.inputs[1], partial=True)
+    paddings_constant = _get_paddings_constant(paddings)
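+    # Compute constant values from the original `paddings` argument rather
+    # than from result.op.inputs[1], so that paddings given as nested
+    # lists/tuples of tensors and numbers still yield partial information.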
     input_shape = result.op.inputs[0].shape
     if (input_shape.ndims is not None and
         not result.shape.is_fully_defined() and paddings_constant is not None):
@@ -2968,6 +2970,33 @@
   return result
 
 
+def _get_paddings_constant(paddings):
+  """Helper to get the constant values of the paddings arg to pad().
+
+  Used under V1 graph mode to facilitate computation of the shape of the output
+  tensor of `pad()`.
+
+  Args:
+    paddings: The same paddings arg as passed to pad(). Can be a Tensor, or
+      a nested list or tuple of Tensors and/or numbers.
+
+  Returns:
+    A nested list of numbers and/or `None`, in which `None` indicates an
+    unknown padding size.
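+
+  Example:
+    If `paddings` is `[[1, unknown], [2, 2]]`, where `unknown` is a Tensor
+    whose value cannot be determined statically, the return value is
+    `[[1, None], [2, 2]]`.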
+  """
+  if isinstance(paddings, ops.Tensor):
+    return tensor_util.constant_value(paddings, partial=True)
+  elif isinstance(paddings, (list, tuple)):
+    return [_get_paddings_constant(x) for x in paddings]
+  else:
+    return paddings
+
+
 @tf_export("meshgrid")
 def meshgrid(*args, **kwargs):
   """Broadcasts parameters for evaluation on an N-D grid.