TFE: Improves the interfaces of tape.watch_variable() and implicit_grad().

tape.watch_variable() replaces tape.watch() and is now called on ResourceVariable objects instead of their underlying handles.

implicit_grad() now returns a list of (gradient, variable) pairs to be consistent with tf.Optimizer's interface.
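
For illustration, a minimal sketch of the new calling convention under eager execution as of this commit (the variable, loss function, and learning rate are hypothetical stand-ins, not part of the change itself):

```python
from tensorflow.python.eager import backprop, tape
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import gradient_descent

v = resource_variable_ops.ResourceVariable(2.0)

def loss_fn():
  tape.watch_variable(v)  # new interface: pass the variable, not v.handle
  return v * v

# implicit_grad() now yields (gradient, variable) pairs, so its output can be
# handed straight to an optimizer, mirroring Optimizer.compute_gradients().
grads_and_vars = backprop.implicit_grad(loss_fn)()
gradient_descent.GradientDescentOptimizer(0.1).apply_gradients(grads_and_vars)
```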

PiperOrigin-RevId: 168232055
Ali Yahya, 2017-09-11 08:08:06 -07:00 (committed by TensorFlower Gardener)
commit ca43fe82bb, parent b72862dfc5
5 changed files with 51 additions and 44 deletions

tensorflow/python/eager/backprop.py

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import threading
 
 from autograd import container_types
@ -171,22 +172,22 @@ execute.record_gradient = _record_gradient
def _aggregate_grads(gradients): def _aggregate_grads(gradients):
"""Aggregate gradients of the same tensor.""" """Aggregate gradients of the same tensor."""
grad_lists = dict() grad_lists = collections.OrderedDict()
for t, g in gradients: for g, v in gradients:
if g is None: if g is None:
continue continue
if id(t) not in grad_lists: if id(v) not in grad_lists:
grad_lists[id(t)] = [(t, g)] grad_lists[id(v)] = [(g, v)]
else: else:
grad_lists[id(t)].append((t, g)) grad_lists[id(v)].append((g, v))
ret = [] ret = []
for t, g_list in six.iteritems(grad_lists): for _, g_list in six.iteritems(grad_lists):
if len(g_list) == 1: if len(g_list) == 1:
ret.append(g_list[0]) ret.append(g_list[0])
else: else:
# TODO(xpan): Aggregate IndexedSlices. # TODO(xpan): Aggregate IndexedSlices.
ret.append((g_list[0][0], math_ops.add_n(list(zip(*g_list))[1]))) ret.append((math_ops.add_n(list(zip(*g_list))[0]), g_list[1][1]))
return ret return ret
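
The aggregation above now keys on the variable and sums duplicate gradients. A self-contained sketch of the same pattern in plain Python (a stand-in for illustration, not the TF implementation):

```python
import collections

def aggregate_grads(gradients):
  """Sum gradients that belong to the same variable, keeping (grad, var) pairs."""
  grad_lists = collections.OrderedDict()
  for g, v in gradients:
    if g is None:
      continue  # unconnected gradients are dropped
    grad_lists.setdefault(id(v), []).append((g, v))
  ret = []
  for g_list in grad_lists.values():
    if len(g_list) == 1:
      ret.append(g_list[0])
    else:
      grads, variables = zip(*g_list)
      ret.append((sum(grads), variables[0]))  # sum() stands in for math_ops.add_n
  return ret

x = object()
print(aggregate_grads([(1.0, x), (2.0, x), (None, x)]))  # -> [(3.0, x)]
```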
@@ -200,11 +201,27 @@ def implicit_val_and_grad(f):
   This function is useful when the exact set of variables to differentiate with
   is not known ahead of time.
 
+  Example:
+
+  ```python
+  def train(model, inputs, labels, optimizer):
+    def forward_fn():
+      prediction = model(inputs)
+      return loss_fn(labels, prediction)
+    loss, grads_and_vars = implicit_val_and_grad(forward_fn)()
+    optimizer.apply_gradients(grads_and_vars)
+    return loss
+  ```
+
   Args:
     f: The function to be differentiated.
 
   Returns:
-    A function which, when called, returns the value and gradients.
+    A function which, when called, returns a tuple pair. Its first element is
+    the value to which the function evaluates. Its second element is a list of
+    (gradient, variable) pairs.
   """
 
   def grad_fn(*args, **kwds):
@@ -242,7 +259,7 @@ def implicit_grad(f):
     f: The function to be differentiated.
 
   Returns:
-    A function which, when called, returns the gradients.
+    A function which, when called, returns a list of (gradient, variable) pairs.
   """
 
   def grad_fn(*args, **kwds):
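
A hedged sketch of how the two wrappers now relate (the variable and function names here are illustrative):

```python
from tensorflow.python.eager import backprop
from tensorflow.python.ops import resource_variable_ops

w = resource_variable_ops.ResourceVariable(3.0)

def forward_fn():
  return w * w  # reading a trainable variable watches it implicitly

# implicit_val_and_grad returns (value, grads_and_vars);
# implicit_grad returns only the grads_and_vars list.
loss, grads_and_vars = backprop.implicit_val_and_grad(forward_fn)()
grads_and_vars_only = backprop.implicit_grad(forward_fn)()
```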

tensorflow/python/eager/backprop_test.py

@@ -90,8 +90,9 @@ class BackpropTest(test.TestCase):
       c = math_ops.add(x.value(), b)
       return math_ops.add(c, tensor.Tensor(3.0))
 
-    grad = backprop.implicit_grad(fn)()[0][1]
-    self.assertEqual(grad.numpy(), 1.0)
+    grads_and_vars = backprop.implicit_grad(fn)()
+    self.assertEqual(grads_and_vars[0][0].numpy(), 1.0)
+    self.assertEqual(id(grads_and_vars[0][1]), id(x))
 
   def testImplicitGradOverEmbeddingLookup(self):
     batch_size = 8
@@ -105,11 +106,11 @@ class BackpropTest(test.TestCase):
         initial_value=random_init, dtype=dtypes.float32, name='embedding')
 
     def f():
-      tape.watch(embedding.handle)
+      tape.watch_variable(embedding)
       embedded_x = embedding_ops.embedding_lookup(embedding, x)
       return tensor.Tensor(1.0, dtypes.float32) - embedded_x
 
-    grad = backprop.implicit_grad(f)()[0][1]
+    grad = backprop.implicit_grad(f)()[0][0]
 
     opt = training.GradientDescentOptimizer(lrn_rate)
     with context.graph_mode(), self.test_session():
@@ -207,11 +208,11 @@ class BackpropTest(test.TestCase):
     def f():
       with context.device('gpu:0'):
-        tape.watch(v.handle)
+        tape.watch_variable(v)
         return v.read_value()
 
     self.assertEqual(
-        backprop.implicit_grad(f)()[0][1].as_cpu_tensor().numpy(), 1.0)
+        backprop.implicit_grad(f)()[0][0].as_cpu_tensor().numpy(), 1.0)
 
   def testCPU(self):
@@ -318,7 +319,7 @@ class BackpropTest(test.TestCase):
       b = array_ops.stack([a, a], axis=0)
       return math_ops.reduce_mean(b)
 
-    grad = backprop.implicit_grad(fn)()[0][1]
+    grad = backprop.implicit_grad(fn)()[0][0]
     self.assertAllEqual([1.0], grad.numpy())

tensorflow/python/eager/function_test.py

@@ -59,10 +59,10 @@ class FunctionTest(test.TestCase):
     @function.defun
     def step():
       def inner():
-        tape.watch(v.handle)
+        tape.watch_variable(v)
         return v * v
 
-      return backprop.implicit_grad(inner)()[0][1]
+      return backprop.implicit_grad(inner)()[0][0]
 
     self.assertAllEqual(step().numpy(), 2.0)
@@ -113,17 +113,17 @@ class FunctionTest(test.TestCase):
     g(tensor.Tensor(1.0))
 
   def testGradientTensorConversionWithDefun(self):
-    three = tensor.Tensor(3.0)
+    three = resource_variable_ops.ResourceVariable(3.0)
 
     @function.defun
     def f(x):
       return math_ops.add(x, three)
 
     def g(x):
-      tape.watch(three)
+      tape.watch_variable(three)
       return f(x)
 
-    g = backprop.implicit_grad(g)(tensor.Tensor(1.0))[0][1]
+    g = backprop.implicit_grad(g)(tensor.Tensor(1.0))[0][0]
     self.assertEqual(g.numpy(), 1.0)
 
   def testGradient(self):

tensorflow/python/eager/tape.py

@@ -34,6 +34,7 @@ class ImplicitTape(object):
 
   def __init__(self):
     self.tensors = {}
+    self.variables = {}
     self.gradients = []
 
   def __eq__(self, other):
@@ -49,21 +50,22 @@ def _watch_with_tape_internal(_, tensor):
   return tensor
 
 
-def _watch_with_tape(tape, tensor):
+def _watch_with_tape(tape, resource_variable):
   """Wraps a watched Tensor and keeps track of it in the implicit tape."""
+  tensor = resource_variable.handle
   w = _watch_with_tape_internal(tape, tensor)
   if ag_core.isnode(tape):
+    tape.value.variables[ops.tensor_id(tensor)] = resource_variable
     tape.value.tensors[ops.tensor_id(tensor)] = w
-  return w
 
 
 def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor):
   """Gradient for _watch_with_tape_internal."""
-  del ans, gvs, tape
+  del ans, gvs
 
   def mut_add(implicit_tape):
-    t = ag_core.getval(tensor)
-    implicit_tape.gradients.append((t, g))
+    resource_variable = tape.value.variables[ops.tensor_id(tensor)]
+    implicit_tape.gradients.append((g, resource_variable))
     return implicit_tape
 
   return ag_core.SparseObject(vs, mut_add)
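
The new `variables` dict is what lets the gradient function recover the owning variable from a handle's id. A runnable miniature of this bookkeeping pattern, with hypothetical names (`MiniTape`, `Var`), might look like:

```python
class Var(object):
  """Stand-in for a ResourceVariable exposing a .handle."""
  def __init__(self, handle):
    self.handle = handle

class MiniTape(object):
  def __init__(self):
    self.variables = {}   # id(handle) -> variable, filled when watching
    self.gradients = []   # accumulated (gradient, variable) pairs

  def watch_variable(self, var):
    self.variables[id(var.handle)] = var

  def record_gradient(self, handle, grad):
    # Pair the incoming gradient with the variable that owns the handle.
    self.gradients.append((grad, self.variables[id(handle)]))

t = MiniTape()
v = Var(handle=object())
t.watch_variable(v)
t.record_gradient(v.handle, 1.0)
assert t.gradients == [(1.0, v)]
```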
@@ -137,27 +139,14 @@ def push_new_tape():
   ag_core.active_progenitors.add(progenitor)
 
 
-def watch(tensor):
-  """Marks this tensor to be watched by all tapes in the stack.
-
-  Args:
-    tensor: tensor to be watched.
-
-  Returns:
-    The tensor, potentially wrapped by all tapes in the stack.
-  """
-  for t in _tape_stack.stack:
-    tensor = _watch_with_tape(t, tensor)
-  return tensor
-
-
 def watch_variable(resource_variable):
   """Marks this ResourceVariable to be watched by all tapes in the stack.
 
   Args:
     resource_variable: A ResourceVariable to be watched.
   """
-  watch(resource_variable.handle)  # py-lint: disable=protected-access
+  for t in _tape_stack.stack:
+    _watch_with_tape(t, resource_variable)
 
 
 def pop_tape():
@@ -175,7 +164,7 @@ def any_tape_has(tensor):
 
 def should_record(tensors):
-  """Returns true if any tape in the stach watches any of these tensors."""
+  """Returns true if any tape in the stack watches any of these tensors."""
   return any(ag_core.isnode(x) for x in tensors)

tensorflow/python/ops/resource_variable_ops.py

@@ -505,7 +505,7 @@ class ResourceVariable(variables.Variable):
 
   def _read_variable_op(self):
     if hasattr(self, "_trainable") and self._trainable:
-      tape.watch(self._handle)
+      tape.watch_variable(self)
       return read_variable_op(self._handle, dtype=self._dtype)
     else:
       return gen_resource_variable_ops.read_variable_op(self._handle,
@@ -540,7 +540,7 @@ class ResourceVariable(variables.Variable):
     """Reads the value of this variable sparsely, using `gather`."""
     with ops.name_scope("Gather" if name is None else name) as name:
       if self._trainable:
-        tape.watch(self._handle)
+        tape.watch_variable(self)
       value = resource_gather(
           self._handle, indices, dtype=self._dtype, name=name)
       return array_ops.identity(value)
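
Because `_read_variable_op` and `sparse_read` now call `tape.watch_variable(self)`, reading a trainable variable is enough to register it with every active tape; no explicit watch call is needed. A sketch under the same assumptions as the earlier examples:

```python
from tensorflow.python.eager import backprop
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable(1.0)  # trainable by default

def f():
  # No tape.watch_variable() here: read_value() on a trainable
  # ResourceVariable registers the variable with all active tapes.
  return v.read_value() * 2.0

(grad, var), = backprop.implicit_grad(f)()
assert var is v  # the pair carries the variable itself, not its handle
```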