Add a `swap_memory` argument to control GPU-CPU memory swapping for while-loop gradients.
Change: 115637967
This commit is contained in:
parent
5c5b29adda
commit
5d01083d62
@ -530,6 +530,7 @@ class GradLoopState(object):
|
||||
outer_grad_ctxt.Enter()
|
||||
self._grad_context = WhileContext(forward_ctxt.parallel_iterations,
|
||||
forward_ctxt.back_prop,
|
||||
forward_ctxt.swap_memory,
|
||||
forward_ctxt.name)
|
||||
real_cnt = outer_grad_state.AddBackPropAccumulatedValue(history_cnt, cnt)
|
||||
self._grad_index = self._grad_context.AddBackPropCounter(real_cnt)
|
||||
@ -538,6 +539,7 @@ class GradLoopState(object):
|
||||
if outer_forward_ctxt: outer_forward_ctxt.Enter()
|
||||
self._grad_context = WhileContext(forward_ctxt.parallel_iterations,
|
||||
forward_ctxt.back_prop,
|
||||
forward_ctxt.swap_memory,
|
||||
forward_ctxt.name)
|
||||
self._grad_index = self._grad_context.AddBackPropCounter(cnt)
|
||||
if outer_forward_ctxt: outer_forward_ctxt.Exit()
|
||||
@ -642,13 +644,15 @@ class GradLoopState(object):
|
||||
enter_acc = self.forward_context.AddValue(acc)
|
||||
|
||||
# Add the stack_push op in the context of value.op.
|
||||
swap_enabled = self.forward_context.swap_memory
|
||||
value_ctxt = value.op._get_control_flow_context()
|
||||
if _IsLoopExit(value.op):
|
||||
value_ctxt = value_ctxt.outer_context
|
||||
if value_ctxt == self.forward_context:
|
||||
# value is not nested in the forward context.
|
||||
self.forward_context.Enter()
|
||||
push = gen_data_flow_ops._stack_push(enter_acc, value)
|
||||
push = gen_data_flow_ops._stack_push(enter_acc, value,
|
||||
swap_memory=swap_enabled)
|
||||
self.forward_context.Exit()
|
||||
# Protect stack push and order it before forward_index.
|
||||
self.forward_index.op._add_control_input(push.op)
|
||||
@ -659,12 +663,14 @@ class GradLoopState(object):
|
||||
# The special case for creating a zero tensor for a dead
|
||||
# branch of a switch. See ControlFlowState.ZerosLike().
|
||||
value_ctxt.outer_context.Enter()
|
||||
push = gen_data_flow_ops._stack_push(enter_acc, value)
|
||||
push = gen_data_flow_ops._stack_push(enter_acc, value,
|
||||
swap_memory=swap_enabled)
|
||||
value_ctxt.outer_context.Exit()
|
||||
push.op._set_control_flow_context(value_ctxt)
|
||||
else:
|
||||
value_ctxt.Enter()
|
||||
push = gen_data_flow_ops._stack_push(enter_acc, value)
|
||||
push = gen_data_flow_ops._stack_push(enter_acc, value,
|
||||
swap_memory=swap_enabled)
|
||||
value_ctxt.Exit()
|
||||
# Protect stack push and order it before forward_sync.
|
||||
self.forward_sync._add_control_input(push.op)
|
||||
@ -1242,11 +1248,12 @@ def cond(pred, fn1, fn2, name=None):
|
||||
class WhileContext(ControlFlowContext):
|
||||
"""The context for the loop construct."""
|
||||
|
||||
def __init__(self, parallel_iterations, back_prop, name):
|
||||
def __init__(self, parallel_iterations, back_prop, swap_memory, name):
|
||||
ControlFlowContext.__init__(self)
|
||||
self._name = ops.get_default_graph().unique_name(name)
|
||||
self._parallel_iterations = parallel_iterations
|
||||
self._back_prop = back_prop
|
||||
self._swap_memory = swap_memory
|
||||
# We use this node to control constants created by the pred lambda.
|
||||
self._pivot_for_pred = None
|
||||
# We use this node to control constants created by the body lambda.
|
||||
@ -1271,6 +1278,11 @@ class WhileContext(ControlFlowContext):
|
||||
"""True iff backprop is enabled for this While loop."""
|
||||
return self._back_prop
|
||||
|
||||
@property
|
||||
def swap_memory(self):
|
||||
"""True iff GPU-CPU memory swap is enabled for this While loop."""
|
||||
return self._swap_memory
|
||||
|
||||
@property
|
||||
def pivot(self):
|
||||
"""The boolean tensor representing the loop termination condition."""
|
||||
@ -1540,7 +1552,7 @@ class WhileContext(ControlFlowContext):
|
||||
|
||||
|
||||
def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
|
||||
name=None):
|
||||
swap_memory=False, name=None):
|
||||
"""Repeat `body` while the condition `cond` is true.
|
||||
|
||||
`cond` is a function taking a list of tensors and returning a boolean scalar
|
||||
@ -1560,6 +1572,7 @@ def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
|
||||
loop_vars: The list of variable input tensors.
|
||||
parallel_iterations: The number of iterations allowed to run in parallel.
|
||||
back_prop: Whether backprop is enabled for this while loop.
|
||||
swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
|
||||
name: Optional name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
@ -1585,7 +1598,7 @@ def While(cond, body, loop_vars, parallel_iterations=10, back_prop=True,
|
||||
if not callable(body):
|
||||
raise TypeError("body must be callable.")
|
||||
|
||||
context = WhileContext(parallel_iterations, back_prop, name)
|
||||
context = WhileContext(parallel_iterations, back_prop, swap_memory, name)
|
||||
context.Enter()
|
||||
result = context.BuildLoop(cond, body, loop_vars)
|
||||
context.Exit()
|
||||
|
Loading…
Reference in New Issue
Block a user