[XLA:Python] Fail gracefully when an unsupported numpy type is converted to/from XLA.
Add support for int8/int16/uint16. PiperOrigin-RevId: 232502865
This commit is contained in:
parent
0906bafa73
commit
32b84541e4
@ -453,16 +453,6 @@ tensorflow::ImportNumpy();
|
||||
|
||||
// Literal
|
||||
|
||||
%typemap(out) StatusOr<Literal> {
|
||||
if ($1.ok()) {
|
||||
Literal value = $1.ConsumeValueOrDie();
|
||||
$result = numpy::PyObjectFromXlaLiteral(*value);
|
||||
} else {
|
||||
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
|
||||
SWIG_fail;
|
||||
}
|
||||
}
|
||||
|
||||
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
|
||||
literal_status = numpy::XlaLiteralFromPyObject($input);
|
||||
if (!literal_status.ok()) {
|
||||
@ -472,16 +462,26 @@ tensorflow::ImportNumpy();
|
||||
$1 = &literal_status.ValueOrDie();
|
||||
}
|
||||
|
||||
%typemap(out) Literal {
|
||||
$result = numpy::PyObjectFromXlaLiteral(*$1);
|
||||
%typemap(out) Literal (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
|
||||
obj_status = numpy::PyObjectFromXlaLiteral(*$1);
|
||||
if (!obj_status.ok()) {
|
||||
PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
|
||||
SWIG_fail;
|
||||
}
|
||||
$result = obj_status.ValueOrDie().release();
|
||||
}
|
||||
|
||||
%typemap(out) StatusOr<Literal> {
|
||||
%typemap(out) StatusOr<Literal> (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
|
||||
if (!$1.ok()) {
|
||||
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
|
||||
SWIG_fail;
|
||||
}
|
||||
$result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
|
||||
obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
|
||||
if (!obj_status.ok()) {
|
||||
PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
|
||||
SWIG_fail;
|
||||
}
|
||||
$result = obj_status.ValueOrDie().release();
|
||||
}
|
||||
|
||||
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
|
||||
|
@ -26,6 +26,10 @@ namespace swig {
|
||||
|
||||
namespace numpy {
|
||||
|
||||
Safe_PyObjectPtr make_safe(PyObject* object) {
|
||||
return Safe_PyObjectPtr(object);
|
||||
}
|
||||
|
||||
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
|
||||
switch (primitive_type) {
|
||||
case PRED:
|
||||
@ -349,13 +353,17 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o) {
|
||||
return result;
|
||||
}
|
||||
|
||||
PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
|
||||
StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal) {
|
||||
if (literal.shape().IsTuple()) {
|
||||
int num_elements = ShapeUtil::TupleElementCount(literal.shape());
|
||||
PyObject* tuple = PyTuple_New(num_elements);
|
||||
std::vector<Safe_PyObjectPtr> elems(num_elements);
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
PyTuple_SET_ITEM(tuple, i,
|
||||
PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
|
||||
TF_ASSIGN_OR_RETURN(elems[i],
|
||||
PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
|
||||
}
|
||||
Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements));
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
PyTuple_SET_ITEM(tuple.get(), i, elems[i].release());
|
||||
}
|
||||
return tuple;
|
||||
} else {
|
||||
@ -365,10 +373,10 @@ PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
|
||||
dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
|
||||
}
|
||||
int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
|
||||
PyObject* array =
|
||||
PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0);
|
||||
CopyLiteralToNumpyArray(np_type, literal,
|
||||
reinterpret_cast<PyArrayObject*>(array));
|
||||
Safe_PyObjectPtr array = make_safe(
|
||||
PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0));
|
||||
TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray(
|
||||
np_type, literal, reinterpret_cast<PyArrayObject*>(array.get())));
|
||||
return array;
|
||||
}
|
||||
}
|
||||
@ -408,6 +416,12 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
||||
case NPY_BOOL:
|
||||
CopyNumpyArrayToLiteral<bool>(py_array, literal);
|
||||
break;
|
||||
case NPY_INT8:
|
||||
CopyNumpyArrayToLiteral<int8>(py_array, literal);
|
||||
break;
|
||||
case NPY_INT16:
|
||||
CopyNumpyArrayToLiteral<int16>(py_array, literal);
|
||||
break;
|
||||
case NPY_INT32:
|
||||
CopyNumpyArrayToLiteral<int32>(py_array, literal);
|
||||
break;
|
||||
@ -417,6 +431,9 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
||||
case NPY_UINT8:
|
||||
CopyNumpyArrayToLiteral<uint8>(py_array, literal);
|
||||
break;
|
||||
case NPY_UINT16:
|
||||
CopyNumpyArrayToLiteral<uint16>(py_array, literal);
|
||||
break;
|
||||
case NPY_UINT32:
|
||||
CopyNumpyArrayToLiteral<uint32>(py_array, literal);
|
||||
break;
|
||||
@ -445,12 +462,18 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
||||
PyArrayObject* py_array) {
|
||||
Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
||||
PyArrayObject* py_array) {
|
||||
switch (np_type) {
|
||||
case NPY_BOOL:
|
||||
CopyLiteralToNumpyArray<bool>(literal, py_array);
|
||||
break;
|
||||
case NPY_INT8:
|
||||
CopyLiteralToNumpyArray<int8>(literal, py_array);
|
||||
break;
|
||||
case NPY_INT16:
|
||||
CopyLiteralToNumpyArray<int16>(literal, py_array);
|
||||
break;
|
||||
case NPY_INT32:
|
||||
CopyLiteralToNumpyArray<int32>(literal, py_array);
|
||||
break;
|
||||
@ -460,6 +483,9 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
||||
case NPY_UINT8:
|
||||
CopyLiteralToNumpyArray<uint8>(literal, py_array);
|
||||
break;
|
||||
case NPY_UINT16:
|
||||
CopyLiteralToNumpyArray<uint16>(literal, py_array);
|
||||
break;
|
||||
case NPY_UINT32:
|
||||
CopyLiteralToNumpyArray<uint32>(literal, py_array);
|
||||
break;
|
||||
@ -482,8 +508,10 @@ void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
||||
CopyLiteralToNumpyArray<complex128>(literal, py_array);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "No XLA literal container for Numpy type" << np_type;
|
||||
return InvalidArgument(
|
||||
"No XLA literal container for Numpy type number: %d", np_type);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
PyObject* LongToPyIntOrPyLong(long x) { // NOLINT
|
||||
|
@ -36,6 +36,16 @@ namespace swig {
|
||||
|
||||
namespace numpy {
|
||||
|
||||
struct PyDecrefDeleter {
|
||||
void operator()(PyObject* p) const { Py_DECREF(p); }
|
||||
};
|
||||
|
||||
// Safe container for an owned PyObject. On destruction, the reference count of
|
||||
// the contained object will be decremented.
|
||||
using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
|
||||
|
||||
Safe_PyObjectPtr make_safe(PyObject* object);
|
||||
|
||||
// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
|
||||
// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
|
||||
// vice versa.
|
||||
@ -74,7 +84,7 @@ StatusOr<OpMetadata> OpMetadataFromPyObject(PyObject* o);
|
||||
// array data.
|
||||
//
|
||||
// The return value is a new reference.
|
||||
PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
|
||||
StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal);
|
||||
|
||||
// Converts a Numpy ndarray or a nested Python tuple thereof to a
|
||||
// corresponding XLA literal.
|
||||
@ -90,8 +100,8 @@ StatusOr<Literal> XlaLiteralFromPyObject(PyObject* o);
|
||||
Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
|
||||
Literal* literal);
|
||||
|
||||
void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
||||
PyArrayObject* py_array);
|
||||
Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
|
||||
PyArrayObject* py_array);
|
||||
|
||||
template <typename NativeT>
|
||||
void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
|
||||
|
@ -88,6 +88,12 @@ def NumpyArrayBool(*args, **kwargs):
|
||||
class ComputationsWithConstantsTest(LocalComputationTest):
|
||||
"""Tests focusing on Constant ops."""
|
||||
|
||||
def testConstantScalarSumS8(self):
|
||||
c = self._NewComputation()
|
||||
root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2)))
|
||||
self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
|
||||
self._ExecuteAndCompareExact(c, expected=np.int8(3))
|
||||
|
||||
def testConstantScalarSumF32(self):
|
||||
c = self._NewComputation()
|
||||
root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
|
||||
|
Loading…
Reference in New Issue
Block a user