From de7d814b50ebf74bda4fe560afe22d120428b1fd Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Fri, 22 Feb 2019 15:47:10 -0800 Subject: [PATCH] Support DT_UINT64 as a direct conversion type PiperOrigin-RevId: 235273319 --- tensorflow/python/eager/tensor_test.py | 8 ++++++ tensorflow/python/lib/core/py_seq_tensor.cc | 29 +++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 23fb983767b..4e6a12b897e 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -364,6 +364,14 @@ class TFETensorTest(test_util.TensorFlowTestCase): "Provided value.*Requested dtype.*"): _ = ops.convert_to_tensor(1., dtype=dtypes.int32) + def testEagerLargeConstant(self): + self.assertEqual( + constant_op.constant(dtypes.uint64.max, dtype=dtypes.uint64).numpy(), + dtypes.uint64.max) + self.assertEqual( + constant_op.constant(dtypes.uint32.max, dtype=dtypes.uint32).numpy(), + dtypes.uint32.max) + class TFETensorUtilTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 77fbfd51bbb..726701eb8f5 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -316,6 +316,31 @@ const char* ConvertOneInt64(PyObject* v, int64* out) { DEFINE_HELPER(ConvertInt64, int64, DT_INT64, ConvertOneInt64); +const char* ConvertOneUint64(PyObject* v, uint64* out) { +#if PY_MAJOR_VERSION < 3 + if (TF_PREDICT_TRUE(PyInt_Check(v))) { + *out = PyInt_AsUnsignedLongMask(v); + return nullptr; + } +#endif + if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) { + *out = PyLong_AsUnsignedLong(v); + return nullptr; + } + if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers +#if PY_MAJOR_VERSION < 3 + Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v)); +#else + Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v)); +#endif + return ConvertOneUint64(as_int.get(), out); + } + if (IsPyFloat(v)) return ErrorFoundFloat; + return ErrorMixedTypes; +} + +DEFINE_HELPER(ConvertUint64, uint64, DT_UINT64, ConvertOneUint64); + const char* ConvertOneInt32(PyObject* v, int32* out) { int64 i; #if PY_MAJOR_VERSION < 3 @@ -525,6 +550,10 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { if (ConvertInt32(obj, shape, ret) == nullptr) return Status::OK(); break; + case DT_UINT64: + if (ConvertUint64(obj, shape, ret) == nullptr) return Status::OK(); + break; + case DT_COMPLEX128: if (ConvertComplex(obj, shape, ret) == nullptr) return Status::OK(); break;