Re-write a long macro as a template class instead.
This will facilitate changes in the functions that were previously defined by macro calls. The re-write tries to be quite literal: no other refactors or optimizations were done in this change. PiperOrigin-RevId: 279088175 Change-Id: I2d6f24ad53eb70e330b3d3448a5616cb7f6b6cf7
This commit is contained in:
parent
700263d02a
commit
1746a9229a
@ -77,34 +77,6 @@ PyObject* ZeroDimArrayToScalar(PyObject* obj) {
|
|||||||
return obj;
|
return obj;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converts Python object `c` that should hold a Python string into a
|
|
||||||
// C++ string in *out. Returns nullptr on success, or a message on error.
|
|
||||||
// Defined below, but forward declared here for use in PyRepr.
|
|
||||||
const char* ConvertOneString(PyObject* v, tstring* out);
|
|
||||||
|
|
||||||
tstring PyRepr(PyObject* obj) {
|
|
||||||
if (obj == nullptr) {
|
|
||||||
return "<null>";
|
|
||||||
}
|
|
||||||
Safe_PyObjectPtr repr_obj = make_safe(PyObject_Repr(obj));
|
|
||||||
if (repr_obj) {
|
|
||||||
tstring repr_str;
|
|
||||||
if (ConvertOneString(repr_obj.get(), &repr_str) == nullptr) {
|
|
||||||
return repr_str;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "<error computing repr()>";
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsPyDimension(PyObject* obj) {
|
|
||||||
const char* tp_name = obj->ob_type->tp_name;
|
|
||||||
if (strcmp(tp_name, "Dimension") != 0) return false;
|
|
||||||
bool ret = str_util::EndsWith(
|
|
||||||
PyRepr(PyType(obj)),
|
|
||||||
"tensorflow.python.framework.tensor_shape.Dimension'>");
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets *elem to a NEW reference to an element in seq on success.
|
// Sets *elem to a NEW reference to an element in seq on success.
|
||||||
// REQUIRES: PySequence_Check(seq) && PySequence_Length(seq) > 0.
|
// REQUIRES: PySequence_Check(seq) && PySequence_Length(seq) > 0.
|
||||||
Status SampleElementFromSequence(PyObject* seq, PyObject** elem) {
|
Status SampleElementFromSequence(PyObject* seq, PyObject** elem) {
|
||||||
@ -140,6 +112,9 @@ Status SampleElementFromSequence(PyObject* seq, PyObject** elem) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tstring PyRepr(PyObject* obj);
|
||||||
|
bool IsPyDimension(PyObject* obj);
|
||||||
|
|
||||||
Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
|
Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
|
||||||
std::vector<Safe_PyObjectPtr> refs_to_clean;
|
std::vector<Safe_PyObjectPtr> refs_to_clean;
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -223,72 +198,84 @@ const char ErrorFoundFloat[] =
|
|||||||
"Can't convert Python sequence with floating point values to integer "
|
"Can't convert Python sequence with floating point values to integer "
|
||||||
"Tensor.";
|
"Tensor.";
|
||||||
|
|
||||||
// Template for defining a function for recursively convering obj into
|
// Defines a converter that recursively converts an object into
|
||||||
// an array of TYPE using the conversion function CONVERT.
|
// an array of type T using the conversion function defined by the
|
||||||
|
// traits class in a ConvertScalar function.
|
||||||
|
//
|
||||||
// Note that these helper functions require shape.dims() >= 1.
|
// Note that these helper functions require shape.dims() >= 1.
|
||||||
|
template <class T>
|
||||||
|
struct ConverterTraits {
|
||||||
|
static const tensorflow::DataType kTypeEnum;
|
||||||
|
static const char* ConvertScalar(PyObject* v, T* out);
|
||||||
|
};
|
||||||
|
|
||||||
#define DEFINE_HELPER(FUNCTION, TYPE, TYPE_ENUM, CONVERT) \
|
template <class T>
|
||||||
const char* FUNCTION##Helper(PyObject* obj, const TensorShape& shape, \
|
struct Converter {
|
||||||
TYPE** buf) { \
|
static const char* Helper(PyObject* obj, const TensorShape& shape, T** buf) {
|
||||||
if (TF_PREDICT_FALSE(obj == nullptr)) { \
|
if (TF_PREDICT_FALSE(obj == nullptr)) {
|
||||||
return ErrorConverting; \
|
return ErrorConverting;
|
||||||
} \
|
|
||||||
if (shape.dims() > 1) { \
|
|
||||||
/* Iterate over outer dim, and recursively convert each element. */ \
|
|
||||||
const int64 s = shape.dim_size(0); \
|
|
||||||
Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, "")); \
|
|
||||||
if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular; \
|
|
||||||
if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) { \
|
|
||||||
return ErrorRectangular; \
|
|
||||||
} \
|
|
||||||
TensorShape rest = shape; \
|
|
||||||
rest.RemoveDim(0); \
|
|
||||||
for (int64 i = 0; i < s; ++i) { \
|
|
||||||
const char* error = FUNCTION##Helper( \
|
|
||||||
PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf); \
|
|
||||||
if (TF_PREDICT_FALSE(error != nullptr)) return error; \
|
|
||||||
} \
|
|
||||||
} else { \
|
|
||||||
Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, "")); \
|
|
||||||
if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular; \
|
|
||||||
const int64 s = shape.dim_size(0); \
|
|
||||||
if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) { \
|
|
||||||
return ErrorRectangular; \
|
|
||||||
} \
|
|
||||||
PyObject** l = PySequence_Fast_ITEMS(seq.get()); \
|
|
||||||
for (int64 i = 0; i < s; ++i) { \
|
|
||||||
auto scalar = ZeroDimArrayToScalar(l[i]); \
|
|
||||||
const char* error = CONVERT(scalar, *buf); \
|
|
||||||
Py_DECREF(scalar); \
|
|
||||||
if (TF_PREDICT_FALSE(error != nullptr)) return error; \
|
|
||||||
++*buf; \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
return nullptr; \
|
|
||||||
} \
|
|
||||||
const char* FUNCTION(PyObject* obj, const TensorShape& shape, \
|
|
||||||
Tensor* dest) { \
|
|
||||||
/* TODO(josh11b): Allocator & attributes? */ \
|
|
||||||
Tensor result(TYPE_ENUM, shape); \
|
|
||||||
if (shape.dims() == 0) { /* Scalar case */ \
|
|
||||||
TYPE value; \
|
|
||||||
auto scalar = ZeroDimArrayToScalar(obj); \
|
|
||||||
const char* error = CONVERT(scalar, &value); \
|
|
||||||
Py_DECREF(scalar); \
|
|
||||||
if (error != nullptr) return error; \
|
|
||||||
result.scalar<TYPE>()() = value; \
|
|
||||||
} else { \
|
|
||||||
TYPE* buf = result.flat<TYPE>().data(); \
|
|
||||||
const char* error = FUNCTION##Helper(obj, shape, &buf); \
|
|
||||||
if (error != nullptr) return error; \
|
|
||||||
} \
|
|
||||||
*dest = result; \
|
|
||||||
return nullptr; \
|
|
||||||
}
|
}
|
||||||
|
if (shape.dims() > 1) {
|
||||||
|
/* Iterate over outer dim, and recursively convert each element. */
|
||||||
|
const int64 s = shape.dim_size(0);
|
||||||
|
Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, ""));
|
||||||
|
if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular;
|
||||||
|
if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) {
|
||||||
|
return ErrorRectangular;
|
||||||
|
}
|
||||||
|
TensorShape rest = shape;
|
||||||
|
rest.RemoveDim(0);
|
||||||
|
for (int64 i = 0; i < s; ++i) {
|
||||||
|
const char* error =
|
||||||
|
Helper(PySequence_Fast_GET_ITEM(seq.get(), i), rest, buf);
|
||||||
|
if (TF_PREDICT_FALSE(error != nullptr)) return error;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Safe_PyObjectPtr seq = make_safe(PySequence_Fast(obj, ""));
|
||||||
|
if (TF_PREDICT_FALSE(seq == nullptr)) return ErrorRectangular;
|
||||||
|
const int64 s = shape.dim_size(0);
|
||||||
|
if (TF_PREDICT_FALSE(s != PySequence_Fast_GET_SIZE(seq.get()))) {
|
||||||
|
return ErrorRectangular;
|
||||||
|
}
|
||||||
|
PyObject** l = PySequence_Fast_ITEMS(seq.get());
|
||||||
|
for (int64 i = 0; i < s; ++i) {
|
||||||
|
auto scalar = ZeroDimArrayToScalar(l[i]);
|
||||||
|
const char* error = ConverterTraits<T>::ConvertScalar(scalar, *buf);
|
||||||
|
Py_DECREF(scalar);
|
||||||
|
if (TF_PREDICT_FALSE(error != nullptr)) return error;
|
||||||
|
++*buf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
static const char* Convert(PyObject* obj, const TensorShape& shape,
|
||||||
|
Tensor* dest) {
|
||||||
|
/* TODO(josh11b): Allocator & attributes? */
|
||||||
|
Tensor result(ConverterTraits<T>::kTypeEnum, shape);
|
||||||
|
if (shape.dims() == 0) { /* Scalar case */
|
||||||
|
T value;
|
||||||
|
auto scalar = ZeroDimArrayToScalar(obj);
|
||||||
|
const char* error = ConverterTraits<T>::ConvertScalar(scalar, &value);
|
||||||
|
Py_DECREF(scalar);
|
||||||
|
if (error != nullptr) return error;
|
||||||
|
result.scalar<T>()() = value;
|
||||||
|
} else {
|
||||||
|
T* buf = result.flat<T>().data();
|
||||||
|
const char* error = Helper(obj, shape, &buf);
|
||||||
|
if (error != nullptr) return error;
|
||||||
|
}
|
||||||
|
*dest = result;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// Int support
|
// Int support
|
||||||
|
|
||||||
const char* ConvertOneInt64(PyObject* v, int64* out) {
|
template <>
|
||||||
|
struct ConverterTraits<int64> {
|
||||||
|
static const tensorflow::DataType kTypeEnum = DT_INT64;
|
||||||
|
|
||||||
|
static const char* ConvertScalar(PyObject* v, int64* out) {
|
||||||
#if PY_MAJOR_VERSION < 3
|
#if PY_MAJOR_VERSION < 3
|
||||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||||
*out = PyInt_AS_LONG(v);
|
*out = PyInt_AS_LONG(v);
|
||||||
@ -308,15 +295,20 @@ const char* ConvertOneInt64(PyObject* v, int64* out) {
|
|||||||
#else
|
#else
|
||||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||||
#endif
|
#endif
|
||||||
return ConvertOneInt64(as_int.get(), out);
|
return ConvertScalar(as_int.get(), out);
|
||||||
}
|
}
|
||||||
if (IsPyFloat(v)) return ErrorFoundFloat;
|
if (IsPyFloat(v)) return ErrorFoundFloat;
|
||||||
return ErrorMixedTypes;
|
return ErrorMixedTypes;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
DEFINE_HELPER(ConvertInt64, int64, DT_INT64, ConvertOneInt64);
|
typedef Converter<int64> Int64Converter;
|
||||||
|
|
||||||
const char* ConvertOneUint64(PyObject* v, uint64* out) {
|
template <>
|
||||||
|
struct ConverterTraits<uint64> {
|
||||||
|
static const tensorflow::DataType kTypeEnum = DT_UINT64;
|
||||||
|
|
||||||
|
static const char* ConvertScalar(PyObject* v, uint64* out) {
|
||||||
#if PY_MAJOR_VERSION < 3
|
#if PY_MAJOR_VERSION < 3
|
||||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||||
*out = PyInt_AsUnsignedLongLongMask(v);
|
*out = PyInt_AsUnsignedLongLongMask(v);
|
||||||
@ -333,15 +325,20 @@ const char* ConvertOneUint64(PyObject* v, uint64* out) {
|
|||||||
#else
|
#else
|
||||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||||
#endif
|
#endif
|
||||||
return ConvertOneUint64(as_int.get(), out);
|
return ConvertScalar(as_int.get(), out);
|
||||||
}
|
}
|
||||||
if (IsPyFloat(v)) return ErrorFoundFloat;
|
if (IsPyFloat(v)) return ErrorFoundFloat;
|
||||||
return ErrorMixedTypes;
|
return ErrorMixedTypes;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
DEFINE_HELPER(ConvertUint64, uint64, DT_UINT64, ConvertOneUint64);
|
typedef Converter<uint64> UInt64Converter;
|
||||||
|
|
||||||
const char* ConvertOneInt32(PyObject* v, int32* out) {
|
template <>
|
||||||
|
struct ConverterTraits<int32> {
|
||||||
|
static const tensorflow::DataType kTypeEnum = DT_INT32;
|
||||||
|
|
||||||
|
static const char* ConvertScalar(PyObject* v, int32* out) {
|
||||||
int64 i;
|
int64 i;
|
||||||
#if PY_MAJOR_VERSION < 3
|
#if PY_MAJOR_VERSION < 3
|
||||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||||
@ -359,7 +356,7 @@ const char* ConvertOneInt32(PyObject* v, int32* out) {
|
|||||||
#else
|
#else
|
||||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||||
#endif
|
#endif
|
||||||
return ConvertOneInt32(as_int.get(), out);
|
return ConvertScalar(as_int.get(), out);
|
||||||
} else if (IsPyFloat(v)) {
|
} else if (IsPyFloat(v)) {
|
||||||
return ErrorFoundFloat;
|
return ErrorFoundFloat;
|
||||||
} else {
|
} else {
|
||||||
@ -369,14 +366,15 @@ const char* ConvertOneInt32(PyObject* v, int32* out) {
|
|||||||
// Check for 32-bit overflow.
|
// Check for 32-bit overflow.
|
||||||
if (TF_PREDICT_FALSE(i != *out)) return ErrorFoundInt64;
|
if (TF_PREDICT_FALSE(i != *out)) return ErrorFoundInt64;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
DEFINE_HELPER(ConvertInt32, int32, DT_INT32, ConvertOneInt32);
|
typedef Converter<int32> Int32Converter;
|
||||||
|
|
||||||
// Floating-point support
|
// Floating-point support
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
const char* ConvertOneFloat(PyObject* v, T* out) {
|
static const char* ConvertOneFloat(PyObject* v, T* out) {
|
||||||
if (PyErr_Occurred()) {
|
if (PyErr_Occurred()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@ -422,10 +420,30 @@ const char* ConvertOneFloat(PyObject* v, T* out) {
|
|||||||
return ErrorMixedTypes;
|
return ErrorMixedTypes;
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat<double>);
|
template <>
|
||||||
DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat<float>);
|
struct ConverterTraits<float> {
|
||||||
|
static const tensorflow::DataType kTypeEnum = DT_FLOAT;
|
||||||
|
static const char* ConvertScalar(PyObject* v, float* out) {
|
||||||
|
return ConvertOneFloat<float>(v, out);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const char* ConvertOneNumpyHalf(PyObject* v, Eigen::half* out) {
|
template <>
|
||||||
|
struct ConverterTraits<double> {
|
||||||
|
static const tensorflow::DataType kTypeEnum = DT_DOUBLE;
|
||||||
|
static const char* ConvertScalar(PyObject* v, double* out) {
|
||||||
|
return ConvertOneFloat<double>(v, out);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef Converter<double> DoubleConverter;
|
||||||
|
typedef Converter<float> FloatConverter;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct ConverterTraits<Eigen::half> {
|
||||||
|
static const tensorflow::DataType kTypeEnum = DT_HALF;
|
||||||
|
|
||||||
|
static const char* ConvertScalar(PyObject* v, Eigen::half* out) {
|
||||||
// NOTE(nareshmodi): Is there a way to convert to C double without the
|
// NOTE(nareshmodi): Is there a way to convert to C double without the
|
||||||
// intermediate Python double? This will help with ConvertOneFloat as well.
|
// intermediate Python double? This will help with ConvertOneFloat as well.
|
||||||
Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
|
Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v));
|
||||||
@ -433,12 +451,18 @@ const char* ConvertOneNumpyHalf(PyObject* v, Eigen::half* out) {
|
|||||||
*out = Eigen::half(v_double);
|
*out = Eigen::half(v_double);
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
DEFINE_HELPER(ConvertNumpyHalf, Eigen::half, DT_HALF, ConvertOneNumpyHalf);
|
};
|
||||||
|
|
||||||
|
typedef Converter<Eigen::half> NumpyHalfConverter;
|
||||||
|
|
||||||
// String support
|
// String support
|
||||||
|
|
||||||
const char* ConvertOneString(PyObject* v, tstring* out) {
|
template <>
|
||||||
|
struct ConverterTraits<string> {
|
||||||
|
static const tensorflow::DataType kTypeEnum = DT_STRING;
|
||||||
|
|
||||||
|
static const char* ConvertScalar(PyObject* v, tstring* out) {
|
||||||
if (PyBytes_Check(v)) {
|
if (PyBytes_Check(v)) {
|
||||||
out->assign(PyBytes_AS_STRING(v), PyBytes_GET_SIZE(v));
|
out->assign(PyBytes_AS_STRING(v), PyBytes_GET_SIZE(v));
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -459,13 +483,44 @@ const char* ConvertOneString(PyObject* v, tstring* out) {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
return ErrorMixedTypes;
|
return ErrorMixedTypes;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef Converter<string> StringConverter;
|
||||||
|
|
||||||
|
// Converts Python object `c` that should hold a Python string into a
|
||||||
|
// C++ string in *out. Returns nullptr on success, or a message on error.
|
||||||
|
// Defined below, but forward declared here for use in PyRepr.
|
||||||
|
tstring PyRepr(PyObject* obj) {
|
||||||
|
if (obj == nullptr) {
|
||||||
|
return "<null>";
|
||||||
|
}
|
||||||
|
Safe_PyObjectPtr repr_obj = make_safe(PyObject_Repr(obj));
|
||||||
|
if (repr_obj) {
|
||||||
|
tstring repr_str;
|
||||||
|
if (ConverterTraits<string>::ConvertScalar(repr_obj.get(), &repr_str) ==
|
||||||
|
nullptr) {
|
||||||
|
return repr_str;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "<error computing repr()>";
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFINE_HELPER(ConvertString, tstring, DT_STRING, ConvertOneString);
|
bool IsPyDimension(PyObject* obj) {
|
||||||
|
const char* tp_name = obj->ob_type->tp_name;
|
||||||
|
if (strcmp(tp_name, "Dimension") != 0) return false;
|
||||||
|
bool ret = str_util::EndsWith(
|
||||||
|
PyRepr(PyType(obj)),
|
||||||
|
"tensorflow.python.framework.tensor_shape.Dimension'>");
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
// Complex support
|
// Complex support
|
||||||
|
|
||||||
const char* ConvertOneComplex(PyObject* v, complex128* out) {
|
template <>
|
||||||
|
struct ConverterTraits<complex128> {
|
||||||
|
static const tensorflow::DataType kTypeEnum = DT_COMPLEX128;
|
||||||
|
static const char* ConvertScalar(PyObject* v, complex128* out) {
|
||||||
if (PyComplex_Check(v)) {
|
if (PyComplex_Check(v)) {
|
||||||
*out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v));
|
*out = complex128(PyComplex_RealAsDouble(v), PyComplex_ImagAsDouble(v));
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -475,13 +530,19 @@ const char* ConvertOneComplex(PyObject* v, complex128* out) {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return ErrorMixedTypes;
|
return ErrorMixedTypes;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
DEFINE_HELPER(ConvertComplex, complex128, DT_COMPLEX128, ConvertOneComplex);
|
typedef Converter<complex128> Complex128Converter;
|
||||||
|
|
||||||
// Bool support
|
// Bool support
|
||||||
|
|
||||||
const char* ConvertOneBool(PyObject* v, bool* out) {
|
template <>
|
||||||
|
struct ConverterTraits<bool> {
|
||||||
|
typedef bool Type;
|
||||||
|
static const tensorflow::DataType kTypeEnum = DT_BOOL;
|
||||||
|
|
||||||
|
static const char* ConvertScalar(PyObject* v, bool* out) {
|
||||||
if (v == Py_True) {
|
if (v == Py_True) {
|
||||||
*out = true;
|
*out = true;
|
||||||
} else if (v == Py_False) {
|
} else if (v == Py_False) {
|
||||||
@ -492,11 +553,10 @@ const char* ConvertOneBool(PyObject* v, bool* out) {
|
|||||||
return ErrorMixedTypes;
|
return ErrorMixedTypes;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
DEFINE_HELPER(ConvertBool, bool, DT_BOOL, ConvertOneBool);
|
typedef Converter<bool> BoolConverter;
|
||||||
|
|
||||||
#undef DEFINE_HELPER
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -521,38 +581,46 @@ Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) {
|
|||||||
// operation.
|
// operation.
|
||||||
switch (requested_dtype) {
|
switch (requested_dtype) {
|
||||||
case DT_FLOAT:
|
case DT_FLOAT:
|
||||||
if (ConvertFloat(obj, shape, ret) == nullptr) return Status::OK();
|
if (FloatConverter::Convert(obj, shape, ret) == nullptr)
|
||||||
|
return Status::OK();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case DT_DOUBLE:
|
case DT_DOUBLE:
|
||||||
if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK();
|
if (DoubleConverter::Convert(obj, shape, ret) == nullptr)
|
||||||
|
return Status::OK();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case DT_HALF:
|
case DT_HALF:
|
||||||
RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
|
RETURN_STRING_AS_STATUS(NumpyHalfConverter::Convert(obj, shape, ret));
|
||||||
|
|
||||||
case DT_INT64:
|
case DT_INT64:
|
||||||
if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK();
|
if (Int64Converter::Convert(obj, shape, ret) == nullptr)
|
||||||
|
return Status::OK();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case DT_INT32:
|
case DT_INT32:
|
||||||
if (ConvertInt32(obj, shape, ret) == nullptr) return Status::OK();
|
if (Int32Converter::Convert(obj, shape, ret) == nullptr)
|
||||||
|
return Status::OK();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case DT_UINT64:
|
case DT_UINT64:
|
||||||
if (ConvertUint64(obj, shape, ret) == nullptr) return Status::OK();
|
if (UInt64Converter::Convert(obj, shape, ret) == nullptr)
|
||||||
|
return Status::OK();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case DT_COMPLEX128:
|
case DT_COMPLEX128:
|
||||||
if (ConvertComplex(obj, shape, ret) == nullptr) return Status::OK();
|
if (Complex128Converter::Convert(obj, shape, ret) == nullptr)
|
||||||
|
return Status::OK();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case DT_STRING:
|
case DT_STRING:
|
||||||
if (ConvertString(obj, shape, ret) == nullptr) return Status::OK();
|
if (StringConverter::Convert(obj, shape, ret) == nullptr)
|
||||||
|
return Status::OK();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case DT_BOOL:
|
case DT_BOOL:
|
||||||
if (ConvertBool(obj, shape, ret) == nullptr) return Status::OK();
|
if (BoolConverter::Convert(obj, shape, ret) == nullptr)
|
||||||
|
return Status::OK();
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@ -564,49 +632,49 @@ Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) {
|
|||||||
if (requested_dtype == DT_INVALID) {
|
if (requested_dtype == DT_INVALID) {
|
||||||
// TensorFlow uses float32s to represent floating point numbers
|
// TensorFlow uses float32s to represent floating point numbers
|
||||||
// by default (for space and speed over using doubles).
|
// by default (for space and speed over using doubles).
|
||||||
RETURN_STRING_AS_STATUS(ConvertFloat(obj, shape, ret));
|
RETURN_STRING_AS_STATUS(FloatConverter::Convert(obj, shape, ret));
|
||||||
} else {
|
} else {
|
||||||
// We are going to do a cast to the user's requested dtype
|
// We are going to do a cast to the user's requested dtype
|
||||||
// after this. We use doubles for this intermediate result so
|
// after this. We use doubles for this intermediate result so
|
||||||
// we don't lose precision that might be representable in the
|
// we don't lose precision that might be representable in the
|
||||||
// final type.
|
// final type.
|
||||||
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
|
RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, shape, ret));
|
||||||
}
|
}
|
||||||
|
|
||||||
case DT_DOUBLE:
|
case DT_DOUBLE:
|
||||||
RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret));
|
RETURN_STRING_AS_STATUS(DoubleConverter::Convert(obj, shape, ret));
|
||||||
|
|
||||||
case DT_HALF:
|
case DT_HALF:
|
||||||
RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret));
|
RETURN_STRING_AS_STATUS(NumpyHalfConverter::Convert(obj, shape, ret));
|
||||||
|
|
||||||
case DT_INT64:
|
case DT_INT64:
|
||||||
if (requested_dtype == DT_INVALID) {
|
if (requested_dtype == DT_INVALID) {
|
||||||
const char* error = ConvertInt32(obj, shape, ret);
|
const char* error = Int32Converter::Convert(obj, shape, ret);
|
||||||
if (error == ErrorFoundInt64) {
|
if (error == ErrorFoundInt64) {
|
||||||
error = ConvertInt64(obj, shape, ret);
|
error = Int64Converter::Convert(obj, shape, ret);
|
||||||
}
|
}
|
||||||
if (error == ErrorFoundFloat) {
|
if (error == ErrorFoundFloat) {
|
||||||
error = ConvertFloat(obj, shape, ret);
|
error = FloatConverter::Convert(obj, shape, ret);
|
||||||
}
|
}
|
||||||
// TODO(josh11b): May also want to fall back to using doubles if
|
// TODO(josh11b): May also want to fall back to using doubles if
|
||||||
// error == ErrorOutOfRange?
|
// error == ErrorOutOfRange?
|
||||||
RETURN_STRING_AS_STATUS(error);
|
RETURN_STRING_AS_STATUS(error);
|
||||||
} else {
|
} else {
|
||||||
const char* error = ConvertInt64(obj, shape, ret);
|
const char* error = Int64Converter::Convert(obj, shape, ret);
|
||||||
if (error == ErrorFoundFloat) {
|
if (error == ErrorFoundFloat) {
|
||||||
error = ConvertDouble(obj, shape, ret);
|
error = DoubleConverter::Convert(obj, shape, ret);
|
||||||
}
|
}
|
||||||
RETURN_STRING_AS_STATUS(error);
|
RETURN_STRING_AS_STATUS(error);
|
||||||
}
|
}
|
||||||
|
|
||||||
case DT_STRING:
|
case DT_STRING:
|
||||||
RETURN_STRING_AS_STATUS(ConvertString(obj, shape, ret));
|
RETURN_STRING_AS_STATUS(StringConverter::Convert(obj, shape, ret));
|
||||||
|
|
||||||
case DT_COMPLEX128:
|
case DT_COMPLEX128:
|
||||||
RETURN_STRING_AS_STATUS(ConvertComplex(obj, shape, ret));
|
RETURN_STRING_AS_STATUS(Complex128Converter::Convert(obj, shape, ret));
|
||||||
|
|
||||||
case DT_BOOL:
|
case DT_BOOL:
|
||||||
RETURN_STRING_AS_STATUS(ConvertBool(obj, shape, ret));
|
RETURN_STRING_AS_STATUS(BoolConverter::Convert(obj, shape, ret));
|
||||||
|
|
||||||
case DT_INVALID: // Only occurs for empty tensors.
|
case DT_INVALID: // Only occurs for empty tensors.
|
||||||
*ret = Tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
|
*ret = Tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user