Fix for IndexedSlices gradient accumulation in while loop.
Change: 120136109
commit 3402f51ecd
parent ec89b0c218
@@ -183,7 +183,10 @@ def _EnterGrad(op, grad):
     return grad
   if op.get_attr("is_constant"):
     # Add a gradient accumulator for each loop invariant.
-    result = grad_ctxt.AddBackPropAccumulator(grad)
+    if isinstance(grad, ops.IndexedSlices):
+      result = grad_ctxt.AddBackPropIndexedSlicesAccumulator(grad)
+    else:
+      result = grad_ctxt.AddBackPropAccumulator(grad)
   else:
     result = exit(grad)
   grad_ctxt.ExitResult([result])
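Background on the dispatch above: an IndexedSlices gradient is the sparse (values, indices) pair TensorFlow emits for row-gather style ops, so the dense addition done by AddBackPropAccumulator does not apply to it. A minimal sketch of how such a gradient arises (illustrative only, not part of this change):

    import tensorflow as tf

    params = tf.Variable(tf.ones([5, 3]))
    rows = tf.gather(params, [0, 2])        # only rows 0 and 2 receive gradients
    loss = tf.reduce_sum(rows)
    grad = tf.gradients(loss, [params])[0]  # an ops.IndexedSlices, not a dense Tensor
    # grad.indices evaluates to [0, 2]; grad.values has shape [2, 3].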
@@ -1078,10 +1078,13 @@ class CondContext(ControlFlowContext):
   def BuildCondBranch(self, fn):
     """Add the subgraph defined by fn() to the graph."""
     r = fn()
+    original_r = r
     result = []
     if r is not None:
       if not isinstance(r, list) and not isinstance(r, _basetuple):
         r = [r]
+        original_r = [original_r]
+      r = _convert_tensorarrays_to_flows(r)
       for v in r:
         real_v = v
         if isinstance(v, ops.Operation):
@@ -1100,7 +1103,7 @@ class CondContext(ControlFlowContext):
         if external_v is not None:
           real_v = external_v
         result.append(real_v)
-    return result
+    return original_r, result


 def cond(pred, fn1, fn2, name=None):
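BuildCondBranch now returns both the branch function's original outputs and the converted results: TensorArrays are swapped for their scalar flow tensors before merging, and the original outputs record which positions must be wrapped back into TensorArrays afterwards. A rough sketch of that round-trip, with hypothetical names standing in for the private _convert_tensorarrays_to_flows / _convert_flows_to_tensorarrays helpers:

    def to_flows(outputs):
      # Swap each TensorArray for its flow tensor so merge sees plain Tensors.
      # TensorArray here is a hypothetical stand-in for the real class.
      return [o.flow if isinstance(o, TensorArray) else o for o in outputs]

    def from_flows(originals, flows):
      # `originals` tells us which merged flows to re-wrap as TensorArrays;
      # wrap_as_tensorarray is likewise a hypothetical helper.
      return [wrap_as_tensorarray(f) if isinstance(o, TensorArray) else f
              for o, f in zip(originals, flows)]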
@@ -1154,14 +1157,14 @@ def cond(pred, fn1, fn2, name=None):
   # Build the graph for the true branch in a new context.
   context_t = CondContext(pred, pivot_1, branch=1)
   context_t.Enter()
-  res_t = context_t.BuildCondBranch(fn1)
+  orig_res, res_t = context_t.BuildCondBranch(fn1)
   context_t.ExitResult(res_t)
   context_t.Exit()

   # Build the graph for the false branch in a new context.
   context_f = CondContext(pred, pivot_2, branch=0)
   context_f.Enter()
-  res_f = context_f.BuildCondBranch(fn2)
+  _, res_f = context_f.BuildCondBranch(fn2)
   context_f.ExitResult(res_f)
   context_f.Exit()

@@ -1180,6 +1183,7 @@ def cond(pred, fn1, fn2, name=None):
       raise ValueError("Outputs of fn1 and fn2 must have the same type: "
                        "%s, %s" % (val_x.dtype.name, val_y.dtype.name))
   merges = [merge([x[0], x[1]])[0] for x in zip(res_f, res_t)]
+  merges = _convert_flows_to_tensorarrays(orig_res, merges)
   return merges[0] if len(merges) == 1 else merges


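With orig_res threaded through to _convert_flows_to_tensorarrays, cond still hands callers outputs in the same structure the branch functions produced. A small usage sketch of the behavior the final line preserves (illustrative only, not part of this change):

    import tensorflow as tf

    x = tf.constant(2.0)
    one = tf.cond(x > 0.0, lambda: x * 2.0, lambda: x)             # single Tensor out
    a, b = tf.cond(x > 0.0, lambda: [x, x * 2.0], lambda: [x, x])  # list in, list out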
@@ -1322,9 +1326,10 @@ class WhileContext(ControlFlowContext):
       else:
         # Control edges must be in the same context.
         for x in op.control_inputs:
+          # pylint: disable=protected-access
           assert x._get_control_flow_context() == self, (
               "Control inputs must come from Operations in the same while "
-              "loop context (not an outer context).")
+              "loop context (not an outer context)." + str(x))
       for x in op.outputs:
         self._values.add(x.name)
     else:
@@ -1455,12 +1460,61 @@ class WhileContext(ControlFlowContext):

     add_acc = math_ops.add(switch_acc[1], value)
     next_acc = _NextIteration(add_acc)
-    merge_acc.op._update_input(1, next_acc)
+    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access

     acc_result = exit(switch_acc[0], name="b_acc")
     self.ExitResult([acc_result])
     return acc_result

+  def AddBackPropIndexedSlicesAccumulator(self, value):
+    """This is used for accumulating gradients that are IndexedSlices.
+
+    This is essentially the equivalent of AddBackPropAccumulator but optimized
+    for things like updating embeddings from within a while loop.
+
+    Args:
+      value: The partial gradients represented as an IndexedSlices.
+
+    Returns:
+      The accumulated IndexedSlices gradient of the loop invariant.
+    """
+    values = value.values
+    indices = value.indices
+
+    self.Exit()
+    shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
+                                     values.get_shape().dims[1:])
+    if not shape.is_fully_defined():
+      shape = None
+    if self.outer_context: self.outer_context.Enter()
+    values_acc = constant_op.constant(0, values.dtype, shape=shape,
+                                      name="b_acc")
+    if not shape:
+      values_acc._shape = shape  # pylint: disable=protected-access
+    indices_acc = constant_op.constant([0], indices.dtype)
+    if self.outer_context: self.outer_context.Exit()
+    self.Enter()
+    self.AddName(values_acc.name)
+    self.AddName(indices_acc.name)
+    enter_acc = [_Enter(x, self._name, is_constant=False,
+                        parallel_iterations=self._parallel_iterations,
+                        name="b_acc") for x in [indices_acc, values_acc]]
+    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
+    switch_acc = [switch(x, self._pivot) for x in merge_acc]
+
+    # The actual accumulation.
+    acc_value = [array_ops.concat(0, [xa[1], xv])
+                 for xa, xv in zip(switch_acc, [indices, values])]
+
+    next_acc = [_NextIteration(x) for x in acc_value]
+    for xm, xn in zip(merge_acc, next_acc):
+      xm.op._update_input(1, xn)  # pylint: disable=protected-access
+
+    acc_result = [exit(x[0], name="b_acc") for x in switch_acc]
+    self.ExitResult(acc_result)
+    return ops.IndexedSlices(values=acc_result[1], indices=acc_result[0],
+                             dense_shape=self.ExitResult(value.dense_shape))
+
   def BuildLoop(self, pred, body, loop_vars):
     """Add the loop termination condition and body to the graph."""
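The new accumulator never adds slices together: it seeds the loop with a dummy index and a matching zero value row, then concatenates each iteration's partial (indices, values) onto the running pair, leaving duplicate indices to be summed only when the IndexedSlices is eventually densified. A plain-NumPy sketch of the scheme, outside any control flow (illustrative only, not part of this change):

    import numpy as np

    indices_acc = np.array([0])    # dummy seed index, mirroring indices_acc above
    values_acc = np.zeros((1, 3))  # matching dummy zero slice
    partials = [(np.array([0]), np.ones((1, 3))),
                (np.array([2]), 2 * np.ones((1, 3)))]
    for step_idx, step_val in partials:  # one entry per loop iteration
      indices_acc = np.concatenate([indices_acc, step_idx])  # grow, don't add
      values_acc = np.concatenate([values_acc, step_val])
    # indices_acc == [0, 0, 2]; summing the rows of values_acc per index gives
    # the same dense gradient that repeated dense addition would have produced.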
@@ -22,8 +22,10 @@ from tensorflow.core.framework import graph_pb2
 from tensorflow.python.framework import ops
 from tensorflow.python.framework.test_util import TensorFlowTestCase
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import standard_ops as tf
 from tensorflow.python.platform import googletest
+from tensorflow.python.training import momentum


 class GroupTestCase(TensorFlowTestCase):
@@ -115,6 +117,59 @@ class SwitchTestCase(TensorFlowTestCase):
       self.assertAllEqual([1, 2, 3], switch_true.values.eval())
       self.assertAllEqual([0, 1], switch_true.indices.eval())

+  def testIndexedSlicesGradient(self):
+    with ops.Graph().as_default():
+      embedding_matrix = tf.get_variable(
+          "embedding_matrix", [5, 5],
+          initializer=tf.random_normal_initializer())
+      def Cond(it, _):
+        return it < 5
+      def Body(it, cost):
+        embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
+        cost += tf.reduce_sum(embedding)
+        return it + 1, cost
+      _, cost = control_flow_ops.While(
+          Cond, Body, [tf.constant(0), tf.constant(0.0)])
+      optimizer = momentum.MomentumOptimizer(0.1, 0.9)
+      train_op = optimizer.minimize(cost)
+      with self.test_session() as sess:
+        sess.run(tf.initialize_all_variables())
+        for _ in range(10):
+          sess.run([train_op])
+
+  def testIndexedSlicesGradientInCondInWhileLoop(self):
+    with ops.Graph().as_default():
+      embedding_matrix = tf.get_variable(
+          "embedding_matrix", [5, 5],
+          initializer=tf.random_normal_initializer())
+
+      def Cond(it, _):
+        return it < 5
+      def Body(it, cost):
+        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
+        cost = tf.cond(tf.equal(it, 3),
+                       lambda: tf.square(cost),
+                       lambda: cost + tf.reduce_sum(embedding))
+        return it + 1, cost
+      _, cost = control_flow_ops.While(
+          Cond, Body, [tf.constant(0), tf.constant(0.0)])
+
+      dynamic_grads = tf.gradients(cost, [embedding_matrix])[0]
+      dynamic_grads = tf.segment_sum(dynamic_grads.values,
+                                     dynamic_grads.indices)
+
+      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
+      static = tf.square(
+          tf.reduce_sum(embedding) +
+          tf.reduce_sum(embedding) +
+          tf.reduce_sum(embedding)) + tf.reduce_sum(embedding)
+      static_grads = tf.gradients(static, [embedding_matrix])[0]
+      static_grads = tf.segment_sum(static_grads.values, static_grads.indices)
+
+      with self.test_session() as sess:
+        sess.run(tf.initialize_all_variables())
+        self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))
+

 if __name__ == "__main__":
   googletest.main()
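Both tests densify the sparse gradients with tf.segment_sum before comparing, since two IndexedSlices can encode the same dense gradient with different (indices, values) layouts. A short sketch of that densification step (made-up numbers, illustrative only):

    import tensorflow as tf

    values = tf.constant([[1.0], [2.0], [3.0]])
    indices = tf.constant([0, 0, 1])          # sorted, as segment_sum requires
    dense = tf.segment_sum(values, indices)   # [[3.0], [3.0]]: rows summed per id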