Add test for passing tensor list handle_data across function calls.

PiperOrigin-RevId: 244265425
This commit is contained in:
Saurabh Saxena 2019-04-18 15:01:09 -07:00 committed by TensorFlower Gardener
parent 5aa52f33d3
commit 1260d2cac3
2 changed files with 15 additions and 0 deletions

View File

@ -131,6 +131,7 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:client_testlib",

View File

@ -25,6 +25,7 @@ import numpy as np # pylint: disable=unused-import
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -1561,6 +1562,19 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
grad = gradients_impl.gradients(t1, t)[0]
self.assertAllEqual(self.evaluate(grad), [1., 1., 1.])
def testHandleDataAcrossFunctionCall(self):
@def_function.function
def func():
t = constant_op.constant([1., 2., 3.])
l = list_ops.tensor_list_from_tensor(t, element_shape=[])
return l
tensor_list = func()
element = list_ops.tensor_list_get_item(
tensor_list, 0, element_dtype=dtypes.float32)
self.assertAllEqual(element.shape.as_list(), [])
if __name__ == "__main__":
test.main()