Add test for passing tensor list handle_data across function calls.
PiperOrigin-RevId: 244265425
This commit is contained in:
parent
5aa52f33d3
commit
1260d2cac3
@ -131,6 +131,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:client_testlib",
|
||||
|
@ -25,6 +25,7 @@ import numpy as np # pylint: disable=unused-import
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -1561,6 +1562,19 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
grad = gradients_impl.gradients(t1, t)[0]
|
||||
self.assertAllEqual(self.evaluate(grad), [1., 1., 1.])
|
||||
|
||||
def testHandleDataAcrossFunctionCall(self):
|
||||
|
||||
@def_function.function
|
||||
def func():
|
||||
t = constant_op.constant([1., 2., 3.])
|
||||
l = list_ops.tensor_list_from_tensor(t, element_shape=[])
|
||||
return l
|
||||
|
||||
tensor_list = func()
|
||||
element = list_ops.tensor_list_get_item(
|
||||
tensor_list, 0, element_dtype=dtypes.float32)
|
||||
self.assertAllEqual(element.shape.as_list(), [])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user