STT-tensorflow/tensorflow/python/ops/default_gradient.py
Saurabh Saxena 6d7211299d Support taking gradients of tf.cond and tf.while_loop using LookupTable.
For tf.cond, this required not creating a default zeros output grad when the output grad for all branch functions is None. For example, since LookupTable ops are marked non-differentiable, the output gradient w.r.t. the LookupTable resource tensor is always None; previously we tried to convert that None to a zeros tensor, which is not supported.
Also added support for branch functions with no outputs in tf.cond v2. This is necessary now that grad If ops may have no outputs.
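
A minimal sketch of the tf.cond case (illustrative TF2-style code, not part of this change; the table and variable names are made up):

    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer([0, 1], [1.0, 2.0]), -1.0)
    v = tf.Variable(3.0)

    @tf.function
    def f():
      # Both branch functions capture the table's resource handle; its output
      # grad is None and is no longer coerced into a zeros tensor.
      return tf.cond(tf.constant(True),
                     lambda: table.lookup(tf.constant(0)) * v,
                     lambda: v * 2.0)

    with tf.GradientTape() as tape:
      out = f()
    tape.gradient(out, v)  # 1.0; no default grad is built for the table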

In tf.while_loop, since a captured LookupTable resource is also a loop output (due to the requirement that input and output signatures match), gradients_util tries to create a default gradient for the LookupTable, which is not supported. So gradients_util now checks whether the resource is differentiable before building the default grad. Hopefully we can avoid this once While supports explicit captures.
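And a sketch of the tf.while_loop case (same illustrative table as above): the captured table handle also becomes a loop output, which is why gradients_util must skip the default grad for it:

    @tf.function
    def g(x):
      return tf.while_loop(lambda i, a: i < 2,
                           lambda i, a: (i + 1, a * table.lookup(i)),
                           [tf.constant(0), x])[1]

    x = tf.constant(1.0)
    with tf.GradientTape() as tape:
      tape.watch(x)
      out = g(x)
    tape.gradient(out, x)  # 2.0 (= lookup(0) * lookup(1)) with the table above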

PiperOrigin-RevId: 277099963
Change-Id: Ib1e87fe42213bd10294d63c6ed4e77859489f1ce
2019-10-28 11:06:28 -07:00


# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for computing default gradients."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops


def get_zeros_dtype(t):
  """Return the dtype for the default gradient for a Tensor."""
  if t.dtype == dtypes.resource:
    handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
    if (handle_data is None or not handle_data.is_set or
        len(handle_data.shape_and_type) != 1):
      raise ValueError("Internal error: Tried to take gradients (or similar) "
                       "of a variable without handle data:\n%s" % str(t))
    return handle_data.shape_and_type[0].dtype
  return t.dtype
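
# A hypothetical usage sketch (comment only, not part of the original module):
# for a tf.Variable `v`, `v.handle` is a resource tensor whose handle data
# records the dtype of the variable's contents, so:
#
#   v = tf.Variable([1., 2.])
#   get_zeros_dtype(v.handle)         # dtype of the contents (float32)
#   get_zeros_dtype(tf.constant(1.))  # plain tensors fall through: float32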


def shape_and_dtype(t):
  """Return the shape and dtype for the default gradient for a Tensor."""
  if t.dtype == dtypes.resource:
    handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
    if (handle_data is None or not handle_data.is_set or
        len(handle_data.shape_and_type) != 1):
      raise ValueError("Internal error: Tried to take gradients (or similar) "
                       "of a variable without handle data:\n%s" % str(t))
    shape_and_type = handle_data.shape_and_type[0]
    return (tensor_shape.TensorShape(shape_and_type.shape),
            dtypes.as_dtype(shape_and_type.dtype))
  return t.shape, t.dtype
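
# Illustrative example (an assumption, not from the original file): for a
# variable handle this returns the shape and dtype of the variable's value,
# which is what a default gradient for the handle must look like:
#
#   v = tf.Variable([[1., 2.]])
#   shape_and_dtype(v.handle)  # (TensorShape([1, 2]), tf.float32)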


def zeros_like(t):
  """Like array_ops.zeros_like, but respects resource handles."""
  if t.dtype == dtypes.resource:
    return array_ops.zeros(*shape_and_dtype(t))
  else:
    return array_ops.zeros_like(t)


def ones_like(t):
  """Like array_ops.ones_like, but respects resource handles."""
  if t.dtype == dtypes.resource:
    return array_ops.ones(*shape_and_dtype(t))
  else:
    return array_ops.ones_like(t)
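
# Sketch of the resource behavior (hypothetical values): calling these on a
# variable's handle yields a dense tensor shaped like the variable's value,
# rather than attempting zeros_like/ones_like on the resource handle itself:
#
#   v = tf.Variable([1., 2.])
#   zeros_like(v.handle)  # [0., 0.]
#   ones_like(v.handle)   # [1., 1.]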


def supports_default_grad(t):
  """Whether tensor `t` supports creating a default gradient.

  This function assumes that `t` is of a trainable type.

  Args:
    t: Tensor

  Returns:
    Bool
  """
  if t.dtype == dtypes.resource:
    handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
    if (handle_data is None or not handle_data.is_set or
        len(handle_data.shape_and_type) != 1):
      return False
  return True
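
# Hypothetical illustration of the distinction this function draws (the commit
# above relies on it in gradients_util): a variable handle carries a single
# (shape, dtype) pair in its handle data, while a lookup-table handle does
# not, so only the former supports a default gradient:
#
#   v = tf.Variable(1.)
#   supports_default_grad(v.handle)               # True
#   supports_default_grad(table.resource_handle)  # False for a hash table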