Merge pull request #30107 from feihugis:Issue_28248_Tape_Watch_Warning

PiperOrigin-RevId: 255187877
This commit is contained in:
TensorFlower Gardener 2019-06-26 13:28:06 -07:00
commit 232fb86bfb

View File

@ -104,8 +104,11 @@ class EagerFunc(object):
"""Passes `args` to `self._func`, which is executed eagerly."""
with context.eager_mode(), backprop.GradientTape() as tape:
# Only watch tensors with a floating dtype.
for tensor in args:
tape.watch(tensor)
for t in nest.flatten(tensor):
if t.dtype.is_floating:
tape.watch(t)
ret = self._func(*args)
# Use tf.identity to copy the returned tensors to device if neccesary.
with ops.device(device):