Fix device resolution for reductions inside tf.functions in eager mode.

PiperOrigin-RevId: 242907308
This commit is contained in:
Sourabh Bajaj 2019-04-10 11:29:24 -07:00 committed by TensorFlower Gardener
parent e2cc5c574b
commit 12b9e3ad86
2 changed files with 3 additions and 4 deletions

View File

@ -87,9 +87,8 @@ class _FakeOperation(object):
def current():
"""Return a string (not canonicalized) for the current device."""
# TODO(josh11b): Work out how this function interacts with ops.colocate_with.
ctx = context.context()
if ctx.executing_eagerly():
d = ctx.device_name
if ops.executing_eagerly_outside_functions():
d = context.context().device_name
else:
op = _FakeOperation()
ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access

View File

@ -477,7 +477,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
host_canonical = device_util.canonicalize(self._host_device)
if dest_canonical != host_canonical:
with ops.device(devices[0]):
with ops.device(dest_canonical):
output = array_ops.identity(output)
return output