Extract tpu_platform_id initialization to its own package.

Now, tpu_computation_placer can get the ID without dependingon tpu_executor_base out and creating a circular dependency.

PiperOrigin-RevId: 353308958
Change-Id: I9b23cc9caaebf7d55cc63b47da4b77a3d3fb94a5
This commit is contained in:
Frank Chen 2021-01-22 13:35:06 -08:00 committed by TensorFlower Gardener
parent 445cdeb9f0
commit de326ad87d
10 changed files with 82 additions and 6 deletions

View File

@ -2923,6 +2923,7 @@ cc_library(
"//tensorflow/stream_executor/tpu:status_helper",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_platform_hdr",
"//tensorflow/stream_executor/tpu:tpu_platform_id",
"//tensorflow/stream_executor/tpu:tpu_topology_external",
],
alwayslink = True, # Contains TPU computation placer registration

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
namespace tensorflow {
namespace tpu {
@ -72,7 +73,7 @@ static std::unique_ptr<xla::ComputationPlacer> CreateTpuComputationPlacer() {
}
static bool InitModule() {
xla::ComputationPlacer::RegisterComputationPlacer(TpuPlatform::kId,
xla::ComputationPlacer::RegisterComputationPlacer(GetTpuPlatformId(),
CreateTpuComputationPlacer);
return true;
}

View File

@ -273,6 +273,7 @@ cc_library(
"//tensorflow/stream_executor/tpu:tpu_executable_interface",
"//tensorflow/stream_executor/tpu:tpu_executor",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"//tensorflow/stream_executor/tpu:tpu_platform_id",
"@com_google_absl//absl/types:span",
],
alwayslink = True,

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
namespace ApiConverter {
@ -211,7 +212,7 @@ class TpuCompiler : public Compiler {
~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }
stream_executor::Platform::Id PlatformId() const override {
return tensorflow::tpu::TpuPlatform::kId;
return tensorflow::tpu::GetTpuPlatformId();
}
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
@ -371,7 +372,7 @@ class TpuCompiler : public Compiler {
static bool InitModule() {
xla::Compiler::RegisterCompilerFactory(
tensorflow::tpu::TpuPlatform::kId,
tensorflow::tpu::GetTpuPlatformId(),
[]() { return absl::make_unique<TpuCompiler>(); });
return true;
}

View File

@ -165,6 +165,13 @@ cc_library(
],
)
cc_library(
name = "tpu_platform_id",
srcs = ["tpu_platform_id.cc"],
hdrs = ["tpu_platform_id.h"],
deps = ["//tensorflow/stream_executor:stream_header"],
)
cc_library(
name = "tpu_executor_base",
srcs = [
@ -184,6 +191,7 @@ cc_library(
":status_helper",
":tpu_executor_c_api_hdrs",
":tpu_executor_interface",
":tpu_platform_id",
":tpu_platform_interface",
":tpu_stream_interface",
"//tensorflow/c:tf_status",
@ -236,6 +244,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":tpu_executor",
":tpu_platform_id",
":tpu_transfer_manager_base",
"//tensorflow/compiler/xla/service:transfer_manager",
],
@ -253,6 +262,7 @@ cc_library(
":status_helper",
":tpu_executor_base",
":tpu_executor_c_api_hdrs",
":tpu_platform_id",
":tpu_transfer_manager_interface",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",

View File

@ -20,11 +20,12 @@ limitations under the License.
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
namespace tensorflow {
namespace tpu {
PLATFORM_DEFINE_ID(TpuPlatform::kId);
const ::stream_executor::Platform::Id TpuPlatform::kId = GetTpuPlatformId();
TpuPlatform* tpu_registered_platform = nullptr;
using Status = ::stream_executor::port::Status;

View File

@ -0,0 +1,30 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
namespace tensorflow {
namespace tpu {
::stream_executor::Platform::Id GetTpuPlatformId() {
// We can't use the PLATFORM_DEFINE_ID macro because of potential
// initialization-order-fiasco errors.
static int plugin_id_value = 42;
const ::stream_executor::Platform::Id platform_id = &plugin_id_value;
return platform_id;
}
} // namespace tpu
} // namespace tensorflow

View File

@ -0,0 +1,29 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_ID_H_
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_ID_H_
#include "tensorflow/stream_executor/platform.h"
namespace tensorflow {
namespace tpu {
::stream_executor::Platform::Id GetTpuPlatformId();
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_ID_H_

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
namespace tensorflow {
namespace tpu {
@ -45,7 +46,7 @@ TpuTransferManager::~TpuTransferManager() {
}
stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
return TpuPlatform::kId;
return GetTpuPlatformId();
}
xla::Shape TpuTransferManager::HostShapeToDeviceShape(

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
namespace tensorflow {
@ -27,7 +28,7 @@ static std::unique_ptr<xla::TransferManager> CreateTpuTransferManager() {
}
static bool InitModule() {
xla::TransferManager::RegisterTransferManager(TpuPlatform::kId,
xla::TransferManager::RegisterTransferManager(GetTpuPlatformId(),
CreateTpuTransferManager);
return true;
}