TFE: Improves the interfaces of tape.watch_variable() and implicit_grad().
tape.watch_variable() replaces tape.watch() and is now called on ResourceVariable objects instead of their underlying handles. implicit_grad() now returns a list of (gradient, variable) pairs to be consistent with tf.Optimizer's interface.

PiperOrigin-RevId: 168232055
commit ca43fe82bb (parent b72862dfc5)
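Before the diff, here is a minimal sketch (not part of the commit) of the calling convention these changes introduce. It assumes eager execution is enabled and uses the internal TFE module paths that the tests below import; the variable and function names are illustrative.

```python
# Minimal sketch of the new implicit_grad() contract (assumes eager execution
# is enabled; module paths are the internal ones used by the tests below).
from tensorflow.python.eager import backprop
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops

x = resource_variable_ops.ResourceVariable(3.0)

def loss_fn():
  # Reading a trainable ResourceVariable watches it on the implicit tape
  # (see _read_variable_op in the diff below), so no explicit watch is needed.
  return math_ops.multiply(x.value(), x.value())

# implicit_grad() now returns a list of (gradient, variable) pairs, so the
# result can be handed straight to a tf.Optimizer's apply_gradients().
grads_and_vars = backprop.implicit_grad(loss_fn)()
gradient, variable = grads_and_vars[0]  # before this commit, the gradient was the second element
```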
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import threading
 
 from autograd import container_types
@@ -171,22 +172,22 @@ execute.record_gradient = _record_gradient
 
 def _aggregate_grads(gradients):
   """Aggregate gradients of the same tensor."""
-  grad_lists = dict()
-  for t, g in gradients:
+  grad_lists = collections.OrderedDict()
+  for g, v in gradients:
     if g is None:
       continue
-    if id(t) not in grad_lists:
-      grad_lists[id(t)] = [(t, g)]
+    if id(v) not in grad_lists:
+      grad_lists[id(v)] = [(g, v)]
     else:
-      grad_lists[id(t)].append((t, g))
+      grad_lists[id(v)].append((g, v))
 
   ret = []
-  for t, g_list in six.iteritems(grad_lists):
+  for _, g_list in six.iteritems(grad_lists):
     if len(g_list) == 1:
       ret.append(g_list[0])
     else:
       # TODO(xpan): Aggregate IndexedSlices.
-      ret.append((g_list[0][0], math_ops.add_n(list(zip(*g_list))[1])))
+      ret.append((math_ops.add_n(list(zip(*g_list))[0]), g_list[1][1]))
   return ret
 
 
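As context for the hunk above, a pure-Python sketch (not the TensorFlow code itself) of what the reworked `_aggregate_grads` does once gradients arrive as (gradient, variable) pairs: entries for the same variable are summed into a single pair.

```python
import collections

def aggregate_grads_sketch(gradients):
  """Pure-Python illustration of the new (gradient, variable) aggregation."""
  grad_lists = collections.OrderedDict()
  for g, v in gradients:
    if g is None:
      continue
    grad_lists.setdefault(id(v), []).append((g, v))
  ret = []
  for g_list in grad_lists.values():
    if len(g_list) == 1:
      ret.append(g_list[0])
    else:
      grads, variables = zip(*g_list)
      # The real code sums the gradients with math_ops.add_n.
      ret.append((sum(grads), variables[0]))
  return ret

# Example: two gradients reported for the same variable collapse into one pair.
w = object()  # stands in for a ResourceVariable
assert aggregate_grads_sketch([(1.0, w), (2.0, w)]) == [(3.0, w)]
```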
@@ -200,11 +201,27 @@ def implicit_val_and_grad(f):
   This function is useful when the exact set of variables to differentiate with
   is not known ahead of time.
 
+  Example:
+
+  ```python
+  def train(model, inputs, labels, optimizer):
+
+    def forward_fn():
+      prediction = model(inputs)
+      return loss_fn(labels, prediction)
+
+    loss, grads_and_vars = implicit_val_and_grad(forward_fn)()
+    optimizer.apply_gradients(grads_and_vars)
+    return loss
+  ```
+
   Args:
     f: The function to be differentiated.
 
   Returns:
-    A function which, when called, returns the value and gradients.
+    A function which, when called, returns a tuple pair.
+    Its first element is the value to which the function evaluates.
+    Its second element is list of (gradient, variable) pairs.
   """
 
   def grad_fn(*args, **kwds):
@@ -242,7 +259,7 @@ def implicit_grad(f):
     f: The function to be differentiated.
 
   Returns:
-    A function which, when called, returns the gradients.
+    A function which, when called, returns a list of (gradient, variable) pairs.
   """
 
   def grad_fn(*args, **kwds):
@@ -90,8 +90,9 @@ class BackpropTest(test.TestCase):
       c = math_ops.add(x.value(), b)
       return math_ops.add(c, tensor.Tensor(3.0))
 
-    grad = backprop.implicit_grad(fn)()[0][1]
-    self.assertEqual(grad.numpy(), 1.0)
+    grads_and_vars = backprop.implicit_grad(fn)()
+    self.assertEqual(grads_and_vars[0][0].numpy(), 1.0)
+    self.assertEqual(id(grads_and_vars[0][1]), id(x))
 
   def testImplicitGradOverEmbeddingLookup(self):
     batch_size = 8
@@ -105,11 +106,11 @@ class BackpropTest(test.TestCase):
         initial_value=random_init, dtype=dtypes.float32, name='embedding')
 
     def f():
-      tape.watch(embedding.handle)
+      tape.watch_variable(embedding)
       embedded_x = embedding_ops.embedding_lookup(embedding, x)
       return tensor.Tensor(1.0, dtypes.float32) - embedded_x
 
-    grad = backprop.implicit_grad(f)()[0][1]
+    grad = backprop.implicit_grad(f)()[0][0]
     opt = training.GradientDescentOptimizer(lrn_rate)
 
     with context.graph_mode(), self.test_session():
@@ -207,11 +208,11 @@ class BackpropTest(test.TestCase):
 
     def f():
       with context.device('gpu:0'):
-        tape.watch(v.handle)
+        tape.watch_variable(v)
         return v.read_value()
 
     self.assertEqual(
-        backprop.implicit_grad(f)()[0][1].as_cpu_tensor().numpy(), 1.0)
+        backprop.implicit_grad(f)()[0][0].as_cpu_tensor().numpy(), 1.0)
 
   def testCPU(self):
 
@@ -318,7 +319,7 @@ class BackpropTest(test.TestCase):
       b = array_ops.stack([a, a], axis=0)
       return math_ops.reduce_mean(b)
 
-    grad = backprop.implicit_grad(fn)()[0][1]
+    grad = backprop.implicit_grad(fn)()[0][0]
     self.assertAllEqual([1.0], grad.numpy())
 
 
@@ -59,10 +59,10 @@ class FunctionTest(test.TestCase):
     @function.defun
     def step():
       def inner():
-        tape.watch(v.handle)
+        tape.watch_variable(v)
         return v * v
 
-      return backprop.implicit_grad(inner)()[0][1]
+      return backprop.implicit_grad(inner)()[0][0]
 
     self.assertAllEqual(step().numpy(), 2.0)
 
@@ -113,17 +113,17 @@ class FunctionTest(test.TestCase):
     g(tensor.Tensor(1.0))
 
   def testGradientTensorConversionWithDefun(self):
-    three = tensor.Tensor(3.0)
+    three = resource_variable_ops.ResourceVariable(3.0)
 
     @function.defun
     def f(x):
       return math_ops.add(x, three)
 
     def g(x):
-      tape.watch(three)
+      tape.watch_variable(three)
       return f(x)
 
-    g = backprop.implicit_grad(g)(tensor.Tensor(1.0))[0][1]
+    g = backprop.implicit_grad(g)(tensor.Tensor(1.0))[0][0]
     self.assertEqual(g.numpy(), 1.0)
 
   def testGradient(self):
@@ -34,6 +34,7 @@ class ImplicitTape(object):
 
   def __init__(self):
     self.tensors = {}
+    self.variables = {}
    self.gradients = []
 
   def __eq__(self, other):
@@ -49,21 +50,22 @@ def _watch_with_tape_internal(_, tensor):
   return tensor
 
 
-def _watch_with_tape(tape, tensor):
+def _watch_with_tape(tape, resource_variable):
   """Wraps a watched Tensor and keeps track of it in the implicit tape."""
+  tensor = resource_variable.handle
   w = _watch_with_tape_internal(tape, tensor)
   if ag_core.isnode(tape):
+    tape.value.variables[ops.tensor_id(tensor)] = resource_variable
     tape.value.tensors[ops.tensor_id(tensor)] = w
-  return w
 
 
 def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor):
   """Gradient for _watch_with_tape_internal."""
-  del ans, gvs, tape
+  del ans, gvs
 
   def mut_add(implicit_tape):
-    t = ag_core.getval(tensor)
-    implicit_tape.gradients.append((t, g))
+    resource_variable = tape.value.variables[ops.tensor_id(tensor)]
+    implicit_tape.gradients.append((g, resource_variable))
     return implicit_tape
 
   return ag_core.SparseObject(vs, mut_add)
@@ -137,27 +139,14 @@ def push_new_tape():
   ag_core.active_progenitors.add(progenitor)
 
 
-def watch(tensor):
-  """Marks this tensor to be watched by all tapes in the stack.
-
-  Args:
-    tensor: tensor to be watched.
-
-  Returns:
-    The tensor, potentially wrapped by all tapes in the stack.
-  """
-  for t in _tape_stack.stack:
-    tensor = _watch_with_tape(t, tensor)
-  return tensor
-
-
 def watch_variable(resource_variable):
   """Marks this ResourceVariable to be watched by all tapes in the stack.
 
   Args:
     resource_variable: A ResourceVariable to be watched.
   """
-  watch(resource_variable.handle)  # py-lint: disable=protected-access
+  for t in _tape_stack.stack:
+    _watch_with_tape(t, resource_variable)
 
 
 def pop_tape():
@@ -175,7 +164,7 @@ def any_tape_has(tensor):
 
 
 def should_record(tensors):
-  """Returns true if any tape in the stach watches any of these tensors."""
+  """Returns true if any tape in the stack watches any of these tensors."""
   return any(ag_core.isnode(x) for x in tensors)
 
 
@@ -505,7 +505,7 @@ class ResourceVariable(variables.Variable):
 
   def _read_variable_op(self):
     if hasattr(self, "_trainable") and self._trainable:
-      tape.watch(self._handle)
+      tape.watch_variable(self)
       return read_variable_op(self._handle, dtype=self._dtype)
     else:
       return gen_resource_variable_ops.read_variable_op(self._handle,
@@ -540,7 +540,7 @@ class ResourceVariable(variables.Variable):
     """Reads the value of this variable sparsely, using `gather`."""
     with ops.name_scope("Gather" if name is None else name) as name:
       if self._trainable:
-        tape.watch(self._handle)
+        tape.watch_variable(self)
       value = resource_gather(
           self._handle, indices, dtype=self._dtype, name=name)
     return array_ops.identity(value)
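For completeness, a hedged before/after sketch of how a call site migrates across this change; the variable `v` and the function `f` are illustrative, not part of the commit, and the module paths are the internal ones the tests above import.

```python
# Hypothetical caller code summarizing the migration in this commit.
from tensorflow.python.eager import backprop
from tensorflow.python.eager import tape
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable(1.0)

def f():
  tape.watch_variable(v)   # before this commit: tape.watch(v.handle)
  return v.read_value()

grads_and_vars = backprop.implicit_grad(f)()
gradient = grads_and_vars[0][0]  # before: the gradient was at index [0][1]
variable = grads_and_vars[0][1]  # now: the watched ResourceVariable itself
```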