Allow TensorArray to be initialized with numpy array for shape.
PiperOrigin-RevId: 338453635 Change-Id: I981c692f8ea4efab0b2d0963b0c3ffbe6e68cd66
This commit is contained in:
parent
9a22465307
commit
10b5284c9d
tensorflow/python
@ -4355,6 +4355,7 @@ py_library(
|
||||
":array_ops",
|
||||
":handle_data_util",
|
||||
":list_ops_gen",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -5455,6 +5456,7 @@ py_test(
|
||||
":array_ops",
|
||||
":client",
|
||||
":client_testlib",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -19,6 +19,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.python.framework import cpp_shape_inference_pb2
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -395,8 +397,8 @@ def _build_element_shape(shape):
|
||||
# Shape is unknown.
|
||||
if shape is None:
|
||||
return -1
|
||||
# Shape is a scalar.
|
||||
if not shape:
|
||||
# Shape is numpy array or a scalar.
|
||||
if isinstance(shape, (np.ndarray, np.generic)) or not shape:
|
||||
return ops.convert_to_tensor(shape, dtype=dtypes.int32)
|
||||
# Shape is a sequence of dimensions. Convert None dims to -1.
|
||||
def convert(val):
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -72,6 +74,18 @@ class TensorArrayOpsTest(test.TestCase):
|
||||
self.assertAllEqual(fn(['a', 'b', 'c'], ['c', 'd', 'e']),
|
||||
[b'a', b'b', b'c', b'c', b'd', b'e'])
|
||||
|
||||
def test_init_numpy_shape(self):
|
||||
@def_function.function
|
||||
def fn():
|
||||
values = tensor_array_ops.TensorArray(
|
||||
np.float32,
|
||||
size=1,
|
||||
dynamic_size=False,
|
||||
element_shape=np.array((2, 3)))
|
||||
values = values.write(0, np.ones((2, 3)))
|
||||
return values.concat()
|
||||
self.assertAllEqual(fn(), [[1., 1., 1.], [1., 1., 1.]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user