Allow tf.Variable objects to be passed as gradients in tf.train.Optimizer

Also supports other arguments convertible to a Tensor (such as Python lists and NumPy arrays) as gradients.

Fixes #3994.
This commit is contained in:
parent f71cc62282
commit 7e7dff529f
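With this change, apply_gradients() first routes each gradient through ops.convert_to_tensor_or_indexed_slices(). A minimal sketch of the resulting usage, assuming the TF 0.x-era API this commit targets (variable names are illustrative):

    import numpy as np
    import tensorflow as tf

    var0 = tf.Variable([1.0, 2.0])
    opt = tf.train.GradientDescentOptimizer(0.5)

    # Gradient supplied as a tf.Variable (previously rejected by the
    # isinstance() check in apply_gradients()):
    grad_var = tf.Variable([0.1, 0.1])
    step_var = opt.apply_gradients([(grad_var, var0)])

    # Gradient supplied as a NumPy array (or Python list) is converted
    # internally the same way:
    step_np = opt.apply_gradients([(np.array([0.1, 0.1], np.float32), var0)])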
@@ -282,26 +282,37 @@ class Optimizer(object):
     # This is a default implementation of apply_gradients() that can be shared
     # by most optimizers. It relies on the subclass implementing the following
     # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
     grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works
+    converted_grads_and_vars = []
     for g, v in grads_and_vars:
+      if g is not None:
+        try:
+          # Convert the grad to Tensor or IndexedSlices if necessary
+          g = ops.convert_to_tensor_or_indexed_slices(g)
+        except TypeError:
+          raise TypeError(
+              "Gradient must be convertible to a Tensor"
+              " or IndexedSlices, or None: %s" % g)
       if not isinstance(g, (ops.Tensor, ops.IndexedSlices, type(None))):
         raise TypeError(
             "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
       if not isinstance(v, variables.Variable):
         raise TypeError(
             "Variable must be a tf.Variable: %s" % v)
       if g is not None:
         self._assert_valid_dtypes([g, v])
-    var_list = [v for g, v in grads_and_vars if g is not None]
+      converted_grads_and_vars.append((g, v))
+
+    converted_grads_and_vars = tuple(converted_grads_and_vars)
+    var_list = [v for g, v in converted_grads_and_vars if g is not None]
     if not var_list:
       raise ValueError("No gradients provided for any variable: %s" %
-                       (grads_and_vars,))
+                       (converted_grads_and_vars,))
     with ops.control_dependencies(None):
       self._create_slots(var_list)
     update_ops = []
     with ops.name_scope(name, self._name) as name:
       self._prepare()
-      for grad, var in grads_and_vars:
+      for grad, var in converted_grads_and_vars:
         if grad is None:
           continue
         # We colocate all ops created in _apply_dense or _apply_sparse
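A consequence of the new try/except is a fail-fast error for gradients that cannot be converted. A hedged sketch of that error path, again assuming the TF 0.x-era API (object() is just an illustrative non-convertible value):

    import tensorflow as tf

    var0 = tf.Variable([1.0, 2.0])
    opt = tf.train.GradientDescentOptimizer(0.5)
    try:
      # object() cannot be converted to a Tensor or IndexedSlices, so the
      # new TypeError above is raised during apply_gradients().
      opt.apply_gradients([(object(), var0)])
    except TypeError as e:
      print(e)  # "Gradient must be convertible to a Tensor or
                #  IndexedSlices, or None: <object object at 0x...>"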
@@ -113,6 +113,33 @@ class OptimizerTest(tf.test.TestCase):
       # var1 has no gradient
       sgd_op.minimize(cost, global_step, [var1])
 
+  def testGradientsAsVariables(self):
+    for dtype in [tf.half, tf.float32, tf.float64]:
+      with self.test_session() as sess:
+        var0 = tf.Variable([1.0, 2.0], dtype=dtype)
+        var1 = tf.Variable([3.0, 4.0], dtype=dtype)
+        cost = 5 * var0 + 3 * var1
+        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+        sgd_op = tf.train.GradientDescentOptimizer(3.0)
+        grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1])
+        # Convert gradients to tf.Variables
+        converted_grads = [tf.Variable(tf.zeros([2], dtype)) for _ in grads_and_vars]
+        convert_ops = [tf.assign(converted_grads[i], gv[0]) for i, gv in enumerate(grads_and_vars)]
+
+        converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
+        opt_op = sgd_op.apply_gradients(converted_grads_and_vars, global_step)
+
+        tf.initialize_all_variables().run()
+        # Run convert_ops to copy the computed gradients into the tf.Variables
+        sess.run(convert_ops)
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+        # Run 1 step of sgd through optimizer
+        opt_op.run()
+        # Validate updated params
+        self.assertAllClose([-14., -13.], var0.eval())
+        self.assertAllClose([-6., -5.], var1.eval())
 
 if __name__ == '__main__':
   tf.test.main()
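For reference, the expected values in testGradientsAsVariables above follow from a single SGD step: cost = 5 * var0 + 3 * var1 has elementwise gradients 5 and 3, and the learning rate is 3.0. A quick plain-Python check of the arithmetic:

    lr, g0, g1 = 3.0, 5.0, 3.0
    print([x - lr * g0 for x in [1.0, 2.0]])  # [-14.0, -13.0]
    print([x - lr * g1 for x in [3.0, 4.0]])  # [-6.0, -5.0]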