diff --git a/tensorflow/python/ops/control_flow_grad.py b/tensorflow/python/ops/control_flow_grad.py
index 0187771b0d2..292e49593de 100644
--- a/tensorflow/python/ops/control_flow_grad.py
+++ b/tensorflow/python/ops/control_flow_grad.py
@@ -183,7 +183,10 @@ def _EnterGrad(op, grad):
     return grad
   if op.get_attr("is_constant"):
     # Add a gradient accumulator for each loop invariant.
-    result = grad_ctxt.AddBackPropAccumulator(grad)
+    if isinstance(grad, ops.IndexedSlices):
+      result = grad_ctxt.AddBackPropIndexedSlicesAccumulator(grad)
+    else:
+      result = grad_ctxt.AddBackPropAccumulator(grad)
   else:
     result = exit(grad)
   grad_ctxt.ExitResult([result])
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 025386aa163..4153fb96567 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1078,10 +1078,13 @@ class CondContext(ControlFlowContext):
   def BuildCondBranch(self, fn):
     """Add the subgraph defined by fn() to the graph."""
     r = fn()
+    original_r = r
     result = []
     if r is not None:
       if not isinstance(r, list) and not isinstance(r, _basetuple):
         r = [r]
+        original_r = [original_r]
+      r = _convert_tensorarrays_to_flows(r)
       for v in r:
         real_v = v
         if isinstance(v, ops.Operation):
@@ -1100,7 +1103,7 @@ class CondContext(ControlFlowContext):
         if external_v is not None:
           real_v = external_v
       result.append(real_v)
-    return result
+    return original_r, result
 
 
 def cond(pred, fn1, fn2, name=None):
@@ -1154,14 +1157,14 @@ def cond(pred, fn1, fn2, name=None):
     # Build the graph for the true branch in a new context.
     context_t = CondContext(pred, pivot_1, branch=1)
     context_t.Enter()
-    res_t = context_t.BuildCondBranch(fn1)
+    orig_res, res_t = context_t.BuildCondBranch(fn1)
     context_t.ExitResult(res_t)
     context_t.Exit()
 
     # Build the graph for the false branch in a new context.
     context_f = CondContext(pred, pivot_2, branch=0)
     context_f.Enter()
-    res_f = context_f.BuildCondBranch(fn2)
+    _, res_f = context_f.BuildCondBranch(fn2)
     context_f.ExitResult(res_f)
     context_f.Exit()
 
@@ -1180,6 +1183,7 @@ def cond(pred, fn1, fn2, name=None):
         raise ValueError("Outputs of fn1 and fn2 must have the same type: "
                          "%s, %s" % (val_x.dtype.name, val_y.dtype.name))
     merges = [merge([x[0], x[1]])[0] for x in zip(res_f, res_t)]
+    merges = _convert_flows_to_tensorarrays(orig_res, merges)
     return merges[0] if len(merges) == 1 else merges
 
 
@@ -1322,9 +1326,10 @@ class WhileContext(ControlFlowContext):
       else:
         # Control edges must be in the same context.
         for x in op.control_inputs:
-          assert x._get_control_flow_context() == self, (
-              "Control inputs must come from Operations in the same while "
-              "loop context (not an outer context).")
+          assert x._get_control_flow_context() == self, (
+              "Control inputs must come from Operations in the same while "
+              "loop context (not an outer context). "
+              + str(x))
       for x in op.outputs:
         self._values.add(x.name)
     else:
@@ -1455,12 +1460,70 @@ class WhileContext(ControlFlowContext):
 
     add_acc = math_ops.add(switch_acc[1], value)
     next_acc = _NextIteration(add_acc)
-    merge_acc.op._update_input(1, next_acc)
+    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access
 
     acc_result = exit(switch_acc[0], name="b_acc")
     self.ExitResult([acc_result])
     return acc_result
 
+  def AddBackPropIndexedSlicesAccumulator(self, value):
+    """This is used for accumulating gradients that are IndexedSlices.
+
+    This is essentially the equivalent of AddBackPropAccumulator, but
+    optimized for things like updating embeddings from within a while loop.
+
+    Args:
+      value: The partial gradients represented as an IndexedSlices.
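+
+    Ignoring the Enter/Merge/Switch/NextIteration plumbing below, the
+    accumulation is roughly this pseudocode (a sketch, not runnable code):
+
+      indices_acc = [0]
+      values_acc = zeros([1] + values.shape[1:])
+      while not loop_done:
+        indices_acc = concat(0, [indices_acc, indices])
+        values_acc = concat(0, [values_acc, values])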
+
+    Returns:
+      The accumulated IndexedSlices gradient of the loop invariant.
+    """
+    values = value.values
+    indices = value.indices
+
+    self.Exit()
+    shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
+                                     values.get_shape().dims[1:])
+    if not shape.is_fully_defined():
+      shape = None
+    if self.outer_context: self.outer_context.Enter()
+    values_acc = constant_op.constant(0, values.dtype, shape=shape,
+                                      name="b_acc")
+    if not shape:
+      values_acc._shape = tensor_shape.TensorShape(  # pylint: disable=protected-access
+          [tensor_shape.Dimension(None)] + values.get_shape().dims[1:])
+    indices_acc = constant_op.constant([0], indices.dtype)
+    if self.outer_context: self.outer_context.Exit()
+    self.Enter()
+    self.AddName(values_acc.name)
+    self.AddName(indices_acc.name)
+    enter_acc = [_Enter(x, self._name, is_constant=False,
+                        parallel_iterations=self._parallel_iterations,
+                        name="b_acc") for x in [indices_acc, values_acc]]
+    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
+    switch_acc = [switch(x, self._pivot) for x in merge_acc]
+
+    # The actual accumulation.
+    acc_value = [array_ops.concat(0, [xa[1], xv])
+                 for xa, xv in zip(switch_acc, [indices, values])]
+
+    next_acc = [_NextIteration(x) for x in acc_value]
+    for xm, xn in zip(merge_acc, next_acc):
+      xm.op._update_input(1, xn)  # pylint: disable=protected-access
+
+    acc_result = [exit(x[0], name="b_acc") for x in switch_acc]
+    self.ExitResult(acc_result)
+    return ops.IndexedSlices(values=acc_result[1], indices=acc_result[0])
+
   def BuildLoop(self, pred, body, loop_vars):
     """Add the loop termination condition and body to the graph."""
 
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index c537ed1b734..399de012108 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -22,8 +22,10 @@ from tensorflow.core.framework import graph_pb2
 from tensorflow.python.framework import ops
 from tensorflow.python.framework.test_util import TensorFlowTestCase
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import standard_ops as tf
 from tensorflow.python.platform import googletest
+from tensorflow.python.training import momentum
 
 
 class GroupTestCase(TensorFlowTestCase):
@@ -115,6 +117,59 @@ class SwitchTestCase(TensorFlowTestCase):
     self.assertAllEqual([1, 2, 3], switch_true.values.eval())
     self.assertAllEqual([0, 1], switch_true.indices.eval())
 
+  def testIndexedSlicesGradient(self):
+    with ops.Graph().as_default():
+      embedding_matrix = tf.get_variable(
+          "embedding_matrix", [5, 5],
+          initializer=tf.random_normal_initializer())
+      def Cond(it, _):
+        return it < 5
+      def Body(it, cost):
+        embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
+        cost += tf.reduce_sum(embedding)
+        return it + 1, cost
+      _, cost = control_flow_ops.While(
+          Cond, Body, [tf.constant(0), tf.constant(0.0)])
+      optimizer = momentum.MomentumOptimizer(0.1, 0.9)
+      train_op = optimizer.minimize(cost)
+      with self.test_session() as sess:
+        sess.run(tf.initialize_all_variables())
+        for _ in range(10):
+          sess.run([train_op])
+
+  def testIndexedSlicesGradientInCondInWhileLoop(self):
+    with ops.Graph().as_default():
+      embedding_matrix = tf.get_variable(
+          "embedding_matrix", [5, 5],
+          initializer=tf.random_normal_initializer())
+
+      def Cond(it, _):
+        return it < 5
+      def Body(it, cost):
+        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
+        cost = tf.cond(tf.equal(it, 3),
+                       lambda: tf.square(cost),
+                       lambda: cost + tf.reduce_sum(embedding))
+        return it + 1, cost
+      _, cost = control_flow_ops.While(
+          Cond, Body, [tf.constant(0), tf.constant(0.0)])
+
+      dynamic_grads = tf.gradients(cost, [embedding_matrix])[0]
+      dynamic_grads = tf.segment_sum(dynamic_grads.values,
+                                     dynamic_grads.indices)
+
+      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
+      static = tf.square(
+          tf.reduce_sum(embedding) +
+          tf.reduce_sum(embedding) +
+          tf.reduce_sum(embedding)) + tf.reduce_sum(embedding)
+      static_grads = tf.gradients(static, [embedding_matrix])[0]
+      static_grads = tf.segment_sum(static_grads.values, static_grads.indices)
+
+      with self.test_session() as sess:
+        sess.run(tf.initialize_all_variables())
+        self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))
+
 
 if __name__ == "__main__":
   googletest.main()
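
Reviewer note: a minimal usage sketch (not part of the patch) of what this
change enables, mirroring the new tests. It assumes the era's APIs
(control_flow_ops.While, tf.initialize_all_variables, tf.segment_sum):
gradients of an embedding lookup inside a While loop now reach the loop
invariant as an accumulated IndexedSlices.

  import tensorflow as tf
  from tensorflow.python.ops import control_flow_ops

  embedding_matrix = tf.get_variable(
      "embedding_matrix", [5, 5],
      initializer=tf.random_normal_initializer())

  def Cond(it, _):
    return it < 5

  def Body(it, cost):
    # Inside the loop, the lookup's gradient w.r.t. the loop-invariant
    # embedding_matrix is an IndexedSlices, accumulated across iterations.
    embedding = tf.nn.embedding_lookup(embedding_matrix, [0])
    return it + 1, cost + tf.reduce_sum(embedding)

  _, cost = control_flow_ops.While(
      Cond, Body, [tf.constant(0), tf.constant(0.0)])
  grads = tf.gradients(cost, [embedding_matrix])[0]  # an ops.IndexedSlices
  dense_grads = tf.segment_sum(grads.values, grads.indices)

  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print(sess.run(dense_grads))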