Switch control_flow_ops library to use Resource variants of Stack operators, instead of deprecated Ref variants.
PiperOrigin-RevId: 168234822
This commit is contained in:
parent
ca43fe82bb
commit
10ba148f77
@ -1410,7 +1410,7 @@ class ControlFlowTest(test.TestCase):
|
||||
|
||||
def testWhileStack_1(self):
|
||||
with self.test_session():
|
||||
s = gen_data_flow_ops._stack(dtypes.int32, stack_name="foo")
|
||||
s = gen_data_flow_ops._stack_v2(-1, dtypes.int32, stack_name="foo")
|
||||
i = constant_op.constant(0)
|
||||
|
||||
def c(i):
|
||||
@ -1419,7 +1419,7 @@ class ControlFlowTest(test.TestCase):
|
||||
def b(i):
|
||||
ni = math_ops.add(i, 1)
|
||||
ni = control_flow_ops.with_dependencies(
|
||||
[gen_data_flow_ops._stack_push(s, i)], ni)
|
||||
[gen_data_flow_ops._stack_push_v2(s, i)], ni)
|
||||
return ni
|
||||
|
||||
r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
|
||||
@ -1431,7 +1431,7 @@ class ControlFlowTest(test.TestCase):
|
||||
|
||||
def b1(i, x):
|
||||
ni = math_ops.subtract(i, 1)
|
||||
nx = x + gen_data_flow_ops._stack_pop(s, dtypes.int32)
|
||||
nx = x + gen_data_flow_ops._stack_pop_v2(s, dtypes.int32)
|
||||
return [ni, nx]
|
||||
|
||||
_, rx = control_flow_ops.while_loop(
|
||||
@ -2612,7 +2612,8 @@ class ControlFlowTest(test.TestCase):
|
||||
r = gradients_impl.gradients(r, x)[0]
|
||||
self.assertEqual(r.eval(), 524288.0)
|
||||
self.assertEqual(
|
||||
len([op for op in x.graph.get_operations() if op.type == "Stack"]), 1)
|
||||
len([op for op in x.graph.get_operations() if op.type == "StackV2"]),
|
||||
1)
|
||||
|
||||
|
||||
class TupleTest(test.TestCase):
|
||||
|
@ -864,7 +864,8 @@ class GradLoopState(object):
|
||||
if curr_ctxt: curr_ctxt.Enter()
|
||||
with ops.colocate_with(value):
|
||||
# pylint: disable=protected-access
|
||||
acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc")
|
||||
acc = gen_data_flow_ops._stack_v2(-1, value.dtype.base_dtype,
|
||||
name="f_acc")
|
||||
# pylint: enable=protected-access
|
||||
if curr_ctxt: curr_ctxt.Exit()
|
||||
|
||||
@ -877,8 +878,10 @@ class GradLoopState(object):
|
||||
if value_ctxt == self.forward_context:
|
||||
# value is not nested in the forward context.
|
||||
self.forward_context.Enter()
|
||||
push = gen_data_flow_ops._stack_push(
|
||||
# pylint: disable=protected-access
|
||||
push = gen_data_flow_ops._stack_push_v2(
|
||||
enter_acc, value, swap_memory=swap_enabled)
|
||||
# pylint: enable=protected-access
|
||||
self.forward_context.Exit()
|
||||
# Protect stack push and order it before forward_index.
|
||||
self.forward_index.op._add_control_input(push.op)
|
||||
@ -891,14 +894,18 @@ class GradLoopState(object):
|
||||
# The special case for creating a zero tensor for a dead
|
||||
# branch of a switch. See ControlFlowState.ZerosLike().
|
||||
value_ctxt.outer_context.Enter()
|
||||
push = gen_data_flow_ops._stack_push(
|
||||
# pylint: disable=protected-access
|
||||
push = gen_data_flow_ops._stack_push_v2(
|
||||
enter_acc, value, swap_memory=swap_enabled)
|
||||
# pylint: enable=protected-access
|
||||
value_ctxt.outer_context.Exit()
|
||||
push.op._set_control_flow_context(value_ctxt)
|
||||
else:
|
||||
value_ctxt.Enter()
|
||||
push = gen_data_flow_ops._stack_push(
|
||||
# pylint: disable=protected-access
|
||||
push = gen_data_flow_ops._stack_push_v2(
|
||||
enter_acc, value, swap_memory=swap_enabled)
|
||||
# pylint: enable=protected-access
|
||||
value_ctxt.Exit()
|
||||
# Protect stack push and order it before forward_sync.
|
||||
self.forward_sync._add_control_input(push.op)
|
||||
@ -945,7 +952,10 @@ class GradLoopState(object):
|
||||
pred = cond_ctxt.pred
|
||||
branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
|
||||
history_value = _SwitchRefOrTensor(history_value, pred)[branch]
|
||||
pop = gen_data_flow_ops._stack_pop(history_value, value.dtype.base_dtype)
|
||||
# pylint: disable=protected-access
|
||||
pop = gen_data_flow_ops._stack_pop_v2(history_value,
|
||||
value.dtype.base_dtype)
|
||||
# pylint: enable=protected-access
|
||||
pop.set_shape(value.get_shape())
|
||||
self.grad_context.Exit()
|
||||
parallel_iterations = self.grad_context.parallel_iterations
|
||||
|
Loading…
Reference in New Issue
Block a user