Make while_v2_indexed_slices_rewriter compatible with tensor equality.

With elementwise tensor equality enabled in TF2, list.index() can no longer
be used to locate a tensor in a list, so the rewriter now looks tensors up
by object identity instead.

PiperOrigin-RevId: 299619903
Change-Id: Ia829ce5942d08fb6c55d83d5180f0a49e80bfdc3
Saurabh Saxena 2020-03-07 19:09:20 -08:00 committed by TensorFlower Gardener
parent 404ea6206c
commit 72a787e2b1
2 changed files with 46 additions and 19 deletions
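
For context, a minimal sketch (assuming a TF 2.x runtime where Tensor.__eq__
is elementwise and eager execution is enabled) of why list.index() breaks on
lists of tensors, and of the identity-based lookup this commit adds:

import tensorflow as tf


def _get_tensor_index_in_iterable(iterable, t):
  """Identity-based lookup, mirroring the helper added in this commit."""
  for i, elem in enumerate(iterable):
    if t is elem:
      return i
  raise ValueError("%s is not in iterable" % str(t))


target = tf.constant([1.0, 2.0])
outputs = [tf.constant([3.0, 4.0]), target]

# list.index() compares outputs[0] with target first: they are not the same
# object, so == is evaluated and yields the boolean tensor [False, False].
# Coercing a multi-element tensor to a Python bool raises, so .index() fails
# even though target really is in the list.
try:
  outputs.index(target)
except Exception as e:
  print("list.index() failed:", type(e).__name__)

# Comparing by object identity sidesteps __eq__ entirely.
print(_get_tensor_index_in_iterable(outputs, target))  # -> 1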


@@ -31,6 +31,7 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import tf2
 from tensorflow.python.client import device_lib
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
@@ -150,6 +151,14 @@ def filter_test_messages(s):
   return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)]


+def tf_function_in_tf2(f):
+  if tf2.enabled():
+    # In TF2, wrap with tf.function so that the v2 control flow code path is
+    # exercised; in TF1, leave `f` unwrapped so the v1 code path gets tested.
+    return def_function.function(f)
+  return f
+
+
 @test_util.with_control_flow_v2
 class ControlFlowTest(test.TestCase, parameterized.TestCase):
@@ -3207,31 +3216,37 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
                         [3., 0., 0.])

-  @test_util.run_deprecated_v1
   def testWhileGrad_Gather(self):
     # NOTE(skyewm): this test is interesting because the gather gradient
     # function returns an IndexedSlices.
-    x = constant_op.constant([1., 1., 1., 1., 1.])
-    y = control_flow_ops.while_loop(
-        lambda i, _: i < 3,
-        lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
-        [0, x[:1]])[1]
-    z = y * 3.0
-    grad = gradients_impl.gradients(z, x)[0]
+    @tf_function_in_tf2
+    def fn():
+      x = constant_op.constant([1., 1., 1., 1., 1.])
+      y = control_flow_ops.while_loop(
+          lambda i, _: i < 3,
+          lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
+          [0, x[:1]])[1]
+      z = y * 3.0
+      grad = gradients_impl.gradients(z, x)[0]
+      return y, grad
+    y, grad = fn()
     self.assertEqual(self.evaluate(y), 8.)
     self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])

-  @test_util.run_deprecated_v1
   def testWhileGrad_GatherNoFanOut(self):
     # NOTE(skyewm): this test is interesting because the gather gradient
     # function returns an IndexedSlices.
-    x = constant_op.constant([1., 1., 1., 1., 1.])
-    y = control_flow_ops.while_loop(
-        lambda i, _: i < 3,
-        lambda i, x: (i + 1, array_ops.gather(x, [0])),
-        [0, x[:1]])[1]
-    z = y * 3.0
-    grad = gradients_impl.gradients(z, x)[0]
+    @tf_function_in_tf2
+    def fn():
+      x = constant_op.constant([1., 1., 1., 1., 1.])
+      y = control_flow_ops.while_loop(
+          lambda i, _: i < 3,
+          lambda i, x: (i + 1, array_ops.gather(x, [0])),
+          [0, x[:1]])[1]
+      z = y * 3.0
+      grad = gradients_impl.gradients(z, x)[0]
+      return y, grad
+    y, grad = fn()
     self.assertEqual(self.evaluate(y), 1.)
     self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])
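
The same wrapped test pattern as a standalone sketch using public TF 2.x APIs
(the names fn, x, y, z mirror the test above; tf.gradients is not supported
in eager mode, which is why the tests wrap their bodies in tf.function under
TF2):

import tensorflow as tf


@tf.function  # build a graph: tf.gradients only works in graph mode
def fn():
  x = tf.constant([1., 1., 1., 1., 1.])
  y = tf.while_loop(
      lambda i, _: i < 3,
      lambda i, x: (i + 1, x + tf.gather(x, [0])),
      [0, x[:1]])[1]
  z = y * 3.0
  return y, tf.gradients(z, x)[0]


y, grad = fn()
print(y.numpy())     # [8.]
print(grad.numpy())  # [24.  0.  0.  0.  0.]

The gather gradient inside the loop is an IndexedSlices, which is exactly the
sparse structure the while_v2 rewriter below has to thread through the loop's
gradient graph.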


@@ -81,6 +81,14 @@ def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
   return loop_vars


+def _get_tensor_index_in_iterable(iterable, t):
+  """Returns index of first occurrence of `t`, raises ValueError if not found."""
+  for i, elem in enumerate(iterable):
+    if t is elem:
+      return i
+  raise ValueError("%s is not in iterable" % str(t))
+
+
 def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
   """Rewrites grad_output_slices to be a Tensor output.
@@ -91,7 +99,8 @@ def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
   with body_grad_graph.as_default():
     new_output = ops.convert_to_tensor_v2(grad_output_slices)

-  idx = body_grad_graph.structured_outputs.index(grad_output_slices)
+  idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
+                                      grad_output_slices)
   body_grad_graph.structured_outputs[idx] = new_output
   body_grad_graph.outputs = func_graph.flatten(
       body_grad_graph.structured_outputs)
@@ -259,11 +268,14 @@ def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
   Returns:
     New loop_vars to pass to graph.
   """
-  structured_idx = graph.structured_outputs.index(old_output_slices)
+  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
+                                                 old_output_slices)
   # We assume that the component tensors of old_output_slices appear
   # sequentially in graph.outputs. We use the first of these tensors
   # as the reference index.
-  flat_idx = graph.outputs.index(func_graph.flatten(old_output_slices)[0])
+  flat_idx = _get_tensor_index_in_iterable(
+      graph.outputs,
+      func_graph.flatten(old_output_slices)[0])
   graph.structured_outputs[structured_idx] = output_slices
   graph.outputs = func_graph.flatten(
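
For readers unfamiliar with the IndexedSlices objects these rewrites handle,
a minimal illustration with public TF 2.x APIs: a gather gradient arrives as
a sparse IndexedSlices, and densifying it with convert_to_tensor is roughly
what _rewrite_output_as_tensor does (via ops.convert_to_tensor_v2) inside the
gradient graph.

import tensorflow as tf

x = tf.Variable([1., 1., 1., 1., 1.])
with tf.GradientTape() as tape:
  y = tf.gather(x, [0]) * 3.0

grad = tape.gradient(y, x)
print(type(grad).__name__)         # IndexedSlices: gather's gradient is sparse
print(tf.convert_to_tensor(grad))  # densified: [3. 0. 0. 0. 0.]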