commit 655f26fc70
parent 8f37f30027

    Resurrects autograd-free eager gradients.

    PiperOrigin-RevId: 168448557
@@ -637,6 +637,7 @@ py_library(
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:core",
+        "//tensorflow/python/eager:tape",
         "@six_archive//:six",
     ],
 )
@@ -1800,7 +1801,6 @@ py_library(
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:custom_gradient",
         "//tensorflow/python/eager:tape",
-        "//tensorflow/python/eager:tensor_node",
     ],
 )
 
@@ -82,7 +82,6 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/python:dtypes",
-        "//tensorflow/python:framework_ops",
         "//tensorflow/python:util",
     ],
 )
@@ -307,25 +306,6 @@ py_library(
     ],
 )
 
-py_library(
-    name = "tensor_node",
-    srcs = ["tensor_node.py"],
-    srcs_version = "PY2AND3",
-    visibility = ["//tensorflow:internal"],
-    deps = [
-        ":context",
-        ":custom_gradient",
-        ":tape",
-        ":tensor",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:common_shapes",
-        "//tensorflow/python:dtypes",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:tensor_shape",
-    ],
-)
-
 py_library(
     name = "backprop",
     srcs = ["backprop.py"],
@@ -344,7 +324,6 @@ py_library(
         "//tensorflow/python/eager:execute",
         "//tensorflow/python/eager:tape",
         "//tensorflow/python/eager:tensor",
-        "//tensorflow/python/eager:tensor_node",
         "@six_archive//:six",
     ],
 )
@@ -386,18 +365,6 @@ py_test(
     ],
 )
 
-py_test(
-    name = "tensor_node_test",
-    srcs = ["tensor_node_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":tensor",
-        ":tensor_node",
-        ":test",
-        "//tensorflow/python:framework_test_lib",
-    ],
-)
-
 py_test(
     name = "ops_test",
     srcs = ["ops_test.py"],
@@ -21,30 +21,265 @@ from __future__ import print_function
 import collections
 import threading
 
-from autograd import container_types
-from autograd import convenience_wrappers
-from autograd import core as ag_core
-
 import six
 
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.eager import execute
 from tensorflow.python.eager import tape
-from tensorflow.python.eager import tensor
-# Imports TensorNode to enable autograd tracing of TF ops. We don't need to use
-# any symbols here but import the file just to get the right registrations to
-# happen.
-from tensorflow.python.eager import tensor_node  # pylint: disable=unused-import
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_inspect
 
 
+# Terminology:
+#
+# - op: a possibly composite operation, which has an entry in the tape
+# - target: dy in dx/dy
+# - source: dx in dx/dy
+# - tensor: one of the many inputs or outputs of an operation
+#
+# Below here we do the gradient algorithm. It works as follows:
+#
+# First we filter the tape to just the subset of operations we want to
+# differentiate. In the process of doing so we count how many times each Tensor
+# is used as an input to an op (so we know when we're done computing gradients
+# for that Tensor). We also count, for each tape entry, how many of its output
+# Tensors need gradients to be computed (Tensors which are not used do not need
+# any gradients to be computed).
+#
+# Finally, we start a backprop stack with a set of tape entries for which we
+# have all gradients available. This set usually is a subset of the set of
+# targets (not all since targets which have outputs in the tape will not have
+# gradients available initially).
+#
+# Then we repeatedly pop an entry from the stack, run its backprop, and update
+# the gradients of its inputs. Once we have computed all gradients for a single
+# input we can mark this input as done, and this can trigger adding an entry to
+# the stack if all outputs of that entry are now done.
+#
+# When the stack is empty we have gradients for all tensors we're interested in.
+
+
+def _prepare_backprop(target, tensor_to_op, op_to_entry, id_sources):
+  """Filters the tape to only include relevant entries and counts tensor usages.
+
+  Args:
+    target: the target to optimize.
+    tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
+    op_to_entry: Map from op id to a tape.TapeEntry object
+    id_sources: the ids of the sources wrt the gradient is being taken.
+
+  Returns:
+    usage counts (how many entries downstream from a tensor use it)
+    op_to_entry_map: entry map (a filtered tape, with only the relevant
+      entries),
+    missing: map from tensor id to how many downstream gradients still need
+      to be computed before this tensor's gradient can be computed.
+  """
+  if isinstance(target, (ops.Tensor)):
+    tensor_stack = [ops.tensor_id(target)]
+  else:
+    tensor_stack = list([ops.tensor_id(x) for x in target])
+  tensor_usage_counts = {}
+  o_to_e = {}  # Copy of just the bits we need from op_to_entry
+  while tensor_stack:
+    t = tensor_stack.pop()
+    op = tensor_to_op[t]
+    # op is None if the tensor is a source (i.e. was watched directly)
+    if op is None or op in o_to_e:
+      continue
+    op_trace = op_to_entry[op]
+    o_to_e[op] = op_trace
+    for i in op_trace.inputs:
+      it = ops.tensor_id(i)
+      if it in tensor_usage_counts:
+        tensor_usage_counts[it] += 1
+      else:
+        tensor_usage_counts[it] = 1
+        if it not in id_sources and it in tensor_to_op:
+          tensor_stack.append(it)
+  op_missing_tensor_counts = collections.defaultdict(int)
+  for t in tensor_usage_counts:
+    if t in tensor_to_op and tensor_to_op[t] is not None:
+      op_missing_tensor_counts[tensor_to_op[t]] += 1
+  return tensor_usage_counts, o_to_e, op_missing_tensor_counts
+
+
+def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
+  """Returns the set of tape entries which are available for backprop."""
+  ready_ops = []
+  for op in op_to_entry:
+    if op not in op_missing_tensor:
+      ready_ops.append(op)
+  return ready_ops
+
+
+def _initial_gradients(target, output_gradients, tensor_usage_counts):
+  """Computes the initial gradients for each Tensor."""
+  # Initialize the backprop stack
+  gradients = collections.defaultdict(list)
+  if isinstance(target, ops.Tensor):
+    if output_gradients is not None:
+      output_gradient = output_gradients
+    else:
+      output_gradient = array_ops.ones_like(target)
+    gradients[ops.tensor_id(target)].append(output_gradient)
+  else:
+    for i, t in enumerate(target):
+      if ops.tensor_id(t) in tensor_usage_counts:
+        # Can't provide a gradient of something we're trying to differentiate
+        assert output_gradients is None or output_gradients[i] is None
+      else:
+        if output_gradients is None or output_gradients[i] is None:
+          out_grad = ops.ones_like(t)
+        else:
+          out_grad = output_gradients[i]
+        gradients[ops.tensor_id(t)].append(out_grad)
+  return gradients
+
+
+@tf_contextlib.contextmanager
+def _no_op():
+  yield
+
+
+def _aggregate_grads(gradients):
+  """Aggregate gradients from multiple sources.
+
+  Args:
+    gradients: A list of 'Tensor' or 'IndexedSlices' gradients.
+
+  Returns:
+    If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'.
+    Otherwise returns an aggregated 'IndexedSlices'.
+  """
+  assert gradients, "No gradients to aggregate"
+
+  if len(gradients) == 1:
+    return gradients[0]
+  if all([isinstance(g, ops.Tensor) for g in gradients]):
+    return math_ops.add_n(gradients)
+  else:
+    assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
+                for g in gradients])
+    indexed_slices_list = []
+    for grad in gradients:
+      # TODO(xpan): Support nested IndexedSlices and core IndexedSlices
+      if isinstance(grad, ops.Tensor):
+        indexed_slices = ops.IndexedSlices(
+            grad,
+            constant_op.constant(range(grad.shape[0])),
+            constant_op.constant(grad.shape.as_list()))
+        indexed_slices_list.append(indexed_slices)
+      else:
+        indexed_slices_list.append(grad)
+
+    # Dense shapes from all gradients should be the same.
+    dense_shape = indexed_slices_list[0].dense_shape
+    # For simplicity now, always cast to int64.
+    indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64)
+                                for x in indexed_slices_list], 0)
+    values = array_ops.concat([x.values for x in indexed_slices_list], 0)
+    return ops.IndexedSlices(values, indices, dense_shape)
+
+
+def imperative_grad(
+    target,
+    sources,
+    output_gradients=None):
+  """Computes gradients from the imperatively defined tape on top of the stack.
+
+  Works by filtering the tape, computing how many downstream usages are of each
+  tensor and entry, and repeatedly applying backward functions until we have
+  gradients for all sources.
+
+  Args:
+    target: either a Tensor or list of Tensors to be differentiated.
+    sources: list of Tensors for which we want gradients
+    output_gradients: if not None, a list of gradient provided for each Target,
+      or None if we are to use the target's computed downstream gradient.
+
+  Returns:
+    the gradient wrt each of the sources.
+
+  Raises:
+    RuntimeError: if something goes wrong.
+    ValueError: if there is no sequence of differentiable operations connecting
+      a source and any target Tensor. This can happen either if the target is
+      not computed based on the source, if the tracing was set up incorrectly,
+      or if only non-differentiable functions of the source were used in the
+      computation of target.
+  """
+  if not tape._tape_stack.stack:  # pylint: disable=protected-access
+    raise RuntimeError("Computing a gradient with no tape present")
+  bp_tape = tape.pop_tape()
+  tensor_to_op, op_to_entry, output_to_shape_dtype = bp_tape.export()
+  # This overwrites the op_to_entry variable, which will release all memory used
+  # to keep traces that are irrelevant to the gradient computation we're doing
+  # here.
+  id_sources = [ops.tensor_id(t) for t in sources]
+  tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
+      target, tensor_to_op, op_to_entry, id_sources)
+  ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
+  gradients = _initial_gradients(target, output_gradients,
+                                 tensor_usage_counts)
+  # Now exhaust the backprop stack
+  while ready_ops:
+    op = ready_ops.pop()
+    op_trace = op_to_entry.pop(op)
+    out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
+    for i in range(len(out_gradients)):
+      if out_gradients[i] is None:
+        # TODO(apassos) this should be in the right device
+        out_gradients[i] = array_ops.zeros(
+            *output_to_shape_dtype[op_trace.output_ids[i]])
+      else:
+        out_gradients[i] = _aggregate_grads(out_gradients[i])
+
+    in_gradients = op_trace.backward_function(
+        *(out_gradients + op_trace.side_outputs))
+    in_gradients = ([in_gradients]
+                    if isinstance(in_gradients, (ops.Tensor,
+                                                 ops.IndexedSlices,
+                                                 type(None)))
+                    else in_gradients)
+    for i, t in enumerate(op_trace.inputs):
+      if in_gradients[i] is not None:
+        gradients[ops.tensor_id(t)].append(in_gradients[i])
+      if tensor_usage_counts.get(ops.tensor_id(t), 0) > 0:
+        tensor_usage_counts[ops.tensor_id(t)] -= 1
+        if ops.tensor_id(t) in tensor_to_op and tensor_usage_counts[
+            ops.tensor_id(t)] == 0 and ops.tensor_id(t) not in id_sources:
+          in_op = tensor_to_op[ops.tensor_id(t)]
+          if in_op is None:
+            continue
+          if op_missing_tensor.get(in_op, 0) > 0:
+            op_missing_tensor[in_op] -= 1
+            if op_missing_tensor.get(in_op, 0) == 0:
+              ready_ops.append(in_op)
+  result = []
+  for i, s in enumerate(sources):
+    g = gradients.get(ops.tensor_id(s), None)
+    if g is None:
+      # TODO(apassos): figure out a way to summarize why sources and targets are
+      # not connected.
+      raise ValueError("There is no sequence of operations connecting source "
+                       "tensor %s (%s) to any of the target Tensors. This is "
+                       "commonly caused by the tape not recording all "
+                       "operations in the forward pass or if by mistake a "
+                       "source was only used in non-differentiable operations."
+                       % (i, s))
+    result.append(_aggregate_grads(g))
+  return result
+
+
 def op_attr_type(op_type, attr_name):
   with errors.raise_exception_on_not_ok_status() as status:
     h = context.context()._handle  # pylint: disable=protected-access
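
Note: the "Terminology" comment block in the hunk above describes the dependency-counting scheme that _prepare_backprop, _initialize_backprop_stack, and imperative_grad implement. The following is a minimal, framework-free sketch of that scheme, added here only for illustration; it is not part of this commit, and the OpRecord/backprop names (and the use of plain floats as gradients) are hypothetical.

# Illustrative sketch only (not from the commit): scalar "gradients" stand in
# for tensors, and OpRecord stands in for tape.TapeEntry.
import collections

OpRecord = collections.namedtuple("OpRecord", ["inputs", "outputs", "backward"])


def backprop(op_records, produced_by, target_id, source_ids):
  """op_records: op_id -> OpRecord; produced_by: tensor_id -> op_id or None."""
  # 1) Filter to ops reachable from the target, counting how often each tensor
  #    is used as an input (mirrors _prepare_backprop).
  usage_counts = collections.Counter()
  relevant = {}
  stack = [target_id]
  while stack:
    t = stack.pop()
    op_id = produced_by.get(t)
    if op_id is None or op_id in relevant:
      continue
    relevant[op_id] = op_records[op_id]
    for i in relevant[op_id].inputs:
      usage_counts[i] += 1
      if usage_counts[i] == 1 and i not in source_ids and i in produced_by:
        stack.append(i)
  # 2) For each producing op, count how many of its outputs still need grads.
  missing = collections.Counter()
  for t in usage_counts:
    if produced_by.get(t) is not None:
      missing[produced_by[t]] += 1
  # 3) Seed the target gradient, start from ops with no missing output
  #    gradients, and drain the stack (mirrors imperative_grad's main loop).
  grads = collections.defaultdict(float)
  grads[target_id] = 1.0
  ready = [op_id for op_id in relevant if missing[op_id] == 0]
  while ready:
    op_id = ready.pop()
    record = relevant.pop(op_id)
    out_grads = [grads.pop(t, 0.0) for t in record.outputs]
    in_grads = record.backward(*out_grads)
    for t, g in zip(record.inputs, in_grads):
      grads[t] += g
      usage_counts[t] -= 1
      if usage_counts[t] == 0 and t not in source_ids:
        producer = produced_by.get(t)
        if producer is not None and missing[producer] > 0:
          missing[producer] -= 1
          if missing[producer] == 0:
            ready.append(producer)
  return [grads.get(s, 0.0) for s in source_ids]
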
@@ -82,26 +317,23 @@ class _MockOp(object):
     raise KeyError(attr)
 
 
-def _magic_gradient_function(op_name, attr_tuple, num_inputs, num_outputs,
-                             *tensors):
+def _magic_gradient_function(op_name, attr_tuple, num_inputs,
+                             inputs, outputs, out_grads):
   """Calls the gradient function of the op.
 
   Args:
     op_name: the name of the op to be differentiated.
     attr_tuple: the attrs, as a tuple.
     num_inputs: the number of inputs to the op.
-    num_outputs: the number of outputs of the op.
-    *tensors: a list of tensors, composed of, in order, the inputs, the outputs,
-      and the gradients with respect to the outputs.
+    inputs: inputs to the original operation.
+    outputs: outputs to the original operation.
+    out_grads: gradients of the operation wrt its outputs.
 
   Returns:
     The gradients with respect to the inputs of the function, as a list.
   """
-  inputs = tensors[:num_inputs]
-  outputs = tensors[num_inputs:num_inputs + num_outputs]
-  out_grads = tensors[num_inputs + num_outputs:]
   mock_op = _MockOp(attr_tuple, inputs, outputs, op_name)
-  grad_fn = tf_ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
+  grad_fn = ops._gradient_registry.lookup(op_name)  # pylint: disable=protected-access
   if grad_fn is None:
     return [None] * num_inputs
   out_grads = [
@@ -136,31 +368,23 @@ def _record_gradient(op_name, inputs, attrs, results, name):
   Raises:
     An exception on error.
   """
-  if not any(ag_core.isnode(x) for x in inputs):
-    return results
   num_outputs = len(results)
   if num_outputs == 0:
     return results
   if attrs is not None:
-    attrs = tuple(tuple(x) if isinstance(x, list) else x for x in attrs)
+    attrs = attrs
 
-  # It is imperative we make a copy of results here as otherwise we create a
-  # dependency cycle in the captured function and this can delay garbage
-  # collecting of the tensors arbitrarily.
-  results_size = len(results) if isinstance(results, (list, tuple)) else 1
-
   def grad_fn(*orig_outputs):
     """Generated gradient function."""
-    tensors = inputs + list(orig_outputs)
-    tensors = container_types.make_sequence(tape.EagerList, *tensors)
     result = _magic_gradient_function(op_name, attrs, len(inputs),
-                                      num_outputs, *(tensors))
+                                      inputs, results, orig_outputs)
     if _tracing:
       print("Gradient for", (name if name else op_name), "inputs", inputs,
-            "output_grads", orig_outputs[results_size:], "gradients", result)
+            "output_grads", orig_outputs, "gradients", result)
     return result
 
-  results = tape.record_operation(results, inputs, [], grad_fn)
+  inputs = [ops.convert_to_tensor(x) for x in inputs]
+  tape.record_operation(results, inputs, [], grad_fn)
   if _tracing:
     print("Computed op", (name if name else op_name), "inputs", inputs,
           "outputs", results)
@@ -170,27 +394,6 @@ def _record_gradient(op_name, inputs, attrs, results, name):
 execute.record_gradient = _record_gradient
 
 
-def _aggregate_grads(gradients):
-  """Aggregate gradients of the same tensor."""
-  grad_lists = collections.OrderedDict()
-  for g, v in gradients:
-    if g is None:
-      continue
-    if id(v) not in grad_lists:
-      grad_lists[id(v)] = [(g, v)]
-    else:
-      grad_lists[id(v)].append((g, v))
-
-  ret = []
-  for _, g_list in six.iteritems(grad_lists):
-    if len(g_list) == 1:
-      ret.append(g_list[0])
-    else:
-      # TODO(xpan): Aggregate IndexedSlices.
-      ret.append((math_ops.add_n(list(zip(*g_list))[0]), g_list[1][1]))
-  return ret
-
-
 def implicit_val_and_grad(f):
   """Returns a function which differentiates f with respect to variables.
 
@@ -224,23 +427,14 @@ def implicit_val_and_grad(f):
     Its second element is list of (gradient, variable) pairs.
   """
 
-  def grad_fn(*args, **kwds):
+  def grad_fn(*args):
     """Computes the gradient of the wrapped function."""
     tape.push_new_tape()
     end_node = f(*args)
-    start_node = tape.pop_tape()
-    ag_core.active_progenitors.remove(start_node)
-    if not ag_core.isnode(end_node):
-      raise ValueError(
-          "Target not part of a computation being traced. %s." % end_node)
-    if start_node not in end_node.progenitors:
-      raise ValueError("Target not derived from source. %s %s." %
-                       (end_node.progenitors, repr(start_node)))
-    output_gradients = kwds.get("output_gradients", None)
-    if output_gradients is None:
-      output_gradients = array_ops.ones_like(end_node.value)
-    grad = ag_core.backward_pass(output_gradients, end_node, start_node)
-    return end_node.value, _aggregate_grads(grad.gradients)
+    variables = tape.top_tape_watched_variables()
+    sources = [x.handle for x in variables]
+    grad = imperative_grad(end_node, sources)
+    return end_node, list(zip(grad, variables))
 
   return grad_fn
 
@@ -295,24 +489,25 @@ def gradients_function(f, params=None):
       differentiates with respect to all parameters.
 
   Returns:
-    function which, when called, returns the gradient of f with
-    respect to all of `params`.
+    function which, when called, returns the value of f and the gradient
+    of f with respect to all of `params`. The function takes an extra optional
+    keyword argument "dy". Setting it allows computation of vector jacobian
+    products for vectors other than the vector of ones.
 
   Raises:
     ValueError: if the params are not all strings or all integers.
   """
-  parameter_positions = _get_arg_spec(f, params)
 
-  def decorated(*args, **kwargs):
-    tensors = convenience_wrappers.multigrad(f, parameter_positions)(*args,
-                                                                     **kwargs)
-    return [t.tensor() if isinstance(t, tensor.LazyZero)
-            else t for t in tensors]
+  def decorated(*args, **kwds):
+    """Computes the gradient of the decorated function."""
+    _, grad = val_and_grad_function(f, params)(*args, **kwds)
+    return grad
 
   return decorated
 
 
-def val_and_grad_function(f, params=None):
+def val_and_grad_function(f, params):
   """Returns a function that computes f and is derivative w.r.t. params.
 
   Args:
@@ -321,11 +516,30 @@ def val_and_grad_function(f, params=None):
       parameters with respect to which we'll differentiate. Passing None
       differentiates with respect to all parameters.
 
-  Returns:
-    function which, when called, returns the value of f and the
-    gradient of f with respect to all of `params`.
+  Returns: function which, when called, returns the value of f and the gradient
+    of f with respect to all of `params`. The function takes an extra optional
+    keyword argument "dy". Setting it allows computation of vector jacobian
+    products for vectors other than the vector of ones.
 
   Raises:
     ValueError: if the params are not all strings or all integers.
   """
-  return convenience_wrappers.value_and_multigrad(f, _get_arg_spec(f, params))
+  parameter_positions = _get_arg_spec(f, params)
+
+  def decorated(*args, **kwds):
+    """Computes the value and gradient of the decorated function."""
+    dy = kwds.pop("dy", None)
+    assert not kwds, "The gradient function can't take keyword arguments."
+    tape.push_new_tape()
+    sources = []
+    args = list(args)
+    for i in parameter_positions:
+      sources.append(args[i])
+      tape.watch(args[i])
+    result = f(*args)
+    return result, imperative_grad(
+        result,
+        sources,
+        output_gradients=dy)
+
+  return decorated
@@ -23,10 +23,10 @@ from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.eager import tensor
-from tensorflow.python.eager import tensor_node
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import embedding_ops
@@ -58,6 +58,7 @@ class BackpropTest(test.TestCase):
     var_np = np.random.rand(4, 2).astype(np.float32)
     var = tensor.Tensor(var_np)
     grad = backprop.gradients_function(fn, [0])(var)[0]
+    grad = ops.convert_to_tensor(grad).numpy()
 
     with context.graph_mode(), self.test_session():
       tf_var = array_ops.constant(var_np, dtypes.float32)
@@ -74,11 +75,7 @@ class BackpropTest(test.TestCase):
       tf_dense_grad = math_ops.unsorted_segment_sum(
           tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])
 
-      self.assertAllClose(grad.numpy(), tf_dense_grad.eval())
+      self.assertAllClose(grad, tf_dense_grad.eval())
 
-  def testTensoVspaceNoneMutAdd(self):
-    t = tensor.Tensor(1.0)
-    self.assertEqual(tensor_node.TensorVSpace(t).mut_add(t, None).numpy(), 1.0)
-
   def testImplicitGradWithResourceVariable(self):
     x = resource_variable_ops.ResourceVariable(
@@ -94,6 +91,16 @@ class BackpropTest(test.TestCase):
     self.assertEqual(grads_and_vars[0][0].numpy(), 1.0)
     self.assertEqual(id(grads_and_vars[0][1]), id(x))
 
+  def testDy(self):
+
+    def f(x):
+      return x
+
+    grad_fn = backprop.gradients_function(f)
+    self.assertAllEqual(grad_fn(constant_op.constant(1.0),
+                                dy=constant_op.constant(2.0))[0].numpy(),
+                        2.0)
+
   def testImplicitGradOverEmbeddingLookup(self):
     batch_size = 8
     embedding_size = 512
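
Note: for orientation, a small usage sketch of the "dy" behavior exercised by testDy above. This is illustrative only (it assumes eager execution is enabled) and is not part of the commit.

# Illustrative sketch (not from the commit), modeled on testDy above.
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op


def f(x):
  return x

grad_fn = backprop.gradients_function(f)
# Without "dy" the output gradient defaults to ones, so df/dx here is ~1.0.
print(grad_fn(constant_op.constant(1.0))[0].numpy())
# "dy" supplies the output gradient (a vector-Jacobian product), giving ~2.0.
print(grad_fn(constant_op.constant(1.0), dy=constant_op.constant(2.0))[0].numpy())
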
@@ -18,22 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
-
 from tensorflow.python.eager import tape
-from tensorflow.python.eager import tensor as _tensor
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.util import nest
 
 
-def _watch_value_from_tape(tensor):
-  for t in tape._tape_stack.stack:  # pylint: disable=protected-access
-    w = t.value.tensors.get(tf_ops.tensor_id(tensor), None)
-    if w is not None:
-      return w
-  return tensor
-
-
 def custom_gradient(f):
   """Decorator to define a function with a custom gradient.
 
@@ -52,27 +41,23 @@ def custom_gradient(f):
 
   def decorated(*args, **kwargs):
     """Decorated function with custom gradient."""
-    input_tensors = [_watch_value_from_tape(x) for x in args
-                     if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
-                     or ag_core.isnode(x)]
+    input_tensors = [x for x in args
+                     if isinstance(x, tf_ops.Tensor)]
     result, grad_fn = f(*args, **kwargs)
-    result_size = len(result) if isinstance(result, (list, tuple)) else 1
 
     # TODO(apassos): naive uses of custom_gradient will not get the correct
     # second derivative this way if they capture any output tensors. Change the
     # signature of custom_gradient.
     def actual_grad_fn(*outputs):
-      outputs = outputs[result_size:]
       return grad_fn(*outputs)
 
     flat_result = nest.flatten(result)
-    flat_result = [ag_core.getval(x) for x in flat_result]
-    flat_result = tape.record_operation(
+    tape.record_operation(
         flat_result,
         input_tensors,
         [],
         actual_grad_fn)
     flat_result = list(flat_result)
-    return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
+    return result
 
   return decorated
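
Note: an illustrative sketch of how the custom_gradient decorator above is meant to be used. The log1pexp example and its gradient are invented here for illustration and are not code from the commit; the decorated function must return a (result, grad_fn) pair, and grad_fn must return one gradient per recorded input.

# Illustrative sketch (not from the commit): a custom gradient for
# log(1 + exp(x)), written against the decorator defined above.
from tensorflow.python.eager import custom_gradient
from tensorflow.python.ops import math_ops


@custom_gradient.custom_gradient
def log1pexp(x):
  e = math_ops.exp(x)

  def grad(dy):
    # Receives the gradient flowing into the output and returns a list with
    # one gradient per recorded input tensor.
    return [dy * (1 - 1 / (1 + e))]

  return math_ops.log(1 + e), grad
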
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
 import six
 
 from google.protobuf import text_format
@@ -57,7 +56,6 @@ def execute(op_name, num_outputs, inputs, attrs=None, name=None):
   """
   ctx = context.get_default_context()
   # TODO(apassos) move this to convert_to_tensor
-  inputs = [ag_core.getval(x) for x in inputs]
   # pylint: disable=protected-access
   input_handles = [c._handle for c in inputs]
   device_name = ctx.device_name
@@ -184,7 +182,7 @@ def args_to_matching_eager(l, default_dtype=None):
   # Is some input already a Tensor with a dtype?
   dtype = None
   for t in l:
-    if isinstance(ag_core.getval(t), tensor.Tensor):
+    if isinstance(t, tensor.Tensor):
       dtype = t.dtype
       break
 
@@ -203,7 +201,7 @@
 
 
 def convert_to_mixed_eager_tensors(values):
-  v = [t if isinstance(ag_core.getval(t), tensor.Tensor) else tensor.Tensor(t)
+  v = [t if isinstance(t, tensor.Tensor) else tensor.Tensor(t)
        for t in values]
   types = [t.dtype for t in v]
   return types, v
@@ -228,7 +226,7 @@ def args_to_mixed_eager_tensors(lists):
     dtype = None
     # If any list has a Tensor, use that dtype
     for l in lists:
-      if isinstance(ag_core.getval(l[i]), tensor.Tensor):
+      if isinstance(l[i], tensor.Tensor):
        dtype = l[i].dtype
        break
    if dtype is None:
@@ -23,7 +23,6 @@ import collections
 import contextlib
 import threading
 
-from autograd import core as ag_core
 import numpy as np
 
 from tensorflow.python.eager import context
@@ -88,6 +87,7 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
     tensor_map[ops.tensor_id(value)] = (value, captured_value)
   else:
     captured_value = captured_value[1]
+  tape.record_operation([captured_value], [value], [], lambda x: x)
   return captured_value
 
 
@@ -193,11 +193,8 @@ class _GraphModeFunction(object):
     self._num_outputs = len(fdef.signature.output_arg)
     self._ops = operations
     self._func_outputs = func_outputs
-    if (isinstance(func_outputs, (ops.Tensor, type(None))) or
-        ag_core.isnode(func_outputs)):
-      self._returns = [func_outputs]
-    else:
-      self._returns = list(func_outputs)
+    self._returns = [func_outputs] if isinstance(
+        func_outputs, (ops.Tensor, type(None))) else list(func_outputs)
     self._returns_to_fedf_outputs = func_outputs_to_fdef_outputs
     self._output_shapes = output_shapes
 
@@ -208,7 +205,7 @@ class _GraphModeFunction(object):
     c = _CapturingContext()
     with c:
       filtered_outputs = [
-          ag_core.getval(x) for x in self._returns if x is not None
+          x for x in self._returns if x is not None
       ]
       self._out_grad_placeholders = [
          graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
@@ -242,16 +239,19 @@ class _GraphModeFunction(object):
     if context.in_graph_mode():
       g = ops.get_default_graph()
       g._add_function(self._forward_fdef)  # pylint: disable=protected-access
-      unwrapped_args = [ag_core.getval(x) for x in all_args]
+      def make_tensor(x):
+        if isinstance(x, ops.Tensor):
+          return x
+        return ops.convert_to_tensor(x)
       op = g.create_op(
-          signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args],
+          signature.name, [make_tensor(x) for x in all_args],
          [dtypes.DType(x.type) for x in signature.output_arg],
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
       outputs = op.outputs
       outputs = [outputs] if isinstance(
-          outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
+          outputs, (ops.Tensor, type(None))) else list(outputs)
       for i, s in enumerate(self._output_shapes):
         outputs[i].set_shape(s)
     else:
@@ -261,25 +261,12 @@ class _GraphModeFunction(object):
           inputs=all_args)
     real_outputs = outputs[:len(self._returns)]
     side_outputs = outputs[len(self._returns):]
-    watched_extra_inputs = []
-    for t in self._extra_inputs:
-      tid = ops.tensor_id(t)
-      for t in tape._tape_stack.stack:  # pylint: disable=protected-access
-        w = t.value.tensors.get(tid, None)
-        if w is not None:
-          watched_extra_inputs.append(w)
-          break
-      else:  # Note: for-else here done on purpose
-        watched_extra_inputs.append(t)
-
-    def backward_function_wrapper(*outputs):
-      outputs = outputs[len(real_outputs):]
-      return self._backward_function(*outputs)
-    real_outputs = tape.record_operation(
+    tape.record_operation(
         real_outputs,
-        (args + watched_extra_inputs),
+        (args + self._extra_inputs),
         side_outputs,
-        backward_function_wrapper)
+        self._backward_function)
 
     return self._build_call_outputs(self._returns, real_outputs)
 
@@ -288,10 +275,10 @@ class _GraphModeFunction(object):
     tensor_inputs = [
         x for x in nest.flatten(args)
         if isinstance(x, (tensor.Tensor, ops.Tensor,
-                          tensor.LazyZero)) or ag_core.isnode(x)
+                          tensor.LazyZero))
     ]
-    if tape.should_record(tensor_inputs) or any(
-        tape.any_tape_has(t) for t in self._extra_inputs):
+    if tape.should_record(tensor_inputs) or tape.should_record(
+        self._extra_inputs):
       if not self._has_backprop:
         self._compute_backprop()
       return self._backprop_call(tensor_inputs)
@@ -334,12 +321,12 @@ class _GraphModeFunction(object):
     """
     if self._func_outputs is None:
       return None
-    if isinstance(ag_core.getval(self._func_outputs), ops.Tensor):
+    if isinstance(self._func_outputs, ops.Tensor):
       return result[0]
 
     outputs = []
     for o in func_outputs:
-      vo = ag_core.getval(o)
+      vo = o
       if isinstance(vo, ops.Tensor):
         outputs.append(result[self._returns_to_fedf_outputs[id(vo)]])
       elif type(vo) in (tuple, list):
@@ -354,7 +341,6 @@ def _get_defun_inputs(args):
   """Maps the inputs args to graph inputs."""
   ret = []
   for a in args:
-    a = ag_core.getval(a)
     if isinstance(a, (tensor.LazyZero, ops.Tensor, tensor.Tensor)):
       ret.append(graph_placeholder(a.dtype, a.shape))
     elif type(a) in (tuple, list):
@@ -395,7 +381,7 @@ def _defun_internal(name, func, args, kwds):
   ]
   all_inputs = flat_inputs + list(extra_placeholders)
 
-  func_def_outputs = [ag_core.getval(x) for x in outputs_list if x is not None]
+  func_def_outputs = [x for x in outputs_list if x is not None]
   inference_function_def = graph_to_function_def.graph_to_function_def(
       tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
   # Register any other functions defined in the graph
@@ -421,7 +407,6 @@ _ZeroDtype = collections.namedtuple("_ZeroDtype", ["dtype", "shape"])
 
 def _cache_key(x):
   """Cache key for tfe functions."""
-  x = ag_core.getval(x)
   if isinstance(x, tensor.Tensor):
     return _TensorDtype(x.dtype, x._shape_tuple())  # pylint: disable=protected-access
   if isinstance(x, tensor.LazyZero):
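
Note: throughout this commit, autograd's tracing is replaced by explicit calls to tape.record_operation(output_tensors, input_tensors, side_outputs, backward_function), as in _backprop_call above and _record_gradient earlier. A minimal illustrative sketch of that contract, not taken from the commit (it assumes eager tensors x and y and the tape module defined in the next hunk):

# Illustrative sketch (not from the commit): manually recording a multiply so
# the tape can later run its backward function.
from tensorflow.python.eager import tape


def mul_and_record(x, y):
  z = x * y
  # backward_function is called with the output gradients (followed by any
  # side outputs) and must return one gradient per recorded input.
  tape.record_operation([z], [x, y], [], lambda dz: [dz * y, dz * x])
  return z
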
@ -18,96 +18,114 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from autograd import container_types
|
|
||||||
from autograd import core as ag_core
|
|
||||||
|
|
||||||
from tensorflow.python.framework import dtypes
|
|
||||||
from tensorflow.python.framework import ops
|
|
||||||
from tensorflow.python.util import nest
|
|
||||||
from tensorflow.python.util import tf_contextlib
|
from tensorflow.python.util import tf_contextlib
|
||||||
|
|
||||||
|
|
||||||
class ImplicitTape(object):
|
def tid(tensor):
|
||||||
"""Global object which can watch tensors and wrap them with autograd."""
|
return tensor._id # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
class TapeEntry(
|
||||||
|
collections.namedtuple("TapeEntry", [
|
||||||
|
"output_ids", "inputs", "side_outputs", "backward_function"
|
||||||
|
])):
|
||||||
|
"""Entry in the gradient tape.
|
||||||
|
|
||||||
|
Represents the execution of one op or function, with instructions for doing
|
||||||
|
its backward pass and useful information for it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_ids: tensor_id(t) for each output tensor T
|
||||||
|
inputs: input tensors
|
||||||
|
side_outputs: optional tensors which need to be provided to the backward
|
||||||
|
function.
|
||||||
|
backward_function: function to be called with the downstream gradients and
|
||||||
|
side outputs as arguments which computes the backward pass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_shape(t):
|
||||||
|
return t._shape_tuple() # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
class Tape(object):
|
||||||
|
"""Represents a gradient propagation trace."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.tensors = {}
|
# _tensor_tape maps from tensor IDs to their operation IDs
|
||||||
self.variables = {}
|
self._tensor_tape = {}
|
||||||
self.gradients = []
|
# maps output tensor IDs to their shapes and dtypes
|
||||||
|
self._shape_dtype = {}
|
||||||
|
# maps from operation ID to TapeEntry
|
||||||
|
self._op_tape = {}
|
||||||
|
# next operation ID
|
||||||
|
self._next_op_id = 0
|
||||||
|
# List of directly watched tensors
|
||||||
|
self._watched = []
|
||||||
|
# Set of directly watched variables
|
||||||
|
self._watched_variables = set()
|
||||||
|
|
||||||
def __eq__(self, other):
|
def should_record(self, tensors):
|
||||||
return self is other
|
"""Returns true if any tensor should be recorded.
|
||||||
|
|
||||||
def __hash__(self):
|
Args:
|
||||||
return id(self)
|
tensors: some tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if any of the tensors is in the tape.
|
||||||
|
"""
|
||||||
|
return any(x._id in self._tensor_tape for x in tensors) # pylint: disable=protected-access
|
||||||
|
|
||||||
@ag_core.primitive
|
def watch(self, tensor):
|
||||||
def _watch_with_tape_internal(_, tensor):
|
"""Adds a tensor to the tape."""
|
||||||
"""Primitive to wrap a tensor around an ImplicitTape progenitor."""
|
if tid(tensor) not in self._tensor_tape:
|
||||||
return tensor
|
self._tensor_tape[tid(tensor)] = None
|
||||||
|
self._watched.append(tensor)
|
||||||
|
|
||||||
|
def watch_variable(self, v):
|
||||||
|
self._watched_variables.add(v)
|
||||||
|
self.watch(v.handle)
|
||||||
|
|
||||||
def _watch_with_tape(tape, resource_variable):
|
def record_operation(self, output_tensors, input_tensors, side_outputs,
|
||||||
"""Wraps a watched Tensor and keeps track of it in the implicit tape."""
|
backward_function):
|
||||||
tensor = resource_variable.handle
|
"""Records an operation in the tape."""
|
||||||
w = _watch_with_tape_internal(tape, tensor)
|
if not self.should_record(input_tensors):
|
||||||
if ag_core.isnode(tape):
|
return output_tensors
|
||||||
tape.value.variables[ops.tensor_id(tensor)] = resource_variable
|
for t in output_tensors:
|
||||||
tape.value.tensors[ops.tensor_id(tensor)] = w
|
self._tensor_tape[tid(t)] = self._next_op_id
|
||||||
|
self._shape_dtype[tid(t)] = (_tensor_shape(t), t.dtype)
|
||||||
|
|
||||||
|
self._op_tape[self._next_op_id] = TapeEntry(
|
||||||
|
[tid(t) for t in output_tensors],
|
||||||
|
input_tensors,
|
||||||
|
side_outputs,
|
||||||
|
backward_function)
|
||||||
|
self._next_op_id += 1
|
||||||
|
|
||||||
def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor):
|
def delete_trace(self, tensor):
|
||||||
"""Gradient for _watch_with_tape_internal."""
|
"""Deletes any trace we have for this tensor."""
|
||||||
del ans, gvs
|
if tid(tensor) in self._tensor_tape:
|
||||||
|
op = self._tensor_tape[tid(tensor)]
|
||||||
|
del self._tensor_tape[tid(tensor)]
|
||||||
|
if op in self._op_tape:
|
||||||
|
if not any(
|
||||||
|
x in self._tensor_tape for x in self._op_tape[op].output_ids):
|
||||||
|
del self._op_tape[op]
|
||||||
|
|
||||||
def mut_add(implicit_tape):
|
def export(self):
|
||||||
resource_variable = tape.value.variables[ops.tensor_id(tensor)]
|
"""Exports the internal state of this tape.
|
||||||
implicit_tape.gradients.append((g, resource_variable))
|
|
||||||
return implicit_tape
|
|
||||||
|
|
||||||
return ag_core.SparseObject(vs, mut_add)
|
Returns:
|
||||||
|
tensor_tape: a map from tensor_id(tensor) to <identifier for op>
|
||||||
_watch_with_tape_internal.defvjp(_watch_with_tape_vjp, argnum=0)
|
responsible for generating that tensor.
|
||||||
_watch_with_tape_internal.defvjp(
|
op_tape: a map from <identifier for op> to TapeEntry for that op.
|
||||||
lambda g, ans, vs, gvs, tape, tensor: g,
|
output_to_shape_dtype: a map from tensor_id(tensor) to its shape and
|
||||||
argnum=1)
|
dtype, for tensors which are outputs
|
||||||
|
"""
|
||||||
|
return self._tensor_tape, self._op_tape, self._shape_dtype
|
||||||
class ImplicitTapeVSpace(ag_core.VSpace):
|
|
||||||
"""VSpace needed to have ImplicitTape be a valid progenitor."""
|
|
||||||
|
|
||||||
def zeros(self):
|
|
||||||
return ImplicitTape()
|
|
||||||
|
|
||||||
|
|
||||||
class ImplicitTapeNode(ag_core.Node):
|
|
||||||
"""Node to wrap ImplicitTape in."""
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return self is other
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return id(self)
|
|
||||||
|
|
||||||
ag_core.register_node(ImplicitTapeNode, ImplicitTape)
|
|
||||||
ag_core.register_vspace(ImplicitTapeVSpace, ImplicitTape)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(apassos) try to not do this.
|
|
||||||
class NoneVSpace(ag_core.VSpace):
|
|
||||||
"""VSpace for python None."""
|
|
||||||
|
|
||||||
def __init__(self, _):
|
|
||||||
self.size = 0
|
|
||||||
|
|
||||||
def zeros(self):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
ag_core.register_vspace(NoneVSpace, type(None))
|
|
||||||
|
|
||||||
|
|
||||||
class _TapeStack(threading.local):
|
class _TapeStack(threading.local):
|
||||||
@ -134,19 +152,33 @@ _tape_stack = _TapeStack()
|
|||||||
|
|
||||||
def push_new_tape():
|
def push_new_tape():
|
||||||
"""Pushes a new tape onto the tape stack."""
|
"""Pushes a new tape onto the tape stack."""
|
||||||
progenitor = ag_core.new_progenitor(ImplicitTape())
|
_tape_stack.stack.append(Tape())
|
||||||
_tape_stack.stack.append(progenitor)
|
|
||||||
ag_core.active_progenitors.add(progenitor)
|
|
||||||
|
|
||||||
|
|
||||||
def watch_variable(resource_variable):
|
def watch(tensor):
|
||||||
"""Marks this ResourceVariable to be watched by all tapes in the stack.
|
"""Marks this tensor to be watched by all tapes in the stack.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
resource_variable: A ResourceVariable to be watched.
|
tensor: tensor to be watched.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The tensor, potentially wrapped by all tapes in the stack.
|
||||||
"""
|
"""
|
||||||
for t in _tape_stack.stack:
|
for t in _tape_stack.stack:
|
||||||
_watch_with_tape(t, resource_variable)
|
t.watch(tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def watch_variable(variable):
|
||||||
|
"""Marks this variable to be watched by all tapes in the stack.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
variable: variable to be watched.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The tensor, potentially wrapped by all tapes in the stack.
|
||||||
|
"""
|
||||||
|
for t in _tape_stack.stack:
|
||||||
|
t.watch_variable(variable)
|
||||||
|
|
||||||
|
|
||||||
def pop_tape():
|
def pop_tape():
|
||||||
@ -156,85 +188,34 @@ def pop_tape():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def any_tape_has(tensor):
|
|
||||||
for t in _tape_stack.stack:
|
|
||||||
if ops.tensor_id(tensor) in t.value.tensors:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def should_record(tensors):
|
def should_record(tensors):
|
||||||
"""Returns true if any tape in the stack watches any of these tensors."""
|
"""Returns true if any tape in the stach watches any of these tensors."""
|
||||||
return any(ag_core.isnode(x) for x in tensors)
|
if not _tape_stack.stack:
|
||||||
|
return False
|
||||||
|
return any(x.should_record(tensors) for x in _tape_stack.stack)
|
||||||
|
|
||||||
|
|
||||||
class _EagerSequenceNode(container_types.SequenceNode):
|
def record_operation(output_tensors, input_tensors, side_outputs,
|
||||||
"""Eager version of SequenceNode, to live in EagerSequenceVSpace."""
|
backward_function):
|
||||||
pass
|
"""Records the operation on all tapes in the stack."""
|
||||||
|
for t in _tape_stack.stack:
|
||||||
|
t.record_operation(output_tensors,
|
||||||
|
input_tensors,
|
||||||
|
side_outputs,
|
||||||
|
backward_function)
|
||||||
|
|
||||||
|
|
||||||
class _EagerSequenceVSpace(container_types.SequenceVSpace):
|
def delete_trace(tensor):
|
||||||
"""Changes equality on SequenceVSpace to conform to tfe requirements."""
|
"""Deletes traces for this Tensor from all tapes in the stack."""
|
||||||
|
for t in _tape_stack.stack:
|
||||||
def __init__(self, value):
|
t.delete_trace(tensor)
|
||||||
self.shape = [ag_core.vspace(x) for x in value]
|
|
||||||
self.size = sum(s.size for s in self.shape)
|
|
||||||
self.sequence_type = type(value)
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
if type(self) != type(other): # pylint: disable=unidiomatic-typecheck
|
|
||||||
return False
|
|
||||||
if len(self.shape) != len(other.shape):
|
|
||||||
# TODO(apassos) function gradients sometimes return gradients for side
|
|
||||||
# inputs which breaks this assertion. Understand how to fix it.
|
|
||||||
return True
|
|
||||||
for ss, os in zip(self.shape, other.shape):
|
|
||||||
if ss != os:
|
|
||||||
if isinstance(ss, NoneVSpace) or isinstance(os, NoneVSpace):
|
|
||||||
continue
|
|
||||||
if ss.dtype == dtypes.resource or os.dtype == dtypes.resource:
|
|
||||||
continue
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class EagerList(list):
|
def top_tape_watched_tensors():
|
||||||
"""Type used to bypass SequenceVSpace.
|
t = _tape_stack.stack[-1]
|
||||||
|
return t._watched # pylint: disable=protected-access
|
||||||
SequenceVSpace has a very strict equality check which does not match
|
|
||||||
tensorflow semantics.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, value):
|
|
||||||
super(EagerList, self).__init__(value)
|
|
||||||
for v in value:
|
|
||||||
assert not ag_core.isnode(v)
|
|
||||||
|
|
||||||
ag_core.register_vspace(_EagerSequenceVSpace, EagerList)
|
|
||||||
ag_core.register_node(_EagerSequenceNode, EagerList)
|
|
||||||
|
|
||||||
|
|
||||||
@ag_core.primitive
|
def top_tape_watched_variables():
|
||||||
def _record_operation(output_tensors, input_tensors, side_outputs,
|
t = _tape_stack.stack[-1]
|
||||||
backward_function):
|
return t._watched_variables # pylint: disable=protected-access
|
||||||
del input_tensors, side_outputs, backward_function
|
|
||||||
return EagerList(output_tensors)
|
|
||||||
|
|
||||||
|
|
||||||
def record_operation(o, i, s, b):
|
|
||||||
"""Primitive to trigger autograd tracing on outputs from inputs."""
|
|
||||||
inputs = container_types.make_sequence(EagerList, *i)
|
|
||||||
return _record_operation(o, inputs, s, b)
|
|
||||||
|
|
||||||
|
|
||||||
def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
|
|
||||||
side_outputs, backward_function):
|
|
||||||
"""Gradient for _record_operation."""
|
|
||||||
del vs, gvs, input_tensors, output_tensors
|
|
||||||
backward_args = tuple(g) + tuple(side_outputs)
|
|
||||||
backward_args = container_types.make_sequence(
|
|
||||||
EagerList, *(tuple(ans) + backward_args))
|
|
||||||
tensors = nest.flatten(backward_function(*backward_args))
|
|
||||||
return container_types.make_sequence(EagerList, *tensors)
|
|
||||||
|
|
||||||
_record_operation.defvjp(_record_operation_vjp, argnum=1)
|
|
||||||
|
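The module-level helpers above only fan the call out to whatever tape objects sit on _tape_stack.stack; each tape is expected to expose should_record, record_operation and delete_trace, plus the _watched and _watched_variables fields read by the top_tape_* accessors. A minimal sketch of an object satisfying that per-tape interface (the class name and field layout are illustrative, not the Tape implementation in this commit):

    # Hypothetical stand-in for the per-tape objects iterated over above.
    class SketchTape(object):

      def __init__(self):
        self._watched = set()             # ids of tensors this tape tracks
        self._watched_variables = set()   # variables watched via watch_variable
        self._recorded = []               # (outputs, inputs, side_outputs, backward_fn)

      def should_record(self, tensors):
        return any(id(t) in self._watched for t in tensors)

      def record_operation(self, output_tensors, input_tensors, side_outputs,
                           backward_function):
        # Outputs of a recorded op become candidates for later recording.
        for t in output_tensors:
          self._watched.add(id(t))
        self._recorded.append((output_tensors, input_tensors, side_outputs,
                               backward_function))

      def delete_trace(self, tensor):
        self._watched.discard(id(tensor))

Pushing one of these onto the stack would make the module-level record_operation above append the same entry to every active tape, which is how several concurrently active tapes can each see the same op.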
@@ -1,301 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""TensorNode for autograd tracing of computations with Tensors."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from autograd import core as ag_core
-
-from tensorflow.python.eager import context
-from tensorflow.python.eager import custom_gradient
-from tensorflow.python.eager import tensor
-from tensorflow.python.framework import common_shapes
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-
-
-@ag_core.primitive
-def _tensor_numpy(t):
-  return t.numpy()
-
-
-@ag_core.primitive
-def _as_gpu_tensor(t, index=0):
-  return t.as_gpu_tensor(gpu_index=index)
-
-
-_as_gpu_tensor.defvjp(
-    lambda g, ans, vs, gvs, t, index: g.as_cpu_tensor(), argnum=0)
-
-
-@custom_gradient.custom_gradient
-def _tensor_copy(t, ctx=None, device_name=None):
-
-  def grad(dresult):
-    return dresult._copy(device_name=t.device)  # pylint: disable=protected-access
-
-  return t.value._copy(ctx=ctx, device_name=device_name), grad  # pylint: disable=protected-access
-
-
-@ag_core.primitive
-def _as_cpu_tensor(t):
-  return t.as_cpu_tensor()
-
-
-_as_cpu_tensor.defvjp(lambda g, ans, vs, gvs, t: g.as_gpu_tensor(), argnum=0)
-
-
-# TODO(apassos,ashankar): The operator overrides here need to be kept in sync
-# with the overrides for ops.Tensor and ops.EagerTensor.
-#
-# Note that we cannot use self.value.__op__() because that would result
-# in an ops.EagerTensor instead of a TensorNode being returned.
-#
-# We need to figure out a way to ensure that the two are in sync.
-class TensorNode(ag_core.Node):
-  """A TensorFlow Tensor."""
-
-  __slots__ = []
-
-  def __getitem__(self, idx):
-    return array_ops._SliceHelper(self, idx)  # pylint: disable=protected-access
-
-  shape = property(lambda self: self.value.shape)
-  dtype = property(lambda self: self.value.dtype)
-  device = property(lambda self: self.value.device)
-
-  def get_shape(self):
-    return self.shape
-
-  def numpy(self):
-    return _tensor_numpy(self)
-
-  def _shape_tuple(self):
-    return self.value._shape_tuple  # pylint: disable=protected-access
-
-  def as_cpu_tensor(self):
-    return _as_cpu_tensor(self)
-
-  def as_gpu_tensor(self, gpu_index=0):
-    return _as_gpu_tensor(self, gpu_index)
-
-  def _copy(self, ctx=None, device_name=None):
-    return _tensor_copy(self, ctx, device_name)
-
-  def __neg__(self):
-    return math_ops.negative(self)
-
-  def __abs__(self):
-    return math_ops.abs(self)  # pylint: disable=protected-access
-
-  def __invert__(self):
-    # ops.Tensor used math_ops.logical_not as of August 2017.
-    # Now that bitwise_ops.invert exists, it might make sense
-    # for both ops.Tensor and TensorNode to use that if the
-    # type is compatible.
-    return math_ops.logical_not(self)
-
-  def __hash__(self):
-    return id(self)
-
-  def __add__(self, other):
-    if isinstance(self.value, tensor.LazyZero):
-      return other
-    if isinstance(other, tensor.LazyZero):
-      return self
-    return math_ops.add(self, other)
-
-  def __radd__(self, other):
-    if isinstance(self.value, tensor.LazyZero):
-      return other
-    if isinstance(ag_core.getval(other), tensor.LazyZero):
-      return self
-    return math_ops.add(other, self)
-
-  def __sub__(self, other):
-    return math_ops.subtract(self, other)
-
-  def __rsub__(self, other):
-    return math_ops.subtract(other, self)
-
-  def __mul__(self, other):
-    return math_ops.multiply(self, other)
-
-  def __rmul__(self, other):
-    return math_ops.multiply(other, self)
-
-  def __mod__(self, other):
-    return math_ops.floormod(self, other)
-
-  def __rmod__(self, other):
-    return math_ops.floormod(other, self)
-
-  def __pow__(self, other):
-    return math_ops.pow(self, other)
-
-  def __rpow__(self, other):
-    return math_ops.pow(other, self)
-
-  def __div__(self, other):
-    return math_ops._div_python2(self, other)  # pylint: disable=protected-access
-
-  def __rdiv__(self, other):
-    return math_ops._div_python2(other, self)  # pylint: disable=protected-access
-
-  def __truediv__(self, other):
-    return math_ops._truediv_python3(self, other)  # pylint: disable=protected-access
-
-  def __rtruediv__(self, other):
-    return math_ops._truediv_python3(other, self)  # pylint: disable=protected-access
-
-  def __floordiv__(self, other):
-    return math_ops.floordiv(self, other)
-
-  def __rfloordiv__(self, other):
-    return math_ops.floordiv(other, self)
-
-  def __eq__(self, other):
-    # math_ops.equal raises an error if shapes are not compatible, so check that
-    # explicitly first.
-    if common_shapes.is_broadcast_compatible(
-        self.shape, ops.convert_to_tensor(other).shape):
-      return math_ops.equal(self, other)
-    return False
-
-  def __gt__(self, other):
-    return math_ops.greater(self, other)
-
-  def __ge__(self, other):
-    return math_ops.greater_equal(self, other)
-
-  def __lt__(self, other):
-    return math_ops.less(self, other)
-
-  def __le__(self, other):
-    return math_ops.less_equal(self, other)
-
-
-ag_core.register_node(TensorNode, tensor.Tensor)
-ag_core.register_node(TensorNode, ops.Tensor)
-
-
-def _zeros(shape, dtype):
-  with context.device("cpu:0"):
-    shape = tensor.Tensor(shape, dtype=dtypes.int32)
-  return array_ops.fill(shape, tensor.Tensor(0, dtype=dtype))
-
-
-def _ones(shape, dtype):
-  return array_ops.fill(
-      tensor.Tensor(shape, dtype=dtypes.int32), tensor.Tensor(1, dtype=dtype))
-
-
-def _lazy_zero_tensor(zero):
-  return _zeros(zero.shape, zero.dtype)
-
-
-tensor.LazyZero.tensor = _lazy_zero_tensor
-
-
-def _lazy_zero_to_tensor(lazy_zero, dtype=None, name=None, as_ref=False):
-  del as_ref, name, dtype
-  return _zeros(lazy_zero.shape, lazy_zero.dtype)
-
-
-ops.register_tensor_conversion_function(tensor.LazyZero, _lazy_zero_to_tensor)
-
-
-def _indexed_slices_to_tensor(value):
-  """Converts an IndexedSlices object `value` to a Tensor.
-
-  Args:
-    value: An ops.IndexedSlices object.
-
-  Returns:
-    A dense Tensor representing the values in the given IndexedSlices.
-
-  Raises:
-    ValueError: If the IndexedSlices does not have the same dtype.
-  """
-  if value.dense_shape is None:
-    raise ValueError(
-        "Tensor conversion requested for IndexedSlices without dense_shape: %s"
-        % str(value))
-  return math_ops.unsorted_segment_sum(value.values, value.indices,
-                                       value.dense_shape[0])
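The _indexed_slices_to_tensor helper being deleted here densifies an IndexedSlices by scatter-adding values into the rows named by indices of a zero tensor with dense_shape[0] rows, via unsorted_segment_sum. For reference, a plain-numpy sketch of that operation (the function name and signature below are illustrative only):

    import numpy as np

    def densify_indexed_slices(values, indices, dense_rows):
      """Numpy sketch of unsorted_segment_sum over the leading dimension."""
      values = np.asarray(values, dtype=float)
      out = np.zeros((dense_rows,) + values.shape[1:], dtype=values.dtype)
      for row, val in zip(indices, values):
        out[row] += val  # duplicate indices accumulate, as in a segment sum
      return out

    # Two slices targeting rows 0 and 2 of a 4-row dense tensor.
    print(densify_indexed_slices([[1., 1.], [2., 2.]], [0, 2], 4))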
-
-
-class TensorVSpace(ag_core.VSpace):
-  """VSpace for tf/tfe Tensors in autograd."""
-
-  def __init__(self, value):
-    self.size = 1
-    if isinstance(value, ops.IndexedSlices):
-      self.shape = tensor_shape.TensorShape(value.dense_shape.numpy())
-      self.dtype = value.values.dtype
-    else:
-      self.shape = value._shape_tuple()  # pylint: disable=protected-access
-      self.dtype = value.dtype
-    # TODO(apassos) put gradients on the same device as ops.
-
-  def __eq__(self, other):
-    # TODO(apassos) consider revisiting this if not performance sensitive.
-    return True
-
-  def __ne__(self, other):
-    return not self.__eq__(other)
-
-  def zeros(self):
-    return tensor.LazyZero(self.shape, self.dtype)
-
-  def ones(self):
-    return _ones(self.shape, self.dtype)
-
-  def standard_basis(self):
-    raise NotImplementedError
-
-  def flatten(self, value):
-    return array_ops.reshape(value, tensor.Tensor(-1))
-
-  def unflatten(self, value):
-    return array_ops.reshape(value, tensor.Tensor(self.shape))
-
-  def mut_add(self, x, y):
-    """Add wrapper safe for IndexedSlices and LazyZero."""
-    if isinstance(ag_core.getval(x), tensor.LazyZero):
-      return y
-    if isinstance(ag_core.getval(y), tensor.LazyZero):
-      return x
-    if isinstance(x, ops.IndexedSlices):
-      x = _indexed_slices_to_tensor(x)
-    if isinstance(y, ops.IndexedSlices):
-      y = _indexed_slices_to_tensor(y)
-    if x is None:
-      return y
-    if y is None:
-      return x
-    return math_ops.add(x, y)
-
-
-ag_core.register_vspace(TensorVSpace, tensor.Tensor)
-ag_core.register_vspace(TensorVSpace, ops.Tensor)
-ag_core.register_vspace(TensorVSpace, ops.IndexedSlices)
-ag_core.register_vspace(TensorVSpace, tensor.LazyZero)
-ag_core.register_node(TensorNode, tensor.LazyZero)
@@ -1,84 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.eager import tensor
-from tensorflow.python.eager import tensor_node
-from tensorflow.python.eager import test
-from tensorflow.python.framework import test_util
-
-
-def public_or_operator(n):
-  if not n.startswith("_"):
-    return True
-  return n.startswith("__") and n.endswith("__")
-
-
-class TensorNodeTest(test_util.TensorFlowTestCase):
-
-  # TensorNode must implement autograd core's Node interface, which
-  # it does via inheritance. It also needs to be duck-typeable as
-  # a tensorflow.python.framework.ops.EagerTensor.
-  #
-  # This is a "test" to help ensure interface compatibility.
-  def testCanBeATensor(self):
-    # TODO(ashankar,apassos): This list of "exceptions" - list of
-    # Tensor methods not implemented by TensorNode needs to be
-    # trimmed.
-    exceptions = set([
-        "OVERLOADABLE_OPERATORS",
-        "__and__",
-        "__del__",
-        "__dict__",
-        "__iter__",
-        "__len__",
-        "__matmul__",
-        "__or__",
-        "__rand__",
-        "__rmatmul__",
-        "__ror__",
-        "__rxor__",
-        "__weakref__",
-        "__xor__",
-        # BEGIN: Methods of Tensor that EagerTensor raises exceptions on.
-        # But perhaps TensorNode should defer to "self.value.<method>" for
-        # them?
-        "consumers",
-        "eval",
-        "graph",
-        "name",
-        "op",
-        "set_shape",
-        "value_index",
-        # END: Methods of Tensor that EagerTensor raises exceptions on.
-    ])
-
-    tensor_dir = dir(tensor.Tensor)
-    tensor_dir = filter(public_or_operator, tensor_dir)
-    tensor_dir = set(tensor_dir).difference(exceptions)
-
-    tensor_node_dir = set(dir(tensor_node.TensorNode))
-
-    missing = tensor_dir.difference(tensor_node_dir)
-    self.assertEqual(
-        0,
-        len(missing),
-        msg="Methods/properties missing in TensorNode: {}".format(missing))
-
-
-if __name__ == "__main__":
-  test.main()
@@ -41,7 +41,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
 import numpy as np
 
 from tensorflow.core.framework import attr_value_pb2
@@ -76,7 +75,7 @@ def _eager_fill(dims, value):
 
 def convert_to_eager_tensor(t, dtype=None):
   """Converts the given `value` to an `EagerTensor`."""
-  if isinstance(ag_core.getval(t), ops.EagerTensor):
+  if isinstance(t, ops.EagerTensor):
     if dtype is not None and t.dtype != dtype:
       raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype))
     return t
@@ -19,7 +19,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
 import six
 
 from tensorflow.core.framework import attr_value_pb2
@@ -503,7 +502,6 @@ class OpDefLibrary(object):
           default_dtype = default_type_attr_map[input_arg.type_attr]
 
         try:
-          values = ag_core.getval(values)
           values = ops.internal_convert_to_tensor(
               values,
               name=input_arg.name,
@@ -784,7 +782,6 @@ class OpDefLibrary(object):
                                if arg.is_ref]
       with _MaybeColocateWith(must_colocate_inputs):
         # Add Op to graph
-        inputs = [ag_core.getval(x) for x in inputs]
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
@@ -25,7 +25,6 @@ import re
 import sys
 import threading
 
-from autograd import core as ag_core
 import numpy as np
 
 import six
@@ -38,6 +37,7 @@ from tensorflow.core.framework import versions_pb2
 from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.eager import context
 from tensorflow.python.eager import core
+from tensorflow.python.eager import tape
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
@@ -70,10 +70,9 @@ from tensorflow.python.util import tf_contextlib
 _USE_C_API = False
 
 
-def tensor_id(t):
+def tensor_id(tensor):
   """Returns a unique identifier for this Tensor."""
-  t = ag_core.getval(t)
-  return t._id  # pylint: disable=protected-access
+  return tensor._id  # pylint: disable=protected-access
 
 
 def _in_gpu_device():
@@ -703,6 +702,7 @@ class EagerTensor(Tensor):
 
   def __del__(self):
    try:
+      tape.delete_trace(self)
      if c_api is not None and c_api.TFE_DeleteTensorHandle is not None:
        c_api.TFE_DeleteTensorHandle(self._handle)
      if core.active_trace() is not None:
@@ -727,7 +727,7 @@ class EagerTensor(Tensor):
                              self.dtype.name)
 
   def __repr__(self):
-    return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s)>" % (
+    return "<tf.Tensor: id=%s, shape=%s, dtype=%s, numpy=%s>" % (
         self._id, self.shape, self.dtype.name, self._numpy_text(is_repr=True))
 
   @staticmethod
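The __del__ hook above now tells every active tape to forget the dying tensor. One plausible reason to do this eagerly is that tape state is keyed per tensor, so entries for dead tensors would otherwise accumulate (and, since CPython can reuse object ids, could even alias later tensors). A toy illustration of that cleanup pattern, not the actual tape code:

    # Toy illustration: registry entries keyed by id() must be dropped in __del__.
    class Traced(object):
      registry = {}  # id(obj) -> bookkeeping, a stand-in for per-tensor tape state

      def __init__(self, note):
        Traced.registry[id(self)] = note

      def __del__(self):
        # Analogous to tape.delete_trace(self) above.
        Traced.registry.pop(id(self), None)

    t = Traced("activation")
    assert len(Traced.registry) == 1
    del t  # CPython refcounting runs __del__ here; without it the entry would leak
    assert len(Traced.registry) == 0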
@@ -770,6 +770,16 @@ class EagerTensor(Tensor):
                          tensor_id(new_tensor),
                          new_tensor.device,
                          new_tensor.shape.num_elements())
+
+    # Record the copy on tape and define backprop copy as well.
+    if not context.in_graph_mode():
+      self_device = self.device
+      def grad_fun(dresult):
+        with errors.raise_exception_on_not_ok_status() as status:
+          grad_h = c_api.TFE_TensorHandleCopyToDevice(
+              dresult._handle, ctx._handle, self_device, status)
+        return _tensor_from_handle(grad_h)
+      tape.record_operation([new_tensor], [self], [], grad_fun)
     return new_tensor
   # pylint: enable=protected-access
 
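The block added above records the cross-device copy so it is differentiable: grad_fun simply copies the incoming gradient back to the source device with TFE_TensorHandleCopyToDevice. Stripped of TensorFlow specifics, the rule it encodes is that the backward pass of a device copy is another copy in the opposite direction; a toy sketch (all names here are made up for illustration):

    # Toy model: a value tagged with a device, and a copy whose backward rule
    # sends the gradient back to the source device.
    class TaggedValue(object):
      def __init__(self, data, device):
        self.data, self.device = data, device

    def copy_to(value, device):
      forward = TaggedValue(value.data, device)
      def grad_fun(dresult):
        return TaggedValue(dresult.data, value.device)  # copy the grad back
      return forward, grad_fun

    x = TaggedValue([1.0, 2.0], "CPU:0")
    y, backward = copy_to(x, "GPU:0")
    dx = backward(TaggedValue([0.1, 0.2], "GPU:0"))
    assert dx.device == "CPU:0" and dx.data == [0.1, 0.2]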
@@ -1033,26 +1043,21 @@ def internal_convert_to_tensor(value,
     RuntimeError: If a registered conversion function returns an invalid value.
 
   """
-  # Note we check the type of the object unwrapped from an autograd node, if
-  # tracing gradients, to ensure the same behavior happens with and without
-  # tracing.
-  unwrapped = ag_core.getval(value)
-
   if context.in_eager_mode():
     # Fast path for EagerTensors that don't need any conversion.
-    if isinstance(unwrapped, EagerTensor):
+    if isinstance(value, EagerTensor):
       # Note that we don't check that value's dtype matches the dtype
       # argument. We exepct that the C runtime will do that checking
       # when we execute the kernel.
       return value
     values = nest.flatten(value)
     if (len(values) > 1 and
-        any(isinstance(ag_core.getval(v), EagerTensor) for v in values)):
+        any(isinstance(v, EagerTensor) for v in values)):
       raise TypeError("Cannot convert to a eager tensor.")
 
   if dtype is not None:
     dtype = dtypes.as_dtype(dtype)
-  unwrapped_type = type(unwrapped)
+  unwrapped_type = type(value)
   conversion_func_list = _tensor_conversion_func_cache.get(unwrapped_type, None)
   if conversion_func_list is None:
     with _tensor_conversion_func_lock:
@@ -1060,7 +1065,7 @@ def internal_convert_to_tensor(value,
       for _, funcs_at_priority in sorted(
          _tensor_conversion_func_registry.items()):
        for base_type, conversion_func in funcs_at_priority:
-          if isinstance(unwrapped, base_type):
+          if isinstance(value, base_type):
            conversion_func_list.append((base_type, conversion_func))
      _tensor_conversion_func_cache[unwrapped_type] = conversion_func_list
 
@@ -1090,7 +1095,7 @@ def internal_convert_to_tensor(value,
    if ret is NotImplemented:
      continue
 
-    if not isinstance(ag_core.getval(ret), Tensor):
+    if not isinstance(ret, Tensor):
      raise RuntimeError(
          "%sConversion function %r for type %s returned non-Tensor: %r" %
          (_error_prefix(name), conversion_func, base_type, ret))
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
 import numpy as np
 import six
 
@@ -607,7 +606,7 @@ def ShapeEquals(tensor_proto, shape):
 
 def _ConstantValue(tensor, partial):
   # TODO(touts): Support Variables?
-  if not isinstance(ag_core.getval(tensor), ops.Tensor):
+  if not isinstance(tensor, ops.Tensor):
    raise TypeError("tensor is not a Tensor")
  if tensor.op.type == "Const":
    return MakeNdarray(tensor.op.get_attr("value"))
@@ -737,7 +736,7 @@ def constant_value(tensor, partial=False):  # pylint: disable=invalid-name
   Raises:
     TypeError: if tensor is not an ops.Tensor.
   """
-  if isinstance(ag_core.getval(tensor), ops.EagerTensor):
+  if isinstance(tensor, ops.EagerTensor):
    return tensor.numpy()
  ret = _ConstantValue(tensor, partial)
  if ret is not None:
@@ -291,10 +291,11 @@ def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
                          sparse_tensor.SparseTensorValue)):
      return gen_math_ops.cast(input.dense_shape, out_type)
    else:
-      input_tensor = ops.convert_to_tensor(input)
-      input_shape = input_tensor.get_shape()
-      if optimize and input_shape.is_fully_defined():
-        return constant(input_shape.as_list(), out_type, name=name)
+      if context.in_graph_mode():
+        input_tensor = ops.convert_to_tensor(input)
+        input_shape = input_tensor.get_shape()
+        if optimize and input_shape.is_fully_defined():
+          return constant(input_shape.as_list(), out_type, name=name)
      return gen_array_ops.shape(input, name=name, out_type=out_type)
 
 
@@ -1428,7 +1429,9 @@ def zeros(shape, dtype=dtypes.float32, name=None):
    zero = ""
  else:
    zero = 0
-  if context.in_eager_mode():
+  # Checking for boolean dtype to prevent attempting to run fill on the GPU
+  # which does not have a boolean kernel registered.
+  if context.in_eager_mode() and dtype != dtypes.bool:
    return fill(shape, constant(zero, dtype=dtype), name=name)
  try:
    shape = tensor_shape.as_shape(shape)
@@ -144,7 +144,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
@@ -1113,12 +1112,11 @@ floormod = gen_math_ops._floor_mod
 
 def _mul_dispatch(x, y, name=None):
   """Dispatches cwise mul for "Dense*Dense" and "Dense*Sparse"."""
-  is_tensor_y = isinstance(ag_core.getval(y), ops.Tensor)
+  is_tensor_y = isinstance(y, ops.Tensor)
   if is_tensor_y:
     return gen_math_ops._mul(x, y, name=name)
   else:
-    assert isinstance(ag_core.getval(y),
-                      sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
+    assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.
     new_vals = gen_sparse_ops.sparse_dense_cwise_mul(y.indices, y.values,
                                                      y.dense_shape, x, name)
     return sparse_tensor.SparseTensor(y.indices, new_vals, y.dense_shape)
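_mul_dispatch keeps the Dense*Sparse branch: element-wise multiplication by a SparseTensor only touches the stored entries, so the result reuses the sparse operand's indices and dense_shape with rescaled values. A small numpy sketch of that identity (illustrative helper, not the TF kernel):

    import numpy as np

    def sparse_dense_cwise_mul(sp_indices, sp_values, sp_shape, dense):
      """Multiply each stored sparse entry by the matching dense element."""
      gathered = np.array([dense[tuple(idx)] for idx in sp_indices])
      return sp_indices, np.asarray(sp_values) * gathered, sp_shape

    dense = np.array([[10., 0.], [0., 2.]])
    _, values, _ = sparse_dense_cwise_mul([(0, 0), (1, 1)], [3., 4.], (2, 2), dense)
    assert values.tolist() == [30., 8.]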
@@ -19,14 +19,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from autograd import core as ag_core
-
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import variable_pb2
 from tensorflow.python.eager import context
-from tensorflow.python.eager import custom_gradient
 from tensorflow.python.eager import tape
-from tensorflow.python.eager import tensor_node
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -506,10 +502,8 @@ class ResourceVariable(variables.Variable):
   def _read_variable_op(self):
     if hasattr(self, "_trainable") and self._trainable:
       tape.watch_variable(self)
-      return read_variable_op(self._handle, dtype=self._dtype)
-    else:
-      return gen_resource_variable_ops.read_variable_op(self._handle,
-                                                        self._dtype)
+    return gen_resource_variable_ops.read_variable_op(self._handle,
+                                                      self._dtype)
 
   def read_value(self):
     """Constructs an op which reads the value of this variable.
@@ -541,7 +535,7 @@ class ResourceVariable(variables.Variable):
     with ops.name_scope("Gather" if name is None else name) as name:
       if self._trainable:
         tape.watch_variable(self)
-      value = resource_gather(
+      value = gen_resource_variable_ops.resource_gather(
           self._handle, indices, dtype=self._dtype, name=name)
       return array_ops.identity(value)
 
@@ -614,13 +608,7 @@ class ResourceVariable(variables.Variable):
     def _run_op(a, *args):
       # pylint: disable=protected-access
       value = a._AsTensor()
-      if ag_core.isnode(value):
-        # This avoids autograd trying to wrap a ResourceVariable.
-        value = ops.convert_to_tensor(value)
-        args = [ops.convert_to_tensor(x) for x in args]
-        return getattr(tensor_node.TensorNode, operator)(value, *args)
-      else:
-        return getattr(ops.Tensor, operator)(value, *args)
+      return getattr(ops.Tensor, operator)(value, *args)
 
     # Propagate __doc__ to wrapper
     try:
@@ -693,33 +681,6 @@ class ResourceVariable(variables.Variable):
     return self.value()
 
 
-@custom_gradient.custom_gradient
-def read_variable_op(handle, dtype):
-  """Reads the value of a variable.
-
-  The tensor returned by this operation is immutable.
-
-  The value returned by this operation is guaranteed to be influenced by all the
-  writes on which this operation depends directly or indirectly, and to not be
-  influenced by any of the writes which depend directly or indirectly on this
-  operation.
-
-  Args:
-    handle: A `Tensor` of type `resource`.
-      handle to the resource in which to store the variable.
-    dtype: A `tf.DType`. the dtype of the value.
-
-  Returns:
-    A `Tensor` of type `dtype`.
-  """
-  result = gen_resource_variable_ops.read_variable_op(handle, dtype)
-
-  def grad(dresult):
-    return dresult
-
-  return result, grad
-
-
 def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
   return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
 
@@ -744,51 +705,6 @@ def _ReadGrad(_, grad):
   return grad
 
 
-# TODO(apassos) do not use custom_gradient here by making other entry points
-# than custom_gradient also aware of how to deal with variables implicitly
-# watched in the tape (i.e. the call to _watch_value in custom_gradient)
-@custom_gradient.custom_gradient
-def resource_gather(resource, indices, dtype, validate_indices=True, name=None):
-  """Gather slices from the variable pointed to by `resource`.
-
-  `indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
-  Produces an output tensor with shape `indices.shape + params.shape[1:]` where:
-
-  ```python
-      # Scalar indices
-      output[:, ..., :] = params[indices, :, ... :]
-
-      # Vector indices
-      output[i, :, ..., :] = params[indices[i], :, ... :]
-
-      # Higher rank indices
-      output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]
-  ```
-
-  Args:
-    resource: A `Tensor` of type `resource`.
-      handle to the resource in which to store the variable.
-    indices: a integer `Tensor` containing the indices to be gathered.
-    dtype: A `tf.DType`. the dtype of the value.
-    validate_indices: optional `bool`. If false will not validate that the
-      indices fit in the variable.
-    name: The optional name for the operation to be added.
-
-  Returns:
-    A `Tensor` of type `dtype`.
-  """
-  result = gen_resource_variable_ops.resource_gather(
-      resource, indices, dtype, validate_indices=validate_indices, name=name)
-
-  def grad(dresult):
-    return ops.IndexedSlices(
-        dresult,
-        indices,
-        dense_shape=gen_resource_variable_ops.variable_shape(resource))
-
-  return result, grad
-
-
 @ops.RegisterGradient("ResourceGather")
 def _GatherGrad(op, grad):
   """Gradient for gather op."""
@@ -797,7 +713,11 @@ def _GatherGrad(op, grad):
   # TODO(apassos): more robust way of getting the shape.
   # TODO(apassos): implement this for EAGER mode.
   if context.in_eager_mode():
-    raise NotImplementedError("_GatherGrad not implemented for EAGER mode")
+    dense_shape = gen_resource_variable_ops.variable_shape(op.inputs[0])
+    return (ops.IndexedSlices(grad,
+                              op.inputs[1],
+                              dense_shape=dense_shape),
+            None)
   handle = op.inputs[0]
   while handle.op.type != "VarHandleOp":
     handle = handle.op.inputs[0]
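The eager branch added to _GatherGrad above returns the upstream gradient as IndexedSlices over the gathered rows instead of raising. The identity behind it: for output = params[indices], the gradient with respect to params is the upstream gradient scatter-added back into those rows, zero elsewhere. A numpy check of that rule (illustrative code, not the TF op):

    import numpy as np

    def gather_grad_dense(grad, indices, dense_rows):
      """Scatter-add the upstream gradient into the rows that were gathered."""
      out = np.zeros((dense_rows,) + grad.shape[1:], dtype=grad.dtype)
      for g, row in zip(grad, indices):
        out[row] += g  # repeated indices accumulate
      return out

    grad = np.array([[1., 1.], [2., 2.], [3., 3.]])
    dense = gather_grad_dense(grad, indices=[2, 0, 2], dense_rows=4)
    assert dense[2].tolist() == [4., 4.]  # row 2 was gathered twice
    assert dense[1].tolist() == [0., 0.]  # rows never gathered get zero gradient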