From 7e7dff529fd35edea443580d58e95f5de39b5356 Mon Sep 17 00:00:00 2001
From: Jingtian Peng <pjt73651@gmail.com>
Date: Thu, 8 Sep 2016 07:25:26 +0800
Subject: [PATCH] Allow tf.Variable objects to be passed as gradients in
 tf.train.Optimizer

Also supports other convertible-to-tensor values (e.g. Python lists, NumPy arrays) as gradients.

* Fixes #3994.
---
 tensorflow/python/training/optimizer.py      | 21 +++++++++++----
 tensorflow/python/training/optimizer_test.py | 27 ++++++++++++++++++++
 2 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index deeec2f6e3e..3990c042e5a 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -282,26 +282,37 @@ class Optimizer(object):
     # This is a default implementation of apply_gradients() that can be shared
     # by most optimizers.  It relies on the subclass implementing the following
     # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
+
     grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works
+    converted_grads_and_vars = []
     for g, v in grads_and_vars:
+      if g is not None:
+        try:
+          # Convert the grad to Tensor or IndexedSlices if necessary
+          g = ops.convert_to_tensor_or_indexed_slices(g)
+        except TypeError:
+          raise TypeError(
+              "Gradient must be convertible to a Tensor or IndexedSlices, or None: %s" %g)  
       if not isinstance(g, (ops.Tensor, ops.IndexedSlices, type(None))):
         raise TypeError(
             "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
       if not isinstance(v, variables.Variable):
         raise TypeError(
             "Variable must be a tf.Variable: %s" % v)
-      if g is not None:
-        self._assert_valid_dtypes([g, v])
-    var_list = [v for g, v in grads_and_vars if g is not None]
+
+      converted_grads_and_vars.append((g,v))
+    
+    converted_grads_and_vars = tuple(converted_grads_and_vars)
+    var_list = [v for g, v in converted_grads_and_vars if g is not None]
     if not var_list:
       raise ValueError("No gradients provided for any variable: %s" %
-                       (grads_and_vars,))
+                       (converted_grads_and_vars,))
     with ops.control_dependencies(None):
       self._create_slots(var_list)
     update_ops = []
     with ops.name_scope(name, self._name) as name:
       self._prepare()
-      for grad, var in grads_and_vars:
+      for grad, var in converted_grads_and_vars:
         if grad is None:
           continue
         # We colocate all ops created in _apply_dense or _apply_sparse
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index 13e8cb9b258..ab4eecf7be2 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -113,6 +113,33 @@ class OptimizerTest(tf.test.TestCase):
           # var1 has no gradient
           sgd_op.minimize(cost, global_step, [var1])
 
+  def testGradientsAsVariables(self):
+    for dtype in [tf.half, tf.float32, tf.float64]:
+      with self.test_session() as sess:
+        var0 = tf.Variable([1.0, 2.0], dtype=dtype)
+        var1 = tf.Variable([3.0, 4.0], dtype=dtype)
+        cost = 5 * var0 + 3 * var1
+        global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+        sgd_op = tf.train.GradientDescentOptimizer(3.0)
+        grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1])
+        # Convert gradients to tf.Variables
+        converted_grads = [tf.Variable(tf.zeros([2], dtype)) for i in grads_and_vars]
+        convert_ops = [tf.assign(converted_grads[i], gv[0]) for i,gv in enumerate(grads_and_vars)]
+        
+        converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
+        opt_op = sgd_op.apply_gradients(converted_grads_and_vars, global_step)
+
+        tf.initialize_all_variables().run()
+        # Run convert_ops to convert the gradients into tf.Variables
+        sess.run(convert_ops)
+        # Fetch params to validate initial values
+        self.assertAllClose([1.0, 2.0], var0.eval())
+        self.assertAllClose([3.0, 4.0], var1.eval())
+        # Run 1 step of sgd through optimizer
+        opt_op.run()
+        # Validate updated params
+        self.assertAllClose([-14., -13.], var0.eval())
+        self.assertAllClose([-6., -5.], var1.eval()) 
 
 if __name__ == '__main__':
   tf.test.main()