Add a pfor converter for TensorListPopBack
PiperOrigin-RevId: 338299497 Change-Id: If0c00525e2628b8db42c46c11d32b1fbb3a96c7a
This commit is contained in:
parent
1c7803fa56
commit
dbc850494b
@ -43,6 +43,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import custom_gradient
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import gradients
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
@ -1713,6 +1714,35 @@ class JacobianTest(test.TestCase):
|
||||
dy_xx_answer = [[[2., 0], [0, 2.]]] * 10
|
||||
self.assertAllClose(dy_xx_answer, self.evaluate(dy_xx))
|
||||
|
||||
def test_nested_batch_jacobian_foldl(self):
|
||||
def _grad(f):
|
||||
def _grad_function(primal):
|
||||
with backprop.GradientTape() as tape:
|
||||
tape.watch(primal)
|
||||
primal_out = f(primal)
|
||||
return tape.batch_jacobian(primal_out, primal)
|
||||
return _grad_function
|
||||
|
||||
def _func(x):
|
||||
return array_ops.reshape(
|
||||
functional_ops.foldl_v2(lambda a, b: math_ops.cos(a + b),
|
||||
array_ops.transpose(x)),
|
||||
[1, 1])
|
||||
|
||||
f = _func
|
||||
x = constant_op.constant([[1., 2.]])
|
||||
for _ in range(2):
|
||||
theoretical, numerical = gradient_checker_v2.compute_gradient(f, [x])
|
||||
self.assertAllClose(theoretical, numerical, rtol=1e-3)
|
||||
f = _grad(f)
|
||||
expected_flat = array_ops.reshape(numerical, [-1])
|
||||
self.assertAllClose(expected_flat,
|
||||
array_ops.reshape(f(x), [-1]),
|
||||
rtol=1e-3)
|
||||
self.assertAllClose(expected_flat,
|
||||
array_ops.reshape(def_function.function(f)(x), [-1]),
|
||||
rtol=1e-3)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_indexed_slices(self):
|
||||
with backprop.GradientTape(persistent=True) as g:
|
||||
|
@ -992,6 +992,44 @@ class TensorListTest(PForTestCase):
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_pop_back_no_shape(self):
|
||||
|
||||
def loop_fn(i):
|
||||
handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
|
||||
handle = list_ops.tensor_list_push_back(handle, [1, 2])
|
||||
handle = list_ops.tensor_list_push_back(handle, [i, 2])
|
||||
handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32)
|
||||
return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_pop_back_no_shape_capture(self):
|
||||
h = list_ops.tensor_list_reserve([2], 1, dtypes.int32)
|
||||
h = list_ops.tensor_list_push_back(h, [1, 2])
|
||||
|
||||
def loop_fn(i):
|
||||
handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32)
|
||||
handle = list_ops.tensor_list_push_back(handle, [1, i])
|
||||
return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_pop_back_with_shape(self):
|
||||
|
||||
@def_function.function
|
||||
def loop_fn(i):
|
||||
with backprop.GradientTape() as tape:
|
||||
handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32)
|
||||
x = math_ops.cast(i, dtypes.float32)[None]
|
||||
tape.watch(x)
|
||||
handle = list_ops.tensor_list_push_back(handle, x)
|
||||
stacked = list_ops.tensor_list_stack(handle, dtypes.float32)
|
||||
list_grad = tape.gradient(stacked, x, x)
|
||||
self.assertEqual("TensorListPopBack", list_grad.op.type)
|
||||
return list_grad, stacked, list_grad.op.inputs[1]
|
||||
|
||||
self._test_loop_fn(loop_fn, 3)
|
||||
|
||||
def test_create_outside_and_scatter(self):
|
||||
h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
|
||||
|
||||
|
@ -3763,6 +3763,27 @@ def _convert_tensor_list_push_back(pfor_input):
|
||||
return wrap(_tile_variant(handle, pfor_input), True)
|
||||
|
||||
|
||||
@RegisterPFor("TensorListPopBack")
|
||||
def _convert_tensor_array_push_back(pfor_input):
|
||||
handle = pfor_input.stacked_input(0)
|
||||
element_shape = pfor_input.unstacked_input(1)
|
||||
handle = _untile_variant(handle)
|
||||
|
||||
if element_shape.shape.ndims == 0:
|
||||
# Default / unspecified
|
||||
vectorized_shape = -1
|
||||
else:
|
||||
# PopBack has an element shape set when it's the gradient of PushBack, only
|
||||
# used when the list is uninitialized.
|
||||
vectorized_shape = array_ops.concat(
|
||||
[pfor_input.pfor.loop_len_vector, element_shape], axis=0)
|
||||
|
||||
output_handle, tensor = gen_list_ops.tensor_list_pop_back(
|
||||
input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"),
|
||||
element_shape=vectorized_shape)
|
||||
return wrap(output_handle, True), wrap(tensor, True)
|
||||
|
||||
|
||||
@RegisterPFor("TensorListConcatV2")
|
||||
def _convert_tensor_list_concat_v2(pfor_input):
|
||||
input_handle = pfor_input.stacked_input(0)
|
||||
|
Loading…
Reference in New Issue
Block a user