Add test for passing tensor list handle_data across function calls.
PiperOrigin-RevId: 244265425
This commit is contained in:
parent
5aa52f33d3
commit
1260d2cac3
@ -131,6 +131,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:tensor_shape",
|
"//tensorflow/python:tensor_shape",
|
||||||
"//tensorflow/python/eager:context",
|
"//tensorflow/python/eager:context",
|
||||||
|
"//tensorflow/python/eager:def_function",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
|
@ -25,6 +25,7 @@ import numpy as np # pylint: disable=unused-import
|
|||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -1561,6 +1562,19 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||||||
grad = gradients_impl.gradients(t1, t)[0]
|
grad = gradients_impl.gradients(t1, t)[0]
|
||||||
self.assertAllEqual(self.evaluate(grad), [1., 1., 1.])
|
self.assertAllEqual(self.evaluate(grad), [1., 1., 1.])
|
||||||
|
|
||||||
|
def testHandleDataAcrossFunctionCall(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def func():
|
||||||
|
t = constant_op.constant([1., 2., 3.])
|
||||||
|
l = list_ops.tensor_list_from_tensor(t, element_shape=[])
|
||||||
|
return l
|
||||||
|
|
||||||
|
tensor_list = func()
|
||||||
|
element = list_ops.tensor_list_get_item(
|
||||||
|
tensor_list, 0, element_dtype=dtypes.float32)
|
||||||
|
self.assertAllEqual(element.shape.as_list(), [])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user