Make multiple changes to enable TPU single host training e2e. The changes include:

- Create the TPU_SYSTEM device (in both the TF and TFRT device managers)
statically at startup. This is a virtual device that enables the initialize_tpu_system() API.
  - Create TPU devices (in both the TF and TFRT device managers) in the tpurt.system.configure_distributed_tpu kernel. This kernel is executed when the user invokes the initialize_tpu_system() API.
  - Set tf_devices attribute in graph lowering.
  - Update tpu_strategy to get the TPU devices lazily.

PiperOrigin-RevId: 353980774
Change-Id: I6dbce77ebfeee41fba7953260110fb8c5a5051b1
This commit is contained in:
Xiao Yu 2021-01-26 16:49:13 -08:00 committed by TensorFlower Gardener
parent 5cac10d182
commit b42b1ec06c
2 changed files with 3 additions and 0 deletions

View File

@ -585,7 +585,9 @@ std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
return errors::Internal("Could not parse device name ", device_name);
}
device.type = "CPU";
device.has_type = true;
device.id = 0;
device.has_id = true;
*host_device_name = DeviceNameUtils::ParsedNameToString(device);
return Status::OK();
}

View File

@ -108,6 +108,7 @@ def initialize_tpu_system(cluster_resolver=None):
# Clear out the eager context caches since the memory is invalid now.
logging.info("Clearing out eager caches")
context.context()._clear_caches() # pylint: disable=protected-access
context.context()._initialize_logical_devices() # pylint: disable=protected-access
context.context().clear_kernel_cache()
serialized_topology = output.numpy()