[XLA:Python] Fail gracefully when an unsupported numpy type is converted to/from XLA.

Add support for int8/int16/uint16.

PiperOrigin-RevId: 232502865
This commit is contained in:
Peter Hawkins 2019-02-05 09:40:53 -08:00 committed by TensorFlower Gardener
parent 0906bafa73
commit 32b84541e4
4 changed files with 72 additions and 28 deletions

View File

@ -453,16 +453,6 @@ tensorflow::ImportNumpy();
// Literal
%typemap(out) StatusOr<Literal> {
if ($1.ok()) {
Literal value = $1.ConsumeValueOrDie();
$result = numpy::PyObjectFromXlaLiteral(*value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
}
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) {
@ -472,16 +462,26 @@ tensorflow::ImportNumpy();
$1 = &literal_status.ValueOrDie();
}
%typemap(out) Literal {
$result = numpy::PyObjectFromXlaLiteral(*$1);
%typemap(out) Literal (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
obj_status = numpy::PyObjectFromXlaLiteral(*$1);
if (!obj_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
SWIG_fail;
}
$result = obj_status.ValueOrDie().release();
}
%typemap(out) StatusOr<Literal> {
%typemap(out) StatusOr<Literal> (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
$result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
if (!obj_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
SWIG_fail;
}
$result = obj_status.ValueOrDie().release();
}
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {

View File

@ -26,6 +26,10 @@ namespace swig {
namespace numpy {
Safe_PyObjectPtr make_safe(PyObject* object) {
return Safe_PyObjectPtr(object);
}
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
switch (primitive_type) {
case PRED:
@ -349,13 +353,17 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
return result;
}
PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal) {
if (literal.shape().IsTuple()) {
int num_elements = ShapeUtil::TupleElementCount(literal.shape());
PyObject* tuple = PyTuple_New(num_elements);
std::vector<Safe_PyObjectPtr> elems(num_elements);
for (int i = 0; i < num_elements; i++) {
PyTuple_SET_ITEM(tuple, i,
PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
TF_ASSIGN_OR_RETURN(elems[i],
PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
}
Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements));
for (int i = 0; i < num_elements; i++) {
PyTuple_SET_ITEM(tuple.get(), i, elems[i].release());
}
return tuple;
} else {
@ -365,10 +373,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
}
int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
PyObject* array =
PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0);
CopyLiteralToNumpyArray(np_type, literal,
reinterpret_cast<PyArrayObject*>(array));
Safe_PyObjectPtr array = make_safe(
PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0));
TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray(
np_type, literal, reinterpret_cast<PyArrayObject*>(array.get())));
return array;
}
}
@ -408,6 +416,12 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
case NPY_BOOL:
CopyNumpyArrayToLiteral<bool>(py_array, literal);
break;
case NPY_INT8:
CopyNumpyArrayToLiteral<int8>(py_array, literal);
break;
case NPY_INT16:
CopyNumpyArrayToLiteral<int16>(py_array, literal);
break;
case NPY_INT32:
CopyNumpyArrayToLiteral<int32>(py_array, literal);
break;
@ -417,6 +431,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
case NPY_UINT8:
CopyNumpyArrayToLiteral<uint8>(py_array, literal);
break;
case NPY_UINT16:
CopyNumpyArrayToLiteral<uint16>(py_array, literal);
break;
case NPY_UINT32:
CopyNumpyArrayToLiteral<uint32>(py_array, literal);
break;
@ -445,12 +462,18 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
return Status::OK();
}
void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array) {
Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array) {
switch (np_type) {
case NPY_BOOL:
CopyLiteralToNumpyArray<bool>(literal, py_array);
break;
case NPY_INT8:
CopyLiteralToNumpyArray<int8>(literal, py_array);
break;
case NPY_INT16:
CopyLiteralToNumpyArray<int16>(literal, py_array);
break;
case NPY_INT32:
CopyLiteralToNumpyArray<int32>(literal, py_array);
break;
@ -460,6 +483,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
case NPY_UINT8:
CopyLiteralToNumpyArray<uint8>(literal, py_array);
break;
case NPY_UINT16:
CopyLiteralToNumpyArray<uint16>(literal, py_array);
break;
case NPY_UINT32:
CopyLiteralToNumpyArray<uint32>(literal, py_array);
break;
@ -482,8 +508,10 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
CopyLiteralToNumpyArray<complex128>(literal, py_array);
break;
default:
LOG(FATAL) << "No XLA literal container for Numpy type" << np_type;
return InvalidArgument(
"No XLA literal container for Numpy type number: %d", np_type);
}
return Status::OK();
}
PyObject* LongToPyIntOrPyLong(long x) { // NOLINT

View File

@ -36,6 +36,16 @@ namespace swig {
namespace numpy {
struct PyDecrefDeleter {
void operator()(PyObject* p) const { Py_DECREF(p); }
};
// Safe container for an owned PyObject. On destruction, the reference count of
// the contained object will be decremented.
using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
Safe_PyObjectPtr make_safe(PyObject* object);
// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
// vice versa.
@ -74,7 +84,7 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o);
// array data.
//
// The return value is a new reference.
PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal);
// Converts a Numpy ndarray or a nested Python tuple thereof to a
// corresponding XLA literal.
@ -90,8 +100,8 @@ StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
Literal* literal);
void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array);
Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
PyArrayObject* py_array);
template <typename NativeT>
void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {

View File

@ -88,6 +88,12 @@ def NumpyArrayBool(*args, **kwargs):
class ComputationsWithConstantsTest(LocalComputationTest):
"""Tests focusing on Constant ops."""
def testConstantScalarSumS8(self):
c = self._NewComputation()
root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2)))
self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
self._ExecuteAndCompareExact(c, expected=np.int8(3))
def testConstantScalarSumF32(self):
c = self._NewComputation()
root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))