Add a pfor converter for TensorListPopBack

PiperOrigin-RevId: 338299497
Change-Id: If0c00525e2628b8db42c46c11d32b1fbb3a96c7a
This commit is contained in:
Allen Lavoie 2020-10-21 11:06:05 -07:00 committed by TensorFlower Gardener
parent 1c7803fa56
commit dbc850494b
3 changed files with 89 additions and 0 deletions

View File

@ -43,6 +43,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@ -1713,6 +1714,35 @@ class JacobianTest(test.TestCase):
dy_xx_answer = [[[2., 0], [0, 2.]]] * 10
self.assertAllClose(dy_xx_answer, self.evaluate(dy_xx))
def test_nested_batch_jacobian_foldl(self):
def _grad(f):
def _grad_function(primal):
with backprop.GradientTape() as tape:
tape.watch(primal)
primal_out = f(primal)
return tape.batch_jacobian(primal_out, primal)
return _grad_function
def _func(x):
return array_ops.reshape(
functional_ops.foldl_v2(lambda a, b: math_ops.cos(a + b),
array_ops.transpose(x)),
[1, 1])
f = _func
x = constant_op.constant([[1., 2.]])
for _ in range(2):
theoretical, numerical = gradient_checker_v2.compute_gradient(f, [x])
self.assertAllClose(theoretical, numerical, rtol=1e-3)
f = _grad(f)
expected_flat = array_ops.reshape(numerical, [-1])
self.assertAllClose(expected_flat,
array_ops.reshape(f(x), [-1]),
rtol=1e-3)
self.assertAllClose(expected_flat,
array_ops.reshape(def_function.function(f)(x), [-1]),
rtol=1e-3)
@test_util.run_in_graph_and_eager_modes
def test_indexed_slices(self):
with backprop.GradientTape(persistent=True) as g:

View File

@ -992,6 +992,44 @@ class TensorListTest(PForTestCase):
self._test_loop_fn(loop_fn, 3)
def test_pop_back_no_shape(self):
def loop_fn(i):
handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
handle = list_ops.tensor_list_push_back(handle, [1, 2])
handle = list_ops.tensor_list_push_back(handle, [i, 2])
handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32)
return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)
self._test_loop_fn(loop_fn, 3)
def test_pop_back_no_shape_capture(self):
h = list_ops.tensor_list_reserve([2], 1, dtypes.int32)
h = list_ops.tensor_list_push_back(h, [1, 2])
def loop_fn(i):
handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32)
handle = list_ops.tensor_list_push_back(handle, [1, i])
return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)
self._test_loop_fn(loop_fn, 3)
def test_pop_back_with_shape(self):
@def_function.function
def loop_fn(i):
with backprop.GradientTape() as tape:
handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32)
x = math_ops.cast(i, dtypes.float32)[None]
tape.watch(x)
handle = list_ops.tensor_list_push_back(handle, x)
stacked = list_ops.tensor_list_stack(handle, dtypes.float32)
list_grad = tape.gradient(stacked, x, x)
self.assertEqual("TensorListPopBack", list_grad.op.type)
return list_grad, stacked, list_grad.op.inputs[1]
self._test_loop_fn(loop_fn, 3)
def test_create_outside_and_scatter(self):
h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

View File

@ -3763,6 +3763,27 @@ def _convert_tensor_list_push_back(pfor_input):
return wrap(_tile_variant(handle, pfor_input), True)
@RegisterPFor("TensorListPopBack")
def _convert_tensor_array_push_back(pfor_input):
handle = pfor_input.stacked_input(0)
element_shape = pfor_input.unstacked_input(1)
handle = _untile_variant(handle)
if element_shape.shape.ndims == 0:
# Default / unspecified
vectorized_shape = -1
else:
# PopBack has an element shape set when it's the gradient of PushBack, only
# used when the list is uninitialized.
vectorized_shape = array_ops.concat(
[pfor_input.pfor.loop_len_vector, element_shape], axis=0)
output_handle, tensor = gen_list_ops.tensor_list_pop_back(
input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"),
element_shape=vectorized_shape)
return wrap(output_handle, True), wrap(tensor, True)
@RegisterPFor("TensorListConcatV2")
def _convert_tensor_list_concat_v2(pfor_input):
input_handle = pfor_input.stacked_input(0)