Silence a warning when reading variables with forwardprop
GradientTapes, used in the forwardprop implementation, logged warnings for non-floating inputs. This silences those warnings for trainable dtypes (I've kept the warning text the same, since mentioning variants and resource handles there would be needlessly confusing).

PiperOrigin-RevId: 266950768
commit 619ad608cb (parent ce2b635fcd)
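For context, a minimal sketch of the scenario being silenced (hypothetical example, not part of the change; `v.handle` stands in for the resource-dtype tensor that the forwardprop implementation watches on its internal GradientTape):

    import tensorflow as tf

    v = tf.Variable(2.0)
    with tf.GradientTape() as tape:
      # The variable's handle has dtype tf.resource, which is not floating,
      # so watching it previously logged "The dtype of the watched tensor
      # must be floating (e.g. tf.float32)" even though the variable is
      # trainable. With IsTrainable() accepting resource/variant dtypes,
      # no warning is logged.
      tape.watch(v.handle)
      y = v * v
    print(tape.gradient(y, v))  # tf.Tensor(4.0, shape=(), dtype=float32)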
@@ -2901,6 +2901,8 @@ py_library(
        ":unconnected_gradients",
        ":util",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:backprop_util",
        "//tensorflow/python/eager:context",
        "//third_party/py/numpy",
        "@six_archive//:six",
@@ -501,6 +501,7 @@ py_library(
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
+        ":backprop_util",
        ":context",
        ":execute",
        ":imperative_grad",
@@ -520,6 +521,17 @@ py_library(
    ],
)

+py_library(
+    name = "backprop_util",
+    srcs = ["backprop_util.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:tensor_util",
+    ],
+)

py_library(
    name = "forwardprop",
    srcs = ["forwardprop.py"],
@@ -26,6 +26,7 @@ import six

from tensorflow.python import pywrap_tensorflow
from tensorflow.python import _pywrap_utils
+from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import imperative_grad
@@ -853,7 +854,7 @@ class GradientTape(object):
      if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
        raise ValueError("Passed in object of type {}, not tf.Tensor".format(
            type(t)))
-      if not t.dtype.is_floating:
+      if not backprop_util.IsTrainable(t):
        logging.log_first_n(
            logging.WARN, "The dtype of the watched tensor must be "
            "floating (e.g. tf.float32), got %r", 5, t.dtype)
@@ -987,7 +988,7 @@ class GradientTape(object):

    flat_targets = []
    for t in nest.flatten(target):
-      if not t.dtype.is_floating:
+      if not backprop_util.IsTrainable(t):
        logging.vlog(
            logging.WARN, "The dtype of the target tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
@@ -1001,7 +1002,7 @@ class GradientTape(object):
    flat_sources_raw = flat_sources
    flat_sources = [_handle_or_self(x) for x in flat_sources]
    for t in flat_sources_raw:
-      if not t.dtype.is_floating:
+      if not backprop_util.IsTrainable(t):
        logging.vlog(
            logging.WARN, "The dtype of the source tensor must be "
            "floating (e.g. tf.float32) when calling GradientTape.gradient, "
tensorflow/python/eager/backprop_util.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shared utilities related to backprop."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util


def IsTrainable(tensor_or_dtype):
  if tensor_util.is_tensor(tensor_or_dtype):
    dtype = tensor_or_dtype.dtype
  else:
    dtype = tensor_or_dtype
  dtype = dtypes.as_dtype(dtype)
  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
                              dtypes.complex64, dtypes.complex128,
                              dtypes.resource, dtypes.variant)
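A quick usage sketch (not part of the diff) of how the helper above classifies dtypes, assuming it is imported from its internal location:

    from tensorflow.python.eager import backprop_util
    from tensorflow.python.framework import dtypes

    backprop_util.IsTrainable(dtypes.float32)   # True  -- floating point
    backprop_util.IsTrainable(dtypes.int32)     # False -- still warns, as before
    backprop_util.IsTrainable(dtypes.resource)  # True  -- variable handles stop warning
    backprop_util.IsTrainable(dtypes.variant)   # True  -- e.g. TensorList handles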
@@ -20,13 +20,13 @@ from __future__ import print_function

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute

from tensorflow.python.framework import ops

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@@ -61,7 +61,7 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
  trainable_indices = []
  nontrivial_tangents = []
  for input_index, tensor in enumerate(inputs):
-    if gradients_util.IsTrainable(tensor):
+    if backprop_util.IsTrainable(tensor):
      trainable_inputs.append(tensor)
      trainable_indices.append(input_index)
      nontrivial_tangents.append(tangents[input_index])
@@ -76,7 +76,7 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
  trainable_outputs = []
  nontrivial_output_indices = []
  for output_index, output in enumerate(outputs):
-    if gradients_util.IsTrainable(output):
+    if backprop_util.IsTrainable(output):
      forwardprop_aids.append(
          array_ops.ones_like(output, name="unused_forwardprop_aid"))
      trainable_outputs.append(output)
@@ -35,6 +35,7 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import _pywrap_utils
+from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
@@ -586,7 +587,7 @@ class _DelayedRewriteGradientFunctions(object):
    """
    trainable_outputs = [
        output for output in self._func_graph.outputs[:num_doutputs]
-        if gradients_util.IsTrainable(output)]
+        if backprop_util.IsTrainable(output)]

    signature = []
    for t in trainable_outputs:
@@ -668,7 +669,7 @@ class _DelayedRewriteGradientFunctions(object):
    # expects numeric inputs.
    cleaned_doutputs = []
    for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
-      if gradients_util.IsTrainable(placeholder):
+      if backprop_util.IsTrainable(placeholder):
        if doutput is not None:
          cleaned_doutputs.append(doutput)
        else:
@@ -749,7 +750,7 @@ class _TapeGradientFunctions(object):
    handles_to_variables = self._func_graph.variable_captures
    trainable_outputs = []
    for output in outputs:
-      if gradients_util.IsTrainable(output):
+      if backprop_util.IsTrainable(output):
        # Swap in the Variable object for resource handles if we can so
        # sparse gradients work.
        output = handles_to_variables.get(ops.tensor_id(output), output)
@@ -858,7 +859,7 @@ class _TapeGradientFunctions(object):
    for output_index, output in enumerate(outputs):
      if trainable_recorded_outputs < backward_function_inputs:
        recorded_outputs.append(output)
-      if gradients_util.IsTrainable(output):
+      if backprop_util.IsTrainable(output):
        trainable_recorded_outputs += 1
      else:
        skip_positions.append(output_index)
@@ -25,6 +25,7 @@ from __future__ import print_function

import collections

+from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
@@ -344,7 +345,7 @@ def _grad_fn(func_graph, grads):
  ys = []
  grad_ys = []
  for y, grad_y in zip(func_graph.outputs, grads):
-    if not gradients_util.IsTrainable(y):
+    if not backprop_util.IsTrainable(y):
      continue
    ys.append(y)
    grad_ys.append(grad_y)
@@ -25,6 +25,7 @@ from six.moves import xrange, zip  # pylint: disable=redefined-builtin

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import backprop
+from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -226,19 +227,8 @@ def _DefaultGradYs(grad_ys,
  return new_grad_ys


-def IsTrainable(tensor_or_dtype):
-  if isinstance(tensor_or_dtype, ops.Tensor):
-    dtype = tensor_or_dtype.dtype
-  else:
-    dtype = tensor_or_dtype
-  dtype = dtypes.as_dtype(dtype)
-  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
-                              dtypes.complex64, dtypes.complex128,
-                              dtypes.resource, dtypes.variant)


def _IsBackpropagatable(tensor):
-  if IsTrainable(tensor):
+  if backprop_util.IsTrainable(tensor):
    return True
  dtype = dtypes.as_dtype(tensor.dtype)
  return dtype.base_dtype == dtypes.bfloat16
@@ -592,7 +582,7 @@ def _GradientsHelper(ys,
    if loop_state:
      loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
      for y in loop_exits:
-        if IsTrainable(y):
+        if backprop_util.IsTrainable(y):
          _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
          queue.append(y.op)

@@ -658,7 +648,8 @@ def _GradientsHelper(ys,
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
-                (not grad_fn and is_func_call) or IsTrainable(op.outputs[i])):
+                (not grad_fn and is_func_call)
+                or backprop_util.IsTrainable(op.outputs[i])):
              # Only trainable outputs or outputs for a function call that
              # will use SymbolicGradient get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
@@ -765,7 +756,7 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
            # For an unused exit, if it has trainable outputs, backprop
            # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
-              if IsTrainable(y):
+              if backprop_util.IsTrainable(y):
                _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
          else:
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
@@ -509,7 +510,7 @@ def _zeros_like(op_output):

def _is_trainable(tensor):
  """Returns whether the given tensor is trainable."""
-  if not gradients_util.IsTrainable(tensor):
+  if not backprop_util.IsTrainable(tensor):
    return False

  # Special case: untrainable accumulator output. The gradients algorithm
@@ -520,7 +521,7 @@ def _is_trainable(tensor):
  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
    assert tensor.dtype == dtypes.variant
    element_type = tensor.op.get_attr("element_dtype")
-    return gradients_util.IsTrainable(element_type)
+    return backprop_util.IsTrainable(element_type)

  return True