Allow TensorArray to be initialized with numpy array for shape.

PiperOrigin-RevId: 338453635
Change-Id: I981c692f8ea4efab0b2d0963b0c3ffbe6e68cd66
This commit is contained in:
Colin Carroll 2020-10-22 05:33:04 -07:00 committed by TensorFlower Gardener
parent 9a22465307
commit 10b5284c9d
3 changed files with 20 additions and 2 deletions

View File

@ -4355,6 +4355,7 @@ py_library(
":array_ops", ":array_ops",
":handle_data_util", ":handle_data_util",
":list_ops_gen", ":list_ops_gen",
"//third_party/py/numpy",
], ],
) )
@ -5455,6 +5456,7 @@ py_test(
":array_ops", ":array_ops",
":client", ":client",
":client_testlib", ":client_testlib",
"//third_party/py/numpy",
], ],
) )

View File

@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import cpp_shape_inference_pb2 from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -395,8 +397,8 @@ def _build_element_shape(shape):
# Shape is unknown. # Shape is unknown.
if shape is None: if shape is None:
return -1 return -1
# Shape is a scalar. # Shape is numpy array or a scalar.
if not shape: if isinstance(shape, (np.ndarray, np.generic)) or not shape:
return ops.convert_to_tensor(shape, dtype=dtypes.int32) return ops.convert_to_tensor(shape, dtype=dtypes.int32)
# Shape is a sequence of dimensions. Convert None dims to -1. # Shape is a sequence of dimensions. Convert None dims to -1.
def convert(val): def convert(val):

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import numpy as np
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -72,6 +74,18 @@ class TensorArrayOpsTest(test.TestCase):
self.assertAllEqual(fn(['a', 'b', 'c'], ['c', 'd', 'e']), self.assertAllEqual(fn(['a', 'b', 'c'], ['c', 'd', 'e']),
[b'a', b'b', b'c', b'c', b'd', b'e']) [b'a', b'b', b'c', b'c', b'd', b'e'])
def test_init_numpy_shape(self):
@def_function.function
def fn():
values = tensor_array_ops.TensorArray(
np.float32,
size=1,
dynamic_size=False,
element_shape=np.array((2, 3)))
values = values.write(0, np.ones((2, 3)))
return values.concat()
self.assertAllEqual(fn(), [[1., 1., 1.], [1., 1., 1.]])
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()