diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 38c4735ffd0..f75aee95ce5 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -131,6 +131,7 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index f5fda3c9a46..f3bec9ddb72 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -25,6 +25,7 @@ import numpy as np # pylint: disable=unused-import from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -1561,6 +1562,19 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): grad = gradients_impl.gradients(t1, t)[0] self.assertAllEqual(self.evaluate(grad), [1., 1., 1.]) + def testHandleDataAcrossFunctionCall(self): + + @def_function.function + def func(): + t = constant_op.constant([1., 2., 3.]) + l = list_ops.tensor_list_from_tensor(t, element_shape=[]) + return l + + tensor_list = func() + element = list_ops.tensor_list_get_item( + tensor_list, 0, element_dtype=dtypes.float32) + self.assertAllEqual(element.shape.as_list(), []) + if __name__ == "__main__": test.main()