Support DT_UINT64 as a direct conversion type
PiperOrigin-RevId: 235273319
This commit is contained in:
parent
4af8712749
commit
de7d814b50
@ -364,6 +364,14 @@ class TFETensorTest(test_util.TensorFlowTestCase):
|
||||
"Provided value.*Requested dtype.*"):
|
||||
_ = ops.convert_to_tensor(1., dtype=dtypes.int32)
|
||||
|
||||
def testEagerLargeConstant(self):
|
||||
self.assertEqual(
|
||||
constant_op.constant(dtypes.uint64.max, dtype=dtypes.uint64).numpy(),
|
||||
dtypes.uint64.max)
|
||||
self.assertEqual(
|
||||
constant_op.constant(dtypes.uint32.max, dtype=dtypes.uint32).numpy(),
|
||||
dtypes.uint32.max)
|
||||
|
||||
|
||||
class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -316,6 +316,31 @@ const char* ConvertOneInt64(PyObject* v, int64* out) {
|
||||
|
||||
DEFINE_HELPER(ConvertInt64, int64, DT_INT64, ConvertOneInt64);
|
||||
|
||||
const char* ConvertOneUint64(PyObject* v, uint64* out) {
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
if (TF_PREDICT_TRUE(PyInt_Check(v))) {
|
||||
*out = PyInt_AsUnsignedLongMask(v);
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) {
|
||||
*out = PyLong_AsUnsignedLong(v);
|
||||
return nullptr;
|
||||
}
|
||||
if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v));
|
||||
#else
|
||||
Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v));
|
||||
#endif
|
||||
return ConvertOneUint64(as_int.get(), out);
|
||||
}
|
||||
if (IsPyFloat(v)) return ErrorFoundFloat;
|
||||
return ErrorMixedTypes;
|
||||
}
|
||||
|
||||
DEFINE_HELPER(ConvertUint64, uint64, DT_UINT64, ConvertOneUint64);
|
||||
|
||||
const char* ConvertOneInt32(PyObject* v, int32* out) {
|
||||
int64 i;
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
@ -525,6 +550,10 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) {
|
||||
if (ConvertInt32(obj, shape, ret) == nullptr) return Status::OK();
|
||||
break;
|
||||
|
||||
case DT_UINT64:
|
||||
if (ConvertUint64(obj, shape, ret) == nullptr) return Status::OK();
|
||||
break;
|
||||
|
||||
case DT_COMPLEX128:
|
||||
if (ConvertComplex(obj, shape, ret) == nullptr) return Status::OK();
|
||||
break;
|
||||
|
Loading…
Reference in New Issue
Block a user