Give EagerTensor a fully qualified name so __module__ doesn't generate an error
This code still differs between py2 and py3 (__module__ returns "__builtin__" in py2, and the correct value in py3) - but its strictly better than before since earlier it would differ between py2 and py3 and generate an error in py3. We don't seem to correctly initialize the tp_dict in py2, so even when passing the correct, fully qualified name, we get back "__builtin__". Fixes #20701 PiperOrigin-RevId: 204762170
This commit is contained in:
parent
b70a39b4e6
commit
fd04b76337
@ -620,10 +620,6 @@ static PyType_Slot EagerTensor_Type_slots[] = {
|
|||||||
{Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
|
{Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
|
||||||
{0, nullptr},
|
{0, nullptr},
|
||||||
};
|
};
|
||||||
|
|
||||||
PyType_Spec EagerTensor_Type_spec = {"EagerTensor", sizeof(EagerTensor), 0,
|
|
||||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE,
|
|
||||||
EagerTensor_Type_slots};
|
|
||||||
#else
|
#else
|
||||||
// TODO(agarwal): support active_trace.
|
// TODO(agarwal): support active_trace.
|
||||||
static PyTypeObject _EagerTensorType = {
|
static PyTypeObject _EagerTensorType = {
|
||||||
@ -754,6 +750,34 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
|
|||||||
#if PY_MAJOR_VERSION >= 3
|
#if PY_MAJOR_VERSION >= 3
|
||||||
PyObject* bases = PyTuple_New(1);
|
PyObject* bases = PyTuple_New(1);
|
||||||
PyTuple_SET_ITEM(bases, 0, base_class);
|
PyTuple_SET_ITEM(bases, 0, base_class);
|
||||||
|
|
||||||
|
tensorflow::Safe_PyObjectPtr base_class_module(
|
||||||
|
PyObject_GetAttrString(base_class, "__module__"));
|
||||||
|
const char* module = nullptr;
|
||||||
|
if (PyErr_Occurred()) {
|
||||||
|
PyErr_Clear();
|
||||||
|
module = "__builtin__";
|
||||||
|
} else {
|
||||||
|
module = PyBytes_AsString(base_class_module.get());
|
||||||
|
if (module == nullptr) {
|
||||||
|
PyErr_Clear();
|
||||||
|
module = PyUnicode_AsUTF8(base_class_module.get());
|
||||||
|
if (module == nullptr) {
|
||||||
|
PyErr_Clear();
|
||||||
|
module = "__builtin__";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: The c_str from this string needs to outlast the function, hence is
|
||||||
|
// static.
|
||||||
|
static tensorflow::string fully_qualified_name =
|
||||||
|
tensorflow::strings::StrCat(module, ".EagerTensor");
|
||||||
|
|
||||||
|
static PyType_Spec EagerTensor_Type_spec = {
|
||||||
|
fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
|
||||||
|
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};
|
||||||
|
|
||||||
EagerTensorType = reinterpret_cast<PyTypeObject*>(
|
EagerTensorType = reinterpret_cast<PyTypeObject*>(
|
||||||
PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
|
PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
|
||||||
if (PyErr_Occurred()) {
|
if (PyErr_Occurred()) {
|
||||||
|
@ -278,7 +278,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
TypeError,
|
TypeError,
|
||||||
r"tensors argument must be a list or a tuple. Got \"EagerTensor\""):
|
r"tensors argument must be a list or a tuple. Got.*EagerTensor"):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice(t1, -2)
|
||||||
|
|
||||||
def testNegativeSliceDim(self):
|
def testNegativeSliceDim(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user