Always execute the default branch in xla case for execute_fn_for_device

PiperOrigin-RevId: 330594659
Change-Id: I0c753afdb7eb71194c68d412def38c9fc1c4ae58
This commit is contained in:
Yanhua Sun 2020-09-08 15:03:27 -07:00 committed by TensorFlower Gardener
parent 3cdc97310f
commit 60bee51060

View File

@ -3647,7 +3647,11 @@ def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
The tensors returned by the callable identified by device type during
execution, or those returned by 'default_fn' if no key matches.
"""
# Always execute the default fn for XLA to avoid complicated graph by case op.
# see more discussions in b/167276293.
is_in_xla = util.GraphOrParentsInXlaContext(ops.get_default_graph())
if is_in_xla:
return default_fn()
device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
branch_fns = list(device_branch_fns_upper.values())
devices = list(device_branch_fns_upper.keys())