From fd04b76337f9f65c5a2ce35a4f7336a25511435b Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Mon, 16 Jul 2018 10:15:40 -0700 Subject: [PATCH] Give EagerTensor a fully qualified name so __module__ doesn't generate an error This code still differs between py2 and py3 (__module__ returns "__builtin__" in py2, and the correct value in py3) - but its strictly better than before since earlier it would differ between py2 and py3 and generate an error in py3. We don't seem to correctly initialize the tp_dict in py2, so even when passing the correct, fully qualified name, we get back "__builtin__". Fixes #20701 PiperOrigin-RevId: 204762170 --- tensorflow/python/eager/pywrap_tensor.cc | 32 +++++++++++++++++++++--- tensorflow/python/eager/tensor_test.py | 2 +- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index ea604647fae..cefd5b12061 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -620,10 +620,6 @@ static PyType_Slot EagerTensor_Type_slots[] = { {Py_tp_init, reinterpret_cast(EagerTensor_init)}, {0, nullptr}, }; - -PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0, - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, - EagerTensor_Type_slots}; #else // TODO(agarwal): support active_trace. static PyTypeObject _EagerTensorType = { @@ -754,6 +750,34 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { #if PY_MAJOR_VERSION >= 3 PyObject* bases = PyTuple_New(1); PyTuple_SET_ITEM(bases, 0, base_class); + + tensorflow::Safe_PyObjectPtr base_class_module( + PyObject_GetAttrString(base_class, "__module__")); + const char* module = nullptr; + if (PyErr_Occurred()) { + PyErr_Clear(); + module = "__builtin__"; + } else { + module = PyBytes_AsString(base_class_module.get()); + if (module == nullptr) { + PyErr_Clear(); + module = PyUnicode_AsUTF8(base_class_module.get()); + if (module == nullptr) { + PyErr_Clear(); + module = "__builtin__"; + } + } + } + + // NOTE: The c_str from this string needs to outlast the function, hence is + // static. + static tensorflow::string fully_qualified_name = + tensorflow::strings::StrCat(module, ".EagerTensor"); + + static PyType_Spec EagerTensor_Type_spec = { + fully_qualified_name.c_str(), sizeof(EagerTensor), 0, + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots}; + EagerTensorType = reinterpret_cast( PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases)); if (PyErr_Occurred()) { diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 626a4eb1eee..871136e2c89 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -278,7 +278,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp( TypeError, - r"tensors argument must be a list or a tuple. Got \"EagerTensor\""): + r"tensors argument must be a list or a tuple. Got.*EagerTensor"): pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2) def testNegativeSliceDim(self):