From d97234dc564b2db7543b7f69517e5beefb89dc0e Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Tue, 16 Jun 2020 12:34:27 -0700 Subject: [PATCH] More compatibility fixes for typing.Generic: * types.new_class is required in some distributions * avoid calling `isinstance` on some function objects in python 3.6 * account for some strange zombie pointer issue on windows Required for #40132. PiperOrigin-RevId: 316735720 Change-Id: I1b08ef5f18c77c9343d587562f50632336b684d5 --- tensorflow/python/framework/test_util.py | 2 +- tensorflow/python/util/tf_should_use.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 2967bb3de84..a46bb7c9bda 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -736,7 +736,7 @@ def assert_no_new_tensors(f): return isinstance(obj, (ops.Tensor, variables.Variable, tensor_shape.Dimension, tensor_shape.TensorShape)) - except ReferenceError: + except (ReferenceError, AttributeError): # If the object no longer exists, we don't care about it. return False diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index 1671b078fa3..41c3220f5ca 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -21,15 +21,12 @@ import copy import sys import textwrap import traceback - -import six # pylint: disable=unused-import - +import types from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging from tensorflow.python.util import tf_decorator -# pylint: enable=g-bad-import-order,g-import-not-at-top class _TFShouldUseHelper(object): @@ -154,7 +151,18 @@ def _get_wrapper(x, tf_should_use_helper): tx = copy.deepcopy(type_x) # Prefer using __orig_bases__, which preserve generic type arguments. bases = getattr(tx, '__orig_bases__', tx.__bases__) - copy_tx = type(tx.__name__, bases, dict(tx.__dict__)) + + # Use types.new_class when available, which is preferred over plain type in + # some distributions. + if sys.version_info >= (3, 5): + def set_body(ns): + ns.update(tx.__dict__) + return ns + + copy_tx = types.new_class(tx.__name__, bases, exec_body=set_body) + else: + copy_tx = type(tx.__name__, bases, dict(tx.__dict__)) + copy_tx.__init__ = _new__init__ copy_tx.__getattribute__ = _new__getattribute__ copy_tx.mark_used = _new_mark_used