Add PyMemberDef for __dict__ on eager tensors.

This fixes dir() calls on instances of eager tensors so that it correctly
accesses the __dict__ of EagerTensorType.

Earlier it would fail due to an infinite "loop" in subtype_dict: 7e610bcdf1/Objects/typeobject.c (L2145)

get_builtin_base_with_dict will return the same type (though I'm not sure this is reasonable behavior given its name).
The __dict__ getter for the type is subtype_dict creating an infinite tail recursion.

PiperOrigin-RevId: 212020695
This commit is contained in:
Akshay Modi 2018-09-07 12:48:22 -07:00 committed by TensorFlower Gardener
parent ca92311cbd
commit d644e729ca
3 changed files with 47 additions and 20 deletions

View File

@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "structmember.h" // NOLINT // For PyMemberDef
// forward declare
struct EagerTensor;
@ -643,6 +645,15 @@ static PyGetSetDef EagerTensor_getseters[] = {
{nullptr} /* Sentinel */
};
#if PY_MAJOR_VERSION < 3
// Only used for Python2 since Python3 seems to set the __dict__ correctly.
static PyMemberDef EagerTensor_members[] = {
{const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
READONLY},
{nullptr},
};
#endif
static PyMethodDef EagerTensor_methods[] = {
{"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
PyDoc_STR("_numpy")},
@ -717,7 +728,7 @@ static PyTypeObject _EagerTensorType = {
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
EagerTensor_methods, /* tp_methods */
nullptr, /* tp_members */
EagerTensor_members, /* tp_members */
EagerTensor_getseters, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
@ -853,7 +864,7 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
}
EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
#else
_EagerTensorType.tp_base = reinterpret_cast<PyTypeObject*>(base_class);
_EagerTensorType.tp_base = base_class_type;
if (PyType_Ready(&_EagerTensorType) < 0) {
if (PyErr_Occurred()) return nullptr;

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
def _create_tensor(value, device=None, dtype=None):
@ -333,6 +334,19 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase):
"but tensor at index 2 has rank 0"):
pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testTensorDir(self):
t = array_ops.zeros(1)
t.test_attr = "Test"
instance_dir = dir(t)
type_dir = dir(ops.EagerTensor)
# Monkey patched attributes should show up in dir(t)
self.assertIn("test_attr", instance_dir)
instance_dir.remove("test_attr")
self.assertEqual(instance_dir, type_dir)
if __name__ == "__main__":
test.main()

View File

@ -465,29 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f):
f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
collection_sizes_before = {
collection: len(ops.get_collection(collection))
for collection in ops.get_default_graph().collections
}
if ops.has_default_graph():
collection_sizes_before = {
collection: len(ops.get_collection(collection))
for collection in ops.get_default_graph().collections
}
for _ in range(3):
f(self, **kwargs)
# Note that gc.get_objects misses anything that isn't subject to garbage
# collection (C types). Collections are a common source of leaks, so we
# test for collection sizes explicitly.
for collection_key in ops.get_default_graph().collections:
collection = ops.get_collection(collection_key)
size_before = collection_sizes_before.get(collection_key, 0)
if len(collection) > size_before:
raise AssertionError(
("Collection %s increased in size from "
"%d to %d (current items %s).") % (collection_key, size_before,
len(collection), collection))
# Make sure our collection checks don't show up as leaked memory by
# removing references to temporary variables.
del collection
del collection_key
del size_before
del collection_sizes_before
if ops.has_default_graph():
for collection_key in ops.get_default_graph().collections:
collection = ops.get_collection(collection_key)
size_before = collection_sizes_before.get(collection_key, 0)
if len(collection) > size_before:
raise AssertionError(
("Collection %s increased in size from "
"%d to %d (current items %s).") %
(collection_key, size_before, len(collection), collection))
# Make sure our collection checks don't show up as leaked memory by
# removing references to temporary variables.
del collection
del collection_key
del size_before
del collection_sizes_before
gc.collect()
# There should be no new Python objects hanging around.
new_count = len(gc.get_objects())