diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py index 9ba4b7520e5..1671b078fa3 100644 --- a/tensorflow/python/util/tf_should_use.py +++ b/tensorflow/python/util/tf_should_use.py @@ -152,7 +152,9 @@ def _get_wrapper(x, tf_should_use_helper): return memoized(x, tf_should_use_helper) tx = copy.deepcopy(type_x) - copy_tx = type(tx.__name__, tx.__bases__, dict(tx.__dict__)) + # Prefer using __orig_bases__, which preserve generic type arguments. + bases = getattr(tx, '__orig_bases__', tx.__bases__) + copy_tx = type(tx.__name__, bases, dict(tx.__dict__)) copy_tx.__init__ = _new__init__ copy_tx.__getattribute__ = _new__getattribute__ copy_tx.mark_used = _new_mark_used