[XLA:Python] Fail gracefully when an unsupported numpy type is converted to/from XLA.
Add support for int8/int16/uint16. PiperOrigin-RevId: 232502865
This commit is contained in:
parent
0906bafa73
commit
32b84541e4
@ -453,16 +453,6 @@ tensorflow::ImportNumpy();
|
|||||||
|
|
||||||
// Literal
|
// Literal
|
||||||
|
|
||||||
%typemap(out) StatusOr<Literal> {
|
|
||||||
if ($1.ok()) {
|
|
||||||
Literal value = $1.ConsumeValueOrDie();
|
|
||||||
$result = numpy::PyObjectFromXlaLiteral(*value);
|
|
||||||
} else {
|
|
||||||
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
|
|
||||||
SWIG_fail;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
|
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
|
||||||
literal_status = numpy::XlaLiteralFromPyObject($input);
|
literal_status = numpy::XlaLiteralFromPyObject($input);
|
||||||
if (!literal_status.ok()) {
|
if (!literal_status.ok()) {
|
||||||
@ -472,16 +462,26 @@ tensorflow::ImportNumpy();
|
|||||||
$1 = &literal_status.ValueOrDie();
|
$1 = &literal_status.ValueOrDie();
|
||||||
}
|
}
|
||||||
|
|
||||||
%typemap(out) Literal {
|
%typemap(out) Literal (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
|
||||||
$result = numpy::PyObjectFromXlaLiteral(*$1);
|
obj_status = numpy::PyObjectFromXlaLiteral(*$1);
|
||||||
|
if (!obj_status.ok()) {
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
|
||||||
|
SWIG_fail;
|
||||||
|
}
|
||||||
|
$result = obj_status.ValueOrDie().release();
|
||||||
}
|
}
|
||||||
|
|
||||||
%typemap(out) StatusOr<Literal> {
|
%typemap(out) StatusOr<Literal> (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
|
||||||
if (!$1.ok()) {
|
if (!$1.ok()) {
|
||||||
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
|
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
|
||||||
SWIG_fail;
|
SWIG_fail;
|
||||||
}
|
}
|
||||||
$result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
|
obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
|
||||||
|
if (!obj_status.ok()) {
|
||||||
|
PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
|
||||||
|
SWIG_fail;
|
||||||
|
}
|
||||||
|
$result = obj_status.ValueOrDie().release();
|
||||||
}
|
}
|
||||||
|
|
||||||
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
|
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
|
||||||
|
@ -26,6 +26,10 @@ namespace swig {
|
|||||||
|
|
||||||
namespace numpy {
|
namespace numpy {
|
||||||
|
|
||||||
|
Safe_PyObjectPtr make_safe(PyObject* object) {
|
||||||
|
return Safe_PyObjectPtr(object);
|
||||||
|
}
|
||||||
|
|
||||||
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
|
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
|
||||||
switch (primitive_type) {
|
switch (primitive_type) {
|
||||||
case PRED:
|
case PRED:
|
||||||
@ -349,13 +353,17 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
|
StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal) {
|
||||||
if (literal.shape().IsTuple()) {
|
if (literal.shape().IsTuple()) {
|
||||||
int num_elements = ShapeUtil::TupleElementCount(literal.shape());
|
int num_elements = ShapeUtil::TupleElementCount(literal.shape());
|
||||||
PyObject* tuple = PyTuple_New(num_elements);
|
std::vector<Safe_PyObjectPtr> elems(num_elements);
|
||||||
for (int i = 0; i < num_elements; i++) {
|
for (int i = 0; i < num_elements; i++) {
|
||||||
PyTuple_SET_ITEM(tuple, i,
|
TF_ASSIGN_OR_RETURN(elems[i],
|
||||||
PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
|
PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
|
||||||
|
}
|
||||||
|
Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements));
|
||||||
|
for (int i = 0; i < num_elements; i++) {
|
||||||
|
PyTuple_SET_ITEM(tuple.get(), i, elems[i].release());
|
||||||
}
|
}
|
||||||
return tuple;
|
return tuple;
|
||||||
} else {
|
} else {
|
||||||
@ -365,10 +373,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
|
|||||||
dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
|
dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
|
||||||
}
|
}
|
||||||
int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
|
int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
|
||||||
PyObject* array =
|
Safe_PyObjectPtr array = make_safe(
|
||||||
PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0);
|
PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0));
|
||||||
CopyLiteralToNumpyArray(np_type, literal,
|
TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray(
|
||||||
reinterpret_cast<PyArrayObject*>(array));
|
np_type, literal, reinterpret_cast<PyArrayObject*>(array.get())));
|
||||||
return array;
|
return array;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -408,6 +416,12 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
|||||||
case NPY_BOOL:
|
case NPY_BOOL:
|
||||||
CopyNumpyArrayToLiteral<bool>(py_array, literal);
|
CopyNumpyArrayToLiteral<bool>(py_array, literal);
|
||||||
break;
|
break;
|
||||||
|
case NPY_INT8:
|
||||||
|
CopyNumpyArrayToLiteral<int8>(py_array, literal);
|
||||||
|
break;
|
||||||
|
case NPY_INT16:
|
||||||
|
CopyNumpyArrayToLiteral<int16>(py_array, literal);
|
||||||
|
break;
|
||||||
case NPY_INT32:
|
case NPY_INT32:
|
||||||
CopyNumpyArrayToLiteral<int32>(py_array, literal);
|
CopyNumpyArrayToLiteral<int32>(py_array, literal);
|
||||||
break;
|
break;
|
||||||
@ -417,6 +431,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
|||||||
case NPY_UINT8:
|
case NPY_UINT8:
|
||||||
CopyNumpyArrayToLiteral<uint8>(py_array, literal);
|
CopyNumpyArrayToLiteral<uint8>(py_array, literal);
|
||||||
break;
|
break;
|
||||||
|
case NPY_UINT16:
|
||||||
|
CopyNumpyArrayToLiteral<uint16>(py_array, literal);
|
||||||
|
break;
|
||||||
case NPY_UINT32:
|
case NPY_UINT32:
|
||||||
CopyNumpyArrayToLiteral<uint32>(py_array, literal);
|
CopyNumpyArrayToLiteral<uint32>(py_array, literal);
|
||||||
break;
|
break;
|
||||||
@ -445,12 +462,18 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
||||||
PyArrayObject* py_array) {
|
PyArrayObject* py_array) {
|
||||||
switch (np_type) {
|
switch (np_type) {
|
||||||
case NPY_BOOL:
|
case NPY_BOOL:
|
||||||
CopyLiteralToNumpyArray<bool>(literal, py_array);
|
CopyLiteralToNumpyArray<bool>(literal, py_array);
|
||||||
break;
|
break;
|
||||||
|
case NPY_INT8:
|
||||||
|
CopyLiteralToNumpyArray<int8>(literal, py_array);
|
||||||
|
break;
|
||||||
|
case NPY_INT16:
|
||||||
|
CopyLiteralToNumpyArray<int16>(literal, py_array);
|
||||||
|
break;
|
||||||
case NPY_INT32:
|
case NPY_INT32:
|
||||||
CopyLiteralToNumpyArray<int32>(literal, py_array);
|
CopyLiteralToNumpyArray<int32>(literal, py_array);
|
||||||
break;
|
break;
|
||||||
@ -460,6 +483,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
|||||||
case NPY_UINT8:
|
case NPY_UINT8:
|
||||||
CopyLiteralToNumpyArray<uint8>(literal, py_array);
|
CopyLiteralToNumpyArray<uint8>(literal, py_array);
|
||||||
break;
|
break;
|
||||||
|
case NPY_UINT16:
|
||||||
|
CopyLiteralToNumpyArray<uint16>(literal, py_array);
|
||||||
|
break;
|
||||||
case NPY_UINT32:
|
case NPY_UINT32:
|
||||||
CopyLiteralToNumpyArray<uint32>(literal, py_array);
|
CopyLiteralToNumpyArray<uint32>(literal, py_array);
|
||||||
break;
|
break;
|
||||||
@ -482,8 +508,10 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
|||||||
CopyLiteralToNumpyArray<complex128>(literal, py_array);
|
CopyLiteralToNumpyArray<complex128>(literal, py_array);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "No XLA literal container for Numpy type" << np_type;
|
return InvalidArgument(
|
||||||
|
"No XLA literal container for Numpy type number: %d", np_type);
|
||||||
}
|
}
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* LongToPyIntOrPyLong(long x) { // NOLINT
|
PyObject* LongToPyIntOrPyLong(long x) { // NOLINT
|
||||||
|
@ -36,6 +36,16 @@ namespace swig {
|
|||||||
|
|
||||||
namespace numpy {
|
namespace numpy {
|
||||||
|
|
||||||
|
struct PyDecrefDeleter {
|
||||||
|
void operator()(PyObject* p) const { Py_DECREF(p); }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Safe container for an owned PyObject. On destruction, the reference count of
|
||||||
|
// the contained object will be decremented.
|
||||||
|
using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
|
||||||
|
|
||||||
|
Safe_PyObjectPtr make_safe(PyObject* object);
|
||||||
|
|
||||||
// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
|
// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
|
||||||
// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
|
// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
|
||||||
// vice versa.
|
// vice versa.
|
||||||
@ -74,7 +84,7 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o);
|
|||||||
// array data.
|
// array data.
|
||||||
//
|
//
|
||||||
// The return value is a new reference.
|
// The return value is a new reference.
|
||||||
PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
|
StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal);
|
||||||
|
|
||||||
// Converts a Numpy ndarray or a nested Python tuple thereof to a
|
// Converts a Numpy ndarray or a nested Python tuple thereof to a
|
||||||
// corresponding XLA literal.
|
// corresponding XLA literal.
|
||||||
@ -90,8 +100,8 @@ StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
|
|||||||
Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
||||||
Literal* literal);
|
Literal* literal);
|
||||||
|
|
||||||
void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
||||||
PyArrayObject* py_array);
|
PyArrayObject* py_array);
|
||||||
|
|
||||||
template <typename NativeT>
|
template <typename NativeT>
|
||||||
void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
|
void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
|
||||||
|
@ -88,6 +88,12 @@ def NumpyArrayBool(*args, **kwargs):
|
|||||||
class ComputationsWithConstantsTest(LocalComputationTest):
|
class ComputationsWithConstantsTest(LocalComputationTest):
|
||||||
"""Tests focusing on Constant ops."""
|
"""Tests focusing on Constant ops."""
|
||||||
|
|
||||||
|
def testConstantScalarSumS8(self):
|
||||||
|
c = self._NewComputation()
|
||||||
|
root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2)))
|
||||||
|
self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
|
||||||
|
self._ExecuteAndCompareExact(c, expected=np.int8(3))
|
||||||
|
|
||||||
def testConstantScalarSumF32(self):
|
def testConstantScalarSumF32(self):
|
||||||
c = self._NewComputation()
|
c = self._NewComputation()
|
||||||
root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
|
root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
|
||||||
|
Loading…
Reference in New Issue
Block a user