Merge pull request #30107 from feihugis:Issue_28248_Tape_Watch_Warning
PiperOrigin-RevId: 255187877
This commit is contained in:
commit
232fb86bfb
@ -104,8 +104,11 @@ class EagerFunc(object):
|
||||
"""Passes `args` to `self._func`, which is executed eagerly."""
|
||||
|
||||
with context.eager_mode(), backprop.GradientTape() as tape:
|
||||
# Only watch tensors with a floating dtype.
|
||||
for tensor in args:
|
||||
tape.watch(tensor)
|
||||
for t in nest.flatten(tensor):
|
||||
if t.dtype.is_floating:
|
||||
tape.watch(t)
|
||||
ret = self._func(*args)
|
||||
# Use tf.identity to copy the returned tensors to device if neccesary.
|
||||
with ops.device(device):
|
||||
|
Loading…
x
Reference in New Issue
Block a user