diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 2967bb3de84..a46bb7c9bda 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -736,7 +736,7 @@ def assert_no_new_tensors(f): return isinstance(obj, (ops.Tensor, variables.Variable, tensor_shape.Dimension, tensor_shape.TensorShape)) - except ReferenceError: + except (ReferenceError, AttributeError): # If the object no longer exists, we don't care about it. return False diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index 1671b078fa3..41c3220f5ca 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -21,15 +21,12 @@ import copy import sys import textwrap import traceback - -import six # pylint: disable=unused-import - +import types from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging from tensorflow.python.util import tf_decorator -# pylint: enable=g-bad-import-order,g-import-not-at-top class _TFShouldUseHelper(object): @@ -154,7 +151,18 @@ def _get_wrapper(x, tf_should_use_helper): tx = copy.deepcopy(type_x) # Prefer using __orig_bases__, which preserve generic type arguments. bases = getattr(tx, '__orig_bases__', tx.__bases__) - copy_tx = type(tx.__name__, bases, dict(tx.__dict__)) + + # Use types.new_class when available, which is preferred over plain type in + # some distributions. + if sys.version_info >= (3, 5): + def set_body(ns): + ns.update(tx.__dict__) + return ns + + copy_tx = types.new_class(tx.__name__, bases, exec_body=set_body) + else: + copy_tx = type(tx.__name__, bases, dict(tx.__dict__)) + copy_tx.__init__ = _new__init__ copy_tx.__getattribute__ = _new__getattribute__ copy_tx.mark_used = _new_mark_used