Add a pfor converter for TensorListPopBack
PiperOrigin-RevId: 338299497
Change-Id: If0c00525e2628b8db42c46c11d32b1fbb3a96c7a
parent 1c7803fa56
commit dbc850494b
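
As a minimal sketch (not part of this commit) of what the new converter enables, assuming TF 2.x and the internal list_ops module, a loop body like the one in test_pop_back_no_shape below can now be vectorized directly with tf.vectorized_map:

    import tensorflow as tf
    from tensorflow.python.ops import list_ops

    def loop_fn(i):
      # Per-iteration body: push two elements onto a TensorList, then pop the
      # last one back off. The pop dispatches to the new TensorListPopBack
      # pfor converter when the body is vectorized.
      handle = list_ops.tensor_list_reserve([2], 2, tf.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, 2])
      handle = list_ops.tensor_list_push_back(handle, [i, 2])
      handle, tensor = list_ops.tensor_list_pop_back(handle, tf.int32)
      return tensor, list_ops.tensor_list_stack(handle, tf.int32)

    popped, stacked = tf.vectorized_map(loop_fn, tf.range(3))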
@@ -43,6 +43,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import custom_gradient
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
@@ -1713,6 +1714,35 @@ class JacobianTest(test.TestCase):
     dy_xx_answer = [[[2., 0], [0, 2.]]] * 10
     self.assertAllClose(dy_xx_answer, self.evaluate(dy_xx))
 
+  def test_nested_batch_jacobian_foldl(self):
+    def _grad(f):
+      def _grad_function(primal):
+        with backprop.GradientTape() as tape:
+          tape.watch(primal)
+          primal_out = f(primal)
+        return tape.batch_jacobian(primal_out, primal)
+      return _grad_function
+
+    def _func(x):
+      return array_ops.reshape(
+          functional_ops.foldl_v2(lambda a, b: math_ops.cos(a + b),
+                                  array_ops.transpose(x)),
+          [1, 1])
+
+    f = _func
+    x = constant_op.constant([[1., 2.]])
+    for _ in range(2):
+      theoretical, numerical = gradient_checker_v2.compute_gradient(f, [x])
+      self.assertAllClose(theoretical, numerical, rtol=1e-3)
+      f = _grad(f)
+      expected_flat = array_ops.reshape(numerical, [-1])
+      self.assertAllClose(expected_flat,
+                          array_ops.reshape(f(x), [-1]),
+                          rtol=1e-3)
+      self.assertAllClose(expected_flat,
+                          array_ops.reshape(def_function.function(f)(x), [-1]),
+                          rtol=1e-3)
+
   @test_util.run_in_graph_and_eager_modes
   def test_indexed_slices(self):
     with backprop.GradientTape(persistent=True) as g:
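This Jacobian test reaches the new converter indirectly: under the hood the backward pass of foldl_v2 (a while_loop) pops its accumulated forward values off a TensorList, and batch_jacobian vectorizes that backward pass with pfor. A rough public-API sketch of the same computation, assuming TF 2.x (not part of the diff):

    import tensorflow as tf

    x = tf.constant([[1., 2.]])
    with tf.GradientTape(persistent=True) as tape:
      tape.watch(x)
      y = tf.reshape(
          tf.foldl(lambda a, b: tf.cos(a + b), tf.transpose(x)), [1, 1])
    # batch_jacobian vectorizes the gradient with pfor; the foldl gradient
    # reads its accumulated intermediates back via TensorListPopBack.
    jac = tape.batch_jacobian(y, x)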
@@ -992,6 +992,44 @@ class TensorListTest(PForTestCase):
 
     self._test_loop_fn(loop_fn, 3)
 
+  def test_pop_back_no_shape(self):
+
+    def loop_fn(i):
+      handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
+      handle = list_ops.tensor_list_push_back(handle, [1, 2])
+      handle = list_ops.tensor_list_push_back(handle, [i, 2])
+      handle, tensor = list_ops.tensor_list_pop_back(handle, dtypes.int32)
+      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_pop_back_no_shape_capture(self):
+    h = list_ops.tensor_list_reserve([2], 1, dtypes.int32)
+    h = list_ops.tensor_list_push_back(h, [1, 2])
+
+    def loop_fn(i):
+      handle, tensor = list_ops.tensor_list_pop_back(h, dtypes.int32)
+      handle = list_ops.tensor_list_push_back(handle, [1, i])
+      return tensor, list_ops.tensor_list_stack(handle, dtypes.int32)
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_pop_back_with_shape(self):
+
+    @def_function.function
+    def loop_fn(i):
+      with backprop.GradientTape() as tape:
+        handle = list_ops.tensor_list_reserve(None, 1, dtypes.float32)
+        x = math_ops.cast(i, dtypes.float32)[None]
+        tape.watch(x)
+        handle = list_ops.tensor_list_push_back(handle, x)
+        stacked = list_ops.tensor_list_stack(handle, dtypes.float32)
+      list_grad = tape.gradient(stacked, x, x)
+      self.assertEqual("TensorListPopBack", list_grad.op.type)
+      return list_grad, stacked, list_grad.op.inputs[1]
+
+    self._test_loop_fn(loop_fn, 3)
+
   def test_create_outside_and_scatter(self):
     h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
 
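In test_pop_back_no_shape_capture above, the list handle is created outside the loop and captured, so every iteration pops from the same list. Roughly the same pattern expressed with tf.vectorized_map, as a sketch (not part of the diff):

    import tensorflow as tf
    from tensorflow.python.ops import list_ops

    # List built once outside the vectorized loop and captured by loop_fn.
    h = list_ops.tensor_list_reserve([2], 1, tf.int32)
    h = list_ops.tensor_list_push_back(h, [1, 2])

    def loop_fn(i):
      handle, tensor = list_ops.tensor_list_pop_back(h, tf.int32)
      handle = list_ops.tensor_list_push_back(handle, [1, i])
      return tensor, list_ops.tensor_list_stack(handle, tf.int32)

    popped, stacked = tf.vectorized_map(loop_fn, tf.range(3))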
@@ -3763,6 +3763,27 @@ def _convert_tensor_list_push_back(pfor_input):
   return wrap(_tile_variant(handle, pfor_input), True)
 
 
+@RegisterPFor("TensorListPopBack")
+def _convert_tensor_array_push_back(pfor_input):
+  handle = pfor_input.stacked_input(0)
+  element_shape = pfor_input.unstacked_input(1)
+  handle = _untile_variant(handle)
+
+  if element_shape.shape.ndims == 0:
+    # Default / unspecified
+    vectorized_shape = -1
+  else:
+    # PopBack has an element shape set when it's the gradient of PushBack, only
+    # used when the list is uninitialized.
+    vectorized_shape = array_ops.concat(
+        [pfor_input.pfor.loop_len_vector, element_shape], axis=0)
+
+  output_handle, tensor = gen_list_ops.tensor_list_pop_back(
+      input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"),
+      element_shape=vectorized_shape)
+  return wrap(output_handle, True), wrap(tensor, True)
+
+
 @RegisterPFor("TensorListConcatV2")
 def _convert_tensor_list_concat_v2(pfor_input):
   input_handle = pfor_input.stacked_input(0)
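For the element_shape branch above, which is only taken when PopBack carries a per-iteration element shape (the gradient-of-PushBack case), the converter pops elements whose leading dimension is the loop length. A tiny standalone sketch of that concat with assumed values, 3 iterations and per-iteration shape [2]:

    import tensorflow as tf

    loop_len_vector = tf.constant([3])  # stands in for pfor_input.pfor.loop_len_vector
    element_shape = tf.constant([2])    # per-iteration element shape from the op
    vectorized_shape = tf.concat([loop_len_vector, element_shape], axis=0)
    # vectorized_shape == [3, 2]: each popped element spans all loop iterations.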