From 619ad608cb7d429dbf59fafab5ef90257438419b Mon Sep 17 00:00:00 2001
From: Allen Lavoie <allenl@google.com>
Date: Tue, 3 Sep 2019 09:48:29 -0700
Subject: [PATCH] Silence a warning when reading variables with forwardprop

GradientTapes, used in the forwardprop implementation, warned about non-floating inputs. This change silences those warnings for trainable types (I've kept the warning text the same, since I think mentioning variants and resource handles there would be needlessly confusing).

PiperOrigin-RevId: 266950768
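
For reference (not part of the diff below), a minimal sketch of the dtype policy the new backprop_util.IsTrainable helper encodes, assuming the patch is applied; the assertions simply mirror the dtype tuple in backprop_util.py:

    from tensorflow.python.eager import backprop_util
    from tensorflow.python.framework import dtypes

    assert backprop_util.IsTrainable(dtypes.float32)    # floating point: trainable
    assert backprop_util.IsTrainable(dtypes.resource)   # variable handles: no warning
    assert backprop_util.IsTrainable(dtypes.variant)    # e.g. TensorLists: no warning
    assert not backprop_util.IsTrainable(dtypes.int32)  # integer tensors still warn

Treating resource and variant as trainable here is what lets variable reads pass through the internal GradientTape without tripping the "dtype must be floating" warning path.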
---
 tensorflow/python/BUILD                  |  2 ++
 tensorflow/python/eager/BUILD            | 12 +++++++++
 tensorflow/python/eager/backprop.py      |  7 ++---
 tensorflow/python/eager/backprop_util.py | 33 ++++++++++++++++++++++++
 tensorflow/python/eager/forwardprop.py   |  6 ++---
 tensorflow/python/eager/function.py      |  9 ++++---
 tensorflow/python/ops/cond_v2.py         |  3 ++-
 tensorflow/python/ops/gradients_util.py  | 21 +++++----------
 tensorflow/python/ops/while_v2.py        |  5 ++--
 9 files changed, 70 insertions(+), 28 deletions(-)
 create mode 100644 tensorflow/python/eager/backprop_util.py

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 93701ea3848..2362ce91620 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2901,6 +2901,8 @@ py_library(
         ":unconnected_gradients",
         ":util",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python/eager:backprop",
+        "//tensorflow/python/eager:backprop_util",
         "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
         "@six_archive//:six",
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 4c93ba13fbc..c6d2f1662d1 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -501,6 +501,7 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:internal"],
     deps = [
+        ":backprop_util",
         ":context",
         ":execute",
         ":imperative_grad",
@@ -520,6 +521,17 @@ py_library(
     ],
 )
 
+py_library(
+    name = "backprop_util",
+    srcs = ["backprop_util.py"],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:tensor_util",
+    ],
+)
+
 py_library(
     name = "forwardprop",
     srcs = ["forwardprop.py"],
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 0fdc0d7e53c..826d39c4777 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -26,6 +26,7 @@ import six
 
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python import _pywrap_utils
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import imperative_grad
@@ -853,7 +854,7 @@ class GradientTape(object):
       if not (_pywrap_utils.IsTensor(t) or _pywrap_utils.IsVariable(t)):
         raise ValueError("Passed in object of type {}, not tf.Tensor".format(
             type(t)))
-      if not t.dtype.is_floating:
+      if not backprop_util.IsTrainable(t):
         logging.log_first_n(
             logging.WARN, "The dtype of the watched tensor must be "
             "floating (e.g. tf.float32), got %r", 5, t.dtype)
@@ -987,7 +988,7 @@ class GradientTape(object):
 
     flat_targets = []
     for t in nest.flatten(target):
-      if not t.dtype.is_floating:
+      if not backprop_util.IsTrainable(t):
         logging.vlog(
             logging.WARN, "The dtype of the target tensor must be "
             "floating (e.g. tf.float32) when calling GradientTape.gradient, "
@@ -1001,7 +1002,7 @@ class GradientTape(object):
     flat_sources_raw = flat_sources
     flat_sources = [_handle_or_self(x) for x in flat_sources]
     for t in flat_sources_raw:
-      if not t.dtype.is_floating:
+      if not backprop_util.IsTrainable(t):
         logging.vlog(
             logging.WARN, "The dtype of the source tensor must be "
             "floating (e.g. tf.float32) when calling GradientTape.gradient, "
diff --git a/tensorflow/python/eager/backprop_util.py b/tensorflow/python/eager/backprop_util.py
new file mode 100644
index 00000000000..ae026c0fbbb
--- /dev/null
+++ b/tensorflow/python/eager/backprop_util.py
@@ -0,0 +1,33 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Shared utilities related to backprop."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+
+
+def IsTrainable(tensor_or_dtype):
+  if tensor_util.is_tensor(tensor_or_dtype):
+    dtype = tensor_or_dtype.dtype
+  else:
+    dtype = tensor_or_dtype
+  dtype = dtypes.as_dtype(dtype)
+  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
+                              dtypes.complex64, dtypes.complex128,
+                              dtypes.resource, dtypes.variant)
diff --git a/tensorflow/python/eager/forwardprop.py b/tensorflow/python/eager/forwardprop.py
index a2bfc02daf0..145fada8124 100644
--- a/tensorflow/python/eager/forwardprop.py
+++ b/tensorflow/python/eager/forwardprop.py
@@ -20,13 +20,13 @@ from __future__ import print_function
 
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import execute
 
 from tensorflow.python.framework import ops
 
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_util
 from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
@@ -61,7 +61,7 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
   trainable_indices = []
   nontrivial_tangents = []
   for input_index, tensor in enumerate(inputs):
-    if gradients_util.IsTrainable(tensor):
+    if backprop_util.IsTrainable(tensor):
       trainable_inputs.append(tensor)
       trainable_indices.append(input_index)
       nontrivial_tangents.append(tangents[input_index])
@@ -76,7 +76,7 @@ def _forward_gradient(op_name, attr_tuple, inputs, outputs, tangents):
     trainable_outputs = []
     nontrivial_output_indices = []
     for output_index, output in enumerate(outputs):
-      if gradients_util.IsTrainable(output):
+      if backprop_util.IsTrainable(output):
         forwardprop_aids.append(
             array_ops.ones_like(output, name="unused_forwardprop_aid"))
         trainable_outputs.append(output)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 6447bb55f7d..41c6032f30e 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -35,6 +35,7 @@ from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import function_pb2
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python import _pywrap_utils
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import tape
@@ -586,7 +587,7 @@ class _DelayedRewriteGradientFunctions(object):
     """
     trainable_outputs = [
         output for output in self._func_graph.outputs[:num_doutputs]
-        if gradients_util.IsTrainable(output)]
+        if backprop_util.IsTrainable(output)]
 
     signature = []
     for t in trainable_outputs:
@@ -668,7 +669,7 @@ class _DelayedRewriteGradientFunctions(object):
     # expects numeric inputs.
     cleaned_doutputs = []
     for doutput, placeholder in zip(doutputs, self._func_graph.outputs):
-      if gradients_util.IsTrainable(placeholder):
+      if backprop_util.IsTrainable(placeholder):
         if doutput is not None:
           cleaned_doutputs.append(doutput)
         else:
@@ -749,7 +750,7 @@ class _TapeGradientFunctions(object):
     handles_to_variables = self._func_graph.variable_captures
     trainable_outputs = []
     for output in outputs:
-      if gradients_util.IsTrainable(output):
+      if backprop_util.IsTrainable(output):
         # Swap in the Variable object for resource handles if we can so
         # sparse gradients work.
         output = handles_to_variables.get(ops.tensor_id(output), output)
@@ -858,7 +859,7 @@ class _TapeGradientFunctions(object):
     for output_index, output in enumerate(outputs):
       if trainable_recorded_outputs < backward_function_inputs:
         recorded_outputs.append(output)
-      if gradients_util.IsTrainable(output):
+      if backprop_util.IsTrainable(output):
         trainable_recorded_outputs += 1
       else:
         skip_positions.append(output_index)
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index b3eb9a5718c..3d099d52cbd 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -25,6 +25,7 @@ from __future__ import print_function
 
 import collections
 
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
@@ -344,7 +345,7 @@ def _grad_fn(func_graph, grads):
   ys = []
   grad_ys = []
   for y, grad_y in zip(func_graph.outputs, grads):
-    if not gradients_util.IsTrainable(y):
+    if not backprop_util.IsTrainable(y):
       continue
     ys.append(y)
     grad_ys.append(grad_y)
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 2f8b15925d4..c89978bdfa4 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -25,6 +25,7 @@ from six.moves import xrange, zip  # pylint: disable=redefined-builtin
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.python.eager import backprop
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -226,19 +227,8 @@ def _DefaultGradYs(grad_ys,
   return new_grad_ys
 
 
-def IsTrainable(tensor_or_dtype):
-  if isinstance(tensor_or_dtype, ops.Tensor):
-    dtype = tensor_or_dtype.dtype
-  else:
-    dtype = tensor_or_dtype
-  dtype = dtypes.as_dtype(dtype)
-  return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
-                              dtypes.complex64, dtypes.complex128,
-                              dtypes.resource, dtypes.variant)
-
-
 def _IsBackpropagatable(tensor):
-  if IsTrainable(tensor):
+  if backprop_util.IsTrainable(tensor):
     return True
   dtype = dtypes.as_dtype(tensor.dtype)
   return dtype.base_dtype == dtypes.bfloat16
@@ -592,7 +582,7 @@ def _GradientsHelper(ys,
     if loop_state:
       loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
       for y in loop_exits:
-        if IsTrainable(y):
+        if backprop_util.IsTrainable(y):
           _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
           queue.append(y.op)
 
@@ -658,7 +648,8 @@ def _GradientsHelper(ys,
           # therefore dC/doutput[i] is 0.
           for i, out_grad in enumerate(out_grads):
             if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
-                (not grad_fn and is_func_call) or IsTrainable(op.outputs[i])):
+                (not grad_fn and is_func_call)
+                or backprop_util.IsTrainable(op.outputs[i])):
               # Only trainable outputs or outputs for a function call that
               # will use SymbolicGradient get a zero gradient. Gradient
               # functions should ignore the gradient for other outputs.
@@ -765,7 +756,7 @@ def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
             # For an unused exit, if it has trainable outputs, backprop
             # a zero gradient. Otherwise, just ignore it.
             for y in grad_state.unused_exits:
-              if IsTrainable(y):
+              if backprop_util.IsTrainable(y):
                 _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
               queue.append(y.op)
           else:
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 73a767caf25..dfdf1ef83e9 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -24,6 +24,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import backprop_util
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph as func_graph_module
@@ -509,7 +510,7 @@ def _zeros_like(op_output):
 
 def _is_trainable(tensor):
   """Returns whether the given tensor is trainable."""
-  if not gradients_util.IsTrainable(tensor):
+  if not backprop_util.IsTrainable(tensor):
     return False
 
   # Special case: untrainable accumulator output. The gradients algorithm
@@ -520,7 +521,7 @@ def _is_trainable(tensor):
   if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
     assert tensor.dtype == dtypes.variant
     element_type = tensor.op.get_attr("element_dtype")
-    return gradients_util.IsTrainable(element_type)
+    return backprop_util.IsTrainable(element_type)
 
   return True