Enter the tpu worker device scope automatically if within TPUStrategy scope. This avoids the case that user always need to manually place their dataset and functions on tpu worker.
PiperOrigin-RevId: 261231056
This commit is contained in:
parent
8a142d90de
commit
41884bf3d3
@ -195,6 +195,12 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
|
|
||||||
self._host_device = device_util.get_host_for_device(self._tpu_devices[0])
|
self._host_device = device_util.get_host_for_device(self._tpu_devices[0])
|
||||||
|
|
||||||
|
if ops.executing_eagerly_outside_functions():
|
||||||
|
# In remote eager mode, we need to place the tf.function on tpu worker
|
||||||
|
# cpu, setting the default device here so it can enter the host device
|
||||||
|
# scope automatically when inside strategy scope.
|
||||||
|
self._default_device = self._host_device
|
||||||
|
|
||||||
# Only create variables for the number of replicas we're running.
|
# Only create variables for the number of replicas we're running.
|
||||||
self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
|
self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
|
||||||
self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
|
self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
|
||||||
|
Loading…
Reference in New Issue
Block a user