PR #27183: Add some log info within GradientTape
Imported from GitHub PR #27183 @alextp, reference to #26143. I saw @shashvatshahi1998 made some changes already, So I did some supplement here. Please take a look at it, Thanks. Copybara import of the project: - 958bf19a18c56e4512b464aaad41cbc2ad1141ac Add some log info within GradientTape by a6802739 <songhao9021@gmail.com> - e84c59595ad7d227a95917a32fbb264588d0de06 Merge 958bf19a18c56e4512b464aaad41cbc2ad1141ac into 1f95f... by songhao <songhao9021@gmail.com> COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/27183 from a6802739:add_log_info 958bf19a18c56e4512b464aaad41cbc2ad1141ac PiperOrigin-RevId: 241399172
This commit is contained in:
parent
daf7c5337d
commit
764fdc6fb1
@ -812,6 +812,10 @@ class GradientTape(object):
|
||||
tensor: a Tensor or list of Tensors.
|
||||
"""
|
||||
for t in nest.flatten(tensor):
|
||||
if not t.dtype.is_floating:
|
||||
logging.vlog(
|
||||
logging.WARN, "The dtype of the watched tensor must be "
|
||||
"floating (e.g. tf.float32), got %r", t.dtype)
|
||||
if hasattr(t, "handle"):
|
||||
# There are many variable-like objects, all of them currently have
|
||||
# `handle` attribute that points to a tensor. If this changes, internals
|
||||
@ -940,6 +944,11 @@ class GradientTape(object):
|
||||
|
||||
flat_targets = []
|
||||
for t in nest.flatten(target):
|
||||
if not t.dtype.is_floating:
|
||||
logging.vlog(
|
||||
logging.WARN, "The dtype of the target tensor must be "
|
||||
"floating (e.g. tf.float32) when calling GradientTape.gradient, "
|
||||
"got %r", t.dtype)
|
||||
if resource_variable_ops.is_resource_variable(t):
|
||||
with self:
|
||||
t = ops.convert_to_tensor(t)
|
||||
@ -948,6 +957,12 @@ class GradientTape(object):
|
||||
flat_sources = nest.flatten(sources)
|
||||
flat_sources_raw = flat_sources
|
||||
flat_sources = [_handle_or_self(x) for x in flat_sources]
|
||||
for t in flat_sources:
|
||||
if not t.dtype.is_floating:
|
||||
logging.vlog(
|
||||
logging.WARN, "The dtype of the source tensor must be "
|
||||
"floating (e.g. tf.float32) when calling GradientTape.gradient, "
|
||||
"got %r", t.dtype)
|
||||
|
||||
if output_gradients is not None:
|
||||
output_gradients = [None if x is None else ops.convert_to_tensor(x)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user