diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 7b2f69d6ecf..df2ab0b539b 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -453,16 +453,6 @@ tensorflow::ImportNumpy(); // Literal -%typemap(out) StatusOr { - if ($1.ok()) { - Literal value = $1.ConsumeValueOrDie(); - $result = numpy::PyObjectFromXlaLiteral(*value); - } else { - PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); - SWIG_fail; - } -} - %typemap(in) const Literal& (StatusOr literal_status) { literal_status = numpy::XlaLiteralFromPyObject($input); if (!literal_status.ok()) { @@ -472,16 +462,26 @@ tensorflow::ImportNumpy(); $1 = &literal_status.ValueOrDie(); } -%typemap(out) Literal { - $result = numpy::PyObjectFromXlaLiteral(*$1); +%typemap(out) Literal (StatusOr obj_status) { + obj_status = numpy::PyObjectFromXlaLiteral(*$1); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); } -%typemap(out) StatusOr { +%typemap(out) StatusOr (StatusOr obj_status) { if (!$1.ok()) { PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); SWIG_fail; } - $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); + obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); + if (!obj_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str()); + SWIG_fail; + } + $result = obj_status.ValueOrDie().release(); } %typemap(in) const std::vector& (std::vector temps) { diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index 52c5c621f72..8e056f97255 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -26,6 +26,10 @@ namespace swig { namespace numpy { +Safe_PyObjectPtr make_safe(PyObject* object) { + return Safe_PyObjectPtr(object); +} + int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { switch (primitive_type) { case PRED: @@ -349,13 +353,17 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { +StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal) { if (literal.shape().IsTuple()) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); - PyObject* tuple = PyTuple_New(num_elements); + std::vector elems(num_elements); for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM(tuple, i, - PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); + TF_ASSIGN_OR_RETURN(elems[i], + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); + } + Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements)); + for (int i = 0; i < num_elements; i++) { + PyTuple_SET_ITEM(tuple.get(), i, elems[i].release()); } return tuple; } else { @@ -365,10 +373,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); } int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); - PyObject* array = - PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0); - CopyLiteralToNumpyArray(np_type, literal, - reinterpret_cast(array)); + Safe_PyObjectPtr array = make_safe( + PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0)); + TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray( + np_type, literal, reinterpret_cast(array.get()))); return array; } } @@ -408,6 +416,12 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_BOOL: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_INT8: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_INT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; case NPY_INT32: CopyNumpyArrayToLiteral(py_array, literal); break; @@ -417,6 +431,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, case NPY_UINT8: CopyNumpyArrayToLiteral(py_array, literal); break; + case NPY_UINT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; case NPY_UINT32: CopyNumpyArrayToLiteral(py_array, literal); break; @@ -445,12 +462,18 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, - PyArrayObject* py_array) { +Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, + PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_INT8: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT16: + CopyLiteralToNumpyArray(literal, py_array); + break; case NPY_INT32: CopyLiteralToNumpyArray(literal, py_array); break; @@ -460,6 +483,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, case NPY_UINT8: CopyLiteralToNumpyArray(literal, py_array); break; + case NPY_UINT16: + CopyLiteralToNumpyArray(literal, py_array); + break; case NPY_UINT32: CopyLiteralToNumpyArray(literal, py_array); break; @@ -482,8 +508,10 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, CopyLiteralToNumpyArray(literal, py_array); break; default: - LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; + return InvalidArgument( + "No XLA literal container for Numpy type number: %d", np_type); } + return Status::OK(); } PyObject* LongToPyIntOrPyLong(long x) { // NOLINT diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 40ff2d9ad21..737fc4b29c1 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -36,6 +36,16 @@ namespace swig { namespace numpy { +struct PyDecrefDeleter { + void operator()(PyObject* p) const { Py_DECREF(p); } +}; + +// Safe container for an owned PyObject. On destruction, the reference count of +// the contained object will be decremented. +using Safe_PyObjectPtr = std::unique_ptr; + +Safe_PyObjectPtr make_safe(PyObject* object); + // Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy // dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and // vice versa. @@ -74,7 +84,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. -PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); +StatusOr PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. @@ -90,8 +100,8 @@ StatusOr XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, - PyArrayObject* py_array); +Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, + PyArrayObject* py_array); template void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c80e7924645..aa38c06cf90 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -88,6 +88,12 @@ def NumpyArrayBool(*args, **kwargs): class ComputationsWithConstantsTest(LocalComputationTest): """Tests focusing on Constant ops.""" + def testConstantScalarSumS8(self): + c = self._NewComputation() + root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2))) + self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) + self._ExecuteAndCompareExact(c, expected=np.int8(3)) + def testConstantScalarSumF32(self): c = self._NewComputation() root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))