Add test for element_shape arg of TensorListFromTensor.
PiperOrigin-RevId: 345251392 Change-Id: I84830aeb2044578e6530e92dd5336613793dbc82
This commit is contained in:
parent
bea9ecb9aa
commit
c7685c0ee7
@ -1711,6 +1711,20 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(f(), -1)
|
||||
|
||||
def testElementShapeArgOfTensorListFromTensor(self):
|
||||
|
||||
@def_function.function
|
||||
def f():
|
||||
t = array_ops.ones([3, 3])
|
||||
l = list_ops.tensor_list_from_tensor(t, element_shape=[-1])
|
||||
l = list_ops.tensor_list_push_back(l, array_ops.ones([4]))
|
||||
read_val = list_ops.tensor_list_get_item(
|
||||
l, 3, element_dtype=dtypes.float32)
|
||||
self.assertAllEqual(read_val.shape.as_list(), [None])
|
||||
return read_val
|
||||
|
||||
self.assertAllEqual(f(), [1.0, 1.0, 1.0, 1.0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user