diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 2d301b51e41..5d1a189ecba 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -195,6 +195,12 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     self._host_device = device_util.get_host_for_device(self._tpu_devices[0])
 
+    if ops.executing_eagerly_outside_functions():
+      # In remote eager mode, tf.functions need to be placed on the TPU
+      # worker's CPU. Setting the default device here makes code enter the
+      # host device scope automatically when inside the strategy scope.
+      self._default_device = self._host_device
+
     # Only create variables for the number of replicas we're running.
     self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
     self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
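
For context, below is a minimal sketch of the remote eager setup this change targets. The TPU address is a placeholder and the surrounding cluster setup is assumed, not part of this patch; it only illustrates where the new default-device behavior kicks in.

    import tensorflow as tf

    # Placeholder address for a remote TPU worker; substitute your own.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu='grpc://10.0.0.1:8470')
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    with strategy.scope():
      # With self._default_device set to the host device, ops and
      # tf.functions created here default to the TPU worker's CPU
      # (e.g. "/job:worker/replica:0/task:0/device:CPU:0") rather than
      # the local client, per the comment in the patch above.
      v = tf.Variable(1.0)

Per the patch comment, without this default the user would have to enter the worker's host device scope manually for tf.functions to be placed on the remote worker in eager mode.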