Optimize version_utils.swap_class method for better performance.

PiperOrigin-RevId: 350465713
Change-Id: Ibbacdccd974c7ecf597cb501248a228404996b01
This commit is contained in:
A. Unique TensorFlower 2021-01-06 18:20:50 -08:00 committed by TensorFlower Gardener
parent d9e960b7be
commit fd758534e2

View File

@ -93,21 +93,32 @@ def should_use_v2():
graph.name.startswith("wrapped_function")):
return False
return True
else:
return False
def swap_class(cls, v2_cls, v1_cls, use_v2):
"""Swaps in v2_cls or v1_cls depending on graph mode."""
if cls == object:
return cls
if cls in (v2_cls, v1_cls):
if use_v2:
return v2_cls
return v1_cls
return v2_cls if use_v2 else v1_cls
# Recursively search superclasses to swap in the right Keras class.
cls.__bases__ = tuple(
swap_class(base, v2_cls, v1_cls, use_v2) for base in cls.__bases__)
new_bases = []
for base in cls.__bases__:
if ((use_v2 and issubclass(base, v1_cls)
# `v1_cls` often extends `v2_cls`, so it may still call `swap_class`
# even if it doesn't need to. That being said, it may be the safest
# not to over optimize this logic for the sake of correctness,
# especially if we swap v1 & v2 classes that don't extend each other,
# or when the inheritance order is different.
or (not use_v2 and issubclass(base, v2_cls)))):
new_base = swap_class(base, v2_cls, v1_cls, use_v2)
else:
new_base = base
new_bases.append(new_base)
cls.__bases__ = tuple(new_bases)
return cls