Add test for element_shape arg of TensorListFromTensor.

PiperOrigin-RevId: 345251392
Change-Id: I84830aeb2044578e6530e92dd5336613793dbc82
This commit is contained in:
Saurabh Saxena 2020-12-02 09:45:11 -08:00 committed by TensorFlower Gardener
parent bea9ecb9aa
commit c7685c0ee7

View File

@ -1711,6 +1711,20 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllEqual(f(), -1)
def testElementShapeArgOfTensorListFromTensor(self):
@def_function.function
def f():
t = array_ops.ones([3, 3])
l = list_ops.tensor_list_from_tensor(t, element_shape=[-1])
l = list_ops.tensor_list_push_back(l, array_ops.ones([4]))
read_val = list_ops.tensor_list_get_item(
l, 3, element_dtype=dtypes.float32)
self.assertAllEqual(read_val.shape.as_list(), [None])
return read_val
self.assertAllEqual(f(), [1.0, 1.0, 1.0, 1.0])
if __name__ == "__main__":
test.main()