Allow tf.Variable objects to be passed as gradients in tf.train.Optimizer

Also supports other convertible-to-tensor arguments (such as Python lists and NumPy arrays) as gradients.

* Fixes #3994.
Commit 7e7dff529f (parent f71cc62282)
Authored by Jingtian Peng on 2016-09-08 07:25:26 +08:00; committed by Derek Murray.
2 changed files with 43 additions and 5 deletions
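
As an illustration only (not part of the commit message or diff), here is a minimal sketch of what the change enables, written against the TF 0.x-era API that the diffs below use; the variable names are invented:

    import numpy as np
    import tensorflow as tf

    var = tf.Variable([1.0, 2.0])
    opt = tf.train.GradientDescentOptimizer(0.5)

    # Previously each gradient had to already be a Tensor or IndexedSlices.
    # With this change, anything convertible to a tensor is accepted:
    step_list = opt.apply_gradients([([0.1, 0.2], var)])                      # Python list
    step_np = opt.apply_gradients([(np.array([0.1, 0.2], np.float32), var)])  # NumPy array
    grad_var = tf.Variable([0.1, 0.2])
    step_var = opt.apply_gradients([(grad_var, var)])                         # tf.Variable

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      sess.run(step_var)    # var <- var - 0.5 * grad_var
      print(sess.run(var))  # [0.95, 1.9]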

tensorflow/python/training/optimizer.py

@@ -282,26 +282,37 @@ class Optimizer(object):
     # This is a default implementation of apply_gradients() that can be shared
     # by most optimizers.  It relies on the subclass implementing the following
     # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
     grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works
+    converted_grads_and_vars = []
     for g, v in grads_and_vars:
+      if g is not None:
+        try:
+          # Convert the grad to Tensor or IndexedSlices if necessary.
+          g = ops.convert_to_tensor_or_indexed_slices(g)
+        except TypeError:
+          raise TypeError(
+              "Gradient must be convertible to a Tensor or IndexedSlices,"
+              " or None: %s" % g)
       if not isinstance(g, (ops.Tensor, ops.IndexedSlices, type(None))):
         raise TypeError(
             "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
       if not isinstance(v, variables.Variable):
         raise TypeError(
             "Variable must be a tf.Variable: %s" % v)
       if g is not None:
         self._assert_valid_dtypes([g, v])
-    var_list = [v for g, v in grads_and_vars if g is not None]
+      converted_grads_and_vars.append((g, v))
+    converted_grads_and_vars = tuple(converted_grads_and_vars)
+    var_list = [v for g, v in converted_grads_and_vars if g is not None]
     if not var_list:
       raise ValueError("No gradients provided for any variable: %s" %
-                       (grads_and_vars,))
+                       (converted_grads_and_vars,))
     with ops.control_dependencies(None):
       self._create_slots(var_list)
     update_ops = []
     with ops.name_scope(name, self._name) as name:
       self._prepare()
-      for grad, var in grads_and_vars:
+      for grad, var in converted_grads_and_vars:
         if grad is None:
           continue
         # We colocate all ops created in _apply_dense or _apply_sparse
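
For reference, the ops.convert_to_tensor_or_indexed_slices helper used above lives in TF's internal tensorflow.python.framework.ops module; it returns IndexedSlices values unmodified and converts everything else with convert_to_tensor. A quick sketch of the behavior the new code relies on (not part of the commit):

    import tensorflow as tf
    from tensorflow.python.framework import ops

    g1 = ops.convert_to_tensor_or_indexed_slices([0.1, 0.2])             # list -> tf.Tensor
    g2 = ops.convert_to_tensor_or_indexed_slices(tf.Variable([1., 2.]))  # Variable -> tf.Tensor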

tensorflow/python/training/optimizer_test.py

@@ -113,6 +113,33 @@ class OptimizerTest(tf.test.TestCase):
         # var1 has no gradient
         sgd_op.minimize(cost, global_step, [var1])
 
+  def testGradientsAsVariables(self):
+    for dtype in [tf.half, tf.float32, tf.float64]:
+      with self.test_session() as sess:
+        var0 = tf.Variable([1.0, 2.0], dtype=dtype)
+        var1 = tf.Variable([3.0, 4.0], dtype=dtype)
+        cost = 5 * var0 + 3 * var1
+        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+        sgd_op = tf.train.GradientDescentOptimizer(3.0)
+        grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1])
+        # Convert gradients to tf.Variables.
+        converted_grads = [tf.Variable(tf.zeros([2], dtype))
+                           for _ in grads_and_vars]
+        convert_ops = [tf.assign(converted_grads[i], gv[0])
+                       for i, gv in enumerate(grads_and_vars)]
+        converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
+        opt_op = sgd_op.apply_gradients(converted_grads_and_vars, global_step)
+
+        tf.initialize_all_variables().run()
+        # Run convert_ops to copy the computed gradients into the Variables.
+        sess.run(convert_ops)
+        # Fetch params to validate initial values.
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+        # Run 1 step of sgd through optimizer.
+        opt_op.run()
+        # Validate updated params.
+        self.assertAllClose([-14., -13.], var0.eval())
+        self.assertAllClose([-6., -5.], var1.eval())
+
 if __name__ == '__main__':
   tf.test.main()
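
As a sanity check on the final assertions: cost = 5 * var0 + 3 * var1 has elementwise gradients 5 and 3, so one gradient-descent step with learning rate 3.0 yields:

    lr, g0, g1 = 3.0, 5.0, 3.0   # learning rate and elementwise gradients
    print([x - lr * g0 for x in [1.0, 2.0]])  # [-14.0, -13.0] == expected var0
    print([x - lr * g1 for x in [3.0, 4.0]])  # [-6.0, -5.0]   == expected var1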