Extract UnconnectedGradients from gradients_impl to simplify dependencies

PiperOrigin-RevId: 218661408
Tamara Norman authored 2018-10-25 03:48:34 -07:00, committed by TensorFlower Gardener
parent 5d2b7b34c8
commit 670eff0b0f
8 changed files with 68 additions and 41 deletions
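
In practice the change swaps a heavyweight internal import for a lightweight one; the public `tf.UnconnectedGradients` endpoint is untouched (see the ops/gradients.py hunk below). A minimal before/after sketch of the internal import path:

    # Before: the enum was defined in gradients_impl, so even callers that
    # only needed the enum pulled in all of gradients_impl's dependencies.
    # from tensorflow.python.ops import gradients_impl
    # default = gradients_impl.UnconnectedGradients.NONE

    # After: a dedicated module whose only TensorFlow dependency is tf_export.
    from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
    default = UnconnectedGradients.NONE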

tensorflow/python/BUILD

@@ -2203,6 +2203,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":gradients_impl",
+        ":unconnected_gradients",
         "//tensorflow/python/eager:function",
         "//tensorflow/python/eager:tape",
     ],
@@ -2240,6 +2241,7 @@ py_library(
         ":spectral_grad",
         ":tensor_array_ops",
         ":tensor_util",
+        ":unconnected_gradients",
         ":util",
         ":variable_scope",
         "//tensorflow/core:protos_all_py",
@@ -2250,6 +2252,15 @@ py_library(
     ],
 )
 
+py_library(
+    name = "unconnected_gradients",
+    srcs = ["ops/unconnected_gradients.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":util",
+    ],
+)
+
 py_library(
     name = "histogram_ops",
     srcs = ["ops/histogram_ops.py"],

tensorflow/python/eager/BUILD

@@ -279,10 +279,10 @@ py_library(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
-        "//tensorflow/python:gradients_impl",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:pywrap_tensorflow",
         "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:unconnected_gradients",
        "//tensorflow/python:util",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:execute",
@@ -375,8 +375,8 @@ py_library(
     srcs = ["imperative_grad.py"],
     srcs_version = "PY2AND3",
     deps = [
-        "//tensorflow/python:gradients_impl",
         "//tensorflow/python:pywrap_tensorflow",
+        "//tensorflow/python:unconnected_gradients",
         "//tensorflow/python:util",
     ],
 )

tensorflow/python/eager/backprop.py

@@ -35,9 +35,9 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import gen_math_ops
-from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_contextlib
@@ -855,7 +855,7 @@ class GradientTape(object):
                target,
                sources,
                output_gradients=None,
-               unconnected_gradients=gradients_impl.UnconnectedGradients.NONE):
+               unconnected_gradients=UnconnectedGradients.NONE):
     """Computes the gradient using operations recorded in context of this tape.
 
     Args:
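
A hedged usage sketch of the parameter this signature exposes, assuming a build that includes this change and eager execution enabled (variable names are illustrative):

    import tensorflow as tf

    tf.enable_eager_execution()

    x = tf.constant(1.0)
    unused = tf.constant(2.0)
    with tf.GradientTape() as tape:
      tape.watch([x, unused])
      y = x * x  # y depends on x but not on `unused`

    grads = tape.gradient(
        y, [x, unused],
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    # grads == [2.0, 0.0]; with the default NONE, the second entry is None.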

tensorflow/python/eager/function.py

@@ -481,8 +481,7 @@ class Function(object):
         outputs,
         self._func_graph.inputs,
         grad_ys=gradients_wrt_outputs,
-        src_graph=self._func_graph,
-        unconnected_gradients=gradients_impl.UnconnectedGradients.NONE)
+        src_graph=self._func_graph)
 
     backwards_graph_captures = list(backwards_graph.captures.keys())
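
The dropped keyword appears to be redundant rather than a behavior change: every signature in this commit defaults `unconnected_gradients` to NONE. A graph-mode sketch of the same equivalence at the public endpoint (illustrative only, not part of the commit):

    import tensorflow as tf

    x = tf.placeholder(tf.float32)
    y = x * x
    # Omitting the keyword and passing the default build the same gradient graph.
    g_implicit = tf.gradients([y], [x])
    g_explicit = tf.gradients(
        [y], [x], unconnected_gradients=tf.UnconnectedGradients.NONE)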

tensorflow/python/eager/imperative_grad.py

@@ -21,7 +21,7 @@ from __future__ import print_function
 import collections
 
 from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.util import compat
 
 VSpace = collections.namedtuple("VSpace", [
@@ -34,7 +34,7 @@ def imperative_grad(
     target,
     sources,
     output_gradients=None,
-    unconnected_gradients=gradients_impl.UnconnectedGradients.NONE):
+    unconnected_gradients=UnconnectedGradients.NONE):
   """Computes gradients from the imperatively defined tape on top of the stack.
 
   Works by filtering the tape, computing how many downstream usages are of each
@@ -59,8 +59,7 @@ def imperative_grad(
     RuntimeError: if something goes wrong.
   """
   try:
-    unconnected_gradients = gradients_impl.UnconnectedGradients(
-        unconnected_gradients)
+    unconnected_gradients = UnconnectedGradients(unconnected_gradients)
   except ValueError:
     raise ValueError(
         "Unknown value for unconnected_gradients: %r" % unconnected_gradients)

tensorflow/python/ops/gradients.py

@@ -25,5 +25,5 @@ from tensorflow.python.ops.custom_gradient import custom_gradient
 from tensorflow.python.ops.gradients_impl import AggregationMethod
 from tensorflow.python.ops.gradients_impl import gradients
 from tensorflow.python.ops.gradients_impl import hessians
-from tensorflow.python.ops.gradients_impl import UnconnectedGradients
+from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 # pylint: enable=unused-import
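
Because this module re-exports the symbol from its new home, existing users are unaffected; one way to see that (a sketch, not part of the commit):

    import tensorflow as tf
    from tensorflow.python.ops import gradients

    # The public endpoint and the re-export resolve to the same class object.
    assert tf.UnconnectedGradients is gradients.UnconnectedGradients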

tensorflow/python/ops/gradients_impl.py

@@ -20,7 +20,6 @@ from __future__ import print_function
 import collections
 import contextlib
-import enum  # pylint: disable=g-bad-import-order
 import warnings
 
 import numpy as np
@@ -35,6 +34,7 @@ from tensorflow.python.framework import function as framework_function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework.func_graph import FuncGraph
 from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops  # pylint: disable=unused-import
@@ -53,16 +53,11 @@ from tensorflow.python.ops import random_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import spectral_grad  # pylint: disable=unused-import
 from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 from tensorflow.python.util.lazy_loader import LazyLoader
 from tensorflow.python.util.tf_export import tf_export
 
-# This is to avoid a circular dependency
-# backprop -> gradients_impl -> func_graph
-funct_graph = LazyLoader(
-    "func_graph", globals(),
-    "tensorflow.python.framework.func_graph")
-
 # This is to avoid a circular dependency (eager.function depends on
 # gradients_impl). This is set in eager/function.py.
@@ -455,12 +450,12 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
 
 def _IsFunction(graph):
-  return (isinstance(graph, funct_graph.FuncGraph) or
+  return (isinstance(graph, FuncGraph) or
           isinstance(graph, framework_function._FuncGraph))  # pylint: disable=protected-access
 
 
 def _Captures(func_graph):
-  if isinstance(func_graph, funct_graph.FuncGraph):
+  if isinstance(func_graph, FuncGraph):
     return func_graph.captures
   else:
     assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
@@ -539,26 +534,6 @@ def _Consumers(t, func_graphs):
   return consumers
 
 
-@tf_export("UnconnectedGradients")
-class UnconnectedGradients(enum.Enum):
-  """Controls how gradient computation behaves when y does not depend on x.
-
-  The gradient of y with respect to x can be zero in two different ways: there
-  could be no differentiable path in the graph connecting x to y (and so we can
-  statically prove that the gradient is zero) or it could be that runtime values
-  of tensors in a particular execution lead to a gradient of zero (say, if a
-  relu unit happens to not be activated). To allow you to distinguish between
-  these two cases you can choose what value gets returned for the gradient when
-  there is no path in the graph from x to y:
-
-  * `NONE`: Indicates that [None] will be returned if there is no path from x
-    to y
-  * `ZERO`: Indicates that a zero tensor will be returned in the shape of x.
-  """
-  NONE = "none"
-  ZERO = "zero"
-
-
 @tf_export("gradients")
 def gradients(ys,
               xs,
@@ -703,7 +678,7 @@ def _GradientsHelper(ys,
   curr_graph = src_graph
   while _IsFunction(curr_graph):
     func_graphs.append(curr_graph)
-    if isinstance(curr_graph, funct_graph.FuncGraph):
+    if isinstance(curr_graph, FuncGraph):
       curr_graph = curr_graph.outer_graph
     else:
       assert isinstance(curr_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
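
With `FuncGraph` now imported directly, the LazyLoader workaround deleted above is unnecessary. For context, a minimal sketch of that pattern using the real `LazyLoader` utility (the comments are mine):

    from tensorflow.python.util.lazy_loader import LazyLoader

    # Binds a module proxy without importing it; the real import happens on
    # first attribute access. This is how gradients_impl previously broke the
    # backprop -> gradients_impl -> func_graph cycle.
    func_graph = LazyLoader(
        "func_graph", globals(), "tensorflow.python.framework.func_graph")

    # Accessing an attribute triggers the deferred import:
    # func_graph.FuncGraph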

tensorflow/python/ops/unconnected_gradients.py

@@ -0,0 +1,43 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for calculating gradients."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import enum
+
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("UnconnectedGradients")
+class UnconnectedGradients(enum.Enum):
+  """Controls how gradient computation behaves when y does not depend on x.
+
+  The gradient of y with respect to x can be zero in two different ways: there
+  could be no differentiable path in the graph connecting x to y (and so we can
+  statically prove that the gradient is zero) or it could be that runtime values
+  of tensors in a particular execution lead to a gradient of zero (say, if a
+  relu unit happens to not be activated). To allow you to distinguish between
+  these two cases you can choose what value gets returned for the gradient when
+  there is no path in the graph from x to y:
+
+  * `NONE`: Indicates that [None] will be returned if there is no path from x
+    to y
+  * `ZERO`: Indicates that a zero tensor will be returned in the shape of x.
+  """
+  NONE = "none"
+  ZERO = "zero"