From b42b1ec06c279004656d2c3e07eea32a2f3743ee Mon Sep 17 00:00:00 2001 From: Xiao Yu Date: Tue, 26 Jan 2021 16:49:13 -0800 Subject: [PATCH] Make multiple changes to enable TPU single-host training e2e. The changes include: - Create TPU_SYSTEM device (in both TF and TFRT device manager) statically at startup. This is a virtual device to enable the initialize_tpu_system() API. - Create TPU devices (in both TF and TFRT device manager) in tpurt.system.configure_distributed_tpu kernels. This kernel is executed when the user invokes the initialize_tpu_system() API. - Set tf_devices attribute in graph lowering. - Update tpu_strategy to get the TPU devices lazily. PiperOrigin-RevId: 353980774 Change-Id: I6dbce77ebfeee41fba7953260110fb8c5a5051b1 --- tensorflow/core/util/device_name_utils.cc | 2 ++ tensorflow/python/tpu/tpu_strategy_util.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc index 14dab634416..7a2b0e600d2 100644 --- a/tensorflow/core/util/device_name_utils.cc +++ b/tensorflow/core/util/device_name_utils.cc @@ -585,7 +585,9 @@ std::vector DeviceNameUtils::GetLocalNamesForDeviceMappings( return errors::Internal("Could not parse device name ", device_name); } device.type = "CPU"; + device.has_type = true; device.id = 0; + device.has_id = true; *host_device_name = DeviceNameUtils::ParsedNameToString(device); return Status::OK(); } diff --git a/tensorflow/python/tpu/tpu_strategy_util.py b/tensorflow/python/tpu/tpu_strategy_util.py index 4aac4471a59..281485db8b1 100644 --- a/tensorflow/python/tpu/tpu_strategy_util.py +++ b/tensorflow/python/tpu/tpu_strategy_util.py @@ -108,6 +108,7 @@ def initialize_tpu_system(cluster_resolver=None): # Clear out the eager context caches since the memory is invalid now. 
logging.info("Clearing out eager caches") context.context()._clear_caches() # pylint: disable=protected-access + context.context()._initialize_logical_devices() # pylint: disable=protected-access context.context().clear_kernel_cache() serialized_topology = output.numpy()