Fix device resolution for reductions inside tf.functions in eager mode.
PiperOrigin-RevId: 242907308
This commit is contained in:
parent
e2cc5c574b
commit
12b9e3ad86
@ -87,9 +87,8 @@ class _FakeOperation(object):
|
||||
def current():
|
||||
"""Return a string (not canonicalized) for the current device."""
|
||||
# TODO(josh11b): Work out how this function interacts with ops.colocate_with.
|
||||
ctx = context.context()
|
||||
if ctx.executing_eagerly():
|
||||
d = ctx.device_name
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
d = context.context().device_name
|
||||
else:
|
||||
op = _FakeOperation()
|
||||
ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access
|
||||
|
@ -477,7 +477,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
|
||||
host_canonical = device_util.canonicalize(self._host_device)
|
||||
|
||||
if dest_canonical != host_canonical:
|
||||
with ops.device(devices[0]):
|
||||
with ops.device(dest_canonical):
|
||||
output = array_ops.identity(output)
|
||||
|
||||
return output
|
||||
|
Loading…
x
Reference in New Issue
Block a user