Optimize version_utils.swap_class
method for better performance.
PiperOrigin-RevId: 350465713 Change-Id: Ibbacdccd974c7ecf597cb501248a228404996b01
This commit is contained in:
parent
d9e960b7be
commit
fd758534e2
@ -93,21 +93,32 @@ def should_use_v2():
|
||||
graph.name.startswith("wrapped_function")):
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def swap_class(cls, v2_cls, v1_cls, use_v2):
|
||||
"""Swaps in v2_cls or v1_cls depending on graph mode."""
|
||||
if cls == object:
|
||||
return cls
|
||||
|
||||
if cls in (v2_cls, v1_cls):
|
||||
if use_v2:
|
||||
return v2_cls
|
||||
return v1_cls
|
||||
return v2_cls if use_v2 else v1_cls
|
||||
|
||||
# Recursively search superclasses to swap in the right Keras class.
|
||||
cls.__bases__ = tuple(
|
||||
swap_class(base, v2_cls, v1_cls, use_v2) for base in cls.__bases__)
|
||||
new_bases = []
|
||||
for base in cls.__bases__:
|
||||
if ((use_v2 and issubclass(base, v1_cls)
|
||||
# `v1_cls` often extends `v2_cls`, so it may still call `swap_class`
|
||||
# even if it doesn't need to. That being said, it may be the safest
|
||||
# not to over optimize this logic for the sake of correctness,
|
||||
# especially if we swap v1 & v2 classes that don't extend each other,
|
||||
# or when the inheritance order is different.
|
||||
or (not use_v2 and issubclass(base, v2_cls)))):
|
||||
new_base = swap_class(base, v2_cls, v1_cls, use_v2)
|
||||
else:
|
||||
new_base = base
|
||||
new_bases.append(new_base)
|
||||
cls.__bases__ = tuple(new_bases)
|
||||
return cls
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user