Silence a warning when reading variables with forwardprop

GradientTapes, used in the forwardprop implementation, warn about non-floating inputs such as the resource handles read when accessing variables. This change silences those warnings for all trainable dtypes (I've kept the warning text the same, since mentioning variants and resource handles there would be needlessly confusing).

PiperOrigin-RevId: 266950768
Allen Lavoie 2019-09-03 09:48:29 -07:00 committed by TensorFlower Gardener
parent ce2b635fcd
commit 619ad608cb
9 changed files with 70 additions and 28 deletions
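For context, a minimal sketch of the dtype check being relaxed (hypothetical illustration, not part of the diff below; assumes TF 2.x eager execution and a resource variable):

import tensorflow as tf

v = tf.Variable(2.0)
handle = v.handle  # dtype is tf.resource, not a floating-point dtype

# Old check: any non-floating dtype triggered the "must be floating" warning,
# including the resource handles a tape watches when it reads variables.
old_ok = handle.dtype.is_floating  # False -> warning was logged

# New check (backprop_util.IsTrainable): resource and variant dtypes count as
# trainable, so watching a variable's handle no longer logs the warning.
new_ok = handle.dtype.base_dtype in (tf.float16, tf.float32, tf.float64,
                                     tf.complex64, tf.complex128,
                                     tf.resource, tf.variant)  # True -> silent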

View File

@@ -2901,6 +2901,8 @@ py_library(
":unconnected_gradients",
":util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:backprop_util",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
"@six_archive//:six",

View File

@@ -501,6 +501,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":backprop_util",
":context",
":execute",
":imperative_grad",
@@ -520,6 +521,17 @@ py_library(
],
)
py_library(
name = "backprop_util",
srcs = ["backprop_util.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:dtypes",
"//tensorflow/python:tensor_util",
],
)
py_library(
name = "forwardprop",
srcs = ["forwardprop.py"],

View File

@@ -26,6 +26,7 @@ import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import _pywrap_utils
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import imperative_grad
@@ -853,7 +854,7 @@ class GradientTape(object):
if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
raise ValueError("Passed in object of type {}, not tf.Tensor".format(
type(t)))
if not t.dtype.is_floating:
if not backprop_util.IsTrainable(t):
logging.log_first_n(
logging.WARN, "The dtype of the watched tensor must be "
"floating (e.g. tf.float32), got %r", 5, t.dtype)
@@ -987,7 +988,7 @@ class GradientTape(object):
flat_targets = []
for t in nest.flatten(target):
if not t.dtype.is_floating:
if not backprop_util.IsTrainable(t):
logging.vlog(
logging.WARN, "The dtype of the target tensor must be "
"floating (e.g. tf.float32) when calling GradientTape.gradient, "
@@ -1001,7 +1002,7 @@ class GradientTape(object):
flat_sources_raw = flat_sources
flat_sources = [_handle_or_self(x) for x in flat_sources]
for t in flat_sources_raw:
if not t.dtype.is_floating:
if not backprop_util.IsTrainable(t):
logging.vlog(
logging.WARN, "The dtype of the source tensor must be "
"floating (e.g. tf.float32) when calling GradientTape.gradient, "

View File

@@ -0,0 +1,33 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shared utilities related to backprop."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
def IsTrainable(tensor_or_dtype):
if tensor_util.is_tensor(tensor_or_dtype):
dtype = tensor_or_dtype.dtype
else:
dtype = tensor_or_dtype
dtype = dtypes.as_dtype(dtype)
return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128,
dtypes.resource, dtypes.variant)
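For reference, a hedged usage sketch of the new helper (not part of the commit; module paths as shown in the diff above, behavior inferred from the dtype allowlist):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.eager import backprop_util

backprop_util.IsTrainable(constant_op.constant(1.0))  # True: float32 tensor
backprop_util.IsTrainable(dtypes.int32)               # False: integer dtype
backprop_util.IsTrainable(dtypes.resource)            # True: variable handles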

View File

@@ -20,13 +20,13 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@@ -61,7 +61,7 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
trainable_indices = []
nontrivial_tangents = []
for input_index, tensor in enumerate(inputs):
if gradients_util.IsTrainable(tensor):
if backprop_util.IsTrainable(tensor):
trainable_inputs.append(tensor)
trainable_indices.append(input_index)
nontrivial_tangents.append(tangents[input_index])
@@ -76,7 +76,7 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
trainable_outputs = []
nontrivial_output_indices = []
for output_index, output in enumerate(outputs):
if gradients_util.IsTrainable(output):
if backprop_util.IsTrainable(output):
forwardprop_aids.append(
array_ops.ones_like(output, name="unused_forwardprop_aid"))
trainable_outputs.append(output)

View File

@@ -35,6 +35,7 @@ from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python import _pywrap_utils
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
@@ -586,7 +587,7 @@ class _DelayedRewriteGradientFunctions(object):
"""
trainable_outputs = [
output for output in self._func_graph.outputs[:num_doutputs]
if gradients_util.IsTrainable(output)]
if backprop_util.IsTrainable(output)]
signature = []
for t in trainable_outputs:
@@ -668,7 +669,7 @@
# expects numeric inputs.
cleaned_doutputs = []
for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
if gradients_util.IsTrainable(placeholder):
if backprop_util.IsTrainable(placeholder):
if doutput is not None:
cleaned_doutputs.append(doutput)
else:
@@ -749,7 +750,7 @@ class _TapeGradientFunctions(object):
handles_to_variables = self._func_graph.variable_captures
trainable_outputs = []
for output in outputs:
if gradients_util.IsTrainable(output):
if backprop_util.IsTrainable(output):
# Swap in the Variable object for resource handles if we can so
# sparse gradients work.
output = handles_to_variables.get(ops.tensor_id(output), output)
@@ -858,7 +859,7 @@
for output_index, output in enumerate(outputs):
if trainable_recorded_outputs < backward_function_inputs:
recorded_outputs.append(output)
if gradients_util.IsTrainable(output):
if backprop_util.IsTrainable(output):
trainable_recorded_outputs += 1
else:
skip_positions.append(output_index)

View File

@@ -25,6 +25,7 @@ from __future__ import print_function
import collections
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
@@ -344,7 +345,7 @@ def _grad_fn(func_graph, grads):
ys = []
grad_ys = []
for y, grad_y in zip(func_graph.outputs, grads):
if not gradients_util.IsTrainable(y):
if not backprop_util.IsTrainable(y):
continue
ys.append(y)
grad_ys.append(grad_y)

View File

@@ -25,6 +25,7 @@ from six.moves import xrange, zip # pylint: disable=redefined-builtin
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import backprop_util
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -226,19 +227,8 @@ def _DefaultGradYs(grad_ys,
return new_grad_ys
def IsTrainable(tensor_or_dtype):
if isinstance(tensor_or_dtype, ops.Tensor):
dtype = tensor_or_dtype.dtype
else:
dtype = tensor_or_dtype
dtype = dtypes.as_dtype(dtype)
return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128,
dtypes.resource, dtypes.variant)
def _IsBackpropagatable(tensor):
if IsTrainable(tensor):
if backprop_util.IsTrainable(tensor):
return True
dtype = dtypes.as_dtype(tensor.dtype)
return dtype.base_dtype == dtypes.bfloat16
@@ -592,7 +582,7 @@ def _GradientsHelper(ys,
if loop_state:
loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
for y in loop_exits:
if IsTrainable(y):
if backprop_util.IsTrainable(y):
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
queue.append(y.op)
@@ -658,7 +648,8 @@
# therefore dC/doutput[i] is 0.
for i, out_grad in enumerate(out_grads):
if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
(not grad_fn and is_func_call) or IsTrainable(op.outputs[i])):
(not grad_fn and is_func_call)
or backprop_util.IsTrainable(op.outputs[i])):
# Only trainable outputs or outputs for a function call that
# will use SymbolicGradient get a zero gradient. Gradient
# functions should ignore the gradient for other outputs.
@@ -765,7 +756,7 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
# For an unused exit, if it has trainable outputs, backprop
# a zero gradient. Otherwise, just ignore it.
for y in grad_state.unused_exits:
if IsTrainable(y):
if backprop_util.IsTrainable(y):
_SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
queue.append(y.op)
else:

View File

@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
@@ -509,7 +510,7 @@ def _zeros_like(op_output):
def _is_trainable(tensor):
"""Returns whether the given tensor is trainable."""
if not gradients_util.IsTrainable(tensor):
if not backprop_util.IsTrainable(tensor):
return False
# Special case: untrainable accumulator output. The gradients algorithm
@@ -520,7 +521,7 @@ def _is_trainable(tensor):
if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
assert tensor.dtype == dtypes.variant
element_type = tensor.op.get_attr("element_dtype")
return gradients_util.IsTrainable(element_type)
return backprop_util.IsTrainable(element_type)
return True