From 12b9e3ad86e25085e648c43e3ac61ac5db7fce56 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Wed, 10 Apr 2019 11:29:24 -0700 Subject: [PATCH] Fix device resolution for reductions inside tf.functions in eager mode. PiperOrigin-RevId: 242907308 --- tensorflow/python/distribute/device_util.py | 5 ++--- tensorflow/python/distribute/tpu_strategy.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/distribute/device_util.py b/tensorflow/python/distribute/device_util.py index 3e5b44dbe55..c3d858690e0 100644 --- a/tensorflow/python/distribute/device_util.py +++ b/tensorflow/python/distribute/device_util.py @@ -87,9 +87,8 @@ class _FakeOperation(object): def current(): """Return a string (not canonicalized) for the current device.""" # TODO(josh11b): Work out how this function interacts with ops.colocate_with. - ctx = context.context() - if ctx.executing_eagerly(): - d = ctx.device_name + if ops.executing_eagerly_outside_functions(): + d = context.context().device_name else: op = _FakeOperation() ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index a322839d039..6909b140507 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -477,7 +477,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended): host_canonical = device_util.canonicalize(self._host_device) if dest_canonical != host_canonical: - with ops.device(devices[0]): + with ops.device(dest_canonical): output = array_ops.identity(output) return output