diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 61ce56ced2b..aef2045642b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2923,6 +2923,7 @@ cc_library( "//tensorflow/stream_executor/tpu:status_helper", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", "//tensorflow/stream_executor/tpu:tpu_platform_hdr", + "//tensorflow/stream_executor/tpu:tpu_platform_id", "//tensorflow/stream_executor/tpu:tpu_topology_external", ], alwayslink = True, # Contains TPU computation placer registration diff --git a/tensorflow/compiler/xla/service/tpu_computation_placer.cc b/tensorflow/compiler/xla/service/tpu_computation_placer.cc index a60f1ae1f68..52d11dfba50 100644 --- a/tensorflow/compiler/xla/service/tpu_computation_placer.cc +++ b/tensorflow/compiler/xla/service/tpu_computation_placer.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/tpu/tpu_api.h" #include "tensorflow/stream_executor/tpu/status_helper.h" #include "tensorflow/stream_executor/tpu/tpu_platform.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_id.h" namespace tensorflow { namespace tpu { @@ -72,7 +73,7 @@ static std::unique_ptr CreateTpuComputationPlacer() { } static bool InitModule() { - xla::ComputationPlacer::RegisterComputationPlacer(TpuPlatform::kId, + xla::ComputationPlacer::RegisterComputationPlacer(GetTpuPlatformId(), CreateTpuComputationPlacer); return true; } diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index e8995d7c3b5..3a4114793cf 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -273,6 +273,7 @@ cc_library( "//tensorflow/stream_executor/tpu:tpu_executable_interface", "//tensorflow/stream_executor/tpu:tpu_executor", "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", + "//tensorflow/stream_executor/tpu:tpu_platform_id", "@com_google_absl//absl/types:span", ], alwayslink = True, diff --git a/tensorflow/core/tpu/tpu_on_demand_compiler.cc b/tensorflow/core/tpu/tpu_on_demand_compiler.cc index 50983ef9c9d..75165609191 100644 --- a/tensorflow/core/tpu/tpu_on_demand_compiler.cc +++ b/tensorflow/core/tpu/tpu_on_demand_compiler.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/stream_executor/tpu/tpu_executor.h" #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_platform.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_id.h" #include "tensorflow/stream_executor/tpu/tpu_stream.h" namespace ApiConverter { @@ -211,7 +212,7 @@ class TpuCompiler : public Compiler { ~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); } stream_executor::Platform::Id PlatformId() const override { - return tensorflow::tpu::TpuPlatform::kId; + return tensorflow::tpu::GetTpuPlatformId(); } StatusOr> RunHloPasses( @@ -371,7 +372,7 @@ class TpuCompiler : public Compiler { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( - tensorflow::tpu::TpuPlatform::kId, + tensorflow::tpu::GetTpuPlatformId(), []() { return absl::make_unique(); }); return true; } diff --git a/tensorflow/stream_executor/tpu/BUILD b/tensorflow/stream_executor/tpu/BUILD index 20847f2e3a2..412c2297db6 100644 --- a/tensorflow/stream_executor/tpu/BUILD +++ b/tensorflow/stream_executor/tpu/BUILD @@ -165,6 +165,13 @@ cc_library( ], ) +cc_library( + name = "tpu_platform_id", + srcs = ["tpu_platform_id.cc"], + hdrs = ["tpu_platform_id.h"], + deps = ["//tensorflow/stream_executor:stream_header"], +) + cc_library( name = "tpu_executor_base", srcs = [ @@ -184,6 +191,7 @@ cc_library( ":status_helper", ":tpu_executor_c_api_hdrs", ":tpu_executor_interface", + ":tpu_platform_id", ":tpu_platform_interface", ":tpu_stream_interface", "//tensorflow/c:tf_status", @@ -236,6 +244,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":tpu_executor", + ":tpu_platform_id", ":tpu_transfer_manager_base", "//tensorflow/compiler/xla/service:transfer_manager", ], @@ -253,6 +262,7 @@ cc_library( ":status_helper", ":tpu_executor_base", ":tpu_executor_c_api_hdrs", + ":tpu_platform_id", ":tpu_transfer_manager_interface", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/stream_executor/tpu/tpu_platform.cc b/tensorflow/stream_executor/tpu/tpu_platform.cc index b77b86e0751..d5aab2586a3 100644 --- a/tensorflow/stream_executor/tpu/tpu_platform.cc +++ b/tensorflow/stream_executor/tpu/tpu_platform.cc @@ -20,11 +20,12 @@ limitations under the License. #include "tensorflow/core/tpu/tpu_api.h" #include "tensorflow/stream_executor/tpu/status_helper.h" #include "tensorflow/stream_executor/tpu/tpu_executor.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_id.h" namespace tensorflow { namespace tpu { -PLATFORM_DEFINE_ID(TpuPlatform::kId); +const ::stream_executor::Platform::Id TpuPlatform::kId = GetTpuPlatformId(); TpuPlatform* tpu_registered_platform = nullptr; using Status = ::stream_executor::port::Status; diff --git a/tensorflow/stream_executor/tpu/tpu_platform_id.cc b/tensorflow/stream_executor/tpu/tpu_platform_id.cc new file mode 100644 index 00000000000..ad15bdad2fc --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_platform_id.cc @@ -0,0 +1,30 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/stream_executor/tpu/tpu_platform_id.h" + +namespace tensorflow { +namespace tpu { + +::stream_executor::Platform::Id GetTpuPlatformId() { + // We can't use the PLATFORM_DEFINE_ID macro because of potential + // initialization-order-fiasco errors. + static int plugin_id_value = 42; + const ::stream_executor::Platform::Id platform_id = &plugin_id_value; + return platform_id; +} + +} // namespace tpu +} // namespace tensorflow diff --git a/tensorflow/stream_executor/tpu/tpu_platform_id.h b/tensorflow/stream_executor/tpu/tpu_platform_id.h new file mode 100644 index 00000000000..f6a4af3e34d --- /dev/null +++ b/tensorflow/stream_executor/tpu/tpu_platform_id.h @@ -0,0 +1,29 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_ID_H_ +#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_ID_H_ + +#include "tensorflow/stream_executor/platform.h" + +namespace tensorflow { +namespace tpu { + +::stream_executor::Platform::Id GetTpuPlatformId(); + +} // namespace tpu +} // namespace tensorflow + +#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_ID_H_ diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc index 2900e6b08e5..1790c24a9f6 100644 --- a/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc +++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/stream_executor/tpu/tpu_executor.h" #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" #include "tensorflow/stream_executor/tpu/tpu_platform.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_id.h" namespace tensorflow { namespace tpu { @@ -45,7 +46,7 @@ TpuTransferManager::~TpuTransferManager() { } stream_executor::Platform::Id TpuTransferManager::PlatformId() const { - return TpuPlatform::kId; + return GetTpuPlatformId(); } xla::Shape TpuTransferManager::HostShapeToDeviceShape( diff --git a/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc b/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc index f4af8882e18..ff4cb326ade 100644 --- a/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc +++ b/tensorflow/stream_executor/tpu/tpu_transfer_manager_registration.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/stream_executor/tpu/tpu_platform.h" +#include "tensorflow/stream_executor/tpu/tpu_platform_id.h" #include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h" namespace tensorflow { @@ -27,7 +28,7 @@ static std::unique_ptr CreateTpuTransferManager() { } static bool InitModule() { - xla::TransferManager::RegisterTransferManager(TpuPlatform::kId, + xla::TransferManager::RegisterTransferManager(GetTpuPlatformId(), CreateTpuTransferManager); return true; }