[XLA:Python] Remove some Python 2 compatibility code.
PiperOrigin-RevId: 308274667 Change-Id: I60aedce4c49906fad05a6494044b8ee415ec3e21
This commit is contained in:
parent
3184d08434
commit
76cb507e11
@ -46,52 +46,15 @@ Safe_PyObjectPtr make_safe(PyObject* object) {
|
||||
return Safe_PyObjectPtr(object);
|
||||
}
|
||||
|
||||
// Workarounds for Python 2 vs 3 API differences.
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
|
||||
PyObject* MakePyString(const string& s) {
|
||||
return PyString_FromString(s.c_str());
|
||||
}
|
||||
|
||||
typedef long HashType; // NOLINT
|
||||
|
||||
bool TfPyInt_Check(PyObject* object) { return PyInt_Check(object); }
|
||||
|
||||
PyObject* TfPyInt_FromLong(long x) { // NOLINT
|
||||
return PyInt_FromLong(x);
|
||||
}
|
||||
|
||||
long TfPyInt_AsLong(PyObject* x) { // NOLINT
|
||||
return PyInt_AsLong(x);
|
||||
}
|
||||
|
||||
#else // PY_MAJOR_VERSION < 3
|
||||
|
||||
PyObject* MakePyString(const string& s) {
|
||||
return PyUnicode_FromString(s.c_str());
|
||||
}
|
||||
|
||||
bool TfPyInt_Check(PyObject* object) {
|
||||
bool PyLong_CheckNoOverflow(PyObject* object) {
|
||||
if (!PyLong_Check(object)) {
|
||||
return 0;
|
||||
return false;
|
||||
}
|
||||
int overflow = 0;
|
||||
PyLong_AsLongAndOverflow(object, &overflow);
|
||||
return (overflow == 0);
|
||||
}
|
||||
|
||||
PyObject* TfPyInt_FromLong(long x) { // NOLINT
|
||||
return PyLong_FromLong(x);
|
||||
}
|
||||
|
||||
long TfPyInt_AsLong(PyObject* x) { // NOLINT
|
||||
return PyLong_AsLong(x);
|
||||
}
|
||||
|
||||
typedef Py_hash_t HashType;
|
||||
|
||||
#endif // PY_MAJOR_VERSION < 3
|
||||
|
||||
// Registered numpy type ID. Global variable populated by the registration code.
|
||||
// Protected by the GIL.
|
||||
int npy_bfloat16 = -1;
|
||||
@ -143,8 +106,8 @@ bool CastToBfloat16(PyObject* arg, bfloat16* output) {
|
||||
*output = bfloat16(d);
|
||||
return true;
|
||||
}
|
||||
if (TfPyInt_Check(arg)) {
|
||||
long l = TfPyInt_AsLong(arg); // NOLINT
|
||||
if (PyLong_CheckNoOverflow(arg)) {
|
||||
long l = PyLong_AsLong(arg); // NOLINT
|
||||
if (PyErr_Occurred()) {
|
||||
return false;
|
||||
}
|
||||
@ -205,7 +168,7 @@ PyObject* PyBfloat16_Float(PyObject* self) {
|
||||
PyObject* PyBfloat16_Int(PyObject* self) {
|
||||
bfloat16 x = PyBfloat16_Bfloat16(self);
|
||||
long y = static_cast<long>(x); // NOLINT
|
||||
return TfPyInt_FromLong(y);
|
||||
return PyLong_FromLong(y);
|
||||
}
|
||||
|
||||
// Negates a PyBfloat16.
|
||||
@ -243,11 +206,7 @@ PyObject* PyBfloat16_TrueDivide(PyObject* a, PyObject* b) {
|
||||
if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
|
||||
return PyBfloat16_FromBfloat16(x / y).release();
|
||||
}
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
return PyArray_Type.tp_as_number->nb_divide(a, b);
|
||||
#else
|
||||
return PyArray_Type.tp_as_number->nb_true_divide(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Python number methods for PyBfloat16 objects.
|
||||
@ -255,9 +214,6 @@ PyNumberMethods PyBfloat16_AsNumber = {
|
||||
PyBfloat16_Add, // nb_add
|
||||
PyBfloat16_Subtract, // nb_subtract
|
||||
PyBfloat16_Multiply, // nb_multiply
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
PyBfloat16_TrueDivide, // nb_divide
|
||||
#endif
|
||||
nullptr, // nb_remainder
|
||||
nullptr, // nb_divmod
|
||||
nullptr, // nb_power
|
||||
@ -271,27 +227,13 @@ PyNumberMethods PyBfloat16_AsNumber = {
|
||||
nullptr, // nb_and
|
||||
nullptr, // nb_xor
|
||||
nullptr, // nb_or
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
nullptr, // nb_coerce
|
||||
#endif
|
||||
PyBfloat16_Int, // nb_int
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
PyBfloat16_Int, // nb_long
|
||||
#else
|
||||
nullptr, // reserved
|
||||
#endif
|
||||
PyBfloat16_Float, // nb_float
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
nullptr, // nb_oct
|
||||
nullptr, // nb_hex
|
||||
#endif
|
||||
|
||||
nullptr, // nb_inplace_add
|
||||
nullptr, // nb_inplace_subtract
|
||||
nullptr, // nb_inplace_multiply
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
nullptr, // nb_inplace_divide
|
||||
#endif
|
||||
nullptr, // nb_inplace_remainder
|
||||
nullptr, // nb_inplace_power
|
||||
nullptr, // nb_inplace_lshift
|
||||
@ -376,31 +318,27 @@ PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
|
||||
// Implementation of repr() for PyBfloat16.
|
||||
PyObject* PyBfloat16_Repr(PyObject* self) {
|
||||
bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
|
||||
string v = absl::StrCat(static_cast<float>(x));
|
||||
return MakePyString(v);
|
||||
std::string v = absl::StrCat(static_cast<float>(x));
|
||||
return PyUnicode_FromString(v.c_str());
|
||||
}
|
||||
|
||||
// Implementation of str() for PyBfloat16.
|
||||
PyObject* PyBfloat16_Str(PyObject* self) {
|
||||
bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
|
||||
string v = absl::StrCat(static_cast<float>(x));
|
||||
return MakePyString(v);
|
||||
std::string v = absl::StrCat(static_cast<float>(x));
|
||||
return PyUnicode_FromString(v.c_str());
|
||||
}
|
||||
|
||||
// Hash function for PyBfloat16. We use the identity function, which is a weak
|
||||
// hash function.
|
||||
HashType PyBfloat16_Hash(PyObject* self) {
|
||||
Py_hash_t PyBfloat16_Hash(PyObject* self) {
|
||||
bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
|
||||
return x.value;
|
||||
}
|
||||
|
||||
// Python type for PyBfloat16 objects.
|
||||
PyTypeObject PyBfloat16_Type = {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
PyObject_HEAD_INIT(nullptr) 0, // ob_size
|
||||
#else
|
||||
PyVarObject_HEAD_INIT(nullptr, 0)
|
||||
#endif
|
||||
"bfloat16", // tp_name
|
||||
sizeof(PyBfloat16), // tp_basicsize
|
||||
0, // tp_itemsize
|
||||
@ -420,11 +358,7 @@ PyTypeObject PyBfloat16_Type = {
|
||||
nullptr, // tp_setattro
|
||||
nullptr, // tp_as_buffer
|
||||
// tp_flags
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_CHECKTYPES,
|
||||
#else
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
|
||||
#endif
|
||||
"bfloat16 floating-point values", // tp_doc
|
||||
nullptr, // tp_traverse
|
||||
nullptr, // tp_clear
|
||||
@ -1287,7 +1221,7 @@ bool Initialize() {
|
||||
import_array1(false);
|
||||
import_umath1(false);
|
||||
|
||||
Safe_PyObjectPtr numpy_str = make_safe(MakePyString("numpy"));
|
||||
Safe_PyObjectPtr numpy_str = make_safe(PyUnicode_FromString("numpy"));
|
||||
if (!numpy_str) {
|
||||
return false;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user