diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index bb41559b599..684736f81cb 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -104,8 +104,11 @@ class EagerFunc(object): """Passes `args` to `self._func`, which is executed eagerly.""" with context.eager_mode(), backprop.GradientTape() as tape: + # Only watch tensors with a floating dtype. for tensor in args: - tape.watch(tensor) + for t in nest.flatten(tensor): + if t.dtype.is_floating: + tape.watch(t) ret = self._func(*args) # Use tf.identity to copy the returned tensors to device if neccesary. with ops.device(device):