Resurrects autograd-free eager gradients.

PiperOrigin-RevId: 168448557
Alexandre Passos 2017-09-12 14:37:06 -07:00 committed by TensorFlower Gardener
parent 8f37f30027
commit 655f26fc70
17 changed files with 510 additions and 837 deletions

View File

@ -637,6 +637,7 @@ py_library(
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:core", "//tensorflow/python/eager:core",
"//tensorflow/python/eager:tape",
"@six_archive//:six", "@six_archive//:six",
], ],
) )
@ -1800,7 +1801,6 @@ py_library(
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:custom_gradient", "//tensorflow/python/eager:custom_gradient",
"//tensorflow/python/eager:tape", "//tensorflow/python/eager:tape",
"//tensorflow/python/eager:tensor_node",
], ],
) )

View File

@ -82,7 +82,6 @@ py_library(
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
"//tensorflow/python:dtypes", "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util", "//tensorflow/python:util",
], ],
) )
@ -307,25 +306,6 @@ py_library(
], ],
) )
py_library(
name = "tensor_node",
srcs = ["tensor_node.py"],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
":context",
":custom_gradient",
":tape",
":tensor",
"//tensorflow/python:array_ops",
"//tensorflow/python:common_shapes",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
],
)
py_library( py_library(
name = "backprop", name = "backprop",
srcs = ["backprop.py"], srcs = ["backprop.py"],
@ -344,7 +324,6 @@ py_library(
"//tensorflow/python/eager:execute", "//tensorflow/python/eager:execute",
"//tensorflow/python/eager:tape", "//tensorflow/python/eager:tape",
"//tensorflow/python/eager:tensor", "//tensorflow/python/eager:tensor",
"//tensorflow/python/eager:tensor_node",
"@six_archive//:six", "@six_archive//:six",
], ],
) )
@ -386,18 +365,6 @@ py_test(
], ],
) )
py_test(
name = "tensor_node_test",
srcs = ["tensor_node_test.py"],
srcs_version = "PY2AND3",
deps = [
":tensor",
":tensor_node",
":test",
"//tensorflow/python:framework_test_lib",
],
)
py_test( py_test(
name = "ops_test", name = "ops_test",
srcs = ["ops_test.py"], srcs = ["ops_test.py"],

View File

@ -21,30 +21,265 @@ from __future__ import print_function
import collections import collections
import threading import threading
from autograd import container_types
from autograd import convenience_wrappers
from autograd import core as ag_core
import six import six
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import execute from tensorflow.python.eager import execute
from tensorflow.python.eager import tape from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor from tensorflow.python.framework import constant_op
# Imports TensorNode to enable autograd tracing of TF ops. We don't need to use
# any symbols here but import the file just to get the right registrations to
# happen.
from tensorflow.python.eager import tensor_node # pylint: disable=unused-import
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
# Terminology:
#
# - op: a possibly composite operation, which has an entry in the tape
# - target: dy in dy/dx
# - source: dx in dy/dx
# - tensor: one of the many inputs or outputs of an operation
#
# Below here we do the gradient algorithm. It works as follows:
#
# First we filter the tape to just the subset of operations we want to
# differentiate. In the process of doing so we count how many times each Tensor
# is used as an input to an op (so we know when we're done computing gradients
# for that Tensor). We also count, for each tape entry, how many of its output
# Tensors need gradients to be computed (Tensors which are not used do not need
# any gradients to be computed).
#
# Finally, we start a backprop stack with a set of tape entries for which we
# have all gradients available. This set usually is a subset of the set of
# targets (not all since targets which have outputs in the tape will not have
# gradients available initially).
#
# Then we repeatedly pop an entry from the stack, run its backprop, and update
# the gradients of its inputs. Once we have computed all gradients for a single
# input we can mark this input as done, and this can trigger adding an entry to
# the stack if all outputs of that entry are now done.
#
# When the stack is empty we have gradients for all tensors we're interested in.
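The scheme described above can be pictured with a tiny, self-contained toy trace in plain Python (the names Entry, produced_by, and the toy tensor ids are illustrative only, not symbols from this module):

import collections

# Toy forward trace for z = 2 * (x ** 2) at x = 3: one tape entry per op.
Entry = collections.namedtuple("Entry", ["output_id", "input_id", "backward"])
x_val = 3.0
entries = {
    "square": Entry("y", "x", lambda dy: dy * 2.0 * x_val),   # y = x ** 2
    "double": Entry("z", "y", lambda dz: dz * 2.0),           # z = 2 * y
}
produced_by = {"y": "square", "z": "double"}   # tensor id -> op that made it

# Worklist pass: start from ops whose output gradient is already known (here
# "double", since "z" is the target), then walk backwards until "x" is done.
grads = {"z": 1.0}
ready = ["double"]
while ready:
  op = ready.pop()
  entry = entries.pop(op)
  in_grad = entry.backward(grads.pop(entry.output_id))
  grads[entry.input_id] = grads.get(entry.input_id, 0.0) + in_grad
  upstream = produced_by.get(entry.input_id)
  if upstream in entries:
    ready.append(upstream)
print(grads["x"])   # 12.0 == d(2 * x**2)/dx at x = 3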
def _prepare_backprop(target, tensor_to_op, op_to_entry, id_sources):
"""Filters the tape to only include relevant entries and counts tensor usages.
Args:
target: the tensor or tensors to be differentiated.
tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
op_to_entry: Map from op id to a tape.TapeEntry object.
id_sources: the ids of the sources with respect to which the gradient is
being taken.
Returns:
tensor_usage_counts: map from tensor id to how many relevant downstream
entries use that tensor as an input.
op_to_entry_map: a filtered tape, with only the relevant entries.
op_missing_tensor_counts: map from op id to how many of its output tensors
still need gradients before its backward function can run.
"""
if isinstance(target, (ops.Tensor)):
tensor_stack = [ops.tensor_id(target)]
else:
tensor_stack = list([ops.tensor_id(x) for x in target])
tensor_usage_counts = {}
o_to_e = {} # Copy of just the bits we need from op_to_entry
while tensor_stack:
t = tensor_stack.pop()
op = tensor_to_op[t]
# op is None if the tensor is a source (i.e. was watched directly)
if op is None or op in o_to_e:
continue
op_trace = op_to_entry[op]
o_to_e[op] = op_trace
for i in op_trace.inputs:
it = ops.tensor_id(i)
if it in tensor_usage_counts:
tensor_usage_counts[it] += 1
else:
tensor_usage_counts[it] = 1
if it not in id_sources and it in tensor_to_op:
tensor_stack.append(it)
op_missing_tensor_counts = collections.defaultdict(int)
for t in tensor_usage_counts:
if t in tensor_to_op and tensor_to_op[t] is not None:
op_missing_tensor_counts[tensor_to_op[t]] += 1
return tensor_usage_counts, o_to_e, op_missing_tensor_counts
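As a hypothetical illustration of the three return values: for a chain in which op1 consumes a watched source `a` and produces `b`, and op2 consumes `b` and produces the target `c`, the result would look roughly like this (names are placeholders, not real tensor or op ids):

tensor_usage_counts = {"a": 1, "b": 1}     # each tensor feeds one relevant op
op_to_entry_map = {"op1": "entry_op1", "op2": "entry_op2"}   # both ops kept
op_missing_tensor_counts = {"op1": 1}      # op1's output "b" still needs its
                                           # gradient; op2 is ready immediately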
def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
"""Returns the set of tape entries which are available for backprop."""
ready_ops = []
for op in op_to_entry:
if op not in op_missing_tensor:
ready_ops.append(op)
return ready_ops
def _initial_gradients(target, output_gradients, tensor_usage_counts):
"""Computes the initial gradients for each Tensor."""
# Initialize the backprop stack
gradients = collections.defaultdict(list)
if isinstance(target, ops.Tensor):
if output_gradients is not None:
output_gradient = output_gradients
else:
output_gradient = array_ops.ones_like(target)
gradients[ops.tensor_id(target)].append(output_gradient)
else:
for i, t in enumerate(target):
if ops.tensor_id(t) in tensor_usage_counts:
# Can't provide a gradient of something we're trying to differentiate
assert output_gradients is None or output_gradients[i] is None
else:
if output_gradients is None or output_gradients[i] is None:
out_grad = array_ops.ones_like(t)
else:
out_grad = output_gradients[i]
gradients[ops.tensor_id(t)].append(out_grad)
return gradients
@tf_contextlib.contextmanager
def _no_op():
yield
def _aggregate_grads(gradients):
"""Aggregate gradients from multiple sources.
Args:
gradients: A list of 'Tensor' or 'IndexedSlices' gradients.
Returns:
If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'.
Otherwise returns an aggregated 'IndexedSlices'.
"""
assert gradients, "No gradients to aggregate"
if len(gradients) == 1:
return gradients[0]
if all([isinstance(g, ops.Tensor) for g in gradients]):
return math_ops.add_n(gradients)
else:
assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
for g in gradients])
indexed_slices_list = []
for grad in gradients:
# TODO(xpan): Support nested IndexedSlices and core IndexedSlices
if isinstance(grad, ops.Tensor):
indexed_slices = ops.IndexedSlices(
grad,
constant_op.constant(range(grad.shape[0])),
constant_op.constant(grad.shape.as_list()))
indexed_slices_list.append(indexed_slices)
else:
indexed_slices_list.append(grad)
# Dense shapes from all gradients should be the same.
dense_shape = indexed_slices_list[0].dense_shape
# For simplicity now, always cast to int64.
indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64)
for x in indexed_slices_list], 0)
values = array_ops.concat([x.values for x in indexed_slices_list], 0)
return ops.IndexedSlices(values, indices, dense_shape)
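The dense/sparse mixing above can be pictured with a small NumPy stand-in (illustrative only; the real code works on eager Tensors and ops.IndexedSlices):

import numpy as np

def aggregate(grads, dense_shape):
  """Sketch of the rule above: sum if everything is dense, otherwise turn each
  dense gradient into (indices, values) row slices and concatenate them all."""
  if all(isinstance(g, np.ndarray) for g in grads):
    return sum(grads)                       # the add_n path
  indices, values = [], []
  for g in grads:
    if isinstance(g, np.ndarray):           # dense -> slices covering every row
      indices.append(np.arange(g.shape[0]))
      values.append(g)
    else:                                   # already an (indices, values) pair
      idx, val = g
      indices.append(np.asarray(idx))
      values.append(np.asarray(val, dtype=float))
  return np.concatenate(indices), np.concatenate(values), dense_shape

dense = np.ones((4, 2))                     # gradient touching every row
sparse = ([1, 3], [[2.0, 2.0], [5.0, 5.0]]) # gradient touching rows 1 and 3
print(aggregate([dense, sparse], dense_shape=(4, 2)))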
def imperative_grad(
target,
sources,
output_gradients=None):
"""Computes gradients from the imperatively defined tape on top of the stack.
Works by filtering the tape, computing how many downstream usages each
tensor and entry has, and repeatedly applying backward functions until we have
gradients for all sources.
Args:
target: either a Tensor or list of Tensors to be differentiated.
sources: list of Tensors for which we want gradients
output_gradients: if not None, a list of output gradients, one per target;
if None, a tensor of ones shaped like each target is used.
Returns:
The gradient with respect to each of the sources.
Raises:
RuntimeError: if something goes wrong.
ValueError: if there is no sequence of differentiable operations connecting
a source and any target Tensor. This can happen either if the target is
not computed based on the source, if the tracing was set up incorrectly,
or if only non-differentiable functions of the source were used in the
computation of target.
"""
if not tape._tape_stack.stack: # pylint: disable=protected-access
raise RuntimeError("Computing a gradient with no tape present")
bp_tape = tape.pop_tape()
tensor_to_op, op_to_entry, output_to_shape_dtype = bp_tape.export()
# This overwrites the op_to_entry variable, which will release all memory used
# to keep traces that are irrelevant to the gradient computation we're doing
# here.
id_sources = [ops.tensor_id(t) for t in sources]
tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
target, tensor_to_op, op_to_entry, id_sources)
ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
gradients = _initial_gradients(target, output_gradients,
tensor_usage_counts)
# Now exhaust the backprop stack
while ready_ops:
op = ready_ops.pop()
op_trace = op_to_entry.pop(op)
out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
for i in range(len(out_gradients)):
if out_gradients[i] is None:
# TODO(apassos) this should be in the right device
out_gradients[i] = array_ops.zeros(
*output_to_shape_dtype[op_trace.output_ids[i]])
else:
out_gradients[i] = _aggregate_grads(out_gradients[i])
in_gradients = op_trace.backward_function(
*(out_gradients + op_trace.side_outputs))
in_gradients = ([in_gradients]
if isinstance(in_gradients, (ops.Tensor,
ops.IndexedSlices,
type(None)))
else in_gradients)
for i, t in enumerate(op_trace.inputs):
if in_gradients[i] is not None:
gradients[ops.tensor_id(t)].append(in_gradients[i])
if tensor_usage_counts.get(ops.tensor_id(t), 0) > 0:
tensor_usage_counts[ops.tensor_id(t)] -= 1
if ops.tensor_id(t) in tensor_to_op and tensor_usage_counts[
ops.tensor_id(t)] == 0 and ops.tensor_id(t) not in id_sources:
in_op = tensor_to_op[ops.tensor_id(t)]
if in_op is None:
continue
if op_missing_tensor.get(in_op, 0) > 0:
op_missing_tensor[in_op] -= 1
if op_missing_tensor.get(in_op, 0) == 0:
ready_ops.append(in_op)
result = []
for i, s in enumerate(sources):
g = gradients.get(ops.tensor_id(s), None)
if g is None:
# TODO(apassos): figure out a way to summarize why sources and targets are
# not connected.
raise ValueError("There is no sequence of operations connecting source "
"tensor %s (%s) to any of the target Tensors. This is "
"commonly caused by the tape not recording all "
"operations in the forward pass or if by mistake a "
"source was only used in non-differentiable operations."
% (i, s))
result.append(_aggregate_grads(g))
return result
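For context, a rough usage sketch of this internal entry point (not a public API); it assumes an eager context is active and that importing tensorflow has registered the standard gradient functions:

import tensorflow  # assumed to register standard ops and their gradients
from tensorflow.python.eager import backprop, tape
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops

# Assumes eager execution has already been enabled.
x = constant_op.constant(3.0)
tape.push_new_tape()                 # start recording
tape.watch(x)                        # mark x as a source
y = math_ops.multiply(x, x)          # recorded through _record_gradient
grad, = backprop.imperative_grad(y, [x])   # pops the tape; d(x*x)/dx == 6.0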
def op_attr_type(op_type, attr_name): def op_attr_type(op_type, attr_name):
with errors.raise_exception_on_not_ok_status() as status: with errors.raise_exception_on_not_ok_status() as status:
h = context.context()._handle # pylint: disable=protected-access h = context.context()._handle # pylint: disable=protected-access
@ -82,26 +317,23 @@ class _MockOp(object):
raise KeyError(attr) raise KeyError(attr)
def _magic_gradient_function(op_name, attr_tuple, num_inputs, num_outputs, def _magic_gradient_function(op_name, attr_tuple, num_inputs,
*tensors): inputs, outputs, out_grads):
"""Calls the gradient function of the op. """Calls the gradient function of the op.
Args: Args:
op_name: the name of the op to be differentiated. op_name: the name of the op to be differentiated.
attr_tuple: the attrs, as a tuple. attr_tuple: the attrs, as a tuple.
num_inputs: the number of inputs to the op. num_inputs: the number of inputs to the op.
num_outputs: the number of outputs of the op. inputs: inputs to the original operation.
*tensors: a list of tensors, composed of, in order, the inputs, the outputs, outputs: outputs to the original operation.
and the gradients with respect to the outputs. out_grads: gradients of the operation wrt its outputs.
Returns: Returns:
The gradients with respect to the inputs of the function, as a list. The gradients with respect to the inputs of the function, as a list.
""" """
inputs = tensors[:num_inputs]
outputs = tensors[num_inputs:num_inputs + num_outputs]
out_grads = tensors[num_inputs + num_outputs:]
mock_op = _MockOp(attr_tuple, inputs, outputs, op_name) mock_op = _MockOp(attr_tuple, inputs, outputs, op_name)
grad_fn = tf_ops._gradient_registry.lookup(op_name) # pylint: disable=protected-access grad_fn = ops._gradient_registry.lookup(op_name) # pylint: disable=protected-access
if grad_fn is None: if grad_fn is None:
return [None] * num_inputs return [None] * num_inputs
out_grads = [ out_grads = [
@ -136,31 +368,23 @@ def _record_gradient(op_name, inputs, attrs, results, name):
Raises: Raises:
An exception on error. An exception on error.
""" """
if not any(ag_core.isnode(x) for x in inputs):
return results
num_outputs = len(results) num_outputs = len(results)
if num_outputs == 0: if num_outputs == 0:
return results return results
if attrs is not None: if attrs is not None:
attrs = tuple(tuple(x) if isinstance(x, list) else x for x in attrs) attrs = attrs
# It is imperative we make a copy of results here as otherwise we create a
# dependency cycle in the captured function and this can delay garbage
# collecting of the tensors arbitrarily.
results_size = len(results) if isinstance(results, (list, tuple)) else 1
def grad_fn(*orig_outputs): def grad_fn(*orig_outputs):
"""Generated gradient function.""" """Generated gradient function."""
tensors = inputs + list(orig_outputs)
tensors = container_types.make_sequence(tape.EagerList, *tensors)
result = _magic_gradient_function(op_name, attrs, len(inputs), result = _magic_gradient_function(op_name, attrs, len(inputs),
num_outputs, *(tensors)) inputs, results, orig_outputs)
if _tracing: if _tracing:
print("Gradient for", (name if name else op_name), "inputs", inputs, print("Gradient for", (name if name else op_name), "inputs", inputs,
"output_grads", orig_outputs[results_size:], "gradients", result) "output_grads", orig_outputs, "gradients", result)
return result return result
results = tape.record_operation(results, inputs, [], grad_fn) inputs = [ops.convert_to_tensor(x) for x in inputs]
tape.record_operation(results, inputs, [], grad_fn)
if _tracing: if _tracing:
print("Computed op", (name if name else op_name), "inputs", inputs, print("Computed op", (name if name else op_name), "inputs", inputs,
"outputs", results) "outputs", results)
@ -170,27 +394,6 @@ def _record_gradient(op_name, inputs, attrs, results, name):
execute.record_gradient = _record_gradient execute.record_gradient = _record_gradient
def _aggregate_grads(gradients):
"""Aggregate gradients of the same tensor."""
grad_lists = collections.OrderedDict()
for g, v in gradients:
if g is None:
continue
if id(v) not in grad_lists:
grad_lists[id(v)] = [(g, v)]
else:
grad_lists[id(v)].append((g, v))
ret = []
for _, g_list in six.iteritems(grad_lists):
if len(g_list) == 1:
ret.append(g_list[0])
else:
# TODO(xpan): Aggregate IndexedSlices.
ret.append((math_ops.add_n(list(zip(*g_list))[0]), g_list[1][1]))
return ret
def implicit_val_and_grad(f): def implicit_val_and_grad(f):
"""Returns a function which differentiates f with respect to variables. """Returns a function which differentiates f with respect to variables.
@ -224,23 +427,14 @@ def implicit_val_and_grad(f):
Its second element is list of (gradient, variable) pairs. Its second element is list of (gradient, variable) pairs.
""" """
def grad_fn(*args, **kwds): def grad_fn(*args):
"""Computes the gradient of the wrapped function.""" """Computes the gradient of the wrapped function."""
tape.push_new_tape() tape.push_new_tape()
end_node = f(*args) end_node = f(*args)
start_node = tape.pop_tape() variables = tape.top_tape_watched_variables()
ag_core.active_progenitors.remove(start_node) sources = [x.handle for x in variables]
if not ag_core.isnode(end_node): grad = imperative_grad(end_node, sources)
raise ValueError( return end_node, list(zip(grad, variables))
"Target not part of a computation being traced. %s." % end_node)
if start_node not in end_node.progenitors:
raise ValueError("Target not derived from source. %s %s." %
(end_node.progenitors, repr(start_node)))
output_gradients = kwds.get("output_gradients", None)
if output_gradients is None:
output_gradients = array_ops.ones_like(end_node.value)
grad = ag_core.backward_pass(output_gradients, end_node, start_node)
return end_node.value, _aggregate_grads(grad.gradients)
return grad_fn return grad_fn
@ -295,24 +489,25 @@ def gradients_function(f, params=None):
differentiates with respect to all parameters. differentiates with respect to all parameters.
Returns: Returns:
function which, when called, returns the gradient of f with function which, when called, returns the gradient
respect to all of `params`. of f with respect to all of `params`. The function takes an extra optional
keyword argument "dy". Setting it allows computation of vector jacobian
products for vectors other than the vector of ones.
Raises: Raises:
ValueError: if the params are not all strings or all integers. ValueError: if the params are not all strings or all integers.
""" """
parameter_positions = _get_arg_spec(f, params)
def decorated(*args, **kwargs): def decorated(*args, **kwds):
tensors = convenience_wrappers.multigrad(f, parameter_positions)(*args, """Computes the gradient of the decorated function."""
**kwargs)
return [t.tensor() if isinstance(t, tensor.LazyZero) _, grad = val_and_grad_function(f, params)(*args, **kwds)
else t for t in tensors] return grad
return decorated return decorated
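A hypothetical call site for the wrapper above (assuming an eager context; the function name cube is illustrative only):

# Hypothetical usage: df/dx for f(x) = x * x * x at x = 2.0 (expects 12.0).
def cube(x):
  return x * x * x

grad_of_cube = gradients_function(cube)          # w.r.t. all positional args
print(grad_of_cube(constant_op.constant(2.0))[0])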
def val_and_grad_function(f, params=None): def val_and_grad_function(f, params):
"""Returns a function that computes f and is derivative w.r.t. params. """Returns a function that computes f and is derivative w.r.t. params.
Args: Args:
@ -321,11 +516,30 @@ def val_and_grad_function(f, params=None):
parameters with respect to which we'll differentiate. Passing None parameters with respect to which we'll differentiate. Passing None
differentiates with respect to all parameters. differentiates with respect to all parameters.
Returns: Returns: function which, when called, returns the value of f and the gradient
function which, when called, returns the value of f and the of f with respect to all of `params`. The function takes an extra optional
gradient of f with respect to all of `params`. keyword argument "dy". Setting it allows computation of vector jacobian
products for vectors other than the vector of ones.
Raises: Raises:
ValueError: if the params are not all strings or all integers. ValueError: if the params are not all strings or all integers.
""" """
return convenience_wrappers.value_and_multigrad(f, _get_arg_spec(f, params)) parameter_positions = _get_arg_spec(f, params)
def decorated(*args, **kwds):
"""Computes the value and gradient of the decorated function."""
dy = kwds.pop("dy", None)
assert not kwds, "The gradient function can't take keyword arguments."
tape.push_new_tape()
sources = []
args = list(args)
for i in parameter_positions:
sources.append(args[i])
tape.watch(args[i])
result = f(*args)
return result, imperative_grad(
result,
sources,
output_gradients=dy)
return decorated
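A hypothetical use of the "dy" keyword described above (assuming an eager context; triple is an illustrative name): seeding the backward pass with 2.0 instead of ones yields a vector-Jacobian product.

def triple(x):
  return x * 3.0

value, grads = val_and_grad_function(triple, params=[0])(
    constant_op.constant(1.0), dy=constant_op.constant(2.0))
print(value.numpy(), grads[0].numpy())   # 3.0 6.0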

View File

@ -23,10 +23,10 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import tape from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor from tensorflow.python.eager import tensor
from tensorflow.python.eager import tensor_node
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import embedding_ops
@ -58,6 +58,7 @@ class BackpropTest(test.TestCase):
var_np = np.random.rand(4, 2).astype(np.float32) var_np = np.random.rand(4, 2).astype(np.float32)
var = tensor.Tensor(var_np) var = tensor.Tensor(var_np)
grad = backprop.gradients_function(fn, [0])(var)[0] grad = backprop.gradients_function(fn, [0])(var)[0]
grad = ops.convert_to_tensor(grad).numpy()
with context.graph_mode(), self.test_session(): with context.graph_mode(), self.test_session():
tf_var = array_ops.constant(var_np, dtypes.float32) tf_var = array_ops.constant(var_np, dtypes.float32)
@ -74,11 +75,7 @@ class BackpropTest(test.TestCase):
tf_dense_grad = math_ops.unsorted_segment_sum( tf_dense_grad = math_ops.unsorted_segment_sum(
tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0]) tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
self.assertAllClose(grad.numpy(), tf_dense_grad.eval()) self.assertAllClose(grad, tf_dense_grad.eval())
def testTensoVspaceNoneMutAdd(self):
t = tensor.Tensor(1.0)
self.assertEqual(tensor_node.TensorVSpace(t).mut_add(t, None).numpy(), 1.0)
def testImplicitGradWithResourceVariable(self): def testImplicitGradWithResourceVariable(self):
x = resource_variable_ops.ResourceVariable( x = resource_variable_ops.ResourceVariable(
@ -94,6 +91,16 @@ class BackpropTest(test.TestCase):
self.assertEqual(grads_and_vars[0][0].numpy(), 1.0) self.assertEqual(grads_and_vars[0][0].numpy(), 1.0)
self.assertEqual(id(grads_and_vars[0][1]), id(x)) self.assertEqual(id(grads_and_vars[0][1]), id(x))
def testDy(self):
def f(x):
return x
grad_fn = backprop.gradients_function(f)
self.assertAllEqual(grad_fn(constant_op.constant(1.0),
dy=constant_op.constant(2.0))[0].numpy(),
2.0)
def testImplicitGradOverEmbeddingLookup(self): def testImplicitGradOverEmbeddingLookup(self):
batch_size = 8 batch_size = 8
embedding_size = 512 embedding_size = 512

View File

@ -18,22 +18,11 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from autograd import core as ag_core
from tensorflow.python.eager import tape from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor as _tensor
from tensorflow.python.framework import ops as tf_ops from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.util import nest from tensorflow.python.util import nest
def _watch_value_from_tape(tensor):
for t in tape._tape_stack.stack: # pylint: disable=protected-access
w = t.value.tensors.get(tf_ops.tensor_id(tensor), None)
if w is not None:
return w
return tensor
def custom_gradient(f): def custom_gradient(f):
"""Decorator to define a function with a custom gradient. """Decorator to define a function with a custom gradient.
@ -52,27 +41,23 @@ def custom_gradient(f):
def decorated(*args, **kwargs): def decorated(*args, **kwargs):
"""Decorated function with custom gradient.""" """Decorated function with custom gradient."""
input_tensors = [_watch_value_from_tape(x) for x in args input_tensors = [x for x in args
if isinstance(x, (_tensor.Tensor, tf_ops.Tensor)) if isinstance(x, tf_ops.Tensor)]
or ag_core.isnode(x)]
result, grad_fn = f(*args, **kwargs) result, grad_fn = f(*args, **kwargs)
result_size = len(result) if isinstance(result, (list, tuple)) else 1
# TODO(apassos): naive uses of custom_gradient will not get the correct # TODO(apassos): naive uses of custom_gradient will not get the correct
# second derivative this way if they capture any output tensors. Change the # second derivative this way if they capture any output tensors. Change the
# signature of custom_gradient. # signature of custom_gradient.
def actual_grad_fn(*outputs): def actual_grad_fn(*outputs):
outputs = outputs[result_size:]
return grad_fn(*outputs) return grad_fn(*outputs)
flat_result = nest.flatten(result) flat_result = nest.flatten(result)
flat_result = [ag_core.getval(x) for x in flat_result] tape.record_operation(
flat_result = tape.record_operation(
flat_result, flat_result,
input_tensors, input_tensors,
[], [],
actual_grad_fn) actual_grad_fn)
flat_result = list(flat_result) flat_result = list(flat_result)
return nest.pack_sequence_as(structure=result, flat_sequence=flat_result) return result
return decorated return decorated
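A hypothetical example of the decorator, using the classic log(1 + exp(x)) with a hand-written backward function (assumes an eager context and this internal module path):

from tensorflow.python.eager import custom_gradient
from tensorflow.python.ops import math_ops

@custom_gradient.custom_gradient
def log1pexp(x):
  e = math_ops.exp(x)
  def grad(dy):
    return dy * (1.0 - 1.0 / (1.0 + e))
  return math_ops.log(1.0 + e), grad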

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from autograd import core as ag_core
import six import six
from google.protobuf import text_format from google.protobuf import text_format
@ -57,7 +56,6 @@ def execute(op_name, num_outputs, inputs, attrs=None, name=None):
""" """
ctx = context.get_default_context() ctx = context.get_default_context()
# TODO(apassos) move this to convert_to_tensor # TODO(apassos) move this to convert_to_tensor
inputs = [ag_core.getval(x) for x in inputs]
# pylint: disable=protected-access # pylint: disable=protected-access
input_handles = [c._handle for c in inputs] input_handles = [c._handle for c in inputs]
device_name = ctx.device_name device_name = ctx.device_name
@ -184,7 +182,7 @@ def args_to_matching_eager(l, default_dtype=None):
# Is some input already a Tensor with a dtype? # Is some input already a Tensor with a dtype?
dtype = None dtype = None
for t in l: for t in l:
if isinstance(ag_core.getval(t), tensor.Tensor): if isinstance(t, tensor.Tensor):
dtype = t.dtype dtype = t.dtype
break break
@ -203,7 +201,7 @@ def args_to_matching_eager(l, default_dtype=None):
def convert_to_mixed_eager_tensors(values): def convert_to_mixed_eager_tensors(values):
v = [t if isinstance(ag_core.getval(t), tensor.Tensor) else tensor.Tensor(t) v = [t if isinstance(t, tensor.Tensor) else tensor.Tensor(t)
for t in values] for t in values]
types = [t.dtype for t in v] types = [t.dtype for t in v]
return types, v return types, v
@ -228,7 +226,7 @@ def args_to_mixed_eager_tensors(lists):
dtype = None dtype = None
# If any list has a Tensor, use that dtype # If any list has a Tensor, use that dtype
for l in lists: for l in lists:
if isinstance(ag_core.getval(l[i]), tensor.Tensor): if isinstance(l[i], tensor.Tensor):
dtype = l[i].dtype dtype = l[i].dtype
break break
if dtype is None: if dtype is None:

View File

@ -23,7 +23,6 @@ import collections
import contextlib import contextlib
import threading import threading
from autograd import core as ag_core
import numpy as np import numpy as np
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -88,6 +87,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
tensor_map[ops.tensor_id(value)] = (value, captured_value) tensor_map[ops.tensor_id(value)] = (value, captured_value)
else: else:
captured_value = captured_value[1] captured_value = captured_value[1]
tape.record_operation([captured_value], [value], [], lambda x: x)
return captured_value return captured_value
@ -193,11 +193,8 @@ class _GraphModeFunction(object):
self._num_outputs = len(fdef.signature.output_arg) self._num_outputs = len(fdef.signature.output_arg)
self._ops = operations self._ops = operations
self._func_outputs = func_outputs self._func_outputs = func_outputs
if (isinstance(func_outputs, (ops.Tensor, type(None))) or self._returns = [func_outputs] if isinstance(
ag_core.isnode(func_outputs)): func_outputs, (ops.Tensor, type(None))) else list(func_outputs)
self._returns = [func_outputs]
else:
self._returns = list(func_outputs)
self._returns_to_fedf_outputs = func_outputs_to_fdef_outputs self._returns_to_fedf_outputs = func_outputs_to_fdef_outputs
self._output_shapes = output_shapes self._output_shapes = output_shapes
@ -208,7 +205,7 @@ class _GraphModeFunction(object):
c = _CapturingContext() c = _CapturingContext()
with c: with c:
filtered_outputs = [ filtered_outputs = [
ag_core.getval(x) for x in self._returns if x is not None x for x in self._returns if x is not None
] ]
self._out_grad_placeholders = [ self._out_grad_placeholders = [
graph_placeholder(x.dtype, x.shape) for x in filtered_outputs graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
@ -242,16 +239,19 @@ class _GraphModeFunction(object):
if context.in_graph_mode(): if context.in_graph_mode():
g = ops.get_default_graph() g = ops.get_default_graph()
g._add_function(self._forward_fdef) # pylint: disable=protected-access g._add_function(self._forward_fdef) # pylint: disable=protected-access
unwrapped_args = [ag_core.getval(x) for x in all_args] def make_tensor(x):
if isinstance(x, ops.Tensor):
return x
return ops.convert_to_tensor(x)
op = g.create_op( op = g.create_op(
signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args], signature.name, [make_tensor(x) for x in all_args],
[dtypes.DType(x.type) for x in signature.output_arg], [dtypes.DType(x.type) for x in signature.output_arg],
op_def=signature, op_def=signature,
name="FunctionCall", name="FunctionCall",
compute_shapes=False) compute_shapes=False)
outputs = op.outputs outputs = op.outputs
outputs = [outputs] if isinstance( outputs = [outputs] if isinstance(
outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs) outputs, (ops.Tensor, type(None))) else list(outputs)
for i, s in enumerate(self._output_shapes): for i, s in enumerate(self._output_shapes):
outputs[i].set_shape(s) outputs[i].set_shape(s)
else: else:
@ -261,25 +261,12 @@ class _GraphModeFunction(object):
inputs=all_args) inputs=all_args)
real_outputs = outputs[:len(self._returns)] real_outputs = outputs[:len(self._returns)]
side_outputs = outputs[len(self._returns):] side_outputs = outputs[len(self._returns):]
watched_extra_inputs = []
for t in self._extra_inputs:
tid = ops.tensor_id(t)
for t in tape._tape_stack.stack: # pylint: disable=protected-access
w = t.value.tensors.get(tid, None)
if w is not None:
watched_extra_inputs.append(w)
break
else: # Note: for-else here done on purpose
watched_extra_inputs.append(t)
def backward_function_wrapper(*outputs): tape.record_operation(
outputs = outputs[len(real_outputs):]
return self._backward_function(*outputs)
real_outputs = tape.record_operation(
real_outputs, real_outputs,
(args + watched_extra_inputs), (args + self._extra_inputs),
side_outputs, side_outputs,
backward_function_wrapper) self._backward_function)
return self._build_call_outputs(self._returns, real_outputs) return self._build_call_outputs(self._returns, real_outputs)
@ -288,10 +275,10 @@ class _GraphModeFunction(object):
tensor_inputs = [ tensor_inputs = [
x for x in nest.flatten(args) x for x in nest.flatten(args)
if isinstance(x, (tensor.Tensor, ops.Tensor, if isinstance(x, (tensor.Tensor, ops.Tensor,
tensor.LazyZero)) or ag_core.isnode(x) tensor.LazyZero))
] ]
if tape.should_record(tensor_inputs) or any( if tape.should_record(tensor_inputs) or tape.should_record(
tape.any_tape_has(t) for t in self._extra_inputs): self._extra_inputs):
if not self._has_backprop: if not self._has_backprop:
self._compute_backprop() self._compute_backprop()
return self._backprop_call(tensor_inputs) return self._backprop_call(tensor_inputs)
@ -334,12 +321,12 @@ class _GraphModeFunction(object):
""" """
if self._func_outputs is None: if self._func_outputs is None:
return None return None
if isinstance(ag_core.getval(self._func_outputs), ops.Tensor): if isinstance(self._func_outputs, ops.Tensor):
return result[0] return result[0]
outputs = [] outputs = []
for o in func_outputs: for o in func_outputs:
vo = ag_core.getval(o) vo = o
if isinstance(vo, ops.Tensor): if isinstance(vo, ops.Tensor):
outputs.append(result[self._returns_to_fedf_outputs[id(vo)]]) outputs.append(result[self._returns_to_fedf_outputs[id(vo)]])
elif type(vo) in (tuple, list): elif type(vo) in (tuple, list):
@ -354,7 +341,6 @@ def _get_defun_inputs(args):
"""Maps the inputs args to graph inputs.""" """Maps the inputs args to graph inputs."""
ret = [] ret = []
for a in args: for a in args:
a = ag_core.getval(a)
if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)): if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)):
ret.append(graph_placeholder(a.dtype, a.shape)) ret.append(graph_placeholder(a.dtype, a.shape))
elif type(a) in (tuple, list): elif type(a) in (tuple, list):
@ -395,7 +381,7 @@ def _defun_internal(name, func, args, kwds):
] ]
all_inputs = flat_inputs + list(extra_placeholders) all_inputs = flat_inputs + list(extra_placeholders)
func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None] func_def_outputs = [x for x in outputs_list if x is not None]
inference_function_def = graph_to_function_def.graph_to_function_def( inference_function_def = graph_to_function_def.graph_to_function_def(
tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs) tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
# Register any other functions defined in the graph # Register any other functions defined in the graph
@ -421,7 +407,6 @@ _ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"])
def _cache_key(x): def _cache_key(x):
"""Cache key for tfe functions.""" """Cache key for tfe functions."""
x = ag_core.getval(x)
if isinstance(x, tensor.Tensor): if isinstance(x, tensor.Tensor):
return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access
if isinstance(x, tensor.LazyZero): if isinstance(x, tensor.LazyZero):

View File

@ -18,96 +18,114 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import collections
import threading import threading
from autograd import container_types
from autograd import core as ag_core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_contextlib
class ImplicitTape(object): def tid(tensor):
"""Global object which can watch tensors and wrap them with autograd.""" return tensor._id # pylint: disable=protected-access
class TapeEntry(
collections.namedtuple("TapeEntry", [
"output_ids", "inputs", "side_outputs", "backward_function"
])):
"""Entry in the gradient tape.
Represents the execution of one op or function, with instructions for doing
its backward pass and useful information for it.
Args:
output_ids: tensor_id(t) for each output tensor t
inputs: input tensors
side_outputs: optional tensors which need to be provided to the backward
function.
backward_function: function to be called with the downstream gradients and
side outputs as arguments which computes the backward pass.
"""
def _tensor_shape(t):
return t._shape_tuple() # pylint: disable=protected-access
class Tape(object):
"""Represents a gradient propagation trace."""
def __init__(self): def __init__(self):
self.tensors = {} # _tensor_tape maps from tensor IDs to their operation IDs
self.variables = {} self._tensor_tape = {}
self.gradients = [] # maps output tensor IDs to their shapes and dtypes
self._shape_dtype = {}
# maps from operation ID to TapeEntry
self._op_tape = {}
# next operation ID
self._next_op_id = 0
# List of directly watched tensors
self._watched = []
# Set of directly watched variables
self._watched_variables = set()
def __eq__(self, other): def should_record(self, tensors):
return self is other """Returns true if any tensor should be recorded.
def __hash__(self): Args:
return id(self) tensors: some tensors.
Returns:
True if any of the tensors is in the tape.
"""
return any(x._id in self._tensor_tape for x in tensors) # pylint: disable=protected-access
@ag_core.primitive def watch(self, tensor):
def _watch_with_tape_internal(_, tensor): """Adds a tensor to the tape."""
"""Primitive to wrap a tensor around an ImplicitTape progenitor.""" if tid(tensor) not in self._tensor_tape:
return tensor self._tensor_tape[tid(tensor)] = None
self._watched.append(tensor)
def watch_variable(self, v):
self._watched_variables.add(v)
self.watch(v.handle)
def _watch_with_tape(tape, resource_variable): def record_operation(self, output_tensors, input_tensors, side_outputs,
"""Wraps a watched Tensor and keeps track of it in the implicit tape.""" backward_function):
tensor = resource_variable.handle """Records an operation in the tape."""
w = _watch_with_tape_internal(tape, tensor) if not self.should_record(input_tensors):
if ag_core.isnode(tape): return output_tensors
tape.value.variables[ops.tensor_id(tensor)] = resource_variable for t in output_tensors:
tape.value.tensors[ops.tensor_id(tensor)] = w self._tensor_tape[tid(t)] = self._next_op_id
self._shape_dtype[tid(t)] = (_tensor_shape(t), t.dtype)
self._op_tape[self._next_op_id] = TapeEntry(
[tid(t) for t in output_tensors],
input_tensors,
side_outputs,
backward_function)
self._next_op_id += 1
def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor): def delete_trace(self, tensor):
"""Gradient for _watch_with_tape_internal.""" """Deletes any trace we have for this tensor."""
del ans, gvs if tid(tensor) in self._tensor_tape:
op = self._tensor_tape[tid(tensor)]
del self._tensor_tape[tid(tensor)]
if op in self._op_tape:
if not any(
x in self._tensor_tape for x in self._op_tape[op].output_ids):
del self._op_tape[op]
def mut_add(implicit_tape): def export(self):
resource_variable = tape.value.variables[ops.tensor_id(tensor)] """Exports the internal state of this tape.
implicit_tape.gradients.append((g, resource_variable))
return implicit_tape
return ag_core.SparseObject(vs, mut_add) Returns:
tensor_tape: a map from tensor_id(tensor) to <identifier for op>
_watch_with_tape_internal.defvjp(_watch_with_tape_vjp, argnum=0) responsible for generating that tensor.
_watch_with_tape_internal.defvjp( op_tape: a map from <identifier for op> to TapeEntry for that op.
lambda g, ans, vs, gvs, tape, tensor: g, output_to_shape_dtype: a map from tensor_id(tensor) to its shape and
argnum=1) dtype, for tensors which are outputs
"""
return self._tensor_tape, self._op_tape, self._shape_dtype
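A minimal runnable illustration of the bookkeeping above; FakeTensor is an illustrative stand-in (real callers pass eager Tensors, and the tape is normally driven through push_new_tape/record_operation/pop_tape rather than built by hand):

class FakeTensor(object):
  def __init__(self, tensor_id, shape, dtype):
    self._id, self._shape, self.dtype = tensor_id, shape, dtype
  def _shape_tuple(self):
    return self._shape

x = FakeTensor(1, (2, 2), "float32")
y = FakeTensor(2, (2, 2), "float32")
t = Tape()
t.watch(x)                                          # x becomes a source
t.record_operation([y], [x], [], lambda dy: [dy])   # y = identity(x)
tensor_tape, op_tape, shapes = t.export()
print(tensor_tape)   # {1: None, 2: 0}: x watched directly, y produced by op 0
print(shapes[2])     # ((2, 2), 'float32')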
class ImplicitTapeVSpace(ag_core.VSpace):
"""VSpace needed to have ImplicitTape be a valid progenitor."""
def zeros(self):
return ImplicitTape()
class ImplicitTapeNode(ag_core.Node):
"""Node to wrap ImplicitTape in."""
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
ag_core.register_node(ImplicitTapeNode, ImplicitTape)
ag_core.register_vspace(ImplicitTapeVSpace, ImplicitTape)
# TODO(apassos) try to not do this.
class NoneVSpace(ag_core.VSpace):
"""VSpace for python None."""
def __init__(self, _):
self.size = 0
def zeros(self):
return 0
ag_core.register_vspace(NoneVSpace, type(None))
class _TapeStack(threading.local): class _TapeStack(threading.local):
@ -134,19 +152,33 @@ _tape_stack = _TapeStack()
def push_new_tape(): def push_new_tape():
"""Pushes a new tape onto the tape stack.""" """Pushes a new tape onto the tape stack."""
progenitor = ag_core.new_progenitor(ImplicitTape()) _tape_stack.stack.append(Tape())
_tape_stack.stack.append(progenitor)
ag_core.active_progenitors.add(progenitor)
def watch_variable(resource_variable): def watch(tensor):
"""Marks this ResourceVariable to be watched by all tapes in the stack. """Marks this tensor to be watched by all tapes in the stack.
Args: Args:
resource_variable: A ResourceVariable to be watched. tensor: tensor to be watched.
Returns:
The tensor, potentially wrapped by all tapes in the stack.
""" """
for t in _tape_stack.stack: for t in _tape_stack.stack:
_watch_with_tape(t, resource_variable) t.watch(tensor)
def watch_variable(variable):
"""Marks this variable to be watched by all tapes in the stack.
Args:
variable: variable to be watched.
Returns:
The tensor, potentially wrapped by all tapes in the stack.
"""
for t in _tape_stack.stack:
t.watch_variable(variable)
def pop_tape(): def pop_tape():
@ -156,85 +188,34 @@ def pop_tape():
return None return None
def any_tape_has(tensor):
for t in _tape_stack.stack:
if ops.tensor_id(tensor) in t.value.tensors:
return True
return False
def should_record(tensors): def should_record(tensors):
"""Returns true if any tape in the stack watches any of these tensors.""" """Returns true if any tape in the stach watches any of these tensors."""
return any(ag_core.isnode(x) for x in tensors) if not _tape_stack.stack:
return False
return any(x.should_record(tensors) for x in _tape_stack.stack)
class _EagerSequenceNode(container_types.SequenceNode): def record_operation(output_tensors, input_tensors, side_outputs,
"""Eager version of SequenceNode, to live in EagerSequenceVSpace.""" backward_function):
pass """Records the operation on all tapes in the stack."""
for t in _tape_stack.stack:
t.record_operation(output_tensors,
input_tensors,
side_outputs,
backward_function)
class _EagerSequenceVSpace(container_types.SequenceVSpace): def delete_trace(tensor):
"""Changes equality on SequenceVSpace to conform to tfe requirements.""" """Deletes traces for this Tensor from all tapes in the stack."""
for t in _tape_stack.stack:
def __init__(self, value): t.delete_trace(tensor)
self.shape = [ag_core.vspace(x) for x in value]
self.size = sum(s.size for s in self.shape)
self.sequence_type = type(value)
def __eq__(self, other):
if type(self) != type(other): # pylint: disable=unidiomatic-typecheck
return False
if len(self.shape) != len(other.shape):
# TODO(apassos) function gradients sometimes return gradients for side
# inputs which breaks this assertion. Understand how to fix it.
return True
for ss, os in zip(self.shape, other.shape):
if ss != os:
if isinstance(ss, NoneVSpace) or isinstance(os, NoneVSpace):
continue
if ss.dtype == dtypes.resource or os.dtype == dtypes.resource:
continue
return False
return True
class EagerList(list): def top_tape_watched_tensors():
"""Type used to bypass SequenceVSpace. t = _tape_stack.stack[-1]
return t._watched # pylint: disable=protected-access
SequenceVSpace has a very strict equality check which does not match
tensorflow semantics.
"""
def __init__(self, value):
super(EagerList, self).__init__(value)
for v in value:
assert not ag_core.isnode(v)
ag_core.register_vspace(_EagerSequenceVSpace, EagerList)
ag_core.register_node(_EagerSequenceNode, EagerList)
@ag_core.primitive def top_tape_watched_variables():
def _record_operation(output_tensors, input_tensors, side_outputs, t = _tape_stack.stack[-1]
backward_function): return t._watched_variables # pylint: disable=protected-access
del input_tensors, side_outputs, backward_function
return EagerList(output_tensors)
def record_operation(o, i, s, b):
"""Primitive to trigger autograd tracing on outputs from inputs."""
inputs = container_types.make_sequence(EagerList, *i)
return _record_operation(o, inputs, s, b)
def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
side_outputs, backward_function):
"""Gradient for _record_operation."""
del vs, gvs, input_tensors, output_tensors
backward_args = tuple(g) + tuple(side_outputs)
backward_args = container_types.make_sequence(
EagerList, *(tuple(ans) + backward_args))
tensors = nest.flatten(backward_function(*backward_args))
return container_types.make_sequence(EagerList, *tensors)
_record_operation.defvjp(_record_operation_vjp, argnum=1)

View File

@ -1,301 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorNode for autograd tracing of computations with Tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from autograd import core as ag_core
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
from tensorflow.python.eager import tensor
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ag_core.primitive
def _tensor_numpy(t):
return t.numpy()
@ag_core.primitive
def _as_gpu_tensor(t, index=0):
return t.as_gpu_tensor(gpu_index=index)
_as_gpu_tensor.defvjp(
lambda g, ans, vs, gvs, t, index: g.as_cpu_tensor(), argnum=0)
@custom_gradient.custom_gradient
def _tensor_copy(t, ctx=None, device_name=None):
def grad(dresult):
return dresult._copy(device_name=t.device) # pylint: disable=protected-access
return t.value._copy(ctx=ctx, device_name=device_name), grad # pylint: disable=protected-access
@ag_core.primitive
def _as_cpu_tensor(t):
return t.as_cpu_tensor()
_as_cpu_tensor.defvjp(lambda g, ans, vs, gvs, t: g.as_gpu_tensor(), argnum=0)
# TODO(apassos,ashankar): The operator overrides here need to be kept in sync
# with the overrides for ops.Tensor and ops.EagerTensor.
#
# Note that we cannot use self.value.__op__() because that would result
# in an ops.EagerTensor instead of a TensorNode being returned.
#
# We need to figure out a way to ensure that the two are in sync.
class TensorNode(ag_core.Node):
"""A TensorFlow Tensor."""
__slots__ = []
def __getitem__(self, idx):
return array_ops._SliceHelper(self, idx) # pylint: disable=protected-access
shape = property(lambda self: self.value.shape)
dtype = property(lambda self: self.value.dtype)
device = property(lambda self: self.value.device)
def get_shape(self):
return self.shape
def numpy(self):
return _tensor_numpy(self)
def _shape_tuple(self):
return self.value._shape_tuple # pylint: disable=protected-access
def as_cpu_tensor(self):
return _as_cpu_tensor(self)
def as_gpu_tensor(self, gpu_index=0):
return _as_gpu_tensor(self, gpu_index)
def _copy(self, ctx=None, device_name=None):
return _tensor_copy(self, ctx, device_name)
def __neg__(self):
return math_ops.negative(self)
def __abs__(self):
return math_ops.abs(self) # pylint: disable=protected-access
def __invert__(self):
# ops.Tensor used math_ops.logical_not as of August 2017.
# Now that bitwise_ops.invert exists, it might make sense
# for both ops.Tensor and TensorNode to use that if the
# type is compatible.
return math_ops.logical_not(self)
def __hash__(self):
return id(self)
def __add__(self, other):
if isinstance(self.value, tensor.LazyZero):
return other
if isinstance(other, tensor.LazyZero):
return self
return math_ops.add(self, other)
def __radd__(self, other):
if isinstance(self.value, tensor.LazyZero):
return other
if isinstance(ag_core.getval(other), tensor.LazyZero):
return self
return math_ops.add(other, self)
def __sub__(self, other):
return math_ops.subtract(self, other)
def __rsub__(self, other):
return math_ops.subtract(other, self)
def __mul__(self, other):
return math_ops.multiply(self, other)
def __rmul__(self, other):
return math_ops.multiply(other, self)
def __mod__(self, other):
return math_ops.floormod(self, other)
def __rmod__(self, other):
return math_ops.floormod(other, self)
def __pow__(self, other):
return math_ops.pow(self, other)
def __rpow__(self, other):
return math_ops.pow(other, self)
def __div__(self, other):
return math_ops._div_python2(self, other) # pylint: disable=protected-access
def __rdiv__(self, other):
return math_ops._div_python2(other, self) # pylint: disable=protected-access
def __truediv__(self, other):
return math_ops._truediv_python3(self, other) # pylint: disable=protected-access
def __rtruediv__(self, other):
return math_ops._truediv_python3(other, self) # pylint: disable=protected-access
def __floordiv__(self, other):
return math_ops.floordiv(self, other)
def __rfloordiv__(self, other):
return math_ops.floordiv(other, self)
def __eq__(self, other):
# math_ops.equal raises an error if shapes are not compatible, so check that
# explicitly first.
if common_shapes.is_broadcast_compatible(
self.shape, ops.convert_to_tensor(other).shape):
return math_ops.equal(self, other)
return False
def __gt__(self, other):
return math_ops.greater(self, other)
def __ge__(self, other):
return math_ops.greater_equal(self, other)
def __lt__(self, other):
return math_ops.less(self, other)
def __le__(self, other):
return math_ops.less_equal(self, other)
ag_core.register_node(TensorNode, tensor.Tensor)
ag_core.register_node(TensorNode, ops.Tensor)
def _zeros(shape, dtype):
with context.device("cpu:0"):
shape = tensor.Tensor(shape, dtype=dtypes.int32)
return array_ops.fill(shape, tensor.Tensor(0, dtype=dtype))
def _ones(shape, dtype):
return array_ops.fill(
tensor.Tensor(shape, dtype=dtypes.int32), tensor.Tensor(1, dtype=dtype))
def _lazy_zero_tensor(zero):
return _zeros(zero.shape, zero.dtype)
tensor.LazyZero.tensor = _lazy_zero_tensor
def _lazy_zero_to_tensor(lazy_zero, dtype=None, name=None, as_ref=False):
del as_ref, name, dtype
return _zeros(lazy_zero.shape, lazy_zero.dtype)
ops.register_tensor_conversion_function(tensor.LazyZero, _lazy_zero_to_tensor)
def _indexed_slices_to_tensor(value):
"""Converts an IndexedSlices object `value` to a Tensor.
Args:
value: An ops.IndexedSlices object.
Returns:
A dense Tensor representing the values in the given IndexedSlices.
Raises:
ValueError: If the IndexedSlices does not have the same dtype.
"""
if value.dense_shape is None:
raise ValueError(
"Tensor conversion requested for IndexedSlices without dense_shape: %s"
% str(value))
return math_ops.unsorted_segment_sum(value.values, value.indices,
value.dense_shape[0])
class TensorVSpace(ag_core.VSpace):
"""VSpace for tf/tfe Tensors in autograd."""
def __init__(self, value):
self.size = 1
if isinstance(value, ops.IndexedSlices):
self.shape = tensor_shape.TensorShape(value.dense_shape.numpy())
self.dtype = value.values.dtype
else:
self.shape = value._shape_tuple() # pylint: disable=protected-access
self.dtype = value.dtype
# TODO(apassos) put gradients on the same device as ops.
def __eq__(self, other):
# TODO(apassos) consider revisiting this if not performance sensitive.
return True
def __ne__(self, other):
return not self.__eq__(other)
def zeros(self):
return tensor.LazyZero(self.shape, self.dtype)
def ones(self):
return _ones(self.shape, self.dtype)
def standard_basis(self):
raise NotImplementedError
def flatten(self, value):
return array_ops.reshape(value, tensor.Tensor(-1))
def unflatten(self, value):
return array_ops.reshape(value, tensor.Tensor(self.shape))
def mut_add(self, x, y):
"""Add wrapper safe for IndexedSlices and LazyZero."""
if isinstance(ag_core.getval(x), tensor.LazyZero):
return y
if isinstance(ag_core.getval(y), tensor.LazyZero):
return x
if isinstance(x, ops.IndexedSlices):
x = _indexed_slices_to_tensor(x)
if isinstance(y, ops.IndexedSlices):
y = _indexed_slices_to_tensor(y)
if x is None:
return y
if y is None:
return x
return math_ops.add(x, y)
ag_core.register_vspace(TensorVSpace, tensor.Tensor)
ag_core.register_vspace(TensorVSpace, ops.Tensor)
ag_core.register_vspace(TensorVSpace, ops.IndexedSlices)
ag_core.register_vspace(TensorVSpace, tensor.LazyZero)
ag_core.register_node(TensorNode, tensor.LazyZero)

View File

@ -1,84 +0,0 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import tensor
from tensorflow.python.eager import tensor_node
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
def public_or_operator(n):
if not n.startswith("_"):
return True
return n.startswith("__") and n.endswith("__")
class TensorNodeTest(test_util.TensorFlowTestCase):
# TensorNode must implement autograd core's Node interface, which
# it does via inheritance. It also needs to be duck-typeable as
# a tensorflow.python.framework.ops.EagerTensor.
#
# This is a "test" to help ensure interface compatibility.
def testCanBeATensor(self):
# TODO(ashankar,apassos): This list of "exceptions" - list of
# Tensor methods not implemented by TensorNode needs to be
# trimmed.
exceptions = set([
"OVERLOADABLE_OPERATORS",
"__and__",
"__del__",
"__dict__",
"__iter__",
"__len__",
"__matmul__",
"__or__",
"__rand__",
"__rmatmul__",
"__ror__",
"__rxor__",
"__weakref__",
"__xor__",
# BEGIN: Methods of Tensor that EagerTensor raises exceptions on.
# But perhaps TensorNode should defer to "self.value.<method>" for
# them?
"consumers",
"eval",
"graph",
"name",
"op",
"set_shape",
"value_index",
# END: Methods of Tensor that EagerTensor raises exceptions on.
])
tensor_dir = dir(tensor.Tensor)
tensor_dir = filter(public_or_operator, tensor_dir)
tensor_dir = set(tensor_dir).difference(exceptions)
tensor_node_dir = set(dir(tensor_node.TensorNode))
missing = tensor_dir.difference(tensor_node_dir)
self.assertEqual(
0,
len(missing),
msg="Methods/properties missing in TensorNode: {}".format(missing))
if __name__ == "__main__":
test.main()

View File

@@ -41,7 +41,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from autograd import core as ag_core
import numpy as np
from tensorflow.core.framework import attr_value_pb2
@@ -76,7 +75,7 @@ def _eager_fill(dims, value):
def convert_to_eager_tensor(t, dtype=None):
  """Converts the given `value` to an `EagerTensor`."""
-  if isinstance(ag_core.getval(t), ops.EagerTensor):
+  if isinstance(t, ops.EagerTensor):
    if dtype is not None and t.dtype != dtype:
      raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
    return t


@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from autograd import core as ag_core
import six
from tensorflow.core.framework import attr_value_pb2
@@ -503,7 +502,6 @@ class OpDefLibrary(object):
            default_dtype = default_type_attr_map[input_arg.type_attr]
          try:
-            values = ag_core.getval(values)
            values = ops.internal_convert_to_tensor(
                values,
                name=input_arg.name,
@@ -784,7 +782,6 @@ class OpDefLibrary(object):
                               if arg.is_ref]
      with _MaybeColocateWith(must_colocate_inputs):
        # Add Op to graph
-        inputs = [ag_core.getval(x) for x in inputs]
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)


@@ -25,7 +25,6 @@ import re
import sys
import threading
-from autograd import core as ag_core
import numpy as np
import six
@@ -38,6 +37,7 @@ from tensorflow.core.framework import versions_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.eager import context
from tensorflow.python.eager import core
+from tensorflow.python.eager import tape
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
@@ -70,10 +70,9 @@ from tensorflow.python.util import tf_contextlib
_USE_C_API = False
-def tensor_id(t):
+def tensor_id(tensor):
  """Returns a unique identifier for this Tensor."""
-  t = ag_core.getval(t)
-  return t._id  # pylint: disable=protected-access
+  return tensor._id  # pylint: disable=protected-access
def _in_gpu_device():
@@ -703,6 +702,7 @@ class EagerTensor(Tensor):
  def __del__(self):
    try:
+      tape.delete_trace(self)
      if c_api is not None and c_api.TFE_DeleteTensorHandle is not None:
        c_api.TFE_DeleteTensorHandle(self._handle)
      if core.active_trace() is not None:
@@ -727,7 +727,7 @@ class EagerTensor(Tensor):
        self.dtype.name)
  def __repr__(self):
-    return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s)>" % (
+    return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
        self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True))
  @staticmethod
@@ -770,6 +770,16 @@ class EagerTensor(Tensor):
          tensor_id(new_tensor),
          new_tensor.device,
          new_tensor.shape.num_elements())
+    # Record the copy on tape and define backprop copy as well.
+    if not context.in_graph_mode():
+      self_device = self.device
+      def grad_fun(dresult):
+        with errors.raise_exception_on_not_ok_status() as status:
+          grad_h = c_api.TFE_TensorHandleCopyToDevice(
+              dresult._handle, ctx._handle, self_device, status)
+        return _tensor_from_handle(grad_h)
+      tape.record_operation([new_tensor], [self], [], grad_fun)
    return new_tensor
# pylint: enable=protected-access
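
The hunk above records the device-to-device copy on the tape together with a backward function that copies incoming gradients back to the source device. Below is a minimal sketch of the same pattern, assuming only the four-argument `tape.record_operation(outputs, inputs, side_outputs, backward_fn)` form used in the hunk; `square_with_grad` and its gradient are hypothetical illustrations, not part of this change.

```python
from tensorflow.python.eager import tape

def square_with_grad(x):
  """Hypothetical example of pairing a forward result with its backward fn."""
  y = x * x

  def grad_fn(dresult):
    # Receives the gradient flowing into `y` and returns the gradient w.r.t.
    # `x`, mirroring how grad_fun above maps `dresult` back to the source
    # device of the copied tensor.
    return dresult * 2.0 * x

  tape.record_operation([y], [x], [], grad_fn)
  return y
```
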
@@ -1033,26 +1043,21 @@ def internal_convert_to_tensor(value,
    RuntimeError: If a registered conversion function returns an invalid value.
  """
-  # Note we check the type of the object unwrapped from an autograd node, if
-  # tracing gradients, to ensure the same behavior happens with and without
-  # tracing.
-  unwrapped = ag_core.getval(value)
  if context.in_eager_mode():
    # Fast path for EagerTensors that don't need any conversion.
-    if isinstance(unwrapped, EagerTensor):
+    if isinstance(value, EagerTensor):
      # Note that we don't check that value's dtype matches the dtype
      # argument. We exepct that the C runtime will do that checking
      # when we execute the kernel.
      return value
  values = nest.flatten(value)
  if (len(values) > 1 and
-      any(isinstance(ag_core.getval(v), EagerTensor) for v in values)):
+      any(isinstance(v, EagerTensor) for v in values)):
    raise TypeError("Cannot convert to a eager tensor.")
  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
-  unwrapped_type = type(unwrapped)
+  unwrapped_type = type(value)
  conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
  if conversion_func_list is None:
    with _tensor_conversion_func_lock:
@@ -1060,7 +1065,7 @@ def internal_convert_to_tensor(value,
      for _, funcs_at_priority in sorted(
          _tensor_conversion_func_registry.items()):
        for base_type, conversion_func in funcs_at_priority:
-          if isinstance(unwrapped, base_type):
+          if isinstance(value, base_type):
            conversion_func_list.append((base_type, conversion_func))
      _tensor_conversion_func_cache[unwrapped_type] = conversion_func_list
@@ -1090,7 +1095,7 @@ def internal_convert_to_tensor(value,
      if ret is NotImplemented:
        continue
-      if not isinstance(ag_core.getval(ret), Tensor):
+      if not isinstance(ret, Tensor):
        raise RuntimeError(
            "%sConversion function %r for type %s returned non-Tensor: %r" %
            (_error_prefix(name), conversion_func, base_type, ret))
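
The conversion path above now keys its per-type cache on `type(value)` directly rather than on the autograd-unwrapped object; the caching scheme itself is unchanged. A generic sketch of that idea follows, with every name below a hypothetical stand-in for the real registry structures.

```python
# Per-type memoization of conversion functions drawn from a priority-sorted
# registry, as in internal_convert_to_tensor above (hypothetical names).
_conversion_registry = {}   # priority -> [(base_type, conversion_func), ...]
_conversion_cache = {}      # concrete Python type -> applicable functions

def _conversion_funcs_for(value):
  funcs = _conversion_cache.get(type(value))
  if funcs is None:
    funcs = [(base_type, func)
             for _, funcs_at_priority in sorted(_conversion_registry.items())
             for base_type, func in funcs_at_priority
             if isinstance(value, base_type)]
    _conversion_cache[type(value)] = funcs
  return funcs
```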


@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from autograd import core as ag_core
import numpy as np
import six
@@ -607,7 +606,7 @@ def ShapeEquals(tensor_proto, shape):
def _ConstantValue(tensor, partial):
  # TODO(touts): Support Variables?
-  if not isinstance(ag_core.getval(tensor), ops.Tensor):
+  if not isinstance(tensor, ops.Tensor):
    raise TypeError("tensor is not a Tensor")
  if tensor.op.type == "Const":
    return MakeNdarray(tensor.op.get_attr("value"))
@@ -737,7 +736,7 @@ def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
  Raises:
    TypeError: if tensor is not an ops.Tensor.
  """
-  if isinstance(ag_core.getval(tensor), ops.EagerTensor):
+  if isinstance(tensor, ops.EagerTensor):
    return tensor.numpy()
  ret = _ConstantValue(tensor, partial)
  if ret is not None:


@@ -291,10 +291,11 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
                          sparse_tensor.SparseTensorValue)):
      return gen_math_ops.cast(input.dense_shape, out_type)
    else:
-      input_tensor = ops.convert_to_tensor(input)
-      input_shape = input_tensor.get_shape()
-      if optimize and input_shape.is_fully_defined():
-        return constant(input_shape.as_list(), out_type, name=name)
+      if context.in_graph_mode():
+        input_tensor = ops.convert_to_tensor(input)
+        input_shape = input_tensor.get_shape()
+        if optimize and input_shape.is_fully_defined():
+          return constant(input_shape.as_list(), out_type, name=name)
      return gen_array_ops.shape(input, name=name, out_type=out_type)
@@ -1428,7 +1429,9 @@ def zeros(shape, dtype=dtypes.float32, name=None):
    zero = ""
  else:
    zero = 0
-  if context.in_eager_mode():
+  # Checking for boolean dtype to prevent attempting to run fill on the GPU
+  # which does not have a boolean kernel registered.
+  if context.in_eager_mode() and dtype != dtypes.bool:
    return fill(shape, constant(zero, dtype=dtype), name=name)
  try:
    shape = tensor_shape.as_shape(shape)


@@ -144,7 +144,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from autograd import core as ag_core
import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
@@ -1113,12 +1112,11 @@ floormod = gen_math_ops._floor_mod
def _mul_dispatch(x, y, name=None):
  """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse"."""
-  is_tensor_y = isinstance(ag_core.getval(y), ops.Tensor)
+  is_tensor_y = isinstance(y, ops.Tensor)
  if is_tensor_y:
    return gen_math_ops._mul(x, y, name=name)
  else:
-    assert isinstance(ag_core.getval(y),
-                      sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
+    assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
    new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
                                                     y.dense_shape, x, name)
    return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
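
For context, the two branches dispatched above correspond to multiplying a dense tensor by another dense tensor versus by a `SparseTensor`. A small usage sketch with arbitrary example values:

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y_dense = tf.constant([[10.0, 0.0], [0.0, 10.0]])
y_sparse = tf.SparseTensor(indices=[[0, 0], [1, 1]],
                           values=[10.0, 10.0],
                           dense_shape=[2, 2])

dense_product = x * y_dense    # Dense * Dense branch
sparse_product = x * y_sparse  # Dense * Sparse branch; result is a SparseTensor
```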


@@ -19,14 +19,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from autograd import core as ag_core
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.eager import context
-from tensorflow.python.eager import custom_gradient
from tensorflow.python.eager import tape
-from tensorflow.python.eager import tensor_node
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -506,10 +502,8 @@ class ResourceVariable(variables.Variable):
  def _read_variable_op(self):
    if hasattr(self, "_trainable") and self._trainable:
      tape.watch_variable(self)
-      return read_variable_op(self._handle, dtype=self._dtype)
-    else:
-      return gen_resource_variable_ops.read_variable_op(self._handle,
-                                                        self._dtype)
+    return gen_resource_variable_ops.read_variable_op(self._handle,
+                                                      self._dtype)
  def read_value(self):
    """Constructs an op which reads the value of this variable.
@@ -541,7 +535,7 @@ class ResourceVariable(variables.Variable):
    with ops.name_scope("Gather" if name is None else name) as name:
      if self._trainable:
        tape.watch_variable(self)
-      value = resource_gather(
+      value = gen_resource_variable_ops.resource_gather(
          self._handle, indices, dtype=self._dtype, name=name)
      return array_ops.identity(value)
@@ -614,13 +608,7 @@ class ResourceVariable(variables.Variable):
    def _run_op(a, *args):
      # pylint: disable=protected-access
      value = a._AsTensor()
-      if ag_core.isnode(value):
-        # This avoids autograd trying to wrap a ResourceVariable.
-        value = ops.convert_to_tensor(value)
-        args = [ops.convert_to_tensor(x) for x in args]
-        return getattr(tensor_node.TensorNode, operator)(value, *args)
-      else:
-        return getattr(ops.Tensor, operator)(value, *args)
+      return getattr(ops.Tensor, operator)(value, *args)
    # Propagate __doc__ to wrapper
    try:
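
The simplified `_run_op` above just reads the variable as a tensor and forwards the operator to `ops.Tensor`. A generic sketch of that delegation pattern outside TensorFlow follows; all names here are hypothetical.

```python
class DenseWrapper(object):
  """Hypothetical wrapper that delegates operators to its unwrapped value."""

  def __init__(self, value):
    self._value = value

  def _as_value(self):
    return self._value

def _make_run_op(operator):
  def _run_op(a, *args):
    value = a._as_value()
    # Forward the operator to the wrapped type, as _run_op above forwards
    # to ops.Tensor.
    return getattr(type(value), operator)(value, *args)
  return _run_op

DenseWrapper.__add__ = _make_run_op("__add__")
DenseWrapper.__mul__ = _make_run_op("__mul__")

print(DenseWrapper(3.0) + 4.0)  # prints 7.0
```
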
@@ -693,33 +681,6 @@ class ResourceVariable(variables.Variable):
    return self.value()
@custom_gradient.custom_gradient
def read_variable_op(handle, dtype):
"""Reads the value of a variable.
The tensor returned by this operation is immutable.
The value returned by this operation is guaranteed to be influenced by all the
writes on which this operation depends directly or indirectly, and to not be
influenced by any of the writes which depend directly or indirectly on this
operation.
Args:
handle: A `Tensor` of type `resource`.
handle to the resource in which to store the variable.
dtype: A `tf.DType`. the dtype of the value.
Returns:
A `Tensor` of type `dtype`.
"""
result = gen_resource_variable_ops.read_variable_op(handle, dtype)
def grad(dresult):
return dresult
return result, grad
def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
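
The function removed in the hunk above used the eager `custom_gradient` decorator pattern of returning `(result, grad_fn)`. A minimal sketch of that pattern, assuming the `custom_gradient` module this file previously imported; `scaled_read` is a hypothetical example, not the removed op.

```python
from tensorflow.python.eager import custom_gradient

@custom_gradient.custom_gradient
def scaled_read(x):
  """Hypothetical op whose gradient is the identity, like read_variable_op."""
  result = x * 1.0

  def grad(dresult):
    # Pass the incoming gradient straight through, as the removed
    # read_variable_op's grad function did.
    return dresult

  return result, grad
```
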
@@ -744,51 +705,6 @@ def _ReadGrad(_, grad):
  return grad
# TODO(apassos) do not use custom_gradient here by making other entry points
# than custom_gradient also aware of how to deal with variables implicitly
# watched in the tape (i.e. the call to _watch_value in custom_gradient)
@custom_gradient.custom_gradient
def resource_gather(resource, indices, dtype, validate_indices=True, name=None):
"""Gather slices from the variable pointed to by `resource`.
`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
Produces an output tensor with shape `indices.shape + params.shape[1:]` where:
```python
# Scalar indices
output[:, ..., :] = params[indices, :, ... :]
# Vector indices
output[i, :, ..., :] = params[indices[i], :, ... :]
# Higher rank indices
output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]
```
Args:
resource: A `Tensor` of type `resource`.
handle to the resource in which to store the variable.
indices: a integer `Tensor` containing the indices to be gathered.
dtype: A `tf.DType`. the dtype of the value.
validate_indices: optional `bool`. If false will not validate that the
indices fit in the variable.
name: The optional name for the operation to be added.
Returns:
A `Tensor` of type `dtype`.
"""
result = gen_resource_variable_ops.resource_gather(
resource, indices, dtype, validate_indices=validate_indices, name=name)
def grad(dresult):
return ops.IndexedSlices(
dresult,
indices,
dense_shape=gen_resource_variable_ops.variable_shape(resource))
return result, grad
@ops.RegisterGradient("ResourceGather")
def _GatherGrad(op, grad):
  """Gradient for gather op."""
@@ -797,7 +713,11 @@ def _GatherGrad(op, grad):
  # TODO(apassos): more robust way of getting the shape.
  # TODO(apassos): implement this for EAGER mode.
  if context.in_eager_mode():
-    raise NotImplementedError("_GatherGrad not implemented for EAGER mode")
+    dense_shape = gen_resource_variable_ops.variable_shape(op.inputs[0])
+    return (ops.IndexedSlices(grad,
+                              op.inputs[1],
+                              dense_shape=dense_shape),
+            None)
  handle = op.inputs[0]
  while handle.op.type != "VarHandleOp":
    handle = handle.op.inputs[0]
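
The new eager branch above expresses the gather gradient as an `IndexedSlices`: the incoming gradient supplies the values, the gathered indices pick the rows, and the variable's shape is the dense shape. A small illustration with arbitrary numbers:

```python
import tensorflow as tf

grad = tf.constant([[1.0, 1.0], [2.0, 2.0]])  # gradient flowing into the gather
indices = tf.constant([0, 3])                 # rows that were gathered
dense_shape = tf.constant([5, 2])             # shape of the gathered-from variable

sparse_grad = tf.IndexedSlices(grad, indices, dense_shape=dense_shape)
# Densified, this is a 5x2 tensor with rows 0 and 3 filled and zeros elsewhere.
```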