Switch control_flow_ops library to use Resource variants of Stack operators, instead of deprecated Ref variants.

PiperOrigin-RevId: 168234822
This commit is contained in:
Peter Hawkins 2017-09-11 08:38:23 -07:00 committed by TensorFlower Gardener
parent ca43fe82bb
commit 10ba148f77
2 changed files with 20 additions and 9 deletions

View File

@ -1410,7 +1410,7 @@ class ControlFlowTest(test.TestCase):
def testWhileStack_1(self):
with self.test_session():
s = gen_data_flow_ops._stack(dtypes.int32, stack_name="foo")
s = gen_data_flow_ops._stack_v2(-1, dtypes.int32, stack_name="foo")
i = constant_op.constant(0)
def c(i):
@ -1419,7 +1419,7 @@ class ControlFlowTest(test.TestCase):
def b(i):
ni = math_ops.add(i, 1)
ni = control_flow_ops.with_dependencies(
[gen_data_flow_ops._stack_push(s, i)], ni)
[gen_data_flow_ops._stack_push_v2(s, i)], ni)
return ni
r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
@ -1431,7 +1431,7 @@ class ControlFlowTest(test.TestCase):
def b1(i, x):
ni = math_ops.subtract(i, 1)
nx = x + gen_data_flow_ops._stack_pop(s, dtypes.int32)
nx = x + gen_data_flow_ops._stack_pop_v2(s, dtypes.int32)
return [ni, nx]
_, rx = control_flow_ops.while_loop(
@ -2612,7 +2612,8 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, x)[0]
self.assertEqual(r.eval(), 524288.0)
self.assertEqual(
len([op for op in x.graph.get_operations() if op.type == "Stack"]), 1)
len([op for op in x.graph.get_operations() if op.type == "StackV2"]),
1)
class TupleTest(test.TestCase):

View File

@ -864,7 +864,8 @@ class GradLoopState(object):
if curr_ctxt: curr_ctxt.Enter()
with ops.colocate_with(value):
# pylint: disable=protected-access
acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc")
acc = gen_data_flow_ops._stack_v2(-1, value.dtype.base_dtype,
name="f_acc")
# pylint: enable=protected-access
if curr_ctxt: curr_ctxt.Exit()
@ -877,8 +878,10 @@ class GradLoopState(object):
if value_ctxt == self.forward_context:
# value is not nested in the forward context.
self.forward_context.Enter()
push = gen_data_flow_ops._stack_push(
# pylint: disable=protected-access
push = gen_data_flow_ops._stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
# pylint: enable=protected-access
self.forward_context.Exit()
# Protect stack push and order it before forward_index.
self.forward_index.op._add_control_input(push.op)
@ -891,14 +894,18 @@ class GradLoopState(object):
# The special case for creating a zero tensor for a dead
# branch of a switch. See ControlFlowState.ZerosLike().
value_ctxt.outer_context.Enter()
push = gen_data_flow_ops._stack_push(
# pylint: disable=protected-access
push = gen_data_flow_ops._stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
# pylint: enable=protected-access
value_ctxt.outer_context.Exit()
push.op._set_control_flow_context(value_ctxt)
else:
value_ctxt.Enter()
push = gen_data_flow_ops._stack_push(
# pylint: disable=protected-access
push = gen_data_flow_ops._stack_push_v2(
enter_acc, value, swap_memory=swap_enabled)
# pylint: enable=protected-access
value_ctxt.Exit()
# Protect stack push and order it before forward_sync.
self.forward_sync._add_control_input(push.op)
@ -945,7 +952,10 @@ class GradLoopState(object):
pred = cond_ctxt.pred
branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
history_value = _SwitchRefOrTensor(history_value, pred)[branch]
pop = gen_data_flow_ops._stack_pop(history_value, value.dtype.base_dtype)
# pylint: disable=protected-access
pop = gen_data_flow_ops._stack_pop_v2(history_value,
value.dtype.base_dtype)
# pylint: enable=protected-access
pop.set_shape(value.get_shape())
self.grad_context.Exit()
parallel_iterations = self.grad_context.parallel_iterations