tf numpy: some changes to ndarray constructor logic.

PiperOrigin-RevId: 317765968
Change-Id: Iea4338ad18707ff36fc49b450d0defad5c13a6a2
This commit is contained in:
A. Unique TensorFlower 2020-06-22 17:10:15 -07:00 committed by TensorFlower Gardener
parent 1a3b7af373
commit c692c45dae
2 changed files with 17 additions and 4 deletions

View File

@ -141,13 +141,12 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,'
' Tensor or np.ndarray.'.format(type(buffer)))
if shape is not None and tuple(shape) != buffer._shape_tuple(): # pylint: disable=protected-access
# TODO(srbs): NumPy allows this. Investigate if/how to support this.
raise ValueError('shape arg must match buffer.shape.')
if shape is not None:
buffer.set_shape(shape)
assert isinstance(buffer, ops.Tensor)
if dtype and dtype != buffer.dtype:
buffer = array_ops.bitcast(buffer, dtype)
buffer = math_ops.cast(buffer, dtype)
self._data = buffer
self._type_spec_internal = None

View File

@ -22,6 +22,7 @@ import collections
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -51,6 +52,19 @@ class ArrayTest(test.TestCase):
self.assertIs(a.dtype.type, np.bool_)
self.assertAllEqual([False, True], a)
def testConstructor(self):
t = constant_op.constant([[1], [1]])
a = np_arrays.ndarray(shape=(2, 1), buffer=t)
self.assertAllEqual(t, a)
self.assertEqual(dtypes.float64, a.dtype)
a = np_arrays.ndarray(shape=(2, 1), dtype=dtypes.int32, buffer=t)
self.assertAllEqual(t, a)
self.assertEqual(dtypes.int32, a.dtype)
with self.assertRaises(ValueError): # bad shape
_ = np_arrays.ndarray((2, 2), buffer=t)
def testNeg(self):
a = t2a(ops.convert_to_tensor(value=[1.0, 2.0]))
self.assertAllEqual([-1.0, -2.0], -a)