TFE: Improves the interfaces of tape.watch_variable() and implicit_grad().

tape.watch_variable() replaces tape.watch() and is now called on ResourceVariable objects instead of their underlying handles.

implicit_grad() now returns a list of (gradient, variable) pairs to be consistent with tf.Optimizer's interface.
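
For reference, a minimal usage sketch of the two updated interfaces taken together. This is pieced together from the tests in this change and is not authoritative: it assumes eager execution is enabled and uses the internal tensorflow.python.eager module paths of this era rather than a public API.

```python
# Sketch only: internal module paths as used in the tests below; assumes
# eager execution is enabled.
from tensorflow.python.eager import backprop
from tensorflow.python.eager import tape
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable(3.0)

def loss_fn():
  tape.watch_variable(v)  # new: pass the ResourceVariable, not v.handle
  return v * v

# New return format: a list of (gradient, variable) pairs, matching what
# tf.Optimizer.apply_gradients() expects.
grads_and_vars = backprop.implicit_grad(loss_fn)()
gradient, variable = grads_and_vars[0]  # gradient of v*v w.r.t. v, and v itself
```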

PiperOrigin-RevId: 168232055
Ali Yahya 2017-09-11 08:08:06 -07:00 committed by TensorFlower Gardener
parent b72862dfc5
commit ca43fe82bb
5 changed files with 51 additions and 44 deletions

View File

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import threading
from autograd import container_types
@@ -171,22 +172,22 @@ execute.record_gradient = _record_gradient
def _aggregate_grads(gradients):
"""Aggregate gradients of the same tensor."""
grad_lists = dict()
for t, g in gradients:
grad_lists = collections.OrderedDict()
for g, v in gradients:
if g is None:
continue
if id(t) not in grad_lists:
grad_lists[id(t)] = [(t, g)]
if id(v) not in grad_lists:
grad_lists[id(v)] = [(g, v)]
else:
grad_lists[id(t)].append((t, g))
grad_lists[id(v)].append((g, v))
ret = []
for t, g_list in six.iteritems(grad_lists):
for _, g_list in six.iteritems(grad_lists):
if len(g_list) == 1:
ret.append(g_list[0])
else:
# TODO(xpan): Aggregate IndexedSlices.
ret.append((g_list[0][0], math_ops.add_n(list(zip(*g_list))[1])))
ret.append((math_ops.add_n(list(zip(*g_list))[0]), g_list[1][1]))
return ret
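
For clarity, a plain-Python sketch of the aggregation pattern the rewritten _aggregate_grads follows: incoming (gradient, variable) pairs are grouped by the identity of the variable, and gradients belonging to the same variable are summed (math_ops.add_n in the real code). The helper and class names here are illustrative only, not part of the change.

```python
import collections

def aggregate_grads(grads_and_vars):
  """Group (gradient, variable) pairs by variable and sum their gradients."""
  grouped = collections.OrderedDict()  # keyed by id(variable); keeps input order
  for g, v in grads_and_vars:
    if g is None:
      continue
    grouped.setdefault(id(v), []).append((g, v))
  result = []
  for pairs in grouped.values():
    grads = [g for g, _ in pairs]
    variable = pairs[0][1]
    # sum() stands in for math_ops.add_n over real tensors.
    result.append((grads[0] if len(grads) == 1 else sum(grads), variable))
  return result

class FakeVariable(object):  # stand-in for a ResourceVariable
  pass

v = FakeVariable()
print(aggregate_grads([(1.0, v), (2.0, v)]))  # -> [(3.0, <FakeVariable ...>)]
```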
@@ -200,11 +201,27 @@ def implicit_val_and_grad(f):
This function is useful when the exact set of variables to differentiate with
is not known ahead of time.
Example:
```python
def train(model, inputs, labels, optimizer):
def forward_fn():
prediction = model(inputs)
return loss_fn(labels, prediction)
loss, grads_and_vars = implicit_val_and_grad(forward_fn)()
optimizer.apply_gradients(grads_and_vars)
return loss
```
Args:
f: The function to be differentiated.
Returns:
A function which, when called, returns the value and gradients.
A function which, when called, returns a tuple pair.
Its first element is the value to which the function evaluates.
Its second element is a list of (gradient, variable) pairs.
"""
def grad_fn(*args, **kwds):
@@ -242,7 +259,7 @@ def implicit_grad(f):
f: The function to be differentiated.
Returns:
A function which, when called, returns the gradients.
A function which, when called, returns a list of (gradient, variable) pairs.
"""
def grad_fn(*args, **kwds):

View File

@@ -90,8 +90,9 @@ class BackpropTest(test.TestCase):
c = math_ops.add(x.value(), b)
return math_ops.add(c, tensor.Tensor(3.0))
grad = backprop.implicit_grad(fn)()[0][1]
self.assertEqual(grad.numpy(), 1.0)
grads_and_vars = backprop.implicit_grad(fn)()
self.assertEqual(grads_and_vars[0][0].numpy(), 1.0)
self.assertEqual(id(grads_and_vars[0][1]), id(x))
def testImplicitGradOverEmbeddingLookup(self):
batch_size = 8
@@ -105,11 +106,11 @@ class BackpropTest(test.TestCase):
initial_value=random_init, dtype=dtypes.float32, name='embedding')
def f():
tape.watch(embedding.handle)
tape.watch_variable(embedding)
embedded_x = embedding_ops.embedding_lookup(embedding, x)
return tensor.Tensor(1.0, dtypes.float32) - embedded_x
grad = backprop.implicit_grad(f)()[0][1]
grad = backprop.implicit_grad(f)()[0][0]
opt = training.GradientDescentOptimizer(lrn_rate)
with context.graph_mode(), self.test_session():
@@ -207,11 +208,11 @@ class BackpropTest(test.TestCase):
def f():
with context.device('gpu:0'):
tape.watch(v.handle)
tape.watch_variable(v)
return v.read_value()
self.assertEqual(
backprop.implicit_grad(f)()[0][1].as_cpu_tensor().numpy(), 1.0)
backprop.implicit_grad(f)()[0][0].as_cpu_tensor().numpy(), 1.0)
def testCPU(self):
@@ -318,7 +319,7 @@ class BackpropTest(test.TestCase):
b = array_ops.stack([a, a], axis=0)
return math_ops.reduce_mean(b)
grad = backprop.implicit_grad(fn)()[0][1]
grad = backprop.implicit_grad(fn)()[0][0]
self.assertAllEqual([1.0], grad.numpy())

View File

@@ -59,10 +59,10 @@ class FunctionTest(test.TestCase):
@function.defun
def step():
def inner():
tape.watch(v.handle)
tape.watch_variable(v)
return v * v
return backprop.implicit_grad(inner)()[0][1]
return backprop.implicit_grad(inner)()[0][0]
self.assertAllEqual(step().numpy(), 2.0)
@@ -113,17 +113,17 @@ class FunctionTest(test.TestCase):
g(tensor.Tensor(1.0))
def testGradientTensorConversionWithDefun(self):
three = tensor.Tensor(3.0)
three = resource_variable_ops.ResourceVariable(3.0)
@function.defun
def f(x):
return math_ops.add(x, three)
def g(x):
tape.watch(three)
tape.watch_variable(three)
return f(x)
g = backprop.implicit_grad(g)(tensor.Tensor(1.0))[0][1]
g = backprop.implicit_grad(g)(tensor.Tensor(1.0))[0][0]
self.assertEqual(g.numpy(), 1.0)
def testGradient(self):

View File

@@ -34,6 +34,7 @@ class ImplicitTape(object):
def __init__(self):
self.tensors = {}
self.variables = {}
self.gradients = []
def __eq__(self, other):
@@ -49,21 +50,22 @@ def _watch_with_tape_internal(_, tensor):
return tensor
def _watch_with_tape(tape, tensor):
def _watch_with_tape(tape, resource_variable):
"""Wraps a watched Tensor and keeps track of it in the implicit tape."""
tensor = resource_variable.handle
w = _watch_with_tape_internal(tape, tensor)
if ag_core.isnode(tape):
tape.value.variables[ops.tensor_id(tensor)] = resource_variable
tape.value.tensors[ops.tensor_id(tensor)] = w
return w
def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor):
"""Gradient for _watch_with_tape_internal."""
del ans, gvs, tape
del ans, gvs
def mut_add(implicit_tape):
t = ag_core.getval(tensor)
implicit_tape.gradients.append((t, g))
resource_variable = tape.value.variables[ops.tensor_id(tensor)]
implicit_tape.gradients.append((g, resource_variable))
return implicit_tape
return ag_core.SparseObject(vs, mut_add)
@@ -137,27 +139,14 @@ def push_new_tape():
ag_core.active_progenitors.add(progenitor)
def watch(tensor):
"""Marks this tensor to be watched by all tapes in the stack.
Args:
tensor: tensor to be watched.
Returns:
The tensor, potentially wrapped by all tapes in the stack.
"""
for t in _tape_stack.stack:
tensor = _watch_with_tape(t, tensor)
return tensor
def watch_variable(resource_variable):
"""Marks this ResourceVariable to be watched by all tapes in the stack.
Args:
resource_variable: A ResourceVariable to be watched.
"""
watch(resource_variable.handle)  # pylint: disable=protected-access
for t in _tape_stack.stack:
_watch_with_tape(t, resource_variable)
def pop_tape():
@@ -175,7 +164,7 @@ def any_tape_has(tensor):
def should_record(tensors):
"""Returns true if any tape in the stach watches any of these tensors."""
"""Returns true if any tape in the stack watches any of these tensors."""
return any(ag_core.isnode(x) for x in tensors)

View File

@@ -505,7 +505,7 @@ class ResourceVariable(variables.Variable):
def _read_variable_op(self):
if hasattr(self, "_trainable") and self._trainable:
tape.watch(self._handle)
tape.watch_variable(self)
return read_variable_op(self._handle, dtype=self._dtype)
else:
return gen_resource_variable_ops.read_variable_op(self._handle,
@@ -540,7 +540,7 @@ class ResourceVariable(variables.Variable):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
if self._trainable:
tape.watch(self._handle)
tape.watch_variable(self)
value = resource_gather(
self._handle, indices, dtype=self._dtype, name=name)
return array_ops.identity(value)
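
The resource_variable_ops changes above mean that every read of a trainable ResourceVariable now registers the variable itself (rather than its handle) with the active tapes, so implicit_grad can hand back the variable object directly. A hedged sketch of the resulting behaviour, under the same internal-module and eager-execution assumptions as the earlier example:

```python
# Sketch only: assumes eager execution and the internal module paths above.
from tensorflow.python.eager import backprop
from tensorflow.python.ops import resource_variable_ops

w = resource_variable_ops.ResourceVariable(2.0)  # trainable by default

def loss_fn():
  # No explicit tape.watch_variable() call is needed here: reading a trainable
  # variable calls tape.watch_variable(self) internally (see the diff above).
  return w * w

(gradient, variable), = backprop.implicit_grad(loss_fn)()
assert variable is w  # the ResourceVariable itself, not its handle
```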