Enter the tpu worker device scope automatically if within TPUStrategy scope. This avoids the case that user always need to manually place their dataset and functions on tpu worker.

PiperOrigin-RevId: 261231056
This commit is contained in:
Ruoxin Sang 2019-08-01 17:22:31 -07:00 committed by TensorFlower Gardener
parent 8a142d90de
commit 41884bf3d3

View File

@ -195,6 +195,12 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
self._host_device = device_util.get_host_for_device(self._tpu_devices[0]) self._host_device = device_util.get_host_for_device(self._tpu_devices[0])
if ops.executing_eagerly_outside_functions():
# In remote eager mode, we need to place the tf.function on tpu worker
# cpu, setting the default device here so it can enter the host device
# scope automatically when inside strategy scope.
self._default_device = self._host_device
# Only create variables for the number of replicas we're running. # Only create variables for the number of replicas we're running.
self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync] self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
self._device_map = values.ReplicaDeviceMap(self._tpu_devices) self._device_map = values.ReplicaDeviceMap(self._tpu_devices)