[XLA:Python] Fail gracefully when an unsupported numpy type is converted to/from XLA.

Add support for int8/int16/uint16.

PiperOrigin-RevId: 232502865
This commit is contained in:
Peter Hawkins 2019-02-05 09:40:53 -08:00 committed by TensorFlower Gardener
parent 0906bafa73
commit 32b84541e4
4 changed files with 72 additions and 28 deletions

View File

@ -453,16 +453,6 @@ tensorflow::ImportNumpy();
// Literal // Literal
%typemap(out) StatusOr<Literal> {
if ($1.ok()) {
Literal value = $1.ConsumeValueOrDie();
$result = numpy::PyObjectFromXlaLiteral(*value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
}
%typemap(in) const Literal& (StatusOr<Literal> literal_status) { %typemap(in) const Literal& (StatusOr<Literal> literal_status) {
literal_status = numpy::XlaLiteralFromPyObject($input); literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) { if (!literal_status.ok()) {
@ -472,16 +462,26 @@ tensorflow::ImportNumpy();
$1 = &literal_status.ValueOrDie(); $1 = &literal_status.ValueOrDie();
} }
%typemap(out) Literal { %typemap(out) Literal (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
$result = numpy::PyObjectFromXlaLiteral(*$1); obj_status = numpy::PyObjectFromXlaLiteral(*$1);
if (!obj_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
SWIG_fail;
}
$result = obj_status.ValueOrDie().release();
} }
%typemap(out) StatusOr<Literal> { %typemap(out) StatusOr<Literal> (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
if (!$1.ok()) { if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail; SWIG_fail;
} }
$result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie()); obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
if (!obj_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
SWIG_fail;
}
$result = obj_status.ValueOrDie().release();
} }
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) { %typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {

View File

@ -26,6 +26,10 @@ namespace swig {
namespace numpy { namespace numpy {
Safe_PyObjectPtr make_safe(PyObject* object) {
return Safe_PyObjectPtr(object);
}
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
switch (primitive_type) { switch (primitive_type) {
case PRED: case PRED:
@ -349,14 +353,18 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
return result; return result;
} }
PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal) {
if (literal.shape().IsTuple()) { if (literal.shape().IsTuple()) {
int num_elements = ShapeUtil::TupleElementCount(literal.shape()); int num_elements = ShapeUtil::TupleElementCount(literal.shape());
PyObject* tuple = PyTuple_New(num_elements); std::vector<Safe_PyObjectPtr> elems(num_elements);
for (int i = 0; i < num_elements; i++) { for (int i = 0; i < num_elements; i++) {
PyTuple_SET_ITEM(tuple, i, TF_ASSIGN_OR_RETURN(elems[i],
PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
} }
Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements));
for (int i = 0; i < num_elements; i++) {
PyTuple_SET_ITEM(tuple.get(), i, elems[i].release());
}
return tuple; return tuple;
} else { } else {
int rank = literal.shape().rank(); int rank = literal.shape().rank();
@ -365,10 +373,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
} }
int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
PyObject* array = Safe_PyObjectPtr array = make_safe(
PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0); PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0));
CopyLiteralToNumpyArray(np_type, literal, TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray(
reinterpret_cast<PyArrayObject*>(array)); np_type, literal, reinterpret_cast<PyArrayObject*>(array.get())));
return array; return array;
} }
} }
@ -408,6 +416,12 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
case NPY_BOOL: case NPY_BOOL:
CopyNumpyArrayToLiteral<bool>(py_array, literal); CopyNumpyArrayToLiteral<bool>(py_array, literal);
break; break;
case NPY_INT8:
CopyNumpyArrayToLiteral<int8>(py_array, literal);
break;
case NPY_INT16:
CopyNumpyArrayToLiteral<int16>(py_array, literal);
break;
case NPY_INT32: case NPY_INT32:
CopyNumpyArrayToLiteral<int32>(py_array, literal); CopyNumpyArrayToLiteral<int32>(py_array, literal);
break; break;
@ -417,6 +431,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
case NPY_UINT8: case NPY_UINT8:
CopyNumpyArrayToLiteral<uint8>(py_array, literal); CopyNumpyArrayToLiteral<uint8>(py_array, literal);
break; break;
case NPY_UINT16:
CopyNumpyArrayToLiteral<uint16>(py_array, literal);
break;
case NPY_UINT32: case NPY_UINT32:
CopyNumpyArrayToLiteral<uint32>(py_array, literal); CopyNumpyArrayToLiteral<uint32>(py_array, literal);
break; break;
@ -445,12 +462,18 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
return Status::OK(); return Status::OK();
} }
void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array) { PyArrayObject* py_array) {
switch (np_type) { switch (np_type) {
case NPY_BOOL: case NPY_BOOL:
CopyLiteralToNumpyArray<bool>(literal, py_array); CopyLiteralToNumpyArray<bool>(literal, py_array);
break; break;
case NPY_INT8:
CopyLiteralToNumpyArray<int8>(literal, py_array);
break;
case NPY_INT16:
CopyLiteralToNumpyArray<int16>(literal, py_array);
break;
case NPY_INT32: case NPY_INT32:
CopyLiteralToNumpyArray<int32>(literal, py_array); CopyLiteralToNumpyArray<int32>(literal, py_array);
break; break;
@ -460,6 +483,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
case NPY_UINT8: case NPY_UINT8:
CopyLiteralToNumpyArray<uint8>(literal, py_array); CopyLiteralToNumpyArray<uint8>(literal, py_array);
break; break;
case NPY_UINT16:
CopyLiteralToNumpyArray<uint16>(literal, py_array);
break;
case NPY_UINT32: case NPY_UINT32:
CopyLiteralToNumpyArray<uint32>(literal, py_array); CopyLiteralToNumpyArray<uint32>(literal, py_array);
break; break;
@ -482,8 +508,10 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
CopyLiteralToNumpyArray<complex128>(literal, py_array); CopyLiteralToNumpyArray<complex128>(literal, py_array);
break; break;
default: default:
LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; return InvalidArgument(
"No XLA literal container for Numpy type number: %d", np_type);
} }
return Status::OK();
} }
PyObject* LongToPyIntOrPyLong(long x) { // NOLINT PyObject* LongToPyIntOrPyLong(long x) { // NOLINT

View File

@ -36,6 +36,16 @@ namespace swig {
namespace numpy { namespace numpy {
struct PyDecrefDeleter {
void operator()(PyObject* p) const { Py_DECREF(p); }
};
// Safe container for an owned PyObject. On destruction, the reference count of
// the contained object will be decremented.
using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
Safe_PyObjectPtr make_safe(PyObject* object);
// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy // Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and // dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
// vice versa. // vice versa.
@ -74,7 +84,7 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o);
// array data. // array data.
// //
// The return value is a new reference. // The return value is a new reference.
PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal);
// Converts a Numpy ndarray or a nested Python tuple thereof to a // Converts a Numpy ndarray or a nested Python tuple thereof to a
// corresponding XLA literal. // corresponding XLA literal.
@ -90,7 +100,7 @@ StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
Literal* literal); Literal* literal);
void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array); PyArrayObject* py_array);
template <typename NativeT> template <typename NativeT>

View File

@ -88,6 +88,12 @@ def NumpyArrayBool(*args, **kwargs):
class ComputationsWithConstantsTest(LocalComputationTest): class ComputationsWithConstantsTest(LocalComputationTest):
"""Tests focusing on Constant ops.""" """Tests focusing on Constant ops."""
def testConstantScalarSumS8(self):
c = self._NewComputation()
root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2)))
self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
self._ExecuteAndCompareExact(c, expected=np.int8(3))
def testConstantScalarSumF32(self): def testConstantScalarSumF32(self):
c = self._NewComputation() c = self._NewComputation()
root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))