From 619ad608cb7d429dbf59fafab5ef90257438419b Mon Sep 17 00:00:00 2001
From: Allen Lavoie <allenl@google.com>
Date: Tue, 3 Sep 2019 09:48:29 -0700
Subject: [PATCH] Silence a warning when reading variables with forwardprop

GradientTapes, used in the implementation, didn't like non-floating
inputs. Silences those warnings for trainable types (although I've kept
the warning text the same since I think mentioning variants and
resource handles would be needlessly confusing).

PiperOrigin-RevId: 266950768
---
 tensorflow/python/BUILD                  |  2 ++
 tensorflow/python/eager/BUILD            | 12 +++++++++
 tensorflow/python/eager/backprop.py      |  7 ++---
 tensorflow/python/eager/backprop_util.py | 33 ++++++++++++++++++++++++
 tensorflow/python/eager/forwardprop.py   |  6 ++---
 tensorflow/python/eager/function.py      |  9 ++++---
 tensorflow/python/ops/cond_v2.py         |  3 ++-
 tensorflow/python/ops/gradients_util.py  | 21 +++++----------
 tensorflow/python/ops/while_v2.py        |  5 ++--
 9 files changed, 70 insertions(+), 28 deletions(-)
 create mode 100644 tensorflow/python/eager/backprop_util.py

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 93701ea3848..2362ce91620 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2901,6 +2901,8 @@ py_library(
         ":unconnected_gradients",
         ":util",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/eager:backprop",
+        "//tensorflow/python/eager:backprop_util",
         "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
         "@six_archive//:six",
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 4c93ba13fbc..c6d2f1662d1 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -501,6 +501,7 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:internal"],
     deps = [
+        ":backprop_util",
         ":context",
         ":execute",
         ":imperative_grad",
@@ -520,6 +521,17 @@ py_library(
     ],
 )
 
+py_library(
+    name = "backprop_util",
+    srcs = ["backprop_util.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:tensor_util",
+    ],
+)
+
 py_library(
     name = "forwardprop",
     srcs = ["forwardprop.py"],
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 0fdc0d7e53c..826d39c4777 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -26,6 +26,7 @@ import six
 
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python import _pywrap_utils
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import imperative_grad
@@ -853,7 +854,7 @@ class GradientTape(object):
       if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
         raise ValueError("Passed in object of type {}, not tf.Tensor".format(
             type(t)))
-      if not t.dtype.is_floating:
+      if not backprop_util.IsTrainable(t):
         logging.log_first_n(
             logging.WARN, "The dtype of the watched tensor must be "
             "floating (e.g. tf.float32), got %r", 5, t.dtype)
@@ -987,7 +988,7 @@ class GradientTape(object):
 
     flat_targets = []
     for t in nest.flatten(target):
-      if not t.dtype.is_floating:
+      if not backprop_util.IsTrainable(t):
         logging.vlog(
             logging.WARN, "The dtype of the target tensor must be "
             "floating (e.g. tf.float32) when calling GradientTape.gradient, "
@@ -1001,7 +1002,7 @@ class GradientTape(object):
     flat_sources_raw = flat_sources
     flat_sources = [_handle_or_self(x) for x in flat_sources]
     for t in flat_sources_raw:
-      if not t.dtype.is_floating:
+      if not backprop_util.IsTrainable(t):
         logging.vlog(
             logging.WARN, "The dtype of the source tensor must be "
             "floating (e.g. tf.float32) when calling GradientTape.gradient, "
diff --git a/tensorflow/python/eager/backprop_util.py b/tensorflow/python/eager/backprop_util.py
new file mode 100644
index 00000000000..ae026c0fbbb
--- /dev/null
+++ b/tensorflow/python/eager/backprop_util.py
@@ -0,0 +1,33 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Shared utilities related to backprop."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+
+
+def IsTrainable(tensor_or_dtype):
+  if tensor_util.is_tensor(tensor_or_dtype):
+    dtype = tensor_or_dtype.dtype
+  else:
+    dtype = tensor_or_dtype
+  dtype = dtypes.as_dtype(dtype)
+  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
+                              dtypes.complex64, dtypes.complex128,
+                              dtypes.resource, dtypes.variant)
diff --git a/tensorflow/python/eager/forwardprop.py b/tensorflow/python/eager/forwardprop.py
index a2bfc02daf0..145fada8124 100644
--- a/tensorflow/python/eager/forwardprop.py
+++ b/tensorflow/python/eager/forwardprop.py
@@ -20,13 +20,13 @@ from __future__ import print_function
 
 from tensorflow.python import pywrap_tensorflow
 
 from tensorflow.python.eager import backprop
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import execute
 from tensorflow.python.framework import ops
 
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_util
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
@@ -61,7 +61,7 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
   trainable_indices = []
   nontrivial_tangents = []
   for input_index, tensor in enumerate(inputs):
-    if gradients_util.IsTrainable(tensor):
+    if backprop_util.IsTrainable(tensor):
       trainable_inputs.append(tensor)
       trainable_indices.append(input_index)
       nontrivial_tangents.append(tangents[input_index])
@@ -76,7 +76,7 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
   trainable_outputs = []
   nontrivial_output_indices = []
   for output_index, output in enumerate(outputs):
-    if gradients_util.IsTrainable(output):
+    if backprop_util.IsTrainable(output):
       forwardprop_aids.append(
           array_ops.ones_like(output, name="unused_forwardprop_aid"))
       trainable_outputs.append(output)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 6447bb55f7d..41c6032f30e 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -35,6 +35,7 @@ from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python import _pywrap_utils
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import tape
@@ -586,7 +587,7 @@ class _DelayedRewriteGradientFunctions(object):
     """
     trainable_outputs = [
         output for output in self._func_graph.outputs[:num_doutputs]
-        if gradients_util.IsTrainable(output)]
+        if backprop_util.IsTrainable(output)]
 
     signature = []
     for t in trainable_outputs:
@@ -668,7 +669,7 @@ class _DelayedRewriteGradientFunctions(object):
     # expects numeric inputs.
     cleaned_doutputs = []
     for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
-      if gradients_util.IsTrainable(placeholder):
+      if backprop_util.IsTrainable(placeholder):
         if doutput is not None:
           cleaned_doutputs.append(doutput)
         else:
@@ -749,7 +750,7 @@ class _TapeGradientFunctions(object):
     handles_to_variables = self._func_graph.variable_captures
     trainable_outputs = []
     for output in outputs:
-      if gradients_util.IsTrainable(output):
+      if backprop_util.IsTrainable(output):
         # Swap in the Variable object for resource handles if we can so
         # sparse gradients work.
         output = handles_to_variables.get(ops.tensor_id(output), output)
@@ -858,7 +859,7 @@ class _TapeGradientFunctions(object):
     for output_index, output in enumerate(outputs):
       if trainable_recorded_outputs < backward_function_inputs:
         recorded_outputs.append(output)
-      if gradients_util.IsTrainable(output):
+      if backprop_util.IsTrainable(output):
         trainable_recorded_outputs += 1
       else:
         skip_positions.append(output_index)
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index b3eb9a5718c..3d099d52cbd 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -25,6 +25,7 @@ from __future__ import print_function
 
 import collections
 
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
@@ -344,7 +345,7 @@ def _grad_fn(func_graph, grads):
   ys = []
   grad_ys = []
   for y, grad_y in zip(func_graph.outputs, grads):
-    if not gradients_util.IsTrainable(y):
+    if not backprop_util.IsTrainable(y):
       continue
     ys.append(y)
     grad_ys.append(grad_y)
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 2f8b15925d4..c89978bdfa4 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -25,6 +25,7 @@ from six.moves import xrange, zip  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.eager import backprop
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -226,19 +227,8 @@ def _DefaultGradYs(grad_ys,
   return new_grad_ys
 
 
-def IsTrainable(tensor_or_dtype):
-  if isinstance(tensor_or_dtype, ops.Tensor):
-    dtype = tensor_or_dtype.dtype
-  else:
-    dtype = tensor_or_dtype
-  dtype = dtypes.as_dtype(dtype)
-  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
-                              dtypes.complex64, dtypes.complex128,
-                              dtypes.resource, dtypes.variant)
-
-
 def _IsBackpropagatable(tensor):
-  if IsTrainable(tensor):
+  if backprop_util.IsTrainable(tensor):
     return True
   dtype = dtypes.as_dtype(tensor.dtype)
   return dtype.base_dtype == dtypes.bfloat16
@@ -592,7 +582,7 @@ def _GradientsHelper(ys,
     if loop_state:
       loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
       for y in loop_exits:
-        if IsTrainable(y):
+        if backprop_util.IsTrainable(y):
           _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
           queue.append(y.op)
 
@@ -658,7 +648,8 @@ def _GradientsHelper(ys,
           # therefore dC/doutput[i] is 0.
           for i, out_grad in enumerate(out_grads):
             if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
-                (not grad_fn and is_func_call) or IsTrainable(op.outputs[i])):
+                (not grad_fn and is_func_call)
+                or backprop_util.IsTrainable(op.outputs[i])):
               # Only trainable outputs or outputs for a function call that
               # will use SymbolicGradient get a zero gradient. Gradient
               # functions should ignore the gradient for other outputs.
@@ -765,7 +756,7 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
             # For an unused exit, if it has trainable outputs, backprop
             # a zero gradient. Otherwise, just ignore it.
            for y in grad_state.unused_exits:
-              if IsTrainable(y):
+              if backprop_util.IsTrainable(y):
                 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
              queue.append(y.op)
           else:
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 73a767caf25..dfdf1ef83e9 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -24,6 +24,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph as func_graph_module
@@ -509,7 +510,7 @@ def _zeros_like(op_output):
 
 def _is_trainable(tensor):
   """Returns whether the given tensor is trainable."""
-  if not gradients_util.IsTrainable(tensor):
+  if not backprop_util.IsTrainable(tensor):
     return False
 
   # Special case: untrainable accumulator output. The gradients algorithm
@@ -520,7 +521,7 @@ def _is_trainable(tensor):
   if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
     assert tensor.dtype == dtypes.variant
     element_type = tensor.op.get_attr("element_dtype")
-    return gradients_util.IsTrainable(element_type)
+    return backprop_util.IsTrainable(element_type)
 
   return True
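Note: the sketch below is not part of the commit; it only illustrates the behavior the relocated helper is expected to have on a TensorFlow build that includes this change. The dtype set mirrors the body of backprop_util.IsTrainable above: resource and variant dtypes now count as trainable, which is why watching a variable's resource handle inside the forwardprop implementation's GradientTape no longer logs the "must be floating" warning.

    # Sketch only: exercises backprop_util.IsTrainable as defined in this patch.
    from tensorflow.python.eager import backprop_util
    from tensorflow.python.framework import dtypes

    assert backprop_util.IsTrainable(dtypes.float32)      # floating dtypes are trainable
    assert backprop_util.IsTrainable(dtypes.complex64)    # complex dtypes are trainable
    assert not backprop_util.IsTrainable(dtypes.int32)    # integer dtypes still are not
    assert backprop_util.IsTrainable(dtypes.resource)     # variable handles: no warning when watched
    assert backprop_util.IsTrainable(dtypes.variant)      # e.g. TensorList handles

The helper also accepts tensors (via the tensor_util.is_tensor branch), so passing a watched tensor or a variable's handle directly works the same way.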