From f9a682e70fbd2f6b19a40c2d923da2e5afe25586 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Wed, 1 May 2019 14:46:15 -0700 Subject: [PATCH] Support DT_UINT64 as a direct conversion type PiperOrigin-RevId: 246202834 --- tensorflow/python/eager/tensor_test.py | 7 +++++ tensorflow/python/lib/core/py_seq_tensor.cc | 29 +++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index f8e3f640124..238f0f9eb1c 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -368,6 +368,13 @@ class TFETensorTest(test_util.TensorFlowTestCase): "Provided value.*Requested dtype.*"): _ = ops.convert_to_tensor(1., dtype=dtypes.int32) + def testEagerLargeConstant(self): + for t in [dtypes.uint64, dtypes.uint32, dtypes.int32, dtypes.int64]: + self.assertEqual( + constant_op.constant(t.max, dtype=t).numpy(), t.max) + self.assertEqual( + constant_op.constant(t.min, dtype=t).numpy(), t.min) + class TFETensorUtilTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 354949c31af..8f66a8a7364 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -316,6 +316,31 @@ const char* ConvertOneInt64(PyObject* v, int64* out) { DEFINE_HELPER(ConvertInt64, int64, DT_INT64, ConvertOneInt64); +const char* ConvertOneUint64(PyObject* v, uint64* out) { +#if PY_MAJOR_VERSION < 3 + if (TF_PREDICT_TRUE(PyInt_Check(v))) { + *out = PyInt_AsUnsignedLongLongMask(v); + return nullptr; + } +#endif + if (TF_PREDICT_TRUE(PyLong_Check(v) || IsPyDimension(v))) { + *out = PyLong_AsUnsignedLongLong(v); + return nullptr; + } + if (PyIsInstance(v, &PyIntegerArrType_Type)) { // NumPy integers +#if PY_MAJOR_VERSION < 3 + Safe_PyObjectPtr as_int = make_safe(PyNumber_Int(v)); +#else + Safe_PyObjectPtr as_int = make_safe(PyNumber_Long(v)); +#endif + return ConvertOneUint64(as_int.get(), out); + } + if (IsPyFloat(v)) return ErrorFoundFloat; + return ErrorMixedTypes; +} + +DEFINE_HELPER(ConvertUint64, uint64, DT_UINT64, ConvertOneUint64); + const char* ConvertOneInt32(PyObject* v, int32* out) { int64 i; #if PY_MAJOR_VERSION < 3 @@ -522,6 +547,10 @@ Status PySeqToTensor(PyObject* obj, DataType dtype, Tensor* ret) { if (ConvertInt32(obj, shape, ret) == nullptr) return Status::OK(); break; + case DT_UINT64: + if (ConvertUint64(obj, shape, ret) == nullptr) return Status::OK(); + break; + case DT_COMPLEX128: if (ConvertComplex(obj, shape, ret) == nullptr) return Status::OK(); break;