Fix dtype check in GradientTape

ResourceVariables handles will have Tensor type as dtype.resource.
Instead, we should check the dtype of the Variable Tensor.

PiperOrigin-RevId: 241623008
This commit is contained in:
Gaurav Jain 2019-04-02 16:04:19 -07:00 committed by TensorFlower Gardener
parent a6a0611a63
commit f44bcfbea9

View File

@ -957,7 +957,7 @@ class GradientTape(object):
flat_sources = nest.flatten(sources)
flat_sources_raw = flat_sources
flat_sources = [_handle_or_self(x) for x in flat_sources]
for t in flat_sources:
for t in flat_sources_raw:
if not t.dtype.is_floating:
logging.vlog(
logging.WARN, "The dtype of the source tensor must be "