tf numpy: some changes to ndarray constructor logic.
PiperOrigin-RevId: 317765968 Change-Id: Iea4338ad18707ff36fc49b450d0defad5c13a6a2
This commit is contained in:
parent
1a3b7af373
commit
c692c45dae
@ -141,13 +141,12 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name
|
|||||||
raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,'
|
raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,'
|
||||||
' Tensor or np.ndarray.'.format(type(buffer)))
|
' Tensor or np.ndarray.'.format(type(buffer)))
|
||||||
|
|
||||||
if shape is not None and tuple(shape) != buffer._shape_tuple(): # pylint: disable=protected-access
|
if shape is not None:
|
||||||
# TODO(srbs): NumPy allows this. Investigate if/how to support this.
|
buffer.set_shape(shape)
|
||||||
raise ValueError('shape arg must match buffer.shape.')
|
|
||||||
|
|
||||||
assert isinstance(buffer, ops.Tensor)
|
assert isinstance(buffer, ops.Tensor)
|
||||||
if dtype and dtype != buffer.dtype:
|
if dtype and dtype != buffer.dtype:
|
||||||
buffer = array_ops.bitcast(buffer, dtype)
|
buffer = math_ops.cast(buffer, dtype)
|
||||||
self._data = buffer
|
self._data = buffer
|
||||||
self._type_spec_internal = None
|
self._type_spec_internal = None
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ import collections
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -51,6 +52,19 @@ class ArrayTest(test.TestCase):
|
|||||||
self.assertIs(a.dtype.type, np.bool_)
|
self.assertIs(a.dtype.type, np.bool_)
|
||||||
self.assertAllEqual([False, True], a)
|
self.assertAllEqual([False, True], a)
|
||||||
|
|
||||||
|
def testConstructor(self):
|
||||||
|
t = constant_op.constant([[1], [1]])
|
||||||
|
a = np_arrays.ndarray(shape=(2, 1), buffer=t)
|
||||||
|
self.assertAllEqual(t, a)
|
||||||
|
self.assertEqual(dtypes.float64, a.dtype)
|
||||||
|
|
||||||
|
a = np_arrays.ndarray(shape=(2, 1), dtype=dtypes.int32, buffer=t)
|
||||||
|
self.assertAllEqual(t, a)
|
||||||
|
self.assertEqual(dtypes.int32, a.dtype)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError): # bad shape
|
||||||
|
_ = np_arrays.ndarray((2, 2), buffer=t)
|
||||||
|
|
||||||
def testNeg(self):
|
def testNeg(self):
|
||||||
a = t2a(ops.convert_to_tensor(value=[1.0, 2.0]))
|
a = t2a(ops.convert_to_tensor(value=[1.0, 2.0]))
|
||||||
self.assertAllEqual([-1.0, -2.0], -a)
|
self.assertAllEqual([-1.0, -2.0], -a)
|
||||||
|
Loading…
Reference in New Issue
Block a user