Allows tf.convert_to_tensor to work with zero-dimensional numpy arrays
PiperOrigin-RevId: 232717709
This commit is contained in:
parent
8b633af6ff
commit
324c339486
@ -339,6 +339,24 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
def testConvertToTensorAllowsOverflow(self):
|
||||
_ = ops.convert_to_tensor(123456789, dtype=dtypes.uint8)
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConvertToTensorNumpyZeroDim(self):
|
||||
for np_type, dtype in [(np.int32, dtypes.int32),
|
||||
(np.half, dtypes.half),
|
||||
(np.float32, dtypes.float32)]:
|
||||
x = ops.convert_to_tensor([np.array(65, dtype=np_type),
|
||||
np.array(16, dtype=np_type)])
|
||||
self.assertEqual(x.dtype, dtype)
|
||||
self.assertAllEqual(x, [65, 16])
|
||||
|
||||
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConvertToTensorNumpyScalar(self):
|
||||
x = ops.convert_to_tensor([np.asscalar(np.array(321, dtype=np.int)),
|
||||
np.asscalar(np.array(16, dtype=np.int))])
|
||||
self.assertAllEqual(x, [321, 16])
|
||||
|
||||
def testEagerTensorError(self):
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
@ -347,7 +365,6 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
_ = ops.convert_to_tensor(1., dtype=dtypes.int32)
|
||||
|
||||
|
||||
|
||||
class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testListOfThree(self):
|
||||
|
@ -64,6 +64,19 @@ bool IsPyFloat(PyObject* obj) {
|
||||
PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types
|
||||
}
|
||||
|
||||
// If the input is a zero dimensional PyArray return it converted to a scalar.
|
||||
// Otherwise return the input and increment its reference count.
|
||||
// Users must Py_DECREF the output of this method.
|
||||
PyObject* ZeroDimArrayToScalar(PyObject* obj) {
|
||||
if (PyArray_IsZeroDim(obj) && !PyArray_IsScalar(obj, Generic)) {
|
||||
auto pyarray_obj = reinterpret_cast<PyArrayObject*>(obj);
|
||||
obj = PyArray_ToScalar(PyArray_DATA(pyarray_obj), pyarray_obj);
|
||||
} else {
|
||||
Py_INCREF(obj);
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
||||
// Converts Python object `c` that should hold a Python string into a
|
||||
// C++ string in *out. Returns nullptr on success, or a message on error.
|
||||
// Defined below, but forward declared here for use in PyRepr.
|
||||
@ -130,6 +143,10 @@ Status SampleElementFromSequence(PyObject* seq, PyObject** elem) {
|
||||
Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
|
||||
std::vector<Safe_PyObjectPtr> refs_to_clean;
|
||||
while (true) {
|
||||
// Convert any zero dimensional numpy arrays to scalars first of all.
|
||||
// We also have to make sure a reference to the safe_obj is kept.
|
||||
obj = ZeroDimArrayToScalar(obj);
|
||||
refs_to_clean.push_back(make_safe(obj));
|
||||
// We test strings first, in case a string is considered a sequence.
|
||||
if (IsPyString(obj)) {
|
||||
*dtype = DT_STRING;
|
||||
@ -240,7 +257,9 @@ const char ErrorFoundFloat[] =
|
||||
} \
|
||||
PyObject** l = PySequence_Fast_ITEMS(seq.get()); \
|
||||
for (int64 i = 0; i < s; ++i) { \
|
||||
const char* error = CONVERT(l[i], *buf); \
|
||||
auto scalar = ZeroDimArrayToScalar(l[i]); \
|
||||
const char* error = CONVERT(scalar, *buf); \
|
||||
Py_DECREF(scalar); \
|
||||
if (TF_PREDICT_FALSE(error != nullptr)) return error; \
|
||||
++*buf; \
|
||||
} \
|
||||
@ -253,7 +272,9 @@ const char ErrorFoundFloat[] =
|
||||
Tensor result(TYPE_ENUM, shape); \
|
||||
if (shape.dims() == 0) { /* Scalar case */ \
|
||||
TYPE value; \
|
||||
const char* error = CONVERT(obj, &value); \
|
||||
auto scalar = ZeroDimArrayToScalar(obj); \
|
||||
const char* error = CONVERT(scalar, &value); \
|
||||
Py_DECREF(scalar); \
|
||||
if (error != nullptr) return error; \
|
||||
result.scalar<TYPE>()() = value; \
|
||||
} else { \
|
||||
|
Loading…
Reference in New Issue
Block a user