Re-write a long macro as a template class instead.
This will facilitate changes in the functions that were previously defined by macro calls. The re-write tries to be quite literal: no other refactors or optimizations were done in this change. PiperOrigin-RevId: 279088175 Change-Id: I2d6f24ad53eb70e330b3d3448a5616cb7f6b6cf7
This commit is contained in:
parent
700263d02a
commit
1746a9229a
@ -77,34 +77,6 @@ PyObject* ZeroDimArrayToScalar(PyObject* obj) {
|
||||
return obj;
|
||||
}
|
||||
|
||||
// Converts Python object `c` that should hold a Python string into a
|
||||
// C++ string in *out. Returns nullptr on success, or a message on error.
|
||||
// Defined below, but forward declared here for use in PyRepr.
|
||||
const char* ConvertOneString(PyObject* v, tstring* out);
|
||||
|
||||
tstring PyRepr(PyObject* obj) {
|
||||
if (obj == nullptr) {
|
||||
return "<null>";
|
||||
}
|
||||
Safe_PyObjectPtr repr_obj = make_safe(PyObject_Repr(obj));
|
||||
if (repr_obj) {
|
||||
tstring repr_str;
|
||||
if (ConvertOneString(repr_obj.get(), &repr_str) == nullptr) {
|
||||
return repr_str;
|
||||
}
|
||||
}
|
||||
return "<error computing repr()>";
|
||||
}
|
||||
|
||||
bool IsPyDimension(PyObject* obj) {
|
||||
const char* tp_name = obj->ob_type->tp_name;
|
||||
if (strcmp(tp_name, "Dimension") != 0) return false;
|
||||
bool ret = str_util::EndsWith(
|
||||
PyRepr(PyType(obj)),
|
||||
"tensorflow.python.framework.tensor_shape.Dimension'>");
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Sets *elem to a NEW reference to an element in seq on success.
|
||||
// REQUIRES: PySequence_Check(seq) && PySequence_Length(seq) > 0.
|
||||
Status SampleElementFromSequence(PyObject* seq, PyObject** elem) {
|
||||
@ -140,6 +112,9 @@ Status SampleElementFromSequence(PyObject* seq, PyObject** elem) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
tstring PyRepr(PyObject* obj);
|
||||
bool IsPyDimension(PyObject* obj);
|
||||
|
||||
Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
|
||||
std::vector<Safe_PyObjectPtr> refs_to_clean;
|
||||
while (true) {
|
||||
@ -223,160 +198,183 @@ const char ErrorFoundFloat[] =
|
||||
"Can't convert Python sequence with floating point values to integer "
|
||||
"Tensor.";
|
||||
|
||||
// Template for defining a function for recursively convering obj into
|
||||
// an array of TYPE using the conversion function CONVERT.
|
||||
// Defines a converter that recursively converts an object into
|
||||
// an array of type T using the conversion function defined by the
|
||||
// traits class in a ConvertScalar function.
|
||||
//
|
||||
// Note that these helper functions require shape.dims() >= 1.
|
||||
template <class T>
|
||||
struct ConverterTraits {
|
||||
static const tensorflow::DataType kTypeEnum;
|
||||
static const char* ConvertScalar(PyObject* v, T* out);
|
||||
};
|
||||
|
||||
#define DEFINE_HELPER(FUNCTION, TYPE, TYPE_ENUM, CONVERT) \
|
||||
const char* FUNCTION##Helper(PyObject* obj, const TensorShape& shape, \
|
||||
TYPE** buf) { \
|
||||
if (TF_PREDICT_FALSE(obj == nullptr)) { \
|
||||
return ErrorConverting; \
|
||||
} \
|
||||
if (shape.dims() > 1) { \
|
||||
/* Iterate over outer dim, and recursively convert each element. */ \
|
||||
const int64 s = shape.dim_size(0); \
|
||||
Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, "")); \
|
||||
if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular; \
|
||||
if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) { \
|
||||
return ErrorRectangular; \
|
||||
} \
|
||||
TensorShape rest = shape; \
|
||||
rest.RemoveDim(0); \
|
||||
for (int64 i = 0; i < s; ++i) { \
|
||||
const char* error = FUNCTION##Helper( \
|
||||
PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf); \
|
||||
if (TF_PREDICT_FALSE(error != nullptr)) return error; \
|
||||
} \
|
||||
} else { \
|
||||
Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, "")); \
|
||||
if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular; \
|
||||
const int64 s = shape.dim_size(0); \
|
||||
if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) { \
|
||||
return ErrorRectangular; \
|
||||
} \
|
||||
PyObject** l = PySequence_Fast_ITEMS(seq.get()); \
|
||||
for (int64 i = 0; i < s; ++i) { \
|
||||
auto scalar = ZeroDimArrayToScalar(l[i]); \
|
||||
const char* error = CONVERT(scalar, *buf); \
|
||||
Py_DECREF(scalar); \
|
||||
if (TF_PREDICT_FALSE(error != nullptr)) return error; \
|
||||
++*buf; \
|
||||
} \
|
||||
} \
|
||||
return nullptr; \
|
||||
} \
|
||||
const char* FUNCTION(PyObject* obj, const TensorShape& shape, \
|
||||
Tensor* dest) { \
|
||||
/* TODO(josh11b): Allocator & attributes? */ \
|
||||
Tensor result(TYPE_ENUM, shape); \
|
||||
if (shape.dims() == 0) { /* Scalar case */ \
|
||||
TYPE value; \
|
||||
auto scalar = ZeroDimArrayToScalar(obj); \
|
||||
const char* error = CONVERT(scalar, &value); \
|
||||
Py_DECREF(scalar); \
|
||||
if (error != nullptr) return error; \
|
||||
result.scalar<TYPE>()() = value; \
|
||||
} else { \
|
||||
TYPE* buf = result.flat<TYPE>().data(); \
|
||||
const char* error = FUNCTION##Helper(obj, shape, &buf); \
|
||||
if (error != nullptr) return error; \
|
||||
} \
|
||||
*dest = result; \
|
||||
return nullptr; \
|
||||
template <class T>
|
||||
struct Converter {
|
||||
static const char* Helper(PyObject* obj, const TensorShape& shape, T** buf) {
|
||||
if (TF_PREDICT_FALSE(obj == nullptr)) {
|
||||
return ErrorConverting;
|
||||
}
|
||||
if (shape.dims() > 1) {
|
||||
/* Iterate over outer dim, and recursively convert each element. */
|
||||
const int64 s = shape.dim_size(0);
|
||||
Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, ""));
|
||||
if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular;
|
||||
if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) {
|
||||
return ErrorRectangular;
|
||||
}
|
||||
TensorShape rest = shape;
|
||||
rest.RemoveDim(0);
|
||||
for (int64 i = 0; i < s; ++i) {
|
||||
const char* error =
|
||||
Helper(PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf);
|
||||
if (TF_PREDICT_FALSE(error != nullptr)) return error;
|
||||
}
|
||||
} else {
|
||||
Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, ""));
|
||||
if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular;
|
||||
const int64 s = shape.dim_size(0);
|
||||
if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) {
|
||||
return ErrorRectangular;
|
||||
}
|
||||
PyObject** l = PySequence_Fast_ITEMS(seq.get());
|
||||
for (int64 i = 0; i < s; ++i) {
|
||||
auto scalar = ZeroDimArrayToScalar(l[i]);
|
||||
const char* error = ConverterTraits<T>::ConvertScalar(scalar, *buf);
|
||||
Py_DECREF(scalar);
|
||||
if (TF_PREDICT_FALSE(error != nullptr)) return error;
|
||||
++*buf;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
static const char* Convert(PyObject* obj, const TensorShape& shape,
|
||||
Tensor* dest) {
|
||||
/* TODO(josh11b): Allocator & attributes? */
|
||||
Tensor result(ConverterTraits<T>::kTypeEnum, shape);
|
||||
if (shape.dims() == 0) { /* Scalar case */
|
||||
T value;
|
||||
auto scalar = ZeroDimArrayToScalar(obj);
|
||||
const char* error = ConverterTraits<T>::ConvertScalar(scalar, &value);
|
||||
Py_DECREF(scalar);
|
||||
if (error != nullptr) return error;
|
||||
result.scalar<T>()() = value;
|
||||
} else {
|
||||
T* buf = result.flat<T>().data();
|
||||
const char* error = Helper(obj, shape, &buf);
|
||||
if (error != nullptr) return error;
|
||||
}
|
||||
*dest = result;
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// Int support
|
||||
|
||||
const char* ConvertOneInt64(PyObject* v, int64* out) {
|
||||
template <>
|
||||
struct ConverterTraits<int64> {
|
||||
static const tensorflow::DataType kTypeEnum = DT_INT64;
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, int64* out) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||
*out = PyInt_AS_LONG(v);
|
||||
return nullptr;
|
||||
}
|
||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||
*out = PyInt_AS_LONG(v);
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) {
|
||||
int overflow = 0;
|
||||
// Have to use LongLong for 64 bits, since long is 32 bits on Windows.
|
||||
*out = PyLong_AsLongLongAndOverflow(v, &overflow);
|
||||
if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange;
|
||||
return nullptr;
|
||||
}
|
||||
if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
|
||||
if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) {
|
||||
int overflow = 0;
|
||||
// Have to use LongLong for 64 bits, since long is 32 bits on Windows.
|
||||
*out = PyLong_AsLongLongAndOverflow(v, &overflow);
|
||||
if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange;
|
||||
return nullptr;
|
||||
}
|
||||
if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
|
||||
#else
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||
#endif
|
||||
return ConvertOneInt64(as_int.get(), out);
|
||||
}
|
||||
if (IsPyFloat(v)) return ErrorFoundFloat;
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
|
||||
DEFINE_HELPER(ConvertInt64, int64, DT_INT64, ConvertOneInt64);
|
||||
|
||||
const char* ConvertOneUint64(PyObject* v, uint64* out) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||
*out = PyInt_AsUnsignedLongLongMask(v);
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) {
|
||||
*out = PyLong_AsUnsignedLongLong(v);
|
||||
return nullptr;
|
||||
}
|
||||
if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
|
||||
#else
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||
#endif
|
||||
return ConvertOneUint64(as_int.get(), out);
|
||||
}
|
||||
if (IsPyFloat(v)) return ErrorFoundFloat;
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
|
||||
DEFINE_HELPER(ConvertUint64, uint64, DT_UINT64, ConvertOneUint64);
|
||||
|
||||
const char* ConvertOneInt32(PyObject* v, int32* out) {
|
||||
int64 i;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||
i = PyInt_AS_LONG(v);
|
||||
} else
|
||||
#endif
|
||||
if (PyLong_Check(v) || IsPyDimension(v)) {
|
||||
int overflow = 0;
|
||||
// Have to use LongLong for 64 bits, since long is 32 bits on Windows.
|
||||
i = PyLong_AsLongLongAndOverflow(v, &overflow);
|
||||
if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange;
|
||||
} else if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
|
||||
#else
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||
#endif
|
||||
return ConvertOneInt32(as_int.get(), out);
|
||||
} else if (IsPyFloat(v)) {
|
||||
return ErrorFoundFloat;
|
||||
} else {
|
||||
return ConvertScalar(as_int.get(), out);
|
||||
}
|
||||
if (IsPyFloat(v)) return ErrorFoundFloat;
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
*out = static_cast<uint32>(static_cast<uint64>(i));
|
||||
// Check for 32-bit overflow.
|
||||
if (TF_PREDICT_FALSE(i != *out)) return ErrorFoundInt64;
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
DEFINE_HELPER(ConvertInt32, int32, DT_INT32, ConvertOneInt32);
|
||||
typedef Converter<int64> Int64Converter;
|
||||
|
||||
template <>
|
||||
struct ConverterTraits<uint64> {
|
||||
static const tensorflow::DataType kTypeEnum = DT_UINT64;
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, uint64* out) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||
*out = PyInt_AsUnsignedLongLongMask(v);
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) {
|
||||
*out = PyLong_AsUnsignedLongLong(v);
|
||||
return nullptr;
|
||||
}
|
||||
if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
|
||||
#else
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||
#endif
|
||||
return ConvertScalar(as_int.get(), out);
|
||||
}
|
||||
if (IsPyFloat(v)) return ErrorFoundFloat;
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
};
|
||||
|
||||
typedef Converter<uint64> UInt64Converter;
|
||||
|
||||
template <>
|
||||
struct ConverterTraits<int32> {
|
||||
static const tensorflow::DataType kTypeEnum = DT_INT32;
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, int32* out) {
|
||||
int64 i;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||
i = PyInt_AS_LONG(v);
|
||||
} else
|
||||
#endif
|
||||
if (PyLong_Check(v) || IsPyDimension(v)) {
|
||||
int overflow = 0;
|
||||
// Have to use LongLong for 64 bits, since long is 32 bits on Windows.
|
||||
i = PyLong_AsLongLongAndOverflow(v, &overflow);
|
||||
if (TF_PREDICT_FALSE(overflow)) return ErrorOutOfRange;
|
||||
} else if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
|
||||
#else
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||
#endif
|
||||
return ConvertScalar(as_int.get(), out);
|
||||
} else if (IsPyFloat(v)) {
|
||||
return ErrorFoundFloat;
|
||||
} else {
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
*out = static_cast<uint32>(static_cast<uint64>(i));
|
||||
// Check for 32-bit overflow.
|
||||
if (TF_PREDICT_FALSE(i != *out)) return ErrorFoundInt64;
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
typedef Converter<int32> Int32Converter;
|
||||
|
||||
// Floating-point support
|
||||
|
||||
template <class T>
|
||||
const char* ConvertOneFloat(PyObject* v, T* out) {
|
||||
static const char* ConvertOneFloat(PyObject* v, T* out) {
|
||||
if (PyErr_Occurred()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -422,81 +420,143 @@ const char* ConvertOneFloat(PyObject* v, T* out) {
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
|
||||
DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat<double>);
|
||||
DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat<float>);
|
||||
template <>
|
||||
struct ConverterTraits<float> {
|
||||
static const tensorflow::DataType kTypeEnum = DT_FLOAT;
|
||||
static const char* ConvertScalar(PyObject* v, float* out) {
|
||||
return ConvertOneFloat<float>(v, out);
|
||||
}
|
||||
};
|
||||
|
||||
const char* ConvertOneNumpyHalf(PyObject* v, Eigen::half* out) {
|
||||
// NOTE(nareshmodi): Is there a way to convert to C double without the
|
||||
// intermediate Python double? This will help with ConvertOneFloat as well.
|
||||
Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
|
||||
double v_double = PyFloat_AS_DOUBLE(as_float.get());
|
||||
*out = Eigen::half(v_double);
|
||||
template <>
|
||||
struct ConverterTraits<double> {
|
||||
static const tensorflow::DataType kTypeEnum = DT_DOUBLE;
|
||||
static const char* ConvertScalar(PyObject* v, double* out) {
|
||||
return ConvertOneFloat<double>(v, out);
|
||||
}
|
||||
};
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
DEFINE_HELPER(ConvertNumpyHalf, Eigen::half, DT_HALF, ConvertOneNumpyHalf);
|
||||
typedef Converter<double> DoubleConverter;
|
||||
typedef Converter<float> FloatConverter;
|
||||
|
||||
template <>
|
||||
struct ConverterTraits<Eigen::half> {
|
||||
static const tensorflow::DataType kTypeEnum = DT_HALF;
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
|
||||
// NOTE(nareshmodi): Is there a way to convert to C double without the
|
||||
// intermediate Python double? This will help with ConvertOneFloat as well.
|
||||
Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
|
||||
double v_double = PyFloat_AS_DOUBLE(as_float.get());
|
||||
*out = Eigen::half(v_double);
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
typedef Converter<Eigen::half> NumpyHalfConverter;
|
||||
|
||||
// String support
|
||||
|
||||
const char* ConvertOneString(PyObject* v, tstring* out) {
|
||||
if (PyBytes_Check(v)) {
|
||||
out->assign(PyBytes_AS_STRING(v), PyBytes_GET_SIZE(v));
|
||||
return nullptr;
|
||||
}
|
||||
if (PyUnicode_Check(v)) {
|
||||
template <>
|
||||
struct ConverterTraits<string> {
|
||||
static const tensorflow::DataType kTypeEnum = DT_STRING;
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, tstring* out) {
|
||||
if (PyBytes_Check(v)) {
|
||||
out->assign(PyBytes_AS_STRING(v), PyBytes_GET_SIZE(v));
|
||||
return nullptr;
|
||||
}
|
||||
if (PyUnicode_Check(v)) {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
Py_ssize_t size;
|
||||
const char* str = PyUnicode_AsUTF8AndSize(v, &size);
|
||||
if (str == nullptr) return ErrorConvertingUnicodeString;
|
||||
out->assign(str, size);
|
||||
return nullptr;
|
||||
Py_ssize_t size;
|
||||
const char* str = PyUnicode_AsUTF8AndSize(v, &size);
|
||||
if (str == nullptr) return ErrorConvertingUnicodeString;
|
||||
out->assign(str, size);
|
||||
return nullptr;
|
||||
#else
|
||||
PyObject* py_str = PyUnicode_AsUTF8String(v);
|
||||
if (py_str == nullptr) return ErrorConvertingUnicodeString;
|
||||
out->assign(PyBytes_AS_STRING(py_str), PyBytes_GET_SIZE(py_str));
|
||||
Py_DECREF(py_str);
|
||||
return nullptr;
|
||||
PyObject* py_str = PyUnicode_AsUTF8String(v);
|
||||
if (py_str == nullptr) return ErrorConvertingUnicodeString;
|
||||
out->assign(PyBytes_AS_STRING(py_str), PyBytes_GET_SIZE(py_str));
|
||||
Py_DECREF(py_str);
|
||||
return nullptr;
|
||||
#endif
|
||||
}
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
return ErrorMixedTypes;
|
||||
};
|
||||
|
||||
typedef Converter<string> StringConverter;
|
||||
|
||||
// Converts Python object `c` that should hold a Python string into a
|
||||
// C++ string in *out. Returns nullptr on success, or a message on error.
|
||||
// Defined below, but forward declared here for use in PyRepr.
|
||||
tstring PyRepr(PyObject* obj) {
|
||||
if (obj == nullptr) {
|
||||
return "<null>";
|
||||
}
|
||||
Safe_PyObjectPtr repr_obj = make_safe(PyObject_Repr(obj));
|
||||
if (repr_obj) {
|
||||
tstring repr_str;
|
||||
if (ConverterTraits<string>::ConvertScalar(repr_obj.get(), &repr_str) ==
|
||||
nullptr) {
|
||||
return repr_str;
|
||||
}
|
||||
}
|
||||
return "<error computing repr()>";
|
||||
}
|
||||
|
||||
DEFINE_HELPER(ConvertString, tstring, DT_STRING, ConvertOneString);
|
||||
bool IsPyDimension(PyObject* obj) {
|
||||
const char* tp_name = obj->ob_type->tp_name;
|
||||
if (strcmp(tp_name, "Dimension") != 0) return false;
|
||||
bool ret = str_util::EndsWith(
|
||||
PyRepr(PyType(obj)),
|
||||
"tensorflow.python.framework.tensor_shape.Dimension'>");
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Complex support
|
||||
|
||||
const char* ConvertOneComplex(PyObject* v, complex128* out) {
|
||||
if (PyComplex_Check(v)) {
|
||||
*out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v));
|
||||
return nullptr;
|
||||
} else if (PyIsInstance(v, &PyComplexFloatingArrType_Type)) { // NumPy
|
||||
auto as_complex = PyComplex_AsCComplex(v);
|
||||
*out = complex128(as_complex.real, as_complex.imag);
|
||||
return nullptr;
|
||||
template <>
|
||||
struct ConverterTraits<complex128> {
|
||||
static const tensorflow::DataType kTypeEnum = DT_COMPLEX128;
|
||||
static const char* ConvertScalar(PyObject* v, complex128* out) {
|
||||
if (PyComplex_Check(v)) {
|
||||
*out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v));
|
||||
return nullptr;
|
||||
} else if (PyIsInstance(v, &PyComplexFloatingArrType_Type)) { // NumPy
|
||||
auto as_complex = PyComplex_AsCComplex(v);
|
||||
*out = complex128(as_complex.real, as_complex.imag);
|
||||
return nullptr;
|
||||
}
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
};
|
||||
|
||||
DEFINE_HELPER(ConvertComplex, complex128, DT_COMPLEX128, ConvertOneComplex);
|
||||
typedef Converter<complex128> Complex128Converter;
|
||||
|
||||
// Bool support
|
||||
|
||||
const char* ConvertOneBool(PyObject* v, bool* out) {
|
||||
if (v == Py_True) {
|
||||
*out = true;
|
||||
} else if (v == Py_False) {
|
||||
*out = false;
|
||||
} else if (PyIsInstance(v, &PyBoolArrType_Type)) { // NumPy
|
||||
*out = PyObject_IsTrue(v);
|
||||
} else {
|
||||
return ErrorMixedTypes;
|
||||
template <>
|
||||
struct ConverterTraits<bool> {
|
||||
typedef bool Type;
|
||||
static const tensorflow::DataType kTypeEnum = DT_BOOL;
|
||||
|
||||
static const char* ConvertScalar(PyObject* v, bool* out) {
|
||||
if (v == Py_True) {
|
||||
*out = true;
|
||||
} else if (v == Py_False) {
|
||||
*out = false;
|
||||
} else if (PyIsInstance(v, &PyBoolArrType_Type)) { // NumPy
|
||||
*out = PyObject_IsTrue(v);
|
||||
} else {
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
DEFINE_HELPER(ConvertBool, bool, DT_BOOL, ConvertOneBool);
|
||||
|
||||
#undef DEFINE_HELPER
|
||||
typedef Converter<bool> BoolConverter;
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -521,38 +581,46 @@ Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) {
|
||||
// operation.
|
||||
switch (requested_dtype) {
|
||||
case DT_FLOAT:
|
||||
if (ConvertFloat(obj, shape, ret) == nullptr) return Status::OK();
|
||||
if (FloatConverter::Convert(obj, shape, ret) == nullptr)
|
||||
return Status::OK();
|
||||
break;
|
||||
|
||||
case DT_DOUBLE:
|
||||
if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK();
|
||||
if (DoubleConverter::Convert(obj, shape, ret) == nullptr)
|
||||
return Status::OK();
|
||||
break;
|
||||
|
||||
case DT_HALF:
|
||||
RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
|
||||
RETURN_STRING_AS_STATUS(NumpyHalfConverter::Convert(obj, shape, ret));
|
||||
|
||||
case DT_INT64:
|
||||
if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK();
|
||||
if (Int64Converter::Convert(obj, shape, ret) == nullptr)
|
||||
return Status::OK();
|
||||
break;
|
||||
|
||||
case DT_INT32:
|
||||
if (ConvertInt32(obj, shape, ret) == nullptr) return Status::OK();
|
||||
if (Int32Converter::Convert(obj, shape, ret) == nullptr)
|
||||
return Status::OK();
|
||||
break;
|
||||
|
||||
case DT_UINT64:
|
||||
if (ConvertUint64(obj, shape, ret) == nullptr) return Status::OK();
|
||||
if (UInt64Converter::Convert(obj, shape, ret) == nullptr)
|
||||
return Status::OK();
|
||||
break;
|
||||
|
||||
case DT_COMPLEX128:
|
||||
if (ConvertComplex(obj, shape, ret) == nullptr) return Status::OK();
|
||||
if (Complex128Converter::Convert(obj, shape, ret) == nullptr)
|
||||
return Status::OK();
|
||||
break;
|
||||
|
||||
case DT_STRING:
|
||||
if (ConvertString(obj, shape, ret) == nullptr) return Status::OK();
|
||||
if (StringConverter::Convert(obj, shape, ret) == nullptr)
|
||||
return Status::OK();
|
||||
break;
|
||||
|
||||
case DT_BOOL:
|
||||
if (ConvertBool(obj, shape, ret) == nullptr) return Status::OK();
|
||||
if (BoolConverter::Convert(obj, shape, ret) == nullptr)
|
||||
return Status::OK();
|
||||
break;
|
||||
|
||||
default:
|
||||
@ -564,49 +632,49 @@ Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) {
|
||||
if (requested_dtype == DT_INVALID) {
|
||||
// TensorFlow uses float32s to represent floating point numbers
|
||||
// by default (for space and speed over using doubles).
|
||||
RETURN_STRING_AS_STATUS(ConvertFloat(obj, shape, ret));
|
||||
RETURN_STRING_AS_STATUS(FloatConverter::Convert(obj, shape, ret));
|
||||
} else {
|
||||
// We are going to do a cast to the user's requested dtype
|
||||
// after this. We use doubles for this intermediate result so
|
||||
// we don't lose precision that might be representable in the
|
||||
// final type.
|
||||
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
|
||||
RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, shape, ret));
|
||||
}
|
||||
|
||||
case DT_DOUBLE:
|
||||
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
|
||||
RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, shape, ret));
|
||||
|
||||
case DT_HALF:
|
||||
RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
|
||||
RETURN_STRING_AS_STATUS(NumpyHalfConverter::Convert(obj, shape, ret));
|
||||
|
||||
case DT_INT64:
|
||||
if (requested_dtype == DT_INVALID) {
|
||||
const char* error = ConvertInt32(obj, shape, ret);
|
||||
const char* error = Int32Converter::Convert(obj, shape, ret);
|
||||
if (error == ErrorFoundInt64) {
|
||||
error = ConvertInt64(obj, shape, ret);
|
||||
error = Int64Converter::Convert(obj, shape, ret);
|
||||
}
|
||||
if (error == ErrorFoundFloat) {
|
||||
error = ConvertFloat(obj, shape, ret);
|
||||
error = FloatConverter::Convert(obj, shape, ret);
|
||||
}
|
||||
// TODO(josh11b): May also want to fall back to using doubles if
|
||||
// error == ErrorOutOfRange?
|
||||
RETURN_STRING_AS_STATUS(error);
|
||||
} else {
|
||||
const char* error = ConvertInt64(obj, shape, ret);
|
||||
const char* error = Int64Converter::Convert(obj, shape, ret);
|
||||
if (error == ErrorFoundFloat) {
|
||||
error = ConvertDouble(obj, shape, ret);
|
||||
error = DoubleConverter::Convert(obj, shape, ret);
|
||||
}
|
||||
RETURN_STRING_AS_STATUS(error);
|
||||
}
|
||||
|
||||
case DT_STRING:
|
||||
RETURN_STRING_AS_STATUS(ConvertString(obj, shape, ret));
|
||||
RETURN_STRING_AS_STATUS(StringConverter::Convert(obj, shape, ret));
|
||||
|
||||
case DT_COMPLEX128:
|
||||
RETURN_STRING_AS_STATUS(ConvertComplex(obj, shape, ret));
|
||||
RETURN_STRING_AS_STATUS(Complex128Converter::Convert(obj, shape, ret));
|
||||
|
||||
case DT_BOOL:
|
||||
RETURN_STRING_AS_STATUS(ConvertBool(obj, shape, ret));
|
||||
RETURN_STRING_AS_STATUS(BoolConverter::Convert(obj, shape, ret));
|
||||
|
||||
case DT_INVALID: // Only occurs for empty tensors.
|
||||
*ret = Tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
|
||||
|
Loading…
x
Reference in New Issue
Block a user