Make while_v2_indexed_slices_rewriter compatible with tensor equality.
PiperOrigin-RevId: 299619903
Change-Id: Ia829ce5942d08fb6c55d83d5180f0a49e80bfdc3

parent 404ea6206c · commit 72a787e2b1
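For context: Python's list.index() compares elements with ==. Once tensor equality is enabled (TF2 semantics), Tensor.__eq__ returns an elementwise comparison tensor rather than a Python bool, so index lookups over lists of tensors, as the rewriter did on graph.structured_outputs, stop working. A minimal sketch of the failure mode, assuming TF 2.x eager execution and illustrative values:

import tensorflow as tf

x = tf.constant([1., 2.])
y = tf.constant([3., 4.])
outputs = [x, y]

# list.index() evaluates `x == y`, which under tensor equality is an
# elementwise bool tensor; coercing it to a Python bool raises.
try:
  outputs.index(y)
except (TypeError, ValueError) as e:
  print("index() failed:", e)

# Identity-based lookup, which this commit switches to, is unaffected:
print(next(i for i, t in enumerate(outputs) if t is y))  # -> 1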
tensorflow/python/kernel_tests/control_flow_ops_py_test.py

@@ -31,6 +31,7 @@ import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin

 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import tf2
 from tensorflow.python.client import device_lib
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
@@ -150,6 +151,14 @@ def filter_test_messages(s):
   return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)]


+def tf_function_in_tf2(f):
+  if tf2.enabled():
+    # In TF1 do not wrap with tf.function so that we can test the v1 control
+    # flow code path.
+    return def_function.function(f)
+  return f
+
+
 @test_util.with_control_flow_v2
 class ControlFlowTest(test.TestCase, parameterized.TestCase):

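With this decorator each test body runs under v2 control flow when TF2 is enabled (wrapped in a function) and under the v1 code path otherwise. A rough standalone analogue using only public APIs; tf.executing_eagerly() stands in for the internal tf2.enabled() check and tf.function for def_function.function, both assumptions for illustration:

import tensorflow as tf

def tf_function_in_tf2(f):
  # Public-API approximation of the helper above.
  return tf.function(f) if tf.executing_eagerly() else f

@tf_function_in_tf2
def loop_sum():
  # A trivial while_loop in the spirit of the tests below.
  _, total = tf.while_loop(
      lambda i, total: i < 5,
      lambda i, total: (i + 1, total + i),
      [tf.constant(0), tf.constant(0)])
  return total

print(loop_sum())  # tf.Tensor(10, shape=(), dtype=int32)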
@@ -3207,31 +3216,37 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
     self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
                         [3., 0., 0.])

-  @test_util.run_deprecated_v1
   def testWhileGrad_Gather(self):
     # NOTE(skyewm): this test is interesting because the gather gradient
     # function returns an IndexedSlices.
-    x = constant_op.constant([1., 1., 1., 1., 1.])
-    y = control_flow_ops.while_loop(
-        lambda i, _: i < 3,
-        lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
-        [0, x[:1]])[1]
-    z = y * 3.0
-    grad = gradients_impl.gradients(z, x)[0]
+    @tf_function_in_tf2
+    def fn():
+      x = constant_op.constant([1., 1., 1., 1., 1.])
+      y = control_flow_ops.while_loop(
+          lambda i, _: i < 3,
+          lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
+          [0, x[:1]])[1]
+      z = y * 3.0
+      grad = gradients_impl.gradients(z, x)[0]
+      return y, grad
+    y, grad = fn()
     self.assertEqual(self.evaluate(y), 8.)
     self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])

-  @test_util.run_deprecated_v1
   def testWhileGrad_GatherNoFanOut(self):
     # NOTE(skyewm): this test is interesting because the gather gradient
     # function returns an IndexedSlices.
-    x = constant_op.constant([1., 1., 1., 1., 1.])
-    y = control_flow_ops.while_loop(
-        lambda i, _: i < 3,
-        lambda i, x: (i + 1, array_ops.gather(x, [0])),
-        [0, x[:1]])[1]
-    z = y * 3.0
-    grad = gradients_impl.gradients(z, x)[0]
+    @tf_function_in_tf2
+    def fn():
+      x = constant_op.constant([1., 1., 1., 1., 1.])
+      y = control_flow_ops.while_loop(
+          lambda i, _: i < 3,
+          lambda i, x: (i + 1, array_ops.gather(x, [0])),
+          [0, x[:1]])[1]
+      z = y * 3.0
+      grad = gradients_impl.gradients(z, x)[0]
+      return y, grad
+    y, grad = fn()
     self.assertEqual(self.evaluate(y), 1.)
     self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])

tensorflow/python/ops/while_v2_indexed_slices_rewriter.py

@@ -81,6 +81,14 @@ def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
   return loop_vars


+def _get_tensor_index_in_iterable(iterable, t):
+  """Returns index of first occurrence of `t`; raises ValueError if not found."""
+  for i, elem in enumerate(iterable):
+    if t is elem:
+      return i
+  raise ValueError("%s is not in iterable" % str(t))
+
+
 def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
   """Rewrites grad_output_slices to be a Tensor output.

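A standalone sketch of what the new helper buys (tensor values illustrative): `is` finds the exact object without ever invoking __eq__, and a miss raises ValueError, preserving list.index()'s contract:

import tensorflow as tf

def get_index_by_identity(iterable, t):
  # Same logic as _get_tensor_index_in_iterable above.
  for i, elem in enumerate(iterable):
    if t is elem:
      return i
  raise ValueError("%s is not in iterable" % str(t))

x = tf.constant([0., 1.])
y = tf.constant([2., 3.])

assert get_index_by_identity([x, y], y) == 1  # works where list.index() breaks
try:
  get_index_by_identity([x, y], tf.constant([2., 3.]))  # equal but not identical
except ValueError:
  pass  # a distinct-but-equal tensor is deliberately not found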
@@ -91,7 +99,8 @@ def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
   with body_grad_graph.as_default():
     new_output = ops.convert_to_tensor_v2(grad_output_slices)

-  idx = body_grad_graph.structured_outputs.index(grad_output_slices)
+  idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
+                                      grad_output_slices)
   body_grad_graph.structured_outputs[idx] = new_output
   body_grad_graph.outputs = func_graph.flatten(
       body_grad_graph.structured_outputs)
@@ -259,11 +268,14 @@ def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
   Returns:
     New loop_vars to pass to graph.
   """
-  structured_idx = graph.structured_outputs.index(old_output_slices)
+  structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
+                                                 old_output_slices)
   # We assume that the component tensors of old_output_slices appear
   # sequentially in graph.outputs. We use the first of these tensors
   # as the reference index.
-  flat_idx = graph.outputs.index(func_graph.flatten(old_output_slices)[0])
+  flat_idx = _get_tensor_index_in_iterable(
+      graph.outputs,
+      func_graph.flatten(old_output_slices)[0])

   graph.structured_outputs[structured_idx] = output_slices
   graph.outputs = func_graph.flatten(
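The "component tensors ... appear sequentially" assumption above reflects how composite tensors flatten: an IndexedSlices decomposes into its values, indices, and optional dense_shape tensors, in that order. A quick illustration with the public tf.nest.flatten (the rewriter itself uses the internal func_graph.flatten; the component order shown is my reading of the composite-tensor spec):

import tensorflow as tf

slices = tf.IndexedSlices(values=tf.constant([[1., 2.]]),
                          indices=tf.constant([0]),
                          dense_shape=tf.constant([3, 2]))

# expand_composites=True expands the IndexedSlices into its component
# tensors; the first one is usable as a reference index into a flat
# output list, as _update_indexed_slices_param does.
components = tf.nest.flatten(slices, expand_composites=True)
print([c.dtype.name for c in components])  # e.g. ['float32', 'int32', 'int32']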