Fix dtype check in GradientTape
ResourceVariables handles will have Tensor type as dtype.resource. Instead, we should check the dtype of the Variable Tensor. PiperOrigin-RevId: 241623008
This commit is contained in:
parent
a6a0611a63
commit
f44bcfbea9
@ -957,7 +957,7 @@ class GradientTape(object):
|
||||
flat_sources = nest.flatten(sources)
|
||||
flat_sources_raw = flat_sources
|
||||
flat_sources = [_handle_or_self(x) for x in flat_sources]
|
||||
for t in flat_sources:
|
||||
for t in flat_sources_raw:
|
||||
if not t.dtype.is_floating:
|
||||
logging.vlog(
|
||||
logging.WARN, "The dtype of the source tensor must be "
|
||||
|
Loading…
Reference in New Issue
Block a user