From d644e729caa4071cc2571cf679acac4392117848 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Fri, 7 Sep 2018 12:48:22 -0700 Subject: [PATCH] Add PyMemberDef for __dict__ on eager tensors. This fixes dir() calls on instances of eager tensors so that it correctly accesses the __dict__ of EagerTensorType. Earlier it would fail due to an infinite "loop" in subtype_dict: https://github.com/python/cpython/blob/7e610bcdf128f61b925654e4fa80fbac83537d0e/Objects/typeobject.c#L2145 get_builtin_base_with_dict will return the same type (though I'm not sure this is reasonable behavior given its name). The __dict__ getter for the type is subtype_dict creating an infinite tail recursion. PiperOrigin-RevId: 212020695 --- tensorflow/python/eager/pywrap_tensor.cc | 15 ++++++++-- tensorflow/python/eager/tensor_test.py | 14 +++++++++ tensorflow/python/framework/test_util.py | 38 +++++++++++++----------- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 432dcbc2e21..f34ce6af79c 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/python/lib/core/ndarray_tensor.h" +#include "structmember.h" // NOLINT // For PyMemberDef + // forward declare struct EagerTensor; @@ -643,6 +645,15 @@ static PyGetSetDef EagerTensor_getseters[] = { {nullptr} /* Sentinel */ }; +#if PY_MAJOR_VERSION < 3 +// Only used for Python2 since Python3 seems to set the __dict__ correctly. +static PyMemberDef EagerTensor_members[] = { + {const_cast("__dict__"), T_OBJECT, offsetof(EagerTensor, dict), + READONLY}, + {nullptr}, +}; +#endif + static PyMethodDef EagerTensor_methods[] = { {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS, PyDoc_STR("_numpy")}, @@ -717,7 +728,7 @@ static PyTypeObject _EagerTensorType = { nullptr, /* tp_iter */ nullptr, /* tp_iternext */ EagerTensor_methods, /* tp_methods */ - nullptr, /* tp_members */ + EagerTensor_members, /* tp_members */ EagerTensor_getseters, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ @@ -853,7 +864,7 @@ PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { } EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict); #else - _EagerTensorType.tp_base = reinterpret_cast(base_class); + _EagerTensorType.tp_base = base_class_type; if (PyType_Ready(&_EagerTensorType) < 0) { if (PyErr_Occurred()) return nullptr; diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 32742a9b968..344a9b25bdd 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops def _create_tensor(value, device=None, dtype=None): @@ -333,6 +334,19 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): "but tensor at index 2 has rank 0"): pywrap_tensorflow.TFE_Py_TensorShapeSlice([t2, t1, t3], 0) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testTensorDir(self): + t = array_ops.zeros(1) + t.test_attr = "Test" + + instance_dir = dir(t) + type_dir = dir(ops.EagerTensor) + + # Monkey patched attributes should show up in dir(t) + self.assertIn("test_attr", instance_dir) + instance_dir.remove("test_attr") + self.assertEqual(instance_dir, type_dir) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 0925598e339..4bece9e25e8 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -465,29 +465,31 @@ def assert_no_new_pyobjects_executing_eagerly(f): f(self, **kwargs) gc.collect() previous_count = len(gc.get_objects()) - collection_sizes_before = { - collection: len(ops.get_collection(collection)) - for collection in ops.get_default_graph().collections - } + if ops.has_default_graph(): + collection_sizes_before = { + collection: len(ops.get_collection(collection)) + for collection in ops.get_default_graph().collections + } for _ in range(3): f(self, **kwargs) # Note that gc.get_objects misses anything that isn't subject to garbage # collection (C types). Collections are a common source of leaks, so we # test for collection sizes explicitly. - for collection_key in ops.get_default_graph().collections: - collection = ops.get_collection(collection_key) - size_before = collection_sizes_before.get(collection_key, 0) - if len(collection) > size_before: - raise AssertionError( - ("Collection %s increased in size from " - "%d to %d (current items %s).") % (collection_key, size_before, - len(collection), collection)) - # Make sure our collection checks don't show up as leaked memory by - # removing references to temporary variables. - del collection - del collection_key - del size_before - del collection_sizes_before + if ops.has_default_graph(): + for collection_key in ops.get_default_graph().collections: + collection = ops.get_collection(collection_key) + size_before = collection_sizes_before.get(collection_key, 0) + if len(collection) > size_before: + raise AssertionError( + ("Collection %s increased in size from " + "%d to %d (current items %s).") % + (collection_key, size_before, len(collection), collection)) + # Make sure our collection checks don't show up as leaked memory by + # removing references to temporary variables. + del collection + del collection_key + del size_before + del collection_sizes_before gc.collect() # There should be no new Python objects hanging around. new_count = len(gc.get_objects())