Make multiple changes to enable TPU single host training e2e. The changes include:
- Create the TPU_SYSTEM device (in both the TF and TFRT device managers) statically at startup. This is a virtual device that enables the initialize_tpu_system() API.
- Create the TPU devices (in both the TF and TFRT device managers) in the tpurt.system.configure_distributed_tpu kernel. This kernel is executed when the user invokes the initialize_tpu_system() API.
- Set the tf_devices attribute in graph lowering.
- Update tpu_strategy to get the TPU devices lazily.

PiperOrigin-RevId: 353980774
Change-Id: I6dbce77ebfeee41fba7953260110fb8c5a5051b1
This commit is contained in:
parent
5cac10d182
commit
b42b1ec06c
@ -585,7 +585,9 @@ std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
|
||||
return errors::Internal("Could not parse device name ", device_name);
|
||||
}
|
||||
device.type = "CPU";
|
||||
device.has_type = true;
|
||||
device.id = 0;
|
||||
device.has_id = true;
|
||||
*host_device_name = DeviceNameUtils::ParsedNameToString(device);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -108,6 +108,7 @@ def initialize_tpu_system(cluster_resolver=None):
|
||||
# Clear out the eager context caches since the memory is invalid now.
|
||||
logging.info("Clearing out eager caches")
|
||||
context.context()._clear_caches() # pylint: disable=protected-access
|
||||
context.context()._initialize_logical_devices() # pylint: disable=protected-access
|
||||
context.context().clear_kernel_cache()
|
||||
|
||||
serialized_topology = output.numpy()
|
||||
|
Loading…
x
Reference in New Issue
Block a user