Give EagerTensor a fully qualified name so __module__ doesn't generate an error

This code still differs between py2 and py3 (__module__ returns
"__builtin__" in py2, and the correct value in py3) - but its strictly better
than before since earlier it would differ between py2 and py3 and generate an
error in py3. We don't seem to correctly initialize the tp_dict in py2, so even
when passing the correct, fully qualified name, we get back "__builtin__".

Fixes #20701

PiperOrigin-RevId: 204762170
This commit is contained in:
Akshay Modi 2018-07-16 10:15:40 -07:00 committed by TensorFlower Gardener
parent b70a39b4e6
commit fd04b76337
2 changed files with 29 additions and 5 deletions

View File

@ -620,10 +620,6 @@ static PyType_Slot EagerTensor_Type_slots[] = {
{Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)}, {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
{0, nullptr}, {0, nullptr},
}; };
PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0,
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
EagerTensor_Type_slots};
#else #else
// TODO(agarwal): support active_trace. // TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = { static PyTypeObject _EagerTensorType = {
@ -754,6 +750,34 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
#if PY_MAJOR_VERSION >= 3 #if PY_MAJOR_VERSION >= 3
PyObject* bases = PyTuple_New(1); PyObject* bases = PyTuple_New(1);
PyTuple_SET_ITEM(bases, 0, base_class); PyTuple_SET_ITEM(bases, 0, base_class);
tensorflow::Safe_PyObjectPtr base_class_module(
PyObject_GetAttrString(base_class, "__module__"));
const char* module = nullptr;
if (PyErr_Occurred()) {
PyErr_Clear();
module = "__builtin__";
} else {
module = PyBytes_AsString(base_class_module.get());
if (module == nullptr) {
PyErr_Clear();
module = PyUnicode_AsUTF8(base_class_module.get());
if (module == nullptr) {
PyErr_Clear();
module = "__builtin__";
}
}
}
// NOTE: The c_str from this string needs to outlast the function, hence is
// static.
static tensorflow::string fully_qualified_name =
tensorflow::strings::StrCat(module, ".EagerTensor");
static PyType_Spec EagerTensor_Type_spec = {
fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};
EagerTensorType = reinterpret_cast<PyTypeObject*>( EagerTensorType = reinterpret_cast<PyTypeObject*>(
PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases)); PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
if (PyErr_Occurred()) { if (PyErr_Occurred()) {

View File

@ -278,7 +278,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
TypeError, TypeError,
r"tensors argument must be a list or a tuple. Got \"EagerTensor\""): r"tensors argument must be a list or a tuple. Got.*EagerTensor"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2) pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
def testNegativeSliceDim(self): def testNegativeSliceDim(self):