Allow TensorArray to be initialized with numpy array for shape.
PiperOrigin-RevId: 338453635 Change-Id: I981c692f8ea4efab0b2d0963b0c3ffbe6e68cd66
This commit is contained in:
parent
9a22465307
commit
10b5284c9d
@ -4355,6 +4355,7 @@ py_library(
|
|||||||
":array_ops",
|
":array_ops",
|
||||||
":handle_data_util",
|
":handle_data_util",
|
||||||
":list_ops_gen",
|
":list_ops_gen",
|
||||||
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -5455,6 +5456,7 @@ py_test(
|
|||||||
":array_ops",
|
":array_ops",
|
||||||
":client",
|
":client",
|
||||||
":client_testlib",
|
":client_testlib",
|
||||||
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,6 +19,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.core.framework import types_pb2
|
from tensorflow.core.framework import types_pb2
|
||||||
from tensorflow.python.framework import cpp_shape_inference_pb2
|
from tensorflow.python.framework import cpp_shape_inference_pb2
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -395,8 +397,8 @@ def _build_element_shape(shape):
|
|||||||
# Shape is unknown.
|
# Shape is unknown.
|
||||||
if shape is None:
|
if shape is None:
|
||||||
return -1
|
return -1
|
||||||
# Shape is a scalar.
|
# Shape is numpy array or a scalar.
|
||||||
if not shape:
|
if isinstance(shape, (np.ndarray, np.generic)) or not shape:
|
||||||
return ops.convert_to_tensor(shape, dtype=dtypes.int32)
|
return ops.convert_to_tensor(shape, dtype=dtypes.int32)
|
||||||
# Shape is a sequence of dimensions. Convert None dims to -1.
|
# Shape is a sequence of dimensions. Convert None dims to -1.
|
||||||
def convert(val):
|
def convert(val):
|
||||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -72,6 +74,18 @@ class TensorArrayOpsTest(test.TestCase):
|
|||||||
self.assertAllEqual(fn(['a', 'b', 'c'], ['c', 'd', 'e']),
|
self.assertAllEqual(fn(['a', 'b', 'c'], ['c', 'd', 'e']),
|
||||||
[b'a', b'b', b'c', b'c', b'd', b'e'])
|
[b'a', b'b', b'c', b'c', b'd', b'e'])
|
||||||
|
|
||||||
|
def test_init_numpy_shape(self):
|
||||||
|
@def_function.function
|
||||||
|
def fn():
|
||||||
|
values = tensor_array_ops.TensorArray(
|
||||||
|
np.float32,
|
||||||
|
size=1,
|
||||||
|
dynamic_size=False,
|
||||||
|
element_shape=np.array((2, 3)))
|
||||||
|
values = values.write(0, np.ones((2, 3)))
|
||||||
|
return values.concat()
|
||||||
|
self.assertAllEqual(fn(), [[1., 1., 1.], [1., 1., 1.]])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user