Always execute the default branch in xla case for execute_fn_for_device
PiperOrigin-RevId: 330594659 Change-Id: I0c753afdb7eb71194c68d412def38c9fc1c4ae58
This commit is contained in:
parent
3cdc97310f
commit
60bee51060
@ -3647,7 +3647,11 @@ def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
|
||||
The tensors returned by the callable identified by device type during
|
||||
execution, or those returned by 'default_fn' if no key matches.
|
||||
"""
|
||||
|
||||
# Always execute the default fn for XLA to avoid complicated graph by case op.
|
||||
# see more discussions in b/167276293.
|
||||
is_in_xla = util.GraphOrParentsInXlaContext(ops.get_default_graph())
|
||||
if is_in_xla:
|
||||
return default_fn()
|
||||
device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
|
||||
branch_fns = list(device_branch_fns_upper.values())
|
||||
devices = list(device_branch_fns_upper.keys())
|
||||
|
Loading…
Reference in New Issue
Block a user