Make while_v2_indexed_slices_writers compatible with tensor equality.
PiperOrigin-RevId: 299619903 Change-Id: Ia829ce5942d08fb6c55d83d5180f0a49e80bfdc3
This commit is contained in:
parent
404ea6206c
commit
72a787e2b1
@ -31,6 +31,7 @@ import numpy as np
|
|||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.client import device_lib
|
from tensorflow.python.client import device_lib
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -150,6 +151,14 @@ def filter_test_messages(s):
|
|||||||
return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)]
|
return [l[len(prefix):] for l in s.split("\n") if l.startswith(prefix)]
|
||||||
|
|
||||||
|
|
||||||
|
def tf_function_in_tf2(f):
|
||||||
|
if tf2.enabled():
|
||||||
|
# In TF1 do not wrap with tf.function so that we can test the v1 control
|
||||||
|
# flow code path.
|
||||||
|
return def_function.function(f)
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
@test_util.with_control_flow_v2
|
@test_util.with_control_flow_v2
|
||||||
class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@ -3207,31 +3216,37 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
|
self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
|
||||||
[3., 0., 0.])
|
[3., 0., 0.])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testWhileGrad_Gather(self):
|
def testWhileGrad_Gather(self):
|
||||||
# NOTE(skyewm): this test is interesting because the gather gradient
|
# NOTE(skyewm): this test is interesting because the gather gradient
|
||||||
# function returns an IndexedSlices.
|
# function returns an IndexedSlices.
|
||||||
x = constant_op.constant([1., 1., 1., 1., 1.])
|
@tf_function_in_tf2
|
||||||
y = control_flow_ops.while_loop(
|
def fn():
|
||||||
lambda i, _: i < 3,
|
x = constant_op.constant([1., 1., 1., 1., 1.])
|
||||||
lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
|
y = control_flow_ops.while_loop(
|
||||||
[0, x[:1]])[1]
|
lambda i, _: i < 3,
|
||||||
z = y * 3.0
|
lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
|
||||||
grad = gradients_impl.gradients(z, x)[0]
|
[0, x[:1]])[1]
|
||||||
|
z = y * 3.0
|
||||||
|
grad = gradients_impl.gradients(z, x)[0]
|
||||||
|
return y, grad
|
||||||
|
y, grad = fn()
|
||||||
self.assertEqual(self.evaluate(y), 8.)
|
self.assertEqual(self.evaluate(y), 8.)
|
||||||
self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])
|
self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testWhileGrad_GatherNoFanOut(self):
|
def testWhileGrad_GatherNoFanOut(self):
|
||||||
# NOTE(skyewm): this test is interesting because the gather gradient
|
# NOTE(skyewm): this test is interesting because the gather gradient
|
||||||
# function returns an IndexedSlices.
|
# function returns an IndexedSlices.
|
||||||
x = constant_op.constant([1., 1., 1., 1., 1.])
|
@tf_function_in_tf2
|
||||||
y = control_flow_ops.while_loop(
|
def fn():
|
||||||
lambda i, _: i < 3,
|
x = constant_op.constant([1., 1., 1., 1., 1.])
|
||||||
lambda i, x: (i + 1, array_ops.gather(x, [0])),
|
y = control_flow_ops.while_loop(
|
||||||
[0, x[:1]])[1]
|
lambda i, _: i < 3,
|
||||||
z = y * 3.0
|
lambda i, x: (i + 1, array_ops.gather(x, [0])),
|
||||||
grad = gradients_impl.gradients(z, x)[0]
|
[0, x[:1]])[1]
|
||||||
|
z = y * 3.0
|
||||||
|
grad = gradients_impl.gradients(z, x)[0]
|
||||||
|
return y, grad
|
||||||
|
y, grad = fn()
|
||||||
self.assertEqual(self.evaluate(y), 1.)
|
self.assertEqual(self.evaluate(y), 1.)
|
||||||
self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])
|
self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])
|
||||||
|
|
||||||
|
@ -81,6 +81,14 @@ def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
|
|||||||
return loop_vars
|
return loop_vars
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tensor_index_in_iterable(iterable, t):
|
||||||
|
"""Returns index of first occurence of `t`, raises ValueError if not found."""
|
||||||
|
for i, elem in enumerate(iterable):
|
||||||
|
if t is elem:
|
||||||
|
return i
|
||||||
|
raise ValueError("%s is not in iterable" % str(t))
|
||||||
|
|
||||||
|
|
||||||
def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
|
def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
|
||||||
"""Rewrites grad_output_slices to be a Tensor output.
|
"""Rewrites grad_output_slices to be a Tensor output.
|
||||||
|
|
||||||
@ -91,7 +99,8 @@ def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
|
|||||||
with body_grad_graph.as_default():
|
with body_grad_graph.as_default():
|
||||||
new_output = ops.convert_to_tensor_v2(grad_output_slices)
|
new_output = ops.convert_to_tensor_v2(grad_output_slices)
|
||||||
|
|
||||||
idx = body_grad_graph.structured_outputs.index(grad_output_slices)
|
idx = _get_tensor_index_in_iterable(body_grad_graph.structured_outputs,
|
||||||
|
grad_output_slices)
|
||||||
body_grad_graph.structured_outputs[idx] = new_output
|
body_grad_graph.structured_outputs[idx] = new_output
|
||||||
body_grad_graph.outputs = func_graph.flatten(
|
body_grad_graph.outputs = func_graph.flatten(
|
||||||
body_grad_graph.structured_outputs)
|
body_grad_graph.structured_outputs)
|
||||||
@ -259,11 +268,14 @@ def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
|
|||||||
Returns:
|
Returns:
|
||||||
New loop_vars to pass to graph.
|
New loop_vars to pass to graph.
|
||||||
"""
|
"""
|
||||||
structured_idx = graph.structured_outputs.index(old_output_slices)
|
structured_idx = _get_tensor_index_in_iterable(graph.structured_outputs,
|
||||||
|
old_output_slices)
|
||||||
# We assume that the component tensors of old_output_slices appear
|
# We assume that the component tensors of old_output_slices appear
|
||||||
# sequentially in graph.outputs. We use the first of these tensors
|
# sequentially in graph.outputs. We use the first of these tensors
|
||||||
# as the reference index.
|
# as the reference index.
|
||||||
flat_idx = graph.outputs.index(func_graph.flatten(old_output_slices)[0])
|
flat_idx = _get_tensor_index_in_iterable(
|
||||||
|
graph.outputs,
|
||||||
|
func_graph.flatten(old_output_slices)[0])
|
||||||
|
|
||||||
graph.structured_outputs[structured_idx] = output_slices
|
graph.structured_outputs[structured_idx] = output_slices
|
||||||
graph.outputs = func_graph.flatten(
|
graph.outputs = func_graph.flatten(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user