Give EagerTensor a fully qualified name so __module__ doesn't generate an error
This code still differs between py2 and py3 (__module__ returns "__builtin__" in py2, and the correct value in py3) - but its strictly better than before since earlier it would differ between py2 and py3 and generate an error in py3. We don't seem to correctly initialize the tp_dict in py2, so even when passing the correct, fully qualified name, we get back "__builtin__". Fixes #20701 PiperOrigin-RevId: 204762170
This commit is contained in:
parent
b70a39b4e6
commit
fd04b76337
@ -620,10 +620,6 @@ static PyType_Slot EagerTensor_Type_slots[] = {
|
||||
{Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
|
||||
{0, nullptr},
|
||||
};
|
||||
|
||||
PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0,
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
|
||||
EagerTensor_Type_slots};
|
||||
#else
|
||||
// TODO(agarwal): support active_trace.
|
||||
static PyTypeObject _EagerTensorType = {
|
||||
@ -754,6 +750,34 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
|
||||
#if PY_MAJOR_VERSION >= 3
|
||||
PyObject* bases = PyTuple_New(1);
|
||||
PyTuple_SET_ITEM(bases, 0, base_class);
|
||||
|
||||
tensorflow::Safe_PyObjectPtr base_class_module(
|
||||
PyObject_GetAttrString(base_class, "__module__"));
|
||||
const char* module = nullptr;
|
||||
if (PyErr_Occurred()) {
|
||||
PyErr_Clear();
|
||||
module = "__builtin__";
|
||||
} else {
|
||||
module = PyBytes_AsString(base_class_module.get());
|
||||
if (module == nullptr) {
|
||||
PyErr_Clear();
|
||||
module = PyUnicode_AsUTF8(base_class_module.get());
|
||||
if (module == nullptr) {
|
||||
PyErr_Clear();
|
||||
module = "__builtin__";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: The c_str from this string needs to outlast the function, hence is
|
||||
// static.
|
||||
static tensorflow::string fully_qualified_name =
|
||||
tensorflow::strings::StrCat(module, ".EagerTensor");
|
||||
|
||||
static PyType_Spec EagerTensor_Type_spec = {
|
||||
fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};
|
||||
|
||||
EagerTensorType = reinterpret_cast<PyTypeObject*>(
|
||||
PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
|
||||
if (PyErr_Occurred()) {
|
||||
|
@ -278,7 +278,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError,
|
||||
r"tensors argument must be a list or a tuple. Got \"EagerTensor\""):
|
||||
r"tensors argument must be a list or a tuple. Got.*EagerTensor"):
|
||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
|
||||
|
||||
def testNegativeSliceDim(self):
|
||||
|
Loading…
Reference in New Issue
Block a user