Add an argument to control memory swap for while loop gradient.

Change: 115637967
This commit is contained in:
Yuan Yu 2016-02-25 21:19:18 -08:00 committed by TensorFlower Gardener
parent 5c5b29adda
commit 5d01083d62

View File

@ -530,6 +530,7 @@ class GradLoopState(object):
outer_grad_ctxt.Enter()
self._grad_context = WhileContext(forward_ctxt.parallel_iterations,
forward_ctxt.back_prop,
forward_ctxt.swap_memory,
forward_ctxt.name)
real_cnt = outer_grad_state.AddBackPropAccumulatedValue(history_cnt, cnt)
self._grad_index = self._grad_context.AddBackPropCounter(real_cnt)
@ -538,6 +539,7 @@ class GradLoopState(object):
if outer_forward_ctxt: outer_forward_ctxt.Enter()
self._grad_context = WhileContext(forward_ctxt.parallel_iterations,
forward_ctxt.back_prop,
forward_ctxt.swap_memory,
forward_ctxt.name)
self._grad_index = self._grad_context.AddBackPropCounter(cnt)
if outer_forward_ctxt: outer_forward_ctxt.Exit()
@ -642,13 +644,15 @@ class GradLoopState(object):
enter_acc = self.forward_context.AddValue(acc)
# Add the stack_push op in the context of value.op.
swap_enabled = self.forward_context.swap_memory
value_ctxt = value.op._get_control_flow_context()
if _IsLoopExit(value.op):
value_ctxt = value_ctxt.outer_context
if value_ctxt == self.forward_context:
# value is not nested in the forward context.
self.forward_context.Enter()
push = gen_data_flow_ops._stack_push(enter_acc, value)
push = gen_data_flow_ops._stack_push(enter_acc, value,
swap_memory=swap_enabled)
self.forward_context.Exit()
# Protect stack push and order it before forward_index.
self.forward_index.op._add_control_input(push.op)
@ -659,12 +663,14 @@ class GradLoopState(object):
# The special case for creating a zero tensor for a dead
# branch of a switch. See ControlFlowState.ZerosLike().
value_ctxt.outer_context.Enter()
push = gen_data_flow_ops._stack_push(enter_acc, value)
push = gen_data_flow_ops._stack_push(enter_acc, value,
swap_memory=swap_enabled)
value_ctxt.outer_context.Exit()
push.op._set_control_flow_context(value_ctxt)
else:
value_ctxt.Enter()
push = gen_data_flow_ops._stack_push(enter_acc, value)
push = gen_data_flow_ops._stack_push(enter_acc, value,
swap_memory=swap_enabled)
value_ctxt.Exit()
# Protect stack push and order it before forward_sync.
self.forward_sync._add_control_input(push.op)
@ -1242,11 +1248,12 @@ def cond(pred, fn1, fn2, name=None):
class WhileContext(ControlFlowContext):
"""The context for the loop construct."""
def __init__(self, parallel_iterations, back_prop, name):
def __init__(self, parallel_iterations, back_prop, swap_memory, name):
ControlFlowContext.__init__(self)
self._name = ops.get_default_graph().unique_name(name)
self._parallel_iterations = parallel_iterations
self._back_prop = back_prop
self._swap_memory = swap_memory
# We use this node to control constants created by the pred lambda.
self._pivot_for_pred = None
# We use this node to control constants created by the body lambda.
@ -1271,6 +1278,11 @@ class WhileContext(ControlFlowContext):
"""True iff backprop is enabled for this While loop."""
return self._back_prop
@property
def swap_memory(self):
"""True iff GPU-CPU memory swap is enabled for this While loop."""
return self._swap_memory
@property
def pivot(self):
"""The boolean tensor representing the loop termination condition."""
@ -1540,7 +1552,7 @@ class WhileContext(ControlFlowContext):
def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
name=None):
swap_memory=False, name=None):
"""Repeat `body` while the condition `cond` is true.
`cond` is a function taking a list of tensors and returning a boolean scalar
@ -1560,6 +1572,7 @@ def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
loop_vars: The list of variable input tensors.
parallel_iterations: The number of iterations allowed to run in parallel.
back_prop: Whether backprop is enabled for this while loop.
swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
name: Optional name prefix for the returned tensors.
Returns:
@ -1585,7 +1598,7 @@ def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
if not callable(body):
raise TypeError("body must be callable.")
context = WhileContext(parallel_iterations, back_prop, name)
context = WhileContext(parallel_iterations, back_prop, swap_memory, name)
context.Enter()
result = context.BuildLoop(cond, body, loop_vars)
context.Exit()