Add PyMemberDef for __dict__ on eager tensors.
This fixes dir() calls on instances of eager tensors so that it correctly
accesses the __dict__ of EagerTensorType.
Earlier it would fail due to an infinite "loop" in subtype_dict: 7e610bcdf1/Objects/typeobject.c (L2145)
get_builtin_base_with_dict will return the same type (though I'm not sure this is reasonable behavior given its name).
The __dict__ getter for the type is subtype_dict creating an infinite tail recursion.
PiperOrigin-RevId: 212020695
This commit is contained in:
parent
ca92311cbd
commit
d644e729ca
@ -27,6 +27,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
#include "tensorflow/python/lib/core/ndarray_tensor.h"
|
||||||
|
|
||||||
|
#include "structmember.h" // NOLINT // For PyMemberDef
|
||||||
|
|
||||||
// forward declare
|
// forward declare
|
||||||
struct EagerTensor;
|
struct EagerTensor;
|
||||||
|
|
||||||
@ -643,6 +645,15 @@ static PyGetSetDef EagerTensor_getseters[] = {
|
|||||||
{nullptr} /* Sentinel */
|
{nullptr} /* Sentinel */
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if PY_MAJOR_VERSION < 3
|
||||||
|
// Only used for Python2 since Python3 seems to set the __dict__ correctly.
|
||||||
|
static PyMemberDef EagerTensor_members[] = {
|
||||||
|
{const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
|
||||||
|
READONLY},
|
||||||
|
{nullptr},
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
static PyMethodDef EagerTensor_methods[] = {
|
static PyMethodDef EagerTensor_methods[] = {
|
||||||
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
|
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
|
||||||
PyDoc_STR("_numpy")},
|
PyDoc_STR("_numpy")},
|
||||||
@ -717,7 +728,7 @@ static PyTypeObject _EagerTensorType = {
|
|||||||
nullptr, /* tp_iter */
|
nullptr, /* tp_iter */
|
||||||
nullptr, /* tp_iternext */
|
nullptr, /* tp_iternext */
|
||||||
EagerTensor_methods, /* tp_methods */
|
EagerTensor_methods, /* tp_methods */
|
||||||
nullptr, /* tp_members */
|
EagerTensor_members, /* tp_members */
|
||||||
EagerTensor_getseters, /* tp_getset */
|
EagerTensor_getseters, /* tp_getset */
|
||||||
nullptr, /* tp_base */
|
nullptr, /* tp_base */
|
||||||
nullptr, /* tp_dict */
|
nullptr, /* tp_dict */
|
||||||
@ -853,7 +864,7 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
|
|||||||
}
|
}
|
||||||
EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
|
EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
|
||||||
#else
|
#else
|
||||||
_EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
|
_EagerTensorType.tp_base = base_class_type;
|
||||||
|
|
||||||
if (PyType_Ready(&_EagerTensorType) < 0) {
|
if (PyType_Ready(&_EagerTensorType) < 0) {
|
||||||
if (PyErr_Occurred()) return nullptr;
|
if (PyErr_Occurred()) return nullptr;
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
|
||||||
|
|
||||||
def _create_tensor(value, device=None, dtype=None):
|
def _create_tensor(value, device=None, dtype=None):
|
||||||
@ -333,6 +334,19 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
|
|||||||
"but tensor at index 2 has rank 0"):
|
"but tensor at index 2 has rank 0"):
|
||||||
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
|
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
|
||||||
|
|
||||||
|
@test_util.assert_no_new_pyobjects_executing_eagerly
|
||||||
|
def testTensorDir(self):
|
||||||
|
t = array_ops.zeros(1)
|
||||||
|
t.test_attr = "Test"
|
||||||
|
|
||||||
|
instance_dir = dir(t)
|
||||||
|
type_dir = dir(ops.EagerTensor)
|
||||||
|
|
||||||
|
# Monkey patched attributes should show up in dir(t)
|
||||||
|
self.assertIn("test_attr", instance_dir)
|
||||||
|
instance_dir.remove("test_attr")
|
||||||
|
self.assertEqual(instance_dir, type_dir)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -465,29 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f):
|
|||||||
f(self, **kwargs)
|
f(self, **kwargs)
|
||||||
gc.collect()
|
gc.collect()
|
||||||
previous_count = len(gc.get_objects())
|
previous_count = len(gc.get_objects())
|
||||||
collection_sizes_before = {
|
if ops.has_default_graph():
|
||||||
collection: len(ops.get_collection(collection))
|
collection_sizes_before = {
|
||||||
for collection in ops.get_default_graph().collections
|
collection: len(ops.get_collection(collection))
|
||||||
}
|
for collection in ops.get_default_graph().collections
|
||||||
|
}
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
f(self, **kwargs)
|
f(self, **kwargs)
|
||||||
# Note that gc.get_objects misses anything that isn't subject to garbage
|
# Note that gc.get_objects misses anything that isn't subject to garbage
|
||||||
# collection (C types). Collections are a common source of leaks, so we
|
# collection (C types). Collections are a common source of leaks, so we
|
||||||
# test for collection sizes explicitly.
|
# test for collection sizes explicitly.
|
||||||
for collection_key in ops.get_default_graph().collections:
|
if ops.has_default_graph():
|
||||||
collection = ops.get_collection(collection_key)
|
for collection_key in ops.get_default_graph().collections:
|
||||||
size_before = collection_sizes_before.get(collection_key, 0)
|
collection = ops.get_collection(collection_key)
|
||||||
if len(collection) > size_before:
|
size_before = collection_sizes_before.get(collection_key, 0)
|
||||||
raise AssertionError(
|
if len(collection) > size_before:
|
||||||
("Collection %s increased in size from "
|
raise AssertionError(
|
||||||
"%d to %d (current items %s).") % (collection_key, size_before,
|
("Collection %s increased in size from "
|
||||||
len(collection), collection))
|
"%d to %d (current items %s).") %
|
||||||
# Make sure our collection checks don't show up as leaked memory by
|
(collection_key, size_before, len(collection), collection))
|
||||||
# removing references to temporary variables.
|
# Make sure our collection checks don't show up as leaked memory by
|
||||||
del collection
|
# removing references to temporary variables.
|
||||||
del collection_key
|
del collection
|
||||||
del size_before
|
del collection_key
|
||||||
del collection_sizes_before
|
del size_before
|
||||||
|
del collection_sizes_before
|
||||||
gc.collect()
|
gc.collect()
|
||||||
# There should be no new Python objects hanging around.
|
# There should be no new Python objects hanging around.
|
||||||
new_count = len(gc.get_objects())
|
new_count = len(gc.get_objects())
|
||||||
|
Loading…
Reference in New Issue
Block a user