From 72a787e2b1e6f1d284b3692235bf1d8a30ef5cff Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Sat, 7 Mar 2020 19:09:20 -0800 Subject: [PATCH] Make while_v2_indexed_slices_writers compatible with tensor equality. PiperOrigin-RevId: 299619903 Change-Id: Ia829ce5942d08fb6c55d83d5180f0a49e80bfdc3 --- .../kernel_tests/control_flow_ops_py_test.py | 47 ++++++++++++------- .../ops/while_v2_indexed_slices_rewriter.py | 18 +++++-- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index f65cd64c93d..ec9d97c4bcc 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -31,6 +31,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import tf2 from tensorflow.python.client import device_lib from tensorflow.python.client import session from tensorflow.python.eager import context @@ -150,6 +151,14 @@ def filter_test_messages(s): return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)] +def tf_function_in_tf2(f): + if tf2.enabled(): + # In TF1 do not wrap with tf.function so that we can test the v1 control + # flow code path. + return def_function.function(f) + return f + + @test_util.with_control_flow_v2 class ControlFlowTest(test.TestCase, parameterized.TestCase): @@ -3207,31 +3216,37 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase): self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val), [3., 0., 0.]) - @test_util.run_deprecated_v1 def testWhileGrad_Gather(self): # NOTE(skyewm): this test is interesting because the gather gradient # function returns an IndexedSlices. 
- x = constant_op.constant([1., 1., 1., 1., 1.]) - y = control_flow_ops.while_loop( - lambda i, _: i < 3, - lambda i, x: (i + 1, x + array_ops.gather(x, [0])), - [0, x[:1]])[1] - z = y * 3.0 - grad = gradients_impl.gradients(z, x)[0] + @tf_function_in_tf2 + def fn(): + x = constant_op.constant([1., 1., 1., 1., 1.]) + y = control_flow_ops.while_loop( + lambda i, _: i < 3, + lambda i, x: (i + 1, x + array_ops.gather(x, [0])), + [0, x[:1]])[1] + z = y * 3.0 + grad = gradients_impl.gradients(z, x)[0] + return y, grad + y, grad = fn() self.assertEqual(self.evaluate(y), 8.) self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.]) - @test_util.run_deprecated_v1 def testWhileGrad_GatherNoFanOut(self): # NOTE(skyewm): this test is interesting because the gather gradient # function returns an IndexedSlices. - x = constant_op.constant([1., 1., 1., 1., 1.]) - y = control_flow_ops.while_loop( - lambda i, _: i < 3, - lambda i, x: (i + 1, array_ops.gather(x, [0])), - [0, x[:1]])[1] - z = y * 3.0 - grad = gradients_impl.gradients(z, x)[0] + @tf_function_in_tf2 + def fn(): + x = constant_op.constant([1., 1., 1., 1., 1.]) + y = control_flow_ops.while_loop( + lambda i, _: i < 3, + lambda i, x: (i + 1, array_ops.gather(x, [0])), + [0, x[:1]])[1] + z = y * 3.0 + grad = gradients_impl.gradients(z, x)[0] + return y, grad + y, grad = fn() self.assertEqual(self.evaluate(y), 1.) 
self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.]) diff --git a/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py index 9637ee174d7..70e63133bdb 100644 --- a/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py +++ b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py @@ -81,6 +81,14 @@ def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars, return loop_vars +def _get_tensor_index_in_iterable(iterable, t): + """Returns index of first occurrence of `t`, raises ValueError if not found.""" + for i, elem in enumerate(iterable): + if t is elem: + return i + raise ValueError("%s is not in iterable" % str(t)) + + def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices): """Rewrites grad_output_slices to be a Tensor output. @@ -91,7 +99,8 @@ def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices): with body_grad_graph.as_default(): new_output = ops.convert_to_tensor_v2(grad_output_slices) - idx = body_grad_graph.structured_outputs.index(grad_output_slices) + idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs, + grad_output_slices) body_grad_graph.structured_outputs[idx] = new_output body_grad_graph.outputs = func_graph.flatten( body_grad_graph.structured_outputs) @@ -259,11 +268,14 @@ def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices, Returns: New loop_vars to pass to graph. """ - structured_idx = graph.structured_outputs.index(old_output_slices) + structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs, + old_output_slices) # We assume that the component tensors of old_output_slices appear # sequentially in graph.outputs. We use the first of these tensors # as the reference index.
- flat_idx = graph.outputs.index(func_graph.flatten(old_output_slices)[0]) + flat_idx = _get_tensor_index_in_iterable( + graph.outputs, + func_graph.flatten(old_output_slices)[0]) graph.structured_outputs[structured_idx] = output_slices graph.outputs = func_graph.flatten(