From c692c45daef9e4a1d5937f705fe8a3977b04fddd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 22 Jun 2020 17:10:15 -0700 Subject: [PATCH] tf numpy: some changes to ndarray constructor logic. PiperOrigin-RevId: 317765968 Change-Id: Iea4338ad18707ff36fc49b450d0defad5c13a6a2 --- tensorflow/python/ops/numpy_ops/np_arrays.py | 7 +++---- tensorflow/python/ops/numpy_ops/np_arrays_test.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/numpy_ops/np_arrays.py b/tensorflow/python/ops/numpy_ops/np_arrays.py index eca84421d1b..77157544e8f 100644 --- a/tensorflow/python/ops/numpy_ops/np_arrays.py +++ b/tensorflow/python/ops/numpy_ops/np_arrays.py @@ -141,13 +141,12 @@ class ndarray(composite_tensor.CompositeTensor): # pylint: disable=invalid-name raise ValueError('Unexpected type for `buffer` {}. Must be an ndarray,' ' Tensor or np.ndarray.'.format(type(buffer))) - if shape is not None and tuple(shape) != buffer._shape_tuple(): # pylint: disable=protected-access - # TODO(srbs): NumPy allows this. Investigate if/how to support this. - raise ValueError('shape arg must match buffer.shape.') + if shape is not None: + buffer.set_shape(shape) assert isinstance(buffer, ops.Tensor) if dtype and dtype != buffer.dtype: - buffer = array_ops.bitcast(buffer, dtype) + buffer = math_ops.cast(buffer, dtype) self._data = buffer self._type_spec_internal = None diff --git a/tensorflow/python/ops/numpy_ops/np_arrays_test.py b/tensorflow/python/ops/numpy_ops/np_arrays_test.py index 412addc0ad7..ab407d2bfcf 100644 --- a/tensorflow/python/ops/numpy_ops/np_arrays_test.py +++ b/tensorflow/python/ops/numpy_ops/np_arrays_test.py @@ -22,6 +22,7 @@ import collections import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -51,6 +52,19 @@ class ArrayTest(test.TestCase): self.assertIs(a.dtype.type, np.bool_) self.assertAllEqual([False, True], a) + def testConstructor(self): + t = constant_op.constant([[1], [1]]) + a = np_arrays.ndarray(shape=(2, 1), buffer=t) + self.assertAllEqual(t, a) + self.assertEqual(dtypes.float64, a.dtype) + + a = np_arrays.ndarray(shape=(2, 1), dtype=dtypes.int32, buffer=t) + self.assertAllEqual(t, a) + self.assertEqual(dtypes.int32, a.dtype) + + with self.assertRaises(ValueError): # bad shape + _ = np_arrays.ndarray((2, 2), buffer=t) + def testNeg(self): a = t2a(ops.convert_to_tensor(value=[1.0, 2.0])) self.assertAllEqual([-1.0, -2.0], -a)