diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 5753f4d0178..092435a40f1 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -637,6 +637,7 @@ py_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", "//tensorflow/python/eager:core", + "//tensorflow/python/eager:tape", "@six_archive//:six", ], ) @@ -1800,7 +1801,6 @@ py_library( "//tensorflow/python/eager:context", "//tensorflow/python/eager:custom_gradient", "//tensorflow/python/eager:tape", - "//tensorflow/python/eager:tensor_node", ], ) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index c848ee96bc7..39f33a7ecc2 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -82,7 +82,6 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", "//tensorflow/python:util", ], ) @@ -307,25 +306,6 @@ py_library( ], ) -py_library( - name = "tensor_node", - srcs = ["tensor_node.py"], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:internal"], - deps = [ - ":context", - ":custom_gradient", - ":tape", - ":tensor", - "//tensorflow/python:array_ops", - "//tensorflow/python:common_shapes", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:tensor_shape", - ], -) - py_library( name = "backprop", srcs = ["backprop.py"], @@ -344,7 +324,6 @@ py_library( "//tensorflow/python/eager:execute", "//tensorflow/python/eager:tape", "//tensorflow/python/eager:tensor", - "//tensorflow/python/eager:tensor_node", "@six_archive//:six", ], ) @@ -386,18 +365,6 @@ py_test( ], ) -py_test( - name = "tensor_node_test", - srcs = ["tensor_node_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":tensor", - ":tensor_node", - ":test", - "//tensorflow/python:framework_test_lib", - ], -) - py_test( name = "ops_test", srcs = ["ops_test.py"], diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index f6b9cec0bd1..b99474e2ebc 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -21,30 +21,265 @@ from __future__ import print_function import collections import threading -from autograd import container_types -from autograd import convenience_wrappers -from autograd import core as ag_core - import six from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import tape -from tensorflow.python.eager import tensor -# Imports TensorNode to enable autograd tracing of TF ops. We don't need to use -# any symbols here but import the file just to get the right registrations to -# happen. 
-from tensorflow.python.eager import tensor_node # pylint: disable=unused-import +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_inspect +# Terminology: +# +# - op: a possibly composite operation, which has an entry in the tape +# - target: dy in dx/dy +# - source: dx in dx/dy +# - tensor: one of the many inputs or outputs of an operation +# +# Below here we do the gradient algorithm. It works as follows: +# +# First we filter the tape to just the subset of operations we want to +# differentiate. In the process of doing so we count how many times each Tensor +# is used as an input to an op (so we know when we're done computing gradients +# for that Tensor). We also count, for each tape entry, how many of its output +# Tensors need gradients to be computed (Tensors which are not used do not need +# any gradients to be computed). +# +# Finally, we start a backprop stack with a set of tape entries for which we +# have all gradients available. This set usually is a subset of the set of +# targets (not all since targets which have outputs in the tape will not have +# gradients available initially). +# +# Then we repeatedly pop an entry from the stack, run its backprop, and update +# the gradients of its inputs. Once we have computed all gradients for a single +# input we can mark this input as done, and this can trigger adding an entry to +# the stack if all outputs of that entry are now done. +# +# When the stack is empty we have gradients for all tensors we're interested in. + + +def _prepare_backprop(target, tensor_to_op, op_to_entry, id_sources): + """Filters the tape to only include relevant entries and counts tensor usages. + + Args: + target: the target to optimize. + tensor_to_op: Map from tensor id to key in op_to_entry that produced it. + op_to_entry: Map from op id to a tape.TapeEntry object + id_sources: the ids of the sources wrt the gradient is being taken. + + Returns: + usage counts (how many entries downstream from a tensor use it) + op_to_entry_map: entry map (a filtered tape, with only the relevant + entries), + missing: map from tensor id to how many downstream gradients still need + to be computed before this tensor's gradient can be computed. + """ + if isinstance(target, (ops.Tensor)): + tensor_stack = [ops.tensor_id(target)] + else: + tensor_stack = list([ops.tensor_id(x) for x in target]) + tensor_usage_counts = {} + o_to_e = {} # Copy of just the bits we need from op_to_entry + while tensor_stack: + t = tensor_stack.pop() + op = tensor_to_op[t] + # op is None if the tensor is a source (i.e. 
was watched directly) + if op is None or op in o_to_e: + continue + op_trace = op_to_entry[op] + o_to_e[op] = op_trace + for i in op_trace.inputs: + it = ops.tensor_id(i) + if it in tensor_usage_counts: + tensor_usage_counts[it] += 1 + else: + tensor_usage_counts[it] = 1 + if it not in id_sources and it in tensor_to_op: + tensor_stack.append(it) + op_missing_tensor_counts = collections.defaultdict(int) + for t in tensor_usage_counts: + if t in tensor_to_op and tensor_to_op[t] is not None: + op_missing_tensor_counts[tensor_to_op[t]] += 1 + return tensor_usage_counts, o_to_e, op_missing_tensor_counts + + +def _initialize_backprop_stack(op_to_entry, op_missing_tensor): + """Returns the set of tape entries which are available for backprop.""" + ready_ops = [] + for op in op_to_entry: + if op not in op_missing_tensor: + ready_ops.append(op) + return ready_ops + + +def _initial_gradients(target, output_gradients, tensor_usage_counts): + """Computes the initial gradients for each Tensor.""" + # Initialize the backprop stack + gradients = collections.defaultdict(list) + if isinstance(target, ops.Tensor): + if output_gradients is not None: + output_gradient = output_gradients + else: + output_gradient = array_ops.ones_like(target) + gradients[ops.tensor_id(target)].append(output_gradient) + else: + for i, t in enumerate(target): + if ops.tensor_id(t) in tensor_usage_counts: + # Can't provide a gradient of something we're trying to differentiate + assert output_gradients is None or output_gradients[i] is None + else: + if output_gradients is None or output_gradients[i] is None: + out_grad = ops.ones_like(t) + else: + out_grad = output_gradients[i] + gradients[ops.tensor_id(t)].append(out_grad) + return gradients + + +@tf_contextlib.contextmanager +def _no_op(): + yield + + +def _aggregate_grads(gradients): + """Aggregate gradients from multiple sources. + + Args: + gradients: A list of 'Tensor' or 'IndexedSlices' gradients. + + Returns: + If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'. + Otherwise returns an aggregated 'IndexedSlices'. + """ + assert gradients, "No gradients to aggregate" + + if len(gradients) == 1: + return gradients[0] + if all([isinstance(g, ops.Tensor) for g in gradients]): + return math_ops.add_n(gradients) + else: + assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices)) + for g in gradients]) + indexed_slices_list = [] + for grad in gradients: + # TODO(xpan): Support nested IndexedSlices and core IndexedSlices + if isinstance(grad, ops.Tensor): + indexed_slices = ops.IndexedSlices( + grad, + constant_op.constant(range(grad.shape[0])), + constant_op.constant(grad.shape.as_list())) + indexed_slices_list.append(indexed_slices) + else: + indexed_slices_list.append(grad) + + # Dense shapes from all gradients should be the same. + dense_shape = indexed_slices_list[0].dense_shape + # For simplicity now, always cast to int64. + indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64) + for x in indexed_slices_list], 0) + values = array_ops.concat([x.values for x in indexed_slices_list], 0) + return ops.IndexedSlices(values, indices, dense_shape) + + +def imperative_grad( + target, + sources, + output_gradients=None): + """Computes gradients from the imperatively defined tape on top of the stack. + + Works by filtering the tape, computing how many downstream usages are of each + tensor and entry, and repeatedly applying backward functions until we have + gradients for all sources. 
+ + Args: + target: either a Tensor or list of Tensors to be differentiated. + sources: list of Tensors for which we want gradients + output_gradients: if not None, a list of gradient provided for each Target, + or None if we are to use the target's computed downstream gradient. + + Returns: + the gradient wrt each of the sources. + + Raises: + RuntimeError: if something goes wrong. + ValueError: if there is no sequence of differentiable operations connecting + a source and any target Tensor. This can happen either if the target is + not computed based on the source, if the tracing was set up incorrectly, + or if only non-differentiable functions of the source were used in the + computation of target. + """ + if not tape._tape_stack.stack: # pylint: disable=protected-access + raise RuntimeError("Computing a gradient with no tape present") + bp_tape = tape.pop_tape() + tensor_to_op, op_to_entry, output_to_shape_dtype = bp_tape.export() + # This overwrites the op_to_entry variable, which will release all memory used + # to keep traces that are irrelevant to the gradient computation we're doing + # here. + id_sources = [ops.tensor_id(t) for t in sources] + tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop( + target, tensor_to_op, op_to_entry, id_sources) + ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor) + gradients = _initial_gradients(target, output_gradients, + tensor_usage_counts) + # Now exhaust the backprop stack + while ready_ops: + op = ready_ops.pop() + op_trace = op_to_entry.pop(op) + out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids] + for i in range(len(out_gradients)): + if out_gradients[i] is None: + # TODO(apassos) this should be in the right device + out_gradients[i] = array_ops.zeros( + *output_to_shape_dtype[op_trace.output_ids[i]]) + else: + out_gradients[i] = _aggregate_grads(out_gradients[i]) + + in_gradients = op_trace.backward_function( + *(out_gradients + op_trace.side_outputs)) + in_gradients = ([in_gradients] + if isinstance(in_gradients, (ops.Tensor, + ops.IndexedSlices, + type(None))) + else in_gradients) + for i, t in enumerate(op_trace.inputs): + if in_gradients[i] is not None: + gradients[ops.tensor_id(t)].append(in_gradients[i]) + if tensor_usage_counts.get(ops.tensor_id(t), 0) > 0: + tensor_usage_counts[ops.tensor_id(t)] -= 1 + if ops.tensor_id(t) in tensor_to_op and tensor_usage_counts[ + ops.tensor_id(t)] == 0 and ops.tensor_id(t) not in id_sources: + in_op = tensor_to_op[ops.tensor_id(t)] + if in_op is None: + continue + if op_missing_tensor.get(in_op, 0) > 0: + op_missing_tensor[in_op] -= 1 + if op_missing_tensor.get(in_op, 0) == 0: + ready_ops.append(in_op) + result = [] + for i, s in enumerate(sources): + g = gradients.get(ops.tensor_id(s), None) + if g is None: + # TODO(apassos): figure out a way to summarize why sources and targets are + # not connected. + raise ValueError("There is no sequence of operations connecting source " + "tensor %s (%s) to any of the target Tensors. This is " + "commonly caused by the tape not recording all " + "operations in the forward pass or if by mistake a " + "source was only used in non-differentiable operations." 
+ % (i, s)) + result.append(_aggregate_grads(g)) + return result + + def op_attr_type(op_type, attr_name): with errors.raise_exception_on_not_ok_status() as status: h = context.context()._handle # pylint: disable=protected-access @@ -82,26 +317,23 @@ class _MockOp(object): raise KeyError(attr) -def _magic_gradient_function(op_name, attr_tuple, num_inputs, num_outputs, - *tensors): +def _magic_gradient_function(op_name, attr_tuple, num_inputs, + inputs, outputs, out_grads): """Calls the gradient function of the op. Args: op_name: the name of the op to be differentiated. attr_tuple: the attrs, as a tuple. num_inputs: the number of inputs to the op. - num_outputs: the number of outputs of the op. - *tensors: a list of tensors, composed of, in order, the inputs, the outputs, - and the gradients with respect to the outputs. + inputs: inputs to the original operation. + outputs: outputs to the original operation. + out_grads: gradients of the operation wrt its outputs. Returns: The gradients with respect to the inputs of the function, as a list. """ - inputs = tensors[:num_inputs] - outputs = tensors[num_inputs:num_inputs + num_outputs] - out_grads = tensors[num_inputs + num_outputs:] mock_op = _MockOp(attr_tuple, inputs, outputs, op_name) - grad_fn = tf_ops._gradient_registry.lookup(op_name) # pylint: disable=protected-access + grad_fn = ops._gradient_registry.lookup(op_name) # pylint: disable=protected-access if grad_fn is None: return [None] * num_inputs out_grads = [ @@ -136,31 +368,23 @@ def _record_gradient(op_name, inputs, attrs, results, name): Raises: An exception on error. """ - if not any(ag_core.isnode(x) for x in inputs): - return results num_outputs = len(results) if num_outputs == 0: return results if attrs is not None: - attrs = tuple(tuple(x) if isinstance(x, list) else x for x in attrs) - - # It is imperative we make a copy of results here as otherwise we create a - # dependency cycle in the captured function and this can delay garbage - # collecting of the tensors arbitrarily. - results_size = len(results) if isinstance(results, (list, tuple)) else 1 + attrs = attrs def grad_fn(*orig_outputs): """Generated gradient function.""" - tensors = inputs + list(orig_outputs) - tensors = container_types.make_sequence(tape.EagerList, *tensors) result = _magic_gradient_function(op_name, attrs, len(inputs), - num_outputs, *(tensors)) + inputs, results, orig_outputs) if _tracing: print("Gradient for", (name if name else op_name), "inputs", inputs, - "output_grads", orig_outputs[results_size:], "gradients", result) + "output_grads", orig_outputs, "gradients", result) return result - results = tape.record_operation(results, inputs, [], grad_fn) + inputs = [ops.convert_to_tensor(x) for x in inputs] + tape.record_operation(results, inputs, [], grad_fn) if _tracing: print("Computed op", (name if name else op_name), "inputs", inputs, "outputs", results) @@ -170,27 +394,6 @@ def _record_gradient(op_name, inputs, attrs, results, name): execute.record_gradient = _record_gradient -def _aggregate_grads(gradients): - """Aggregate gradients of the same tensor.""" - grad_lists = collections.OrderedDict() - for g, v in gradients: - if g is None: - continue - if id(v) not in grad_lists: - grad_lists[id(v)] = [(g, v)] - else: - grad_lists[id(v)].append((g, v)) - - ret = [] - for _, g_list in six.iteritems(grad_lists): - if len(g_list) == 1: - ret.append(g_list[0]) - else: - # TODO(xpan): Aggregate IndexedSlices. 
- ret.append((math_ops.add_n(list(zip(*g_list))[0]), g_list[1][1])) - return ret - - def implicit_val_and_grad(f): """Returns a function which differentiates f with respect to variables. @@ -224,23 +427,14 @@ def implicit_val_and_grad(f): Its second element is list of (gradient, variable) pairs. """ - def grad_fn(*args, **kwds): + def grad_fn(*args): """Computes the gradient of the wrapped function.""" tape.push_new_tape() end_node = f(*args) - start_node = tape.pop_tape() - ag_core.active_progenitors.remove(start_node) - if not ag_core.isnode(end_node): - raise ValueError( - "Target not part of a computation being traced. %s." % end_node) - if start_node not in end_node.progenitors: - raise ValueError("Target not derived from source. %s %s." % - (end_node.progenitors, repr(start_node))) - output_gradients = kwds.get("output_gradients", None) - if output_gradients is None: - output_gradients = array_ops.ones_like(end_node.value) - grad = ag_core.backward_pass(output_gradients, end_node, start_node) - return end_node.value, _aggregate_grads(grad.gradients) + variables = tape.top_tape_watched_variables() + sources = [x.handle for x in variables] + grad = imperative_grad(end_node, sources) + return end_node, list(zip(grad, variables)) return grad_fn @@ -295,24 +489,25 @@ def gradients_function(f, params=None): differentiates with respect to all parameters. Returns: - function which, when called, returns the gradient of f with - respect to all of `params`. + function which, when called, returns the value of f and the gradient + of f with respect to all of `params`. The function takes an extra optional + keyword argument "dy". Setting it allows computation of vector jacobian + products for vectors other than the vector of ones. Raises: ValueError: if the params are not all strings or all integers. """ - parameter_positions = _get_arg_spec(f, params) - def decorated(*args, **kwargs): - tensors = convenience_wrappers.multigrad(f, parameter_positions)(*args, - **kwargs) - return [t.tensor() if isinstance(t, tensor.LazyZero) - else t for t in tensors] + def decorated(*args, **kwds): + """Computes the gradient of the decorated function.""" + + _, grad = val_and_grad_function(f, params)(*args, **kwds) + return grad return decorated -def val_and_grad_function(f, params=None): +def val_and_grad_function(f, params): """Returns a function that computes f and is derivative w.r.t. params. Args: @@ -321,11 +516,30 @@ def val_and_grad_function(f, params=None): parameters with respect to which we'll differentiate. Passing None differentiates with respect to all parameters. - Returns: - function which, when called, returns the value of f and the - gradient of f with respect to all of `params`. + Returns: function which, when called, returns the value of f and the gradient + of f with respect to all of `params`. The function takes an extra optional + keyword argument "dy". Setting it allows computation of vector jacobian + products for vectors other than the vector of ones. Raises: ValueError: if the params are not all strings or all integers. """ - return convenience_wrappers.value_and_multigrad(f, _get_arg_spec(f, params)) + parameter_positions = _get_arg_spec(f, params) + + def decorated(*args, **kwds): + """Computes the value and gradient of the decorated function.""" + dy = kwds.pop("dy", None) + assert not kwds, "The gradient function can't take keyword arguments." 
+ tape.push_new_tape() + sources = [] + args = list(args) + for i in parameter_positions: + sources.append(args[i]) + tape.watch(args[i]) + result = f(*args) + return result, imperative_grad( + result, + sources, + output_gradients=dy) + + return decorated diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index a78a21f0186..908d21073f9 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -23,10 +23,10 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.eager import tape from tensorflow.python.eager import tensor -from tensorflow.python.eager import tensor_node from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops @@ -58,6 +58,7 @@ class BackpropTest(test.TestCase): var_np = np.random.rand(4, 2).astype(np.float32) var = tensor.Tensor(var_np) grad = backprop.gradients_function(fn, [0])(var)[0] + grad = ops.convert_to_tensor(grad).numpy() with context.graph_mode(), self.test_session(): tf_var = array_ops.constant(var_np, dtypes.float32) @@ -74,11 +75,7 @@ class BackpropTest(test.TestCase): tf_dense_grad = math_ops.unsorted_segment_sum( tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0]) - self.assertAllClose(grad.numpy(), tf_dense_grad.eval()) - - def testTensoVspaceNoneMutAdd(self): - t = tensor.Tensor(1.0) - self.assertEqual(tensor_node.TensorVSpace(t).mut_add(t, None).numpy(), 1.0) + self.assertAllClose(grad, tf_dense_grad.eval()) def testImplicitGradWithResourceVariable(self): x = resource_variable_ops.ResourceVariable( @@ -94,6 +91,16 @@ class BackpropTest(test.TestCase): self.assertEqual(grads_and_vars[0][0].numpy(), 1.0) self.assertEqual(id(grads_and_vars[0][1]), id(x)) + def testDy(self): + + def f(x): + return x + + grad_fn = backprop.gradients_function(f) + self.assertAllEqual(grad_fn(constant_op.constant(1.0), + dy=constant_op.constant(2.0))[0].numpy(), + 2.0) + def testImplicitGradOverEmbeddingLookup(self): batch_size = 8 embedding_size = 512 diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py index 6f9372ea2cc..52f3f077086 100644 --- a/tensorflow/python/eager/custom_gradient.py +++ b/tensorflow/python/eager/custom_gradient.py @@ -18,22 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from autograd import core as ag_core - from tensorflow.python.eager import tape -from tensorflow.python.eager import tensor as _tensor from tensorflow.python.framework import ops as tf_ops from tensorflow.python.util import nest -def _watch_value_from_tape(tensor): - for t in tape._tape_stack.stack: # pylint: disable=protected-access - w = t.value.tensors.get(tf_ops.tensor_id(tensor), None) - if w is not None: - return w - return tensor - - def custom_gradient(f): """Decorator to define a function with a custom gradient. 
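(Aside on the algorithm above: the tape-based gradient computation that replaces autograd here can be boiled down to a small standalone sketch. This is an illustration only, not part of the patch; ToyTape and toy_grad are invented names, and the reverse-execution-order sweep is a simplification of the ready-op stack plus usage counts that _prepare_backprop and imperative_grad actually use.)

import collections

# One recorded op: ids of its outputs, ids of its inputs, and a function that
# maps output gradients to input gradients (the backward_function used above).
ToyTapeEntry = collections.namedtuple(
    "ToyTapeEntry", ["output_ids", "input_ids", "backward_fn"])


class ToyTape(object):
  """Records ops in execution order; gradients come from a reverse sweep."""

  def __init__(self):
    self.entries = []

  def record(self, output_ids, input_ids, backward_fn):
    self.entries.append(ToyTapeEntry(output_ids, input_ids, backward_fn))


def toy_grad(tape, target_id, source_ids):
  """Toy analogue of imperative_grad: seed the target, sweep the tape backwards."""
  grads = collections.defaultdict(float)
  grads[target_id] = 1.0  # same role as the ones_like seed in _initial_gradients
  for entry in reversed(tape.entries):
    out_grads = [grads.get(o, 0.0) for o in entry.output_ids]
    if not any(out_grads):
      continue  # nothing downstream of this op needs gradients
    in_grads = entry.backward_fn(*out_grads)
    for tensor_id, g in zip(entry.input_ids, in_grads):
      grads[tensor_id] += g  # aggregation, as _aggregate_grads does for real tensors
  return [grads[s] for s in source_ids]


# y = x * x at x = 3.0; the two uses of x aggregate to dy/dx = 6.0.
toy_tape = ToyTape()
x = 3.0
toy_tape.record(["y"], ["x", "x"], lambda dy: (dy * x, dy * x))
assert toy_grad(toy_tape, "y", ["x"]) == [6.0]
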
@@ -52,27 +41,23 @@ def custom_gradient(f): def decorated(*args, **kwargs): """Decorated function with custom gradient.""" - input_tensors = [_watch_value_from_tape(x) for x in args - if isinstance(x, (_tensor.Tensor, tf_ops.Tensor)) - or ag_core.isnode(x)] + input_tensors = [x for x in args + if isinstance(x, tf_ops.Tensor)] result, grad_fn = f(*args, **kwargs) - result_size = len(result) if isinstance(result, (list, tuple)) else 1 # TODO(apassos): naive uses of custom_gradient will not get the correct # second derivative this way if they capture any output tensors. Change the # signature of custom_gradient. def actual_grad_fn(*outputs): - outputs = outputs[result_size:] return grad_fn(*outputs) flat_result = nest.flatten(result) - flat_result = [ag_core.getval(x) for x in flat_result] - flat_result = tape.record_operation( + tape.record_operation( flat_result, input_tensors, [], actual_grad_fn) flat_result = list(flat_result) - return nest.pack_sequence_as(structure=result, flat_sequence=flat_result) + return result return decorated diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index 5d8141c3672..993d1b53584 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from autograd import core as ag_core import six from google.protobuf import text_format @@ -57,7 +56,6 @@ def execute(op_name, num_outputs, inputs, attrs=None, name=None): """ ctx = context.get_default_context() # TODO(apassos) move this to convert_to_tensor - inputs = [ag_core.getval(x) for x in inputs] # pylint: disable=protected-access input_handles = [c._handle for c in inputs] device_name = ctx.device_name @@ -184,7 +182,7 @@ def args_to_matching_eager(l, default_dtype=None): # Is some input already a Tensor with a dtype? 
dtype = None for t in l: - if isinstance(ag_core.getval(t), tensor.Tensor): + if isinstance(t, tensor.Tensor): dtype = t.dtype break @@ -203,7 +201,7 @@ def args_to_matching_eager(l, default_dtype=None): def convert_to_mixed_eager_tensors(values): - v = [t if isinstance(ag_core.getval(t), tensor.Tensor) else tensor.Tensor(t) + v = [t if isinstance(t, tensor.Tensor) else tensor.Tensor(t) for t in values] types = [t.dtype for t in v] return types, v @@ -228,7 +226,7 @@ def args_to_mixed_eager_tensors(lists): dtype = None # If any list has a Tensor, use that dtype for l in lists: - if isinstance(ag_core.getval(l[i]), tensor.Tensor): + if isinstance(l[i], tensor.Tensor): dtype = l[i].dtype break if dtype is None: diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index ffb7924bffe..026add23456 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -23,7 +23,6 @@ import collections import contextlib import threading -from autograd import core as ag_core import numpy as np from tensorflow.python.eager import context @@ -88,6 +87,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): tensor_map[ops.tensor_id(value)] = (value, captured_value) else: captured_value = captured_value[1] + tape.record_operation([captured_value], [value], [], lambda x: x) return captured_value @@ -193,11 +193,8 @@ class _GraphModeFunction(object): self._num_outputs = len(fdef.signature.output_arg) self._ops = operations self._func_outputs = func_outputs - if (isinstance(func_outputs, (ops.Tensor, type(None))) or - ag_core.isnode(func_outputs)): - self._returns = [func_outputs] - else: - self._returns = list(func_outputs) + self._returns = [func_outputs] if isinstance( + func_outputs, (ops.Tensor, type(None))) else list(func_outputs) self._returns_to_fedf_outputs = func_outputs_to_fdef_outputs self._output_shapes = output_shapes @@ -208,7 +205,7 @@ class _GraphModeFunction(object): c = _CapturingContext() with c: filtered_outputs = [ - ag_core.getval(x) for x in self._returns if x is not None + x for x in self._returns if x is not None ] self._out_grad_placeholders = [ graph_placeholder(x.dtype, x.shape) for x in filtered_outputs @@ -242,16 +239,19 @@ class _GraphModeFunction(object): if context.in_graph_mode(): g = ops.get_default_graph() g._add_function(self._forward_fdef) # pylint: disable=protected-access - unwrapped_args = [ag_core.getval(x) for x in all_args] + def make_tensor(x): + if isinstance(x, ops.Tensor): + return x + return ops.convert_to_tensor(x) op = g.create_op( - signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args], + signature.name, [make_tensor(x) for x in all_args], [dtypes.DType(x.type) for x in signature.output_arg], op_def=signature, name="FunctionCall", compute_shapes=False) outputs = op.outputs outputs = [outputs] if isinstance( - outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs) + outputs, (ops.Tensor, type(None))) else list(outputs) for i, s in enumerate(self._output_shapes): outputs[i].set_shape(s) else: @@ -261,25 +261,12 @@ class _GraphModeFunction(object): inputs=all_args) real_outputs = outputs[:len(self._returns)] side_outputs = outputs[len(self._returns):] - watched_extra_inputs = [] - for t in self._extra_inputs: - tid = ops.tensor_id(t) - for t in tape._tape_stack.stack: # pylint: disable=protected-access - w = t.value.tensors.get(tid, None) - if w is not None: - watched_extra_inputs.append(w) - break - else: # Note: for-else here done on purpose - 
watched_extra_inputs.append(t) - def backward_function_wrapper(*outputs): - outputs = outputs[len(real_outputs):] - return self._backward_function(*outputs) - real_outputs = tape.record_operation( + tape.record_operation( real_outputs, - (args + watched_extra_inputs), + (args + self._extra_inputs), side_outputs, - backward_function_wrapper) + self._backward_function) return self._build_call_outputs(self._returns, real_outputs) @@ -288,10 +275,10 @@ class _GraphModeFunction(object): tensor_inputs = [ x for x in nest.flatten(args) if isinstance(x, (tensor.Tensor, ops.Tensor, - tensor.LazyZero)) or ag_core.isnode(x) + tensor.LazyZero)) ] - if tape.should_record(tensor_inputs) or any( - tape.any_tape_has(t) for t in self._extra_inputs): + if tape.should_record(tensor_inputs) or tape.should_record( + self._extra_inputs): if not self._has_backprop: self._compute_backprop() return self._backprop_call(tensor_inputs) @@ -334,12 +321,12 @@ class _GraphModeFunction(object): """ if self._func_outputs is None: return None - if isinstance(ag_core.getval(self._func_outputs), ops.Tensor): + if isinstance(self._func_outputs, ops.Tensor): return result[0] outputs = [] for o in func_outputs: - vo = ag_core.getval(o) + vo = o if isinstance(vo, ops.Tensor): outputs.append(result[self._returns_to_fedf_outputs[id(vo)]]) elif type(vo) in (tuple, list): @@ -354,7 +341,6 @@ def _get_defun_inputs(args): """Maps the inputs args to graph inputs.""" ret = [] for a in args: - a = ag_core.getval(a) if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)): ret.append(graph_placeholder(a.dtype, a.shape)) elif type(a) in (tuple, list): @@ -395,7 +381,7 @@ def _defun_internal(name, func, args, kwds): ] all_inputs = flat_inputs + list(extra_placeholders) - func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None] + func_def_outputs = [x for x in outputs_list if x is not None] inference_function_def = graph_to_function_def.graph_to_function_def( tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs) # Register any other functions defined in the graph @@ -421,7 +407,6 @@ _ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"]) def _cache_key(x): """Cache key for tfe functions.""" - x = ag_core.getval(x) if isinstance(x, tensor.Tensor): return _TensorDtype(x.dtype, x._shape_tuple()) # pylint: disable=protected-access if isinstance(x, tensor.LazyZero): diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index ea2e3b114bd..e33c52a1b2c 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -18,96 +18,114 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import threading -from autograd import container_types -from autograd import core as ag_core - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.util import nest from tensorflow.python.util import tf_contextlib -class ImplicitTape(object): - """Global object which can watch tensors and wrap them with autograd.""" +def tid(tensor): + return tensor._id # pylint: disable=protected-access + + +class TapeEntry( + collections.namedtuple("TapeEntry", [ + "output_ids", "inputs", "side_outputs", "backward_function" + ])): + """Entry in the gradient tape. + + Represents the execution of one op or function, with instructions for doing + its backward pass and useful information for it. 
+ + Args: + output_ids: tensor_id(t) for each output tensor T + inputs: input tensors + side_outputs: optional tensors which need to be provided to the backward + function. + backward_function: function to be called with the downstream gradients and + side outputs as arguments which computes the backward pass. + """ + + +def _tensor_shape(t): + return t._shape_tuple() # pylint: disable=protected-access + + +class Tape(object): + """Represents a gradient propagation trace.""" def __init__(self): - self.tensors = {} - self.variables = {} - self.gradients = [] + # _tensor_tape maps from tensor IDs to their operation IDs + self._tensor_tape = {} + # maps output tensor IDs to their shapes and dtypes + self._shape_dtype = {} + # maps from operation ID to TapeEntry + self._op_tape = {} + # next operation ID + self._next_op_id = 0 + # List of directly watched tensors + self._watched = [] + # Set of directly watched variables + self._watched_variables = set() - def __eq__(self, other): - return self is other + def should_record(self, tensors): + """Returns true if any tensor should be recorded. - def __hash__(self): - return id(self) + Args: + tensors: some tensors. + Returns: + True if any of the tensors is in the tape. + """ + return any(x._id in self._tensor_tape for x in tensors) # pylint: disable=protected-access -@ag_core.primitive -def _watch_with_tape_internal(_, tensor): - """Primitive to wrap a tensor around an ImplicitTape progenitor.""" - return tensor + def watch(self, tensor): + """Adds a tensor to the tape.""" + if tid(tensor) not in self._tensor_tape: + self._tensor_tape[tid(tensor)] = None + self._watched.append(tensor) + def watch_variable(self, v): + self._watched_variables.add(v) + self.watch(v.handle) -def _watch_with_tape(tape, resource_variable): - """Wraps a watched Tensor and keeps track of it in the implicit tape.""" - tensor = resource_variable.handle - w = _watch_with_tape_internal(tape, tensor) - if ag_core.isnode(tape): - tape.value.variables[ops.tensor_id(tensor)] = resource_variable - tape.value.tensors[ops.tensor_id(tensor)] = w + def record_operation(self, output_tensors, input_tensors, side_outputs, + backward_function): + """Records an operation in the tape.""" + if not self.should_record(input_tensors): + return output_tensors + for t in output_tensors: + self._tensor_tape[tid(t)] = self._next_op_id + self._shape_dtype[tid(t)] = (_tensor_shape(t), t.dtype) + self._op_tape[self._next_op_id] = TapeEntry( + [tid(t) for t in output_tensors], + input_tensors, + side_outputs, + backward_function) + self._next_op_id += 1 -def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor): - """Gradient for _watch_with_tape_internal.""" - del ans, gvs + def delete_trace(self, tensor): + """Deletes any trace we have for this tensor.""" + if tid(tensor) in self._tensor_tape: + op = self._tensor_tape[tid(tensor)] + del self._tensor_tape[tid(tensor)] + if op in self._op_tape: + if not any( + x in self._tensor_tape for x in self._op_tape[op].output_ids): + del self._op_tape[op] - def mut_add(implicit_tape): - resource_variable = tape.value.variables[ops.tensor_id(tensor)] - implicit_tape.gradients.append((g, resource_variable)) - return implicit_tape + def export(self): + """Exports the internal state of this tape. 
- return ag_core.SparseObject(vs, mut_add) - -_watch_with_tape_internal.defvjp(_watch_with_tape_vjp, argnum=0) -_watch_with_tape_internal.defvjp( - lambda g, ans, vs, gvs, tape, tensor: g, - argnum=1) - - -class ImplicitTapeVSpace(ag_core.VSpace): - """VSpace needed to have ImplicitTape be a valid progenitor.""" - - def zeros(self): - return ImplicitTape() - - -class ImplicitTapeNode(ag_core.Node): - """Node to wrap ImplicitTape in.""" - - def __eq__(self, other): - return self is other - - def __hash__(self): - return id(self) - -ag_core.register_node(ImplicitTapeNode, ImplicitTape) -ag_core.register_vspace(ImplicitTapeVSpace, ImplicitTape) - - -# TODO(apassos) try to not do this. -class NoneVSpace(ag_core.VSpace): - """VSpace for python None.""" - - def __init__(self, _): - self.size = 0 - - def zeros(self): - return 0 - - -ag_core.register_vspace(NoneVSpace, type(None)) + Returns: + tensor_tape: a map from tensor_id(tensor) to + responsible for generating that tensor. + op_tape: a map from to TapeEntry for that op. + output_to_shape_dtype: a map from tensor_id(tensor) to its shape and + dtype, for tensors which are outputs + """ + return self._tensor_tape, self._op_tape, self._shape_dtype class _TapeStack(threading.local): @@ -134,19 +152,33 @@ _tape_stack = _TapeStack() def push_new_tape(): """Pushes a new tape onto the tape stack.""" - progenitor = ag_core.new_progenitor(ImplicitTape()) - _tape_stack.stack.append(progenitor) - ag_core.active_progenitors.add(progenitor) + _tape_stack.stack.append(Tape()) -def watch_variable(resource_variable): - """Marks this ResourceVariable to be watched by all tapes in the stack. +def watch(tensor): + """Marks this tensor to be watched by all tapes in the stack. Args: - resource_variable: A ResourceVariable to be watched. + tensor: tensor to be watched. + + Returns: + The tensor, potentially wrapped by all tapes in the stack. """ for t in _tape_stack.stack: - _watch_with_tape(t, resource_variable) + t.watch(tensor) + + +def watch_variable(variable): + """Marks this variable to be watched by all tapes in the stack. + + Args: + variable: variable to be watched. + + Returns: + The tensor, potentially wrapped by all tapes in the stack. 
+ """ + for t in _tape_stack.stack: + t.watch_variable(variable) def pop_tape(): @@ -156,85 +188,34 @@ def pop_tape(): return None -def any_tape_has(tensor): - for t in _tape_stack.stack: - if ops.tensor_id(tensor) in t.value.tensors: - return True - return False - - def should_record(tensors): - """Returns true if any tape in the stack watches any of these tensors.""" - return any(ag_core.isnode(x) for x in tensors) + """Returns true if any tape in the stach watches any of these tensors.""" + if not _tape_stack.stack: + return False + return any(x.should_record(tensors) for x in _tape_stack.stack) -class _EagerSequenceNode(container_types.SequenceNode): - """Eager version of SequenceNode, to live in EagerSequenceVSpace.""" - pass +def record_operation(output_tensors, input_tensors, side_outputs, + backward_function): + """Records the operation on all tapes in the stack.""" + for t in _tape_stack.stack: + t.record_operation(output_tensors, + input_tensors, + side_outputs, + backward_function) -class _EagerSequenceVSpace(container_types.SequenceVSpace): - """Changes equality on SequenceVSpace to conform to tfe requirements.""" - - def __init__(self, value): - self.shape = [ag_core.vspace(x) for x in value] - self.size = sum(s.size for s in self.shape) - self.sequence_type = type(value) - - def __eq__(self, other): - if type(self) != type(other): # pylint: disable=unidiomatic-typecheck - return False - if len(self.shape) != len(other.shape): - # TODO(apassos) function gradients sometimes return gradients for side - # inputs which breaks this assertion. Understand how to fix it. - return True - for ss, os in zip(self.shape, other.shape): - if ss != os: - if isinstance(ss, NoneVSpace) or isinstance(os, NoneVSpace): - continue - if ss.dtype == dtypes.resource or os.dtype == dtypes.resource: - continue - return False - return True +def delete_trace(tensor): + """Deletes traces for this Tensor from all tapes in the stack.""" + for t in _tape_stack.stack: + t.delete_trace(tensor) -class EagerList(list): - """Type used to bypass SequenceVSpace. - - SequenceVSpace has a very strict equality check which does not match - tensorflow semantics. 
- """ - - def __init__(self, value): - super(EagerList, self).__init__(value) - for v in value: - assert not ag_core.isnode(v) - -ag_core.register_vspace(_EagerSequenceVSpace, EagerList) -ag_core.register_node(_EagerSequenceNode, EagerList) +def top_tape_watched_tensors(): + t = _tape_stack.stack[-1] + return t._watched # pylint: disable=protected-access -@ag_core.primitive -def _record_operation(output_tensors, input_tensors, side_outputs, - backward_function): - del input_tensors, side_outputs, backward_function - return EagerList(output_tensors) - - -def record_operation(o, i, s, b): - """Primitive to trigger autograd tracing on outputs from inputs.""" - inputs = container_types.make_sequence(EagerList, *i) - return _record_operation(o, inputs, s, b) - - -def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors, - side_outputs, backward_function): - """Gradient for _record_operation.""" - del vs, gvs, input_tensors, output_tensors - backward_args = tuple(g) + tuple(side_outputs) - backward_args = container_types.make_sequence( - EagerList, *(tuple(ans) + backward_args)) - tensors = nest.flatten(backward_function(*backward_args)) - return container_types.make_sequence(EagerList, *tensors) - -_record_operation.defvjp(_record_operation_vjp, argnum=1) +def top_tape_watched_variables(): + t = _tape_stack.stack[-1] + return t._watched_variables # pylint: disable=protected-access diff --git a/tensorflow/python/eager/tensor_node.py b/tensorflow/python/eager/tensor_node.py deleted file mode 100644 index ea7dfcbe4c8..00000000000 --- a/tensorflow/python/eager/tensor_node.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""TensorNode for autograd tracing of computations with Tensors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from autograd import core as ag_core - -from tensorflow.python.eager import context -from tensorflow.python.eager import custom_gradient -from tensorflow.python.eager import tensor -from tensorflow.python.framework import common_shapes -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops - - -@ag_core.primitive -def _tensor_numpy(t): - return t.numpy() - - -@ag_core.primitive -def _as_gpu_tensor(t, index=0): - return t.as_gpu_tensor(gpu_index=index) - - -_as_gpu_tensor.defvjp( - lambda g, ans, vs, gvs, t, index: g.as_cpu_tensor(), argnum=0) - - -@custom_gradient.custom_gradient -def _tensor_copy(t, ctx=None, device_name=None): - - def grad(dresult): - return dresult._copy(device_name=t.device) # pylint: disable=protected-access - - return t.value._copy(ctx=ctx, device_name=device_name), grad # pylint: disable=protected-access - - -@ag_core.primitive -def _as_cpu_tensor(t): - return t.as_cpu_tensor() - - -_as_cpu_tensor.defvjp(lambda g, ans, vs, gvs, t: g.as_gpu_tensor(), argnum=0) - - -# TODO(apassos,ashankar): The operator overrides here need to be kept in sync -# with the overrides for ops.Tensor and ops.EagerTensor. -# -# Note that we cannot use self.value.__op__() because that would result -# in an ops.EagerTensor instead of a TensorNode being returned. -# -# We need to figure out a way to ensure that the two are in sync. -class TensorNode(ag_core.Node): - """A TensorFlow Tensor.""" - - __slots__ = [] - - def __getitem__(self, idx): - return array_ops._SliceHelper(self, idx) # pylint: disable=protected-access - - shape = property(lambda self: self.value.shape) - dtype = property(lambda self: self.value.dtype) - device = property(lambda self: self.value.device) - - def get_shape(self): - return self.shape - - def numpy(self): - return _tensor_numpy(self) - - def _shape_tuple(self): - return self.value._shape_tuple # pylint: disable=protected-access - - def as_cpu_tensor(self): - return _as_cpu_tensor(self) - - def as_gpu_tensor(self, gpu_index=0): - return _as_gpu_tensor(self, gpu_index) - - def _copy(self, ctx=None, device_name=None): - return _tensor_copy(self, ctx, device_name) - - def __neg__(self): - return math_ops.negative(self) - - def __abs__(self): - return math_ops.abs(self) # pylint: disable=protected-access - - def __invert__(self): - # ops.Tensor used math_ops.logical_not as of August 2017. - # Now that bitwise_ops.invert exists, it might make sense - # for both ops.Tensor and TensorNode to use that if the - # type is compatible. 
- return math_ops.logical_not(self) - - def __hash__(self): - return id(self) - - def __add__(self, other): - if isinstance(self.value, tensor.LazyZero): - return other - if isinstance(other, tensor.LazyZero): - return self - return math_ops.add(self, other) - - def __radd__(self, other): - if isinstance(self.value, tensor.LazyZero): - return other - if isinstance(ag_core.getval(other), tensor.LazyZero): - return self - return math_ops.add(other, self) - - def __sub__(self, other): - return math_ops.subtract(self, other) - - def __rsub__(self, other): - return math_ops.subtract(other, self) - - def __mul__(self, other): - return math_ops.multiply(self, other) - - def __rmul__(self, other): - return math_ops.multiply(other, self) - - def __mod__(self, other): - return math_ops.floormod(self, other) - - def __rmod__(self, other): - return math_ops.floormod(other, self) - - def __pow__(self, other): - return math_ops.pow(self, other) - - def __rpow__(self, other): - return math_ops.pow(other, self) - - def __div__(self, other): - return math_ops._div_python2(self, other) # pylint: disable=protected-access - - def __rdiv__(self, other): - return math_ops._div_python2(other, self) # pylint: disable=protected-access - - def __truediv__(self, other): - return math_ops._truediv_python3(self, other) # pylint: disable=protected-access - - def __rtruediv__(self, other): - return math_ops._truediv_python3(other, self) # pylint: disable=protected-access - - def __floordiv__(self, other): - return math_ops.floordiv(self, other) - - def __rfloordiv__(self, other): - return math_ops.floordiv(other, self) - - def __eq__(self, other): - # math_ops.equal raises an error if shapes are not compatible, so check that - # explicitly first. - if common_shapes.is_broadcast_compatible( - self.shape, ops.convert_to_tensor(other).shape): - return math_ops.equal(self, other) - return False - - def __gt__(self, other): - return math_ops.greater(self, other) - - def __ge__(self, other): - return math_ops.greater_equal(self, other) - - def __lt__(self, other): - return math_ops.less(self, other) - - def __le__(self, other): - return math_ops.less_equal(self, other) - - -ag_core.register_node(TensorNode, tensor.Tensor) -ag_core.register_node(TensorNode, ops.Tensor) - - -def _zeros(shape, dtype): - with context.device("cpu:0"): - shape = tensor.Tensor(shape, dtype=dtypes.int32) - return array_ops.fill(shape, tensor.Tensor(0, dtype=dtype)) - - -def _ones(shape, dtype): - return array_ops.fill( - tensor.Tensor(shape, dtype=dtypes.int32), tensor.Tensor(1, dtype=dtype)) - - -def _lazy_zero_tensor(zero): - return _zeros(zero.shape, zero.dtype) - - -tensor.LazyZero.tensor = _lazy_zero_tensor - - -def _lazy_zero_to_tensor(lazy_zero, dtype=None, name=None, as_ref=False): - del as_ref, name, dtype - return _zeros(lazy_zero.shape, lazy_zero.dtype) - - -ops.register_tensor_conversion_function(tensor.LazyZero, _lazy_zero_to_tensor) - - -def _indexed_slices_to_tensor(value): - """Converts an IndexedSlices object `value` to a Tensor. - - Args: - value: An ops.IndexedSlices object. - - Returns: - A dense Tensor representing the values in the given IndexedSlices. - - Raises: - ValueError: If the IndexedSlices does not have the same dtype. 
- """ - if value.dense_shape is None: - raise ValueError( - "Tensor conversion requested for IndexedSlices without dense_shape: %s" - % str(value)) - return math_ops.unsorted_segment_sum(value.values, value.indices, - value.dense_shape[0]) - - -class TensorVSpace(ag_core.VSpace): - """VSpace for tf/tfe Tensors in autograd.""" - - def __init__(self, value): - self.size = 1 - if isinstance(value, ops.IndexedSlices): - self.shape = tensor_shape.TensorShape(value.dense_shape.numpy()) - self.dtype = value.values.dtype - else: - self.shape = value._shape_tuple() # pylint: disable=protected-access - self.dtype = value.dtype - # TODO(apassos) put gradients on the same device as ops. - - def __eq__(self, other): - # TODO(apassos) consider revisiting this if not performance sensitive. - return True - - def __ne__(self, other): - return not self.__eq__(other) - - def zeros(self): - return tensor.LazyZero(self.shape, self.dtype) - - def ones(self): - return _ones(self.shape, self.dtype) - - def standard_basis(self): - raise NotImplementedError - - def flatten(self, value): - return array_ops.reshape(value, tensor.Tensor(-1)) - - def unflatten(self, value): - return array_ops.reshape(value, tensor.Tensor(self.shape)) - - def mut_add(self, x, y): - """Add wrapper safe for IndexedSlices and LazyZero.""" - if isinstance(ag_core.getval(x), tensor.LazyZero): - return y - if isinstance(ag_core.getval(y), tensor.LazyZero): - return x - if isinstance(x, ops.IndexedSlices): - x = _indexed_slices_to_tensor(x) - if isinstance(y, ops.IndexedSlices): - y = _indexed_slices_to_tensor(y) - if x is None: - return y - if y is None: - return x - return math_ops.add(x, y) - - -ag_core.register_vspace(TensorVSpace, tensor.Tensor) -ag_core.register_vspace(TensorVSpace, ops.Tensor) -ag_core.register_vspace(TensorVSpace, ops.IndexedSlices) -ag_core.register_vspace(TensorVSpace, tensor.LazyZero) -ag_core.register_node(TensorNode, tensor.LazyZero) diff --git a/tensorflow/python/eager/tensor_node_test.py b/tensorflow/python/eager/tensor_node_test.py deleted file mode 100644 index 558c7b706d1..00000000000 --- a/tensorflow/python/eager/tensor_node_test.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.eager import tensor -from tensorflow.python.eager import tensor_node -from tensorflow.python.eager import test -from tensorflow.python.framework import test_util - - -def public_or_operator(n): - if not n.startswith("_"): - return True - return n.startswith("__") and n.endswith("__") - - -class TensorNodeTest(test_util.TensorFlowTestCase): - - # TensorNode must implement autograd core's Node interface, which - # it does via inheritance. It also needs to be duck-typeable as - # a tensorflow.python.framework.ops.EagerTensor. 
- # - # This is a "test" to help ensure interface compatibility. - def testCanBeATensor(self): - # TODO(ashankar,apassos): This list of "exceptions" - list of - # Tensor methods not implemented by TensorNode needs to be - # trimmed. - exceptions = set([ - "OVERLOADABLE_OPERATORS", - "__and__", - "__del__", - "__dict__", - "__iter__", - "__len__", - "__matmul__", - "__or__", - "__rand__", - "__rmatmul__", - "__ror__", - "__rxor__", - "__weakref__", - "__xor__", - # BEGIN: Methods of Tensor that EagerTensor raises exceptions on. - # But perhaps TensorNode should defer to "self.value." for - # them? - "consumers", - "eval", - "graph", - "name", - "op", - "set_shape", - "value_index", - # END: Methods of Tensor that EagerTensor raises exceptions on. - ]) - - tensor_dir = dir(tensor.Tensor) - tensor_dir = filter(public_or_operator, tensor_dir) - tensor_dir = set(tensor_dir).difference(exceptions) - - tensor_node_dir = set(dir(tensor_node.TensorNode)) - - missing = tensor_dir.difference(tensor_node_dir) - self.assertEqual( - 0, - len(missing), - msg="Methods/properties missing in TensorNode: {}".format(missing)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index a859645950d..a961c85f783 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -41,7 +41,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from autograd import core as ag_core import numpy as np from tensorflow.core.framework import attr_value_pb2 @@ -76,7 +75,7 @@ def _eager_fill(dims, value): def convert_to_eager_tensor(t, dtype=None): """Converts the given `value` to an `EagerTensor`.""" - if isinstance(ag_core.getval(t), ops.EagerTensor): + if isinstance(t, ops.EagerTensor): if dtype is not None and t.dtype != dtype: raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) return t diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py index 76424ef579b..e85bba11cd1 100644 --- a/tensorflow/python/framework/op_def_library.py +++ b/tensorflow/python/framework/op_def_library.py @@ -19,7 +19,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from autograd import core as ag_core import six from tensorflow.core.framework import attr_value_pb2 @@ -503,7 +502,6 @@ class OpDefLibrary(object): default_dtype = default_type_attr_map[input_arg.type_attr] try: - values = ag_core.getval(values) values = ops.internal_convert_to_tensor( values, name=input_arg.name, @@ -784,7 +782,6 @@ class OpDefLibrary(object): if arg.is_ref] with _MaybeColocateWith(must_colocate_inputs): # Add Op to graph - inputs = [ag_core.getval(x) for x in inputs] op = g.create_op(op_type_name, inputs, output_types, name=scope, input_types=input_types, attrs=attr_protos, op_def=op_def) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index b197e96886e..4973643fa8c 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -25,7 +25,6 @@ import re import sys import threading -from autograd import core as ag_core import numpy as np import six @@ -38,6 +37,7 @@ from tensorflow.core.framework import versions_pb2 from tensorflow.python import pywrap_tensorflow as c_api from tensorflow.python.eager import context from tensorflow.python.eager import core +from 
tensorflow.python.eager import tape from tensorflow.python.framework import c_api_util from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes @@ -70,10 +70,9 @@ from tensorflow.python.util import tf_contextlib _USE_C_API = False -def tensor_id(t): +def tensor_id(tensor): """Returns a unique identifier for this Tensor.""" - t = ag_core.getval(t) - return t._id # pylint: disable=protected-access + return tensor._id # pylint: disable=protected-access def _in_gpu_device(): @@ -703,6 +702,7 @@ class EagerTensor(Tensor): def __del__(self): try: + tape.delete_trace(self) if c_api is not None and c_api.TFE_DeleteTensorHandle is not None: c_api.TFE_DeleteTensorHandle(self._handle) if core.active_trace() is not None: @@ -727,7 +727,7 @@ class EagerTensor(Tensor): self.dtype.name) def __repr__(self): - return "" % ( + return "" % ( self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True)) @staticmethod @@ -770,6 +770,16 @@ class EagerTensor(Tensor): tensor_id(new_tensor), new_tensor.device, new_tensor.shape.num_elements()) + + # Record the copy on tape and define backprop copy as well. + if not context.in_graph_mode(): + self_device = self.device + def grad_fun(dresult): + with errors.raise_exception_on_not_ok_status() as status: + grad_h = c_api.TFE_TensorHandleCopyToDevice( + dresult._handle, ctx._handle, self_device, status) + return _tensor_from_handle(grad_h) + tape.record_operation([new_tensor], [self], [], grad_fun) return new_tensor # pylint: enable=protected-access @@ -1033,26 +1043,21 @@ def internal_convert_to_tensor(value, RuntimeError: If a registered conversion function returns an invalid value. """ - # Note we check the type of the object unwrapped from an autograd node, if - # tracing gradients, to ensure the same behavior happens with and without - # tracing. - unwrapped = ag_core.getval(value) - if context.in_eager_mode(): # Fast path for EagerTensors that don't need any conversion. - if isinstance(unwrapped, EagerTensor): + if isinstance(value, EagerTensor): # Note that we don't check that value's dtype matches the dtype # argument. We exepct that the C runtime will do that checking # when we execute the kernel. 
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 745428e530c..c8bdb35e804 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
 import numpy as np
 import six
 
@@ -607,7 +606,7 @@ def ShapeEquals(tensor_proto, shape):
 
 def _ConstantValue(tensor, partial):
   # TODO(touts): Support Variables?
-  if not isinstance(ag_core.getval(tensor), ops.Tensor):
+  if not isinstance(tensor, ops.Tensor):
     raise TypeError("tensor is not a Tensor")
   if tensor.op.type == "Const":
     return MakeNdarray(tensor.op.get_attr("value"))
@@ -737,7 +736,7 @@ def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
   Raises:
     TypeError: if tensor is not an ops.Tensor.
   """
-  if isinstance(ag_core.getval(tensor), ops.EagerTensor):
+  if isinstance(tensor, ops.EagerTensor):
     return tensor.numpy()
   ret = _ConstantValue(tensor, partial)
   if ret is not None:
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 274eda4f643..9bf178cc49e 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -291,10 +291,11 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
                           sparse_tensor.SparseTensorValue)):
       return gen_math_ops.cast(input.dense_shape, out_type)
     else:
-      input_tensor = ops.convert_to_tensor(input)
-      input_shape = input_tensor.get_shape()
-      if optimize and input_shape.is_fully_defined():
-        return constant(input_shape.as_list(), out_type, name=name)
+      if context.in_graph_mode():
+        input_tensor = ops.convert_to_tensor(input)
+        input_shape = input_tensor.get_shape()
+        if optimize and input_shape.is_fully_defined():
+          return constant(input_shape.as_list(), out_type, name=name)
       return gen_array_ops.shape(input, name=name, out_type=out_type)
 
@@ -1428,7 +1429,9 @@ def zeros(shape, dtype=dtypes.float32, name=None):
       zero = ""
     else:
       zero = 0
-    if context.in_eager_mode():
+    # Checking for boolean dtype to prevent attempting to run fill on the GPU
+    # which does not have a boolean kernel registered.
+    if context.in_eager_mode() and dtype != dtypes.bool:
       return fill(shape, constant(zero, dtype=dtype), name=name)
     try:
       shape = tensor_shape.as_shape(shape)
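The `zeros()` hunk above guards the eager fast path on the dtype: per the added comment, the GPU has no boolean fill kernel, so boolean requests fall through to the generic path. A minimal sketch of that control flow, using NumPy stand-ins (the name `eager_zeros` and the `in_eager_mode` flag are hypothetical, not TensorFlow API):

```python
import numpy as np


def eager_zeros(shape, dtype=np.float32, in_eager_mode=True):
  # Fast path: fill with a scalar zero, skipped for bool to mirror the
  # missing boolean fill kernel in the diff above.
  if in_eager_mode and dtype != np.bool_:
    return np.full(shape, 0, dtype=dtype)
  # Fallback path (graph construction or boolean dtype).
  return np.zeros(shape, dtype=dtype)


assert eager_zeros((2, 3)).dtype == np.float32
assert eager_zeros((2,), dtype=np.bool_).dtype == np.bool_
```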
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 3cd82d60417..59b1238c6f0 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -144,7 +144,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
@@ -1113,12 +1112,11 @@ floormod = gen_math_ops._floor_mod
 
 def _mul_dispatch(x, y, name=None):
   """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse"."""
-  is_tensor_y = isinstance(ag_core.getval(y), ops.Tensor)
+  is_tensor_y = isinstance(y, ops.Tensor)
   if is_tensor_y:
     return gen_math_ops._mul(x, y, name=name)
   else:
-    assert isinstance(ag_core.getval(y),
-                      sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
+    assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
     new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
                                                      y.dense_shape, x, name)
     return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
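For reference, the Dense*Sparse branch of `_mul_dispatch` multiplies the dense operand only at the sparse operand's stored coordinates and reuses the sparse indices and dense_shape for the result. A small illustrative sketch of that computation with NumPy (the `SparseCOO` container and `sparse_dense_cwise_mul` helper here are hypothetical stand-ins, not the generated TensorFlow op):

```python
from collections import namedtuple

import numpy as np

SparseCOO = namedtuple("SparseCOO", ["indices", "values", "dense_shape"])


def sparse_dense_cwise_mul(sp, dense):
  # Gather the dense entries at the sparse coordinates and multiply; the
  # result keeps the same sparsity pattern as the sparse operand.
  rows, cols = sp.indices[:, 0], sp.indices[:, 1]
  new_values = sp.values * dense[rows, cols]
  return SparseCOO(sp.indices, new_values, sp.dense_shape)


dense = np.arange(6, dtype=np.float32).reshape(2, 3)
sp = SparseCOO(indices=np.array([[0, 1], [1, 2]]),
               values=np.array([10.0, 20.0], dtype=np.float32),
               dense_shape=(2, 3))
out = sparse_dense_cwise_mul(sp, dense)
assert out.values.tolist() == [10.0, 100.0]  # 10*dense[0,1], 20*dense[1,2]
```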
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index b628fb26d14..2848e978981 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -19,14 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
-
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import variable_pb2
 from tensorflow.python.eager import context
-from tensorflow.python.eager import custom_gradient
 from tensorflow.python.eager import tape
-from tensorflow.python.eager import tensor_node
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -506,10 +502,8 @@ class ResourceVariable(variables.Variable):
   def _read_variable_op(self):
     if hasattr(self, "_trainable") and self._trainable:
       tape.watch_variable(self)
-      return read_variable_op(self._handle, dtype=self._dtype)
-    else:
-      return gen_resource_variable_ops.read_variable_op(self._handle,
-                                                        self._dtype)
+    return gen_resource_variable_ops.read_variable_op(self._handle,
+                                                      self._dtype)
 
   def read_value(self):
     """Constructs an op which reads the value of this variable.
@@ -541,7 +535,7 @@ class ResourceVariable(variables.Variable):
     with ops.name_scope("Gather" if name is None else name) as name:
       if self._trainable:
         tape.watch_variable(self)
-      value = resource_gather(
+      value = gen_resource_variable_ops.resource_gather(
           self._handle, indices, dtype=self._dtype, name=name)
     return array_ops.identity(value)
 
@@ -614,13 +608,7 @@ class ResourceVariable(variables.Variable):
     def _run_op(a, *args):
       # pylint: disable=protected-access
       value = a._AsTensor()
-      if ag_core.isnode(value):
-        # This avoids autograd trying to wrap a ResourceVariable.
-        value = ops.convert_to_tensor(value)
-        args = [ops.convert_to_tensor(x) for x in args]
-        return getattr(tensor_node.TensorNode, operator)(value, *args)
-      else:
-        return getattr(ops.Tensor, operator)(value, *args)
+      return getattr(ops.Tensor, operator)(value, *args)
 
     # Propagate __doc__ to wrapper
     try:
@@ -693,33 +681,6 @@ class ResourceVariable(variables.Variable):
     return self.value()
 
 
-@custom_gradient.custom_gradient
-def read_variable_op(handle, dtype):
-  """Reads the value of a variable.
-
-  The tensor returned by this operation is immutable.
-
-  The value returned by this operation is guaranteed to be influenced by all the
-  writes on which this operation depends directly or indirectly, and to not be
-  influenced by any of the writes which depend directly or indirectly on this
-  operation.
-
-  Args:
-    handle: A `Tensor` of type `resource`.
-      handle to the resource in which to store the variable.
-    dtype: A `tf.DType`. the dtype of the value.
-
-  Returns:
-    A `Tensor` of type `dtype`.
-  """
-  result = gen_resource_variable_ops.read_variable_op(handle, dtype)
-
-  def grad(dresult):
-    return dresult
-
-  return result, grad
-
-
 def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
   return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
 
@@ -744,51 +705,6 @@ def _ReadGrad(_, grad):
   return grad
 
 
-# TODO(apassos) do not use custom_gradient here by making other entry points
-# than custom_gradient also aware of how to deal with variables implicitly
-# watched in the tape (i.e. the call to _watch_value in custom_gradient)
-@custom_gradient.custom_gradient
-def resource_gather(resource, indices, dtype, validate_indices=True, name=None):
-  """Gather slices from the variable pointed to by `resource`.
-
-  `indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
-  Produces an output tensor with shape `indices.shape + params.shape[1:]` where:
-
-  ```python
-      # Scalar indices
-      output[:, ..., :] = params[indices, :, ... :]
-
-      # Vector indices
-      output[i, :, ..., :] = params[indices[i], :, ... :]
-
-      # Higher rank indices
-      output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]
-  ```
-
-  Args:
-    resource: A `Tensor` of type `resource`.
-      handle to the resource in which to store the variable.
-    indices: a integer `Tensor` containing the indices to be gathered.
-    dtype: A `tf.DType`. the dtype of the value.
-    validate_indices: optional `bool`. If false will not validate that the
-      indices fit in the variable.
-    name: The optional name for the operation to be added.
-
-  Returns:
-    A `Tensor` of type `dtype`.
-  """
-  result = gen_resource_variable_ops.resource_gather(
-      resource, indices, dtype, validate_indices=validate_indices, name=name)
-
-  def grad(dresult):
-    return ops.IndexedSlices(
-        dresult,
-        indices,
-        dense_shape=gen_resource_variable_ops.variable_shape(resource))
-
-  return result, grad
-
-
 @ops.RegisterGradient("ResourceGather")
 def _GatherGrad(op, grad):
   """Gradient for gather op."""
@@ -797,7 +713,11 @@ def _GatherGrad(op, grad):
   # TODO(apassos): more robust way of getting the shape.
   # TODO(apassos): implement this for EAGER mode.
   if context.in_eager_mode():
-    raise NotImplementedError("_GatherGrad not implemented for EAGER mode")
+    dense_shape = gen_resource_variable_ops.variable_shape(op.inputs[0])
+    return (ops.IndexedSlices(grad,
+                              op.inputs[1],
+                              dense_shape=dense_shape),
+            None)
   handle = op.inputs[0]
   while handle.op.type != "VarHandleOp":
     handle = handle.op.inputs[0]
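The new eager branch of `_GatherGrad` above returns the gradient as `ops.IndexedSlices(grad, op.inputs[1], dense_shape=...)`: the upstream gradient rows, the gathered indices, and the variable's dense shape, instead of a dense tensor. The sketch below illustrates what that sparse gradient represents and how it densifies via a scatter-add; it uses NumPy, and the helpers `gather_grad_as_indexed_slices` and `densify` are hypothetical names for the concept, not TensorFlow API.

```python
import numpy as np


def gather(params, indices):
  return params[indices]


def gather_grad_as_indexed_slices(grad, indices, dense_shape):
  # Sparse gradient: one row of `grad` per gathered index, plus the shape of
  # the full variable, mirroring what IndexedSlices carries.
  return {"values": grad, "indices": indices, "dense_shape": dense_shape}


def densify(indexed_slices):
  dense = np.zeros(indexed_slices["dense_shape"],
                   dtype=indexed_slices["values"].dtype)
  # Repeated indices must accumulate, hence np.add.at rather than assignment.
  np.add.at(dense, indexed_slices["indices"], indexed_slices["values"])
  return dense


params = np.arange(12.0).reshape(4, 3)
indices = np.array([1, 3, 1])
out = gather(params, indices)            # shape (3, 3)
grad = np.ones_like(out)                 # upstream gradient
sparse_grad = gather_grad_as_indexed_slices(grad, indices, params.shape)
dense_grad = densify(sparse_grad)
assert dense_grad[1].tolist() == [2.0, 2.0, 2.0]  # row 1 was gathered twice
assert dense_grad[0].tolist() == [0.0, 0.0, 0.0]  # row 0 was never gathered
```

Keeping the gradient in this sparse form avoids materializing a dense tensor the size of the whole variable when only a few rows were gathered.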