PR #27183: Add some log info within GradientTape

Imported from GitHub PR #27183

@alextp, this is in reference to #26143.

I saw that @shashvatshahi1998 had already made some changes, so I supplemented them here.

Please take a look at it. Thanks.
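
For context, a minimal sketch of the behavior these messages are meant to flag (assuming TF 2.x-style eager execution as of this change; the snippet and its names are illustrative and not part of the PR): differentiating through a non-floating tensor just returns None, with nothing in the output hinting that the dtype is the cause.

    import tensorflow as tf

    x = tf.constant(3, dtype=tf.int32)   # non-floating dtype

    with tf.GradientTape() as tape:
      tape.watch(x)    # with this PR, a WARN-level message is logged here
      y = x * x

    # Returns None because gradients are only recorded for floating dtypes;
    # with this PR, gradient() also logs WARN-level messages naming the dtype.
    print(tape.gradient(y, x))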

Copybara import of the project:

  - 958bf19a18c56e4512b464aaad41cbc2ad1141ac Add some log info within GradientTape by a6802739 <songhao9021@gmail.com>
  - e84c59595ad7d227a95917a32fbb264588d0de06 Merge 958bf19a18c56e4512b464aaad41cbc2ad1141ac into 1f95f... by songhao <songhao9021@gmail.com>

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/27183 from a6802739:add_log_info 958bf19a18c56e4512b464aaad41cbc2ad1141ac
PiperOrigin-RevId: 241399172
Commit: 764fdc6fb1 (parent: daf7c5337d)
Author: A. Unique TensorFlower
Committed by: TensorFlower Gardener
Date: 2019-04-01 14:24:58 -07:00


@@ -812,6 +812,10 @@ class GradientTape(object):
       tensor: a Tensor or list of Tensors.
     """
     for t in nest.flatten(tensor):
+      if not t.dtype.is_floating:
+        logging.vlog(
+            logging.WARN, "The dtype of the watched tensor must be "
+            "floating (e.g. tf.float32), got %r", t.dtype)
       if hasattr(t, "handle"):
         # There are many variable-like objects, all of them currently have
         # `handle` attribute that points to a tensor. If this changes, internals
@@ -940,6 +944,11 @@ class GradientTape(object):
     flat_targets = []
     for t in nest.flatten(target):
+      if not t.dtype.is_floating:
+        logging.vlog(
+            logging.WARN, "The dtype of the target tensor must be "
+            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
+            "got %r", t.dtype)
       if resource_variable_ops.is_resource_variable(t):
         with self:
           t = ops.convert_to_tensor(t)
@@ -948,6 +957,12 @@ class GradientTape(object):
     flat_sources = nest.flatten(sources)
     flat_sources_raw = flat_sources
     flat_sources = [_handle_or_self(x) for x in flat_sources]
+    for t in flat_sources:
+      if not t.dtype.is_floating:
+        logging.vlog(
+            logging.WARN, "The dtype of the source tensor must be "
+            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
+            "got %r", t.dtype)
     if output_gradients is not None:
       output_gradients = [None if x is None else ops.convert_to_tensor(x)
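
These messages go through logging.vlog at WARN level rather than raising an error, so the gradient still comes back as None and the hint only shows up in the log output. A hedged sketch of how a user might surface them (assuming the tf.compat.v1.logging aliases for the tf_logging module used here; adjusting verbosity is only needed if the TensorFlow logger filters WARN records):

    import tensorflow as tf

    # Make sure WARN-level records from the TensorFlow logger are emitted.
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.WARN)

    x = tf.constant(2, dtype=tf.int32)

    with tf.GradientTape() as tape:
      tape.watch(x)    # logs: "The dtype of the watched tensor must be floating ..."
      y = x * x

    # Logs the target/source warnings added above and returns None.
    print(tape.gradient(y, x))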