Add PyMemberDef for __dict__ on eager tensors.
This fixes dir() calls on instances of eager tensors so that they correctly
access the __dict__ of EagerTensorType.

Previously, dir() would fail due to an infinite "loop" in subtype_dict: 7e610bcdf1/Objects/typeobject.c (L2145)
get_builtin_base_with_dict returns the same type (though I'm not sure this is reasonable behavior given its name),
and since the __dict__ getter for that type is subtype_dict itself, the lookup turns into infinite tail recursion.
PiperOrigin-RevId: 212020695
commit d644e729ca (parent ca92311cbd)
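As a minimal, self-contained sketch (not the TensorFlow code) of the pattern this commit applies: a CPython extension type that stores a per-instance __dict__ and exposes it both through tp_dictoffset and an explicit PyMemberDef entry, so attribute lookup and dir() never fall back to the base type's subtype_dict descriptor. The demo module, DemoObject struct, and all names below are illustrative assumptions, not part of the commit.

// Build as a Python extension module named "demo" (illustrative only).
#include <Python.h>
#include "structmember.h"  // For PyMemberDef, T_OBJECT, READONLY

struct DemoObject {
  PyObject_HEAD
  PyObject* dict;  // Per-instance attribute dictionary; starts as nullptr.
};

static void Demo_dealloc(PyObject* self) {
  Py_CLEAR(reinterpret_cast<DemoObject*>(self)->dict);
  Py_TYPE(self)->tp_free(self);
}

static PyMemberDef Demo_members[] = {
    // Exposing __dict__ as a T_OBJECT member gives the type its own
    // __dict__ descriptor instead of inheriting one that may recurse.
    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(DemoObject, dict),
     READONLY, nullptr},
    {nullptr, 0, 0, 0, nullptr},
};

static PyTypeObject DemoType = {PyVarObject_HEAD_INIT(nullptr, 0)};

static PyModuleDef demo_module = {PyModuleDef_HEAD_INIT, "demo", nullptr, -1,
                                  nullptr};

PyMODINIT_FUNC PyInit_demo() {
  DemoType.tp_name = "demo.Demo";
  DemoType.tp_basicsize = sizeof(DemoObject);
  DemoType.tp_dealloc = Demo_dealloc;
  DemoType.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
  DemoType.tp_members = Demo_members;
  DemoType.tp_new = PyType_GenericNew;
  // Tell generic attribute lookup (and dir()) where the instance dict lives.
  // A production type would also add tp_traverse/tp_clear for GC support.
  DemoType.tp_dictoffset = offsetof(DemoObject, dict);
  if (PyType_Ready(&DemoType) < 0) return nullptr;

  PyObject* m = PyModule_Create(&demo_module);
  if (m == nullptr) return nullptr;
  Py_INCREF(&DemoType);
  if (PyModule_AddObject(m, "Demo",
                         reinterpret_cast<PyObject*>(&DemoType)) < 0) {
    Py_DECREF(&DemoType);
    Py_DECREF(m);
    return nullptr;
  }
  return m;
}

With such a module built, d = demo.Demo(); d.x = 1; dir(d) would report x via the instance __dict__ rather than recursing through the base type's __dict__ getter, which is the behavior the new testTensorDir test below checks for EagerTensor.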
@@ -27,6 +27,8 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/python/lib/core/ndarray_tensor.h"

+#include "structmember.h"  // NOLINT // For PyMemberDef
+
 // forward declare
 struct EagerTensor;

@@ -643,6 +645,15 @@ static PyGetSetDef EagerTensor_getseters[] = {
     {nullptr} /* Sentinel */
 };

+#if PY_MAJOR_VERSION < 3
+// Only used for Python2 since Python3 seems to set the __dict__ correctly.
+static PyMemberDef EagerTensor_members[] = {
+    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
+     READONLY},
+    {nullptr},
+};
+#endif
+
 static PyMethodDef EagerTensor_methods[] = {
     {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
      PyDoc_STR("_numpy")},
@@ -717,7 +728,7 @@ static PyTypeObject _EagerTensorType = {
     nullptr,                /* tp_iter */
     nullptr,                /* tp_iternext */
     EagerTensor_methods,    /* tp_methods */
-    nullptr,                /* tp_members */
+    EagerTensor_members,    /* tp_members */
     EagerTensor_getseters,  /* tp_getset */
     nullptr,                /* tp_base */
     nullptr,                /* tp_dict */
@@ -853,7 +864,7 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
   }
   EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
 #else
-  _EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
+  _EagerTensorType.tp_base = base_class_type;

   if (PyType_Ready(&_EagerTensorType) < 0) {
     if (PyErr_Occurred()) return nullptr;
@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops


 def _create_tensor(value, device=None, dtype=None):
@@ -333,6 +334,19 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
         "but tensor at index 2 has rank 0"):
       pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)

+  @test_util.assert_no_new_pyobjects_executing_eagerly
+  def testTensorDir(self):
+    t = array_ops.zeros(1)
+    t.test_attr = "Test"
+
+    instance_dir = dir(t)
+    type_dir = dir(ops.EagerTensor)
+
+    # Monkey patched attributes should show up in dir(t)
+    self.assertIn("test_attr", instance_dir)
+    instance_dir.remove("test_attr")
+    self.assertEqual(instance_dir, type_dir)
+

 if __name__ == "__main__":
   test.main()
@@ -465,29 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f):
     f(self, **kwargs)
     gc.collect()
     previous_count = len(gc.get_objects())
-    collection_sizes_before = {
-        collection: len(ops.get_collection(collection))
-        for collection in ops.get_default_graph().collections
-    }
+    if ops.has_default_graph():
+      collection_sizes_before = {
+          collection: len(ops.get_collection(collection))
+          for collection in ops.get_default_graph().collections
+      }
     for _ in range(3):
       f(self, **kwargs)
     # Note that gc.get_objects misses anything that isn't subject to garbage
     # collection (C types). Collections are a common source of leaks, so we
     # test for collection sizes explicitly.
-    for collection_key in ops.get_default_graph().collections:
-      collection = ops.get_collection(collection_key)
-      size_before = collection_sizes_before.get(collection_key, 0)
-      if len(collection) > size_before:
-        raise AssertionError(
-            ("Collection %s increased in size from "
-             "%d to %d (current items %s).") % (collection_key, size_before,
-                                                len(collection), collection))
-      # Make sure our collection checks don't show up as leaked memory by
-      # removing references to temporary variables.
-      del collection
-      del collection_key
-      del size_before
-    del collection_sizes_before
+    if ops.has_default_graph():
+      for collection_key in ops.get_default_graph().collections:
+        collection = ops.get_collection(collection_key)
+        size_before = collection_sizes_before.get(collection_key, 0)
+        if len(collection) > size_before:
+          raise AssertionError(
+              ("Collection %s increased in size from "
+               "%d to %d (current items %s).") %
+              (collection_key, size_before, len(collection), collection))
+        # Make sure our collection checks don't show up as leaked memory by
+        # removing references to temporary variables.
+        del collection
+        del collection_key
+        del size_before
+      del collection_sizes_before
     gc.collect()
     # There should be no new Python objects hanging around.
     new_count = len(gc.get_objects())