For tf.cond, this requires that we not create a default zeros output grad when the output grads for all branch functions are None. For example, since LookupTable ops are marked non-differentiable, the output gradient w.r.t. the LookupTable resource tensor is always None; before this change we tried to convert that to a zeros tensor, which is not supported. Also added support for tf.cond v2 branch functions with no outputs, which is necessary now that we may have grad If ops with no outputs. In tf.while_loop, a captured LookupTable resource is also a loop output (because the loop's input and output signatures must match), so gradients_util tries to create a default gradient for the LookupTable, which is not supported. So in gradients_util we now check whether the resource is a differentiable resource before building the default grad. Hopefully we can avoid this once we have explicit captures in While. PiperOrigin-RevId: 277099963 Change-Id: Ib1e87fe42213bd10294d63c6ed4e77859489f1ce
85 lines
3.0 KiB
Python
85 lines
3.0 KiB
Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Utilities for computing default gradients."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import tensor_shape
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import resource_variable_ops
|
|
|
|
|
|
def get_zeros_dtype(t):
  """Return the dtype for the default gradient for a Tensor.

  For a resource tensor (e.g. a variable handle) the tensor's own dtype is
  `dtypes.resource`, so the dtype is instead read from the single
  shape/dtype entry in its handle data. For any other tensor `t.dtype` is
  returned unchanged.

  Args:
    t: Tensor.

  Returns:
    A `DType`. The resource branch converts the handle-data dtype with
    `dtypes.as_dtype`, the same way `shape_and_dtype` does, so both branches
    return the same kind of object.

  Raises:
    ValueError: if `t` is a resource tensor whose handle data is missing,
      unset, or does not hold exactly one shape/dtype entry.
  """
  if t.dtype == dtypes.resource:
    handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
    if (handle_data is None or not handle_data.is_set or
        len(handle_data.shape_and_type) != 1):
      raise ValueError("Internal error: Tried to take gradients (or similar) "
                       "of a variable without handle data:\n%s" % str(t))
    # The handle-data field is a DataType enum value rather than a DType
    # (shape_and_dtype passes the same field through as_dtype). Convert it so
    # callers get a DType regardless of whether `t` is a resource.
    return dtypes.as_dtype(handle_data.shape_and_type[0].dtype)
  return t.dtype
|
|
|
|
|
|
def shape_and_dtype(t):
  """Return the shape and dtype for the default gradient for a Tensor.

  A non-resource tensor reports its own `shape` and `dtype`. A resource
  tensor's own dtype is `dtypes.resource`, so its shape and dtype are taken
  instead from the single shape/dtype entry stored in its handle data.

  Args:
    t: Tensor.

  Returns:
    `(t.shape, t.dtype)` for non-resource tensors; otherwise a
    `(TensorShape, DType)` pair built from the handle data.

  Raises:
    ValueError: if `t` is a resource tensor whose handle data is missing,
      unset, or does not hold exactly one shape/dtype entry.
  """
  if t.dtype != dtypes.resource:
    return t.shape, t.dtype

  data = resource_variable_ops.get_eager_safe_handle_data(t)
  has_single_entry = (data is not None and data.is_set and
                      len(data.shape_and_type) == 1)
  if not has_single_entry:
    raise ValueError("Internal error: Tried to take gradients (or similar) "
                     "of a variable without handle data:\n%s" % str(t))

  entry = data.shape_and_type[0]
  shape = tensor_shape.TensorShape(entry.shape)
  dtype = dtypes.as_dtype(entry.dtype)
  return shape, dtype
|
|
|
|
|
|
def zeros_like(t):
  """Like array_ops.zeros_like, but respects resource handles.

  Non-resource tensors are forwarded to `array_ops.zeros_like`. For a
  resource tensor, the zeros are built with `array_ops.zeros` from the shape
  and dtype recorded in its handle data (via `shape_and_dtype`, which raises
  ValueError when that handle data is unusable).
  """
  if t.dtype != dtypes.resource:
    return array_ops.zeros_like(t)
  shape, dtype = shape_and_dtype(t)
  return array_ops.zeros(shape, dtype)
|
|
|
|
|
|
def ones_like(t):
  """Like array_ops.ones_like, but respects resource handles.

  Non-resource tensors are forwarded to `array_ops.ones_like`. For a
  resource tensor, the ones are built with `array_ops.ones` from the shape
  and dtype recorded in its handle data (via `shape_and_dtype`, which raises
  ValueError when that handle data is unusable).
  """
  if t.dtype != dtypes.resource:
    return array_ops.ones_like(t)
  shape, dtype = shape_and_dtype(t)
  return array_ops.ones(shape, dtype)
|
|
|
|
|
|
def supports_default_grad(t):
  """Whether tensor `t` supports creating a default gradient.

  This function assumes that `t` is of a trainable type.

  Every non-resource tensor qualifies. A resource tensor qualifies only when
  its handle data is present, set, and holds exactly one shape/dtype entry.
  This is the same condition that `shape_and_dtype` and `get_zeros_dtype`
  check before they raise ValueError.

  Args:
    t: Tensor

  Returns:
    Bool
  """
  if t.dtype != dtypes.resource:
    return True
  data = resource_variable_ops.get_eager_safe_handle_data(t)
  return bool(data is not None and data.is_set and
              len(data.shape_and_type) == 1)
|