From 10ba148f7711c4724b41a09b09963f7f0f21fe7a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 11 Sep 2017 08:38:23 -0700 Subject: [PATCH] Switch control_flow_ops library to use Resource variants of Stack operators, instead of deprecated Ref variants. PiperOrigin-RevId: 168234822 --- .../kernel_tests/control_flow_ops_py_test.py | 9 +++++---- tensorflow/python/ops/control_flow_ops.py | 20 ++++++++++++++----- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 6e81e1fdbd8..a21182beba3 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1410,7 +1410,7 @@ class ControlFlowTest(test.TestCase): def testWhileStack_1(self): with self.test_session(): - s = gen_data_flow_ops._stack(dtypes.int32, stack_name="foo") + s = gen_data_flow_ops._stack_v2(-1, dtypes.int32, stack_name="foo") i = constant_op.constant(0) def c(i): @@ -1419,7 +1419,7 @@ class ControlFlowTest(test.TestCase): def b(i): ni = math_ops.add(i, 1) ni = control_flow_ops.with_dependencies( - [gen_data_flow_ops._stack_push(s, i)], ni) + [gen_data_flow_ops._stack_push_v2(s, i)], ni) return ni r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) @@ -1431,7 +1431,7 @@ class ControlFlowTest(test.TestCase): def b1(i, x): ni = math_ops.subtract(i, 1) - nx = x + gen_data_flow_ops._stack_pop(s, dtypes.int32) + nx = x + gen_data_flow_ops._stack_pop_v2(s, dtypes.int32) return [ni, nx] _, rx = control_flow_ops.while_loop( @@ -2612,7 +2612,8 @@ class ControlFlowTest(test.TestCase): r = gradients_impl.gradients(r, x)[0] self.assertEqual(r.eval(), 524288.0) self.assertEqual( - len([op for op in x.graph.get_operations() if op.type == "Stack"]), 1) + len([op for op in x.graph.get_operations() if op.type == "StackV2"]), + 1) class TupleTest(test.TestCase): diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 97b37ea0272..e8f64d58175 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -864,7 +864,8 @@ class GradLoopState(object): if curr_ctxt: curr_ctxt.Enter() with ops.colocate_with(value): # pylint: disable=protected-access - acc = gen_data_flow_ops._stack(value.dtype.base_dtype, name="f_acc") + acc = gen_data_flow_ops._stack_v2(-1, value.dtype.base_dtype, + name="f_acc") # pylint: enable=protected-access if curr_ctxt: curr_ctxt.Exit() @@ -877,8 +878,10 @@ class GradLoopState(object): if value_ctxt == self.forward_context: # value is not nested in the forward context. self.forward_context.Enter() - push = gen_data_flow_ops._stack_push( + # pylint: disable=protected-access + push = gen_data_flow_ops._stack_push_v2( enter_acc, value, swap_memory=swap_enabled) + # pylint: enable=protected-access self.forward_context.Exit() # Protect stack push and order it before forward_index. self.forward_index.op._add_control_input(push.op) @@ -891,14 +894,18 @@ class GradLoopState(object): # The special case for creating a zero tensor for a dead # branch of a switch. See ControlFlowState.ZerosLike(). value_ctxt.outer_context.Enter() - push = gen_data_flow_ops._stack_push( + # pylint: disable=protected-access + push = gen_data_flow_ops._stack_push_v2( enter_acc, value, swap_memory=swap_enabled) + # pylint: enable=protected-access value_ctxt.outer_context.Exit() push.op._set_control_flow_context(value_ctxt) else: value_ctxt.Enter() - push = gen_data_flow_ops._stack_push( + # pylint: disable=protected-access + push = gen_data_flow_ops._stack_push_v2( enter_acc, value, swap_memory=swap_enabled) + # pylint: enable=protected-access value_ctxt.Exit() # Protect stack push and order it before forward_sync. self.forward_sync._add_control_input(push.op) @@ -945,7 +952,10 @@ class GradLoopState(object): pred = cond_ctxt.pred branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch history_value = _SwitchRefOrTensor(history_value, pred)[branch] - pop = gen_data_flow_ops._stack_pop(history_value, value.dtype.base_dtype) + # pylint: disable=protected-access + pop = gen_data_flow_ops._stack_pop_v2(history_value, + value.dtype.base_dtype) + # pylint: enable=protected-access pop.set_shape(value.get_shape()) self.grad_context.Exit() parallel_iterations = self.grad_context.parallel_iterations