TFE: Improves the interfaces of tape.watch_variable() and implicit_grad().

tape.watch_variable() replaces tape.watch() and is now called on ResourceVariable objects instead of their underlying handles. implicit_grad() now returns a list of (gradient, variable) pairs, to be consistent with tf.Optimizer's interface.

PiperOrigin-RevId: 168232055
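A minimal sketch of the updated call pattern, mirroring the tests touched by this change. The import paths and the need for an active eager context are assumptions based on the modules the diff edits (the eager backprop, tape, and tensor modules), not something this commit adds:

```python
# Sketch only: import paths are assumed from the modules this diff touches.
from tensorflow.python.eager import backprop
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops

# Assumes eager execution is active, as in the updated tests.
x = resource_variable_ops.ResourceVariable(1.0)

def fn():
  tape.watch_variable(x)          # new interface; was tape.watch(x.handle)
  b = tensor.Tensor(2.0)
  c = math_ops.add(x.value(), b)
  return math_ops.add(c, tensor.Tensor(3.0))

# implicit_grad() now returns [(gradient, variable), ...].
grads_and_vars = backprop.implicit_grad(fn)()
grad = grads_and_vars[0][0]       # d(fn)/dx; previously the gradient sat at [0][1]
var = grads_and_vars[0][1]        # the ResourceVariable x itself
```

Because the pairs are already in (gradient, variable) order, the result can be fed directly to optimizer.apply_gradients(grads_and_vars), as the new implicit_val_and_grad docstring example below illustrates.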
parent b72862dfc5
commit ca43fe82bb

Diff of files under tensorflow/python:
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
 import threading

 from autograd import container_types
@@ -171,22 +172,22 @@ execute.record_gradient = _record_gradient


 def _aggregate_grads(gradients):
   """Aggregate gradients of the same tensor."""
-  grad_lists = dict()
-  for t, g in gradients:
+  grad_lists = collections.OrderedDict()
+  for g, v in gradients:
     if g is None:
       continue
-    if id(t) not in grad_lists:
-      grad_lists[id(t)] = [(t, g)]
+    if id(v) not in grad_lists:
+      grad_lists[id(v)] = [(g, v)]
     else:
-      grad_lists[id(t)].append((t, g))
+      grad_lists[id(v)].append((g, v))

   ret = []
-  for t, g_list in six.iteritems(grad_lists):
+  for _, g_list in six.iteritems(grad_lists):
     if len(g_list) == 1:
       ret.append(g_list[0])
     else:
       # TODO(xpan): Aggregate IndexedSlices.
-      ret.append((g_list[0][0], math_ops.add_n(list(zip(*g_list))[1])))
+      ret.append((math_ops.add_n(list(zip(*g_list))[0]), g_list[1][1]))
   return ret

@@ -200,11 +201,27 @@ def implicit_val_and_grad(f):
   This function is useful when the exact set of variables to differentiate with
   is not known ahead of time.

+  Example:
+
+  ```python
+  def train(model, inputs, labels, optimizer):
+
+    def forward_fn():
+      prediction = model(inputs)
+      return loss_fn(labels, prediction)
+
+    loss, grads_and_vars = implicit_val_and_grad(forward_fn)()
+    optimizer.apply_gradients(grads_and_vars)
+    return loss
+  ```
+
   Args:
     f: The function to be differentiated.

   Returns:
-    A function which, when called, returns the value and gradients.
+    A function which, when called, returns a tuple pair.
+    Its first element is the value to which the function evaluates.
+    Its second element is list of (gradient, variable) pairs.
   """

   def grad_fn(*args, **kwds):
@@ -242,7 +259,7 @@ def implicit_grad(f):
     f: The function to be differentiated.

   Returns:
-    A function which, when called, returns the gradients.
+    A function which, when called, returns a list of (gradient, variable) pairs.
   """

   def grad_fn(*args, **kwds):
@@ -90,8 +90,9 @@ class BackpropTest(test.TestCase):
       c = math_ops.add(x.value(), b)
       return math_ops.add(c, tensor.Tensor(3.0))

-    grad = backprop.implicit_grad(fn)()[0][1]
-    self.assertEqual(grad.numpy(), 1.0)
+    grads_and_vars = backprop.implicit_grad(fn)()
+    self.assertEqual(grads_and_vars[0][0].numpy(), 1.0)
+    self.assertEqual(id(grads_and_vars[0][1]), id(x))

   def testImplicitGradOverEmbeddingLookup(self):
     batch_size = 8
@@ -105,11 +106,11 @@ class BackpropTest(test.TestCase):
         initial_value=random_init, dtype=dtypes.float32, name='embedding')

     def f():
-      tape.watch(embedding.handle)
+      tape.watch_variable(embedding)
       embedded_x = embedding_ops.embedding_lookup(embedding, x)
       return tensor.Tensor(1.0, dtypes.float32) - embedded_x

-    grad = backprop.implicit_grad(f)()[0][1]
+    grad = backprop.implicit_grad(f)()[0][0]
     opt = training.GradientDescentOptimizer(lrn_rate)

     with context.graph_mode(), self.test_session():
@@ -207,11 +208,11 @@ class BackpropTest(test.TestCase):

     def f():
       with context.device('gpu:0'):
-        tape.watch(v.handle)
+        tape.watch_variable(v)
         return v.read_value()

     self.assertEqual(
-        backprop.implicit_grad(f)()[0][1].as_cpu_tensor().numpy(), 1.0)
+        backprop.implicit_grad(f)()[0][0].as_cpu_tensor().numpy(), 1.0)

   def testCPU(self):

@@ -318,7 +319,7 @@ class BackpropTest(test.TestCase):
       b = array_ops.stack([a, a], axis=0)
       return math_ops.reduce_mean(b)

-    grad = backprop.implicit_grad(fn)()[0][1]
+    grad = backprop.implicit_grad(fn)()[0][0]
     self.assertAllEqual([1.0], grad.numpy())


@@ -59,10 +59,10 @@ class FunctionTest(test.TestCase):
     @function.defun
     def step():
       def inner():
-        tape.watch(v.handle)
+        tape.watch_variable(v)
         return v * v

-      return backprop.implicit_grad(inner)()[0][1]
+      return backprop.implicit_grad(inner)()[0][0]

     self.assertAllEqual(step().numpy(), 2.0)

@@ -113,17 +113,17 @@ class FunctionTest(test.TestCase):
     g(tensor.Tensor(1.0))

   def testGradientTensorConversionWithDefun(self):
-    three = tensor.Tensor(3.0)
+    three = resource_variable_ops.ResourceVariable(3.0)

     @function.defun
     def f(x):
       return math_ops.add(x, three)

     def g(x):
-      tape.watch(three)
+      tape.watch_variable(three)
       return f(x)

-    g = backprop.implicit_grad(g)(tensor.Tensor(1.0))[0][1]
+    g = backprop.implicit_grad(g)(tensor.Tensor(1.0))[0][0]
     self.assertEqual(g.numpy(), 1.0)

   def testGradient(self):
@@ -34,6 +34,7 @@ class ImplicitTape(object):

   def __init__(self):
     self.tensors = {}
+    self.variables = {}
     self.gradients = []

   def __eq__(self, other):
@@ -49,21 +50,22 @@ def _watch_with_tape_internal(_, tensor):
   return tensor


-def _watch_with_tape(tape, tensor):
+def _watch_with_tape(tape, resource_variable):
   """Wraps a watched Tensor and keeps track of it in the implicit tape."""
+  tensor = resource_variable.handle
   w = _watch_with_tape_internal(tape, tensor)
   if ag_core.isnode(tape):
+    tape.value.variables[ops.tensor_id(tensor)] = resource_variable
     tape.value.tensors[ops.tensor_id(tensor)] = w
   return w


 def _watch_with_tape_vjp(g, ans, vs, gvs, tape, tensor):
   """Gradient for _watch_with_tape_internal."""
-  del ans, gvs, tape
+  del ans, gvs

   def mut_add(implicit_tape):
-    t = ag_core.getval(tensor)
-    implicit_tape.gradients.append((t, g))
+    resource_variable = tape.value.variables[ops.tensor_id(tensor)]
+    implicit_tape.gradients.append((g, resource_variable))
     return implicit_tape

   return ag_core.SparseObject(vs, mut_add)
@@ -137,27 +139,14 @@ def push_new_tape():
   ag_core.active_progenitors.add(progenitor)


-def watch(tensor):
-  """Marks this tensor to be watched by all tapes in the stack.
-
-  Args:
-    tensor: tensor to be watched.
-
-  Returns:
-    The tensor, potentially wrapped by all tapes in the stack.
-  """
-  for t in _tape_stack.stack:
-    tensor = _watch_with_tape(t, tensor)
-  return tensor
-
-
 def watch_variable(resource_variable):
   """Marks this ResourceVariable to be watched by all tapes in the stack.

   Args:
     resource_variable: A ResourceVariable to be watched.
   """
-  watch(resource_variable.handle) # py-lint: disable=protected-access
+  for t in _tape_stack.stack:
+    _watch_with_tape(t, resource_variable)


 def pop_tape():
@@ -175,7 +164,7 @@ def any_tape_has(tensor):


 def should_record(tensors):
-  """Returns true if any tape in the stach watches any of these tensors."""
+  """Returns true if any tape in the stack watches any of these tensors."""
   return any(ag_core.isnode(x) for x in tensors)


@@ -505,7 +505,7 @@ class ResourceVariable(variables.Variable):

   def _read_variable_op(self):
     if hasattr(self, "_trainable") and self._trainable:
-      tape.watch(self._handle)
+      tape.watch_variable(self)
       return read_variable_op(self._handle, dtype=self._dtype)
     else:
       return gen_resource_variable_ops.read_variable_op(self._handle,
@@ -540,7 +540,7 @@ class ResourceVariable(variables.Variable):
     """Reads the value of this variable sparsely, using `gather`."""
     with ops.name_scope("Gather" if name is None else name) as name:
       if self._trainable:
-        tape.watch(self._handle)
+        tape.watch_variable(self)
       value = resource_gather(
           self._handle, indices, dtype=self._dtype, name=name)
     return array_ops.identity(value)