Extract tpu_platform_id initialization to its own package.
Now, tpu_computation_placer can get the ID without dependingon tpu_executor_base out and creating a circular dependency. PiperOrigin-RevId: 353308958 Change-Id: I9b23cc9caaebf7d55cc63b47da4b77a3d3fb94a5
This commit is contained in:
parent
445cdeb9f0
commit
de326ad87d
@ -2923,6 +2923,7 @@ cc_library(
|
|||||||
"//tensorflow/stream_executor/tpu:status_helper",
|
"//tensorflow/stream_executor/tpu:status_helper",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_platform_hdr",
|
"//tensorflow/stream_executor/tpu:tpu_platform_hdr",
|
||||||
|
"//tensorflow/stream_executor/tpu:tpu_platform_id",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_topology_external",
|
"//tensorflow/stream_executor/tpu:tpu_topology_external",
|
||||||
],
|
],
|
||||||
alwayslink = True, # Contains TPU computation placer registration
|
alwayslink = True, # Contains TPU computation placer registration
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/tpu/tpu_api.h"
|
#include "tensorflow/core/tpu/tpu_api.h"
|
||||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
@ -72,7 +73,7 @@ static std::unique_ptr<xla::ComputationPlacer> CreateTpuComputationPlacer() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool InitModule() {
|
static bool InitModule() {
|
||||||
xla::ComputationPlacer::RegisterComputationPlacer(TpuPlatform::kId,
|
xla::ComputationPlacer::RegisterComputationPlacer(GetTpuPlatformId(),
|
||||||
CreateTpuComputationPlacer);
|
CreateTpuComputationPlacer);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -273,6 +273,7 @@ cc_library(
|
|||||||
"//tensorflow/stream_executor/tpu:tpu_executable_interface",
|
"//tensorflow/stream_executor/tpu:tpu_executable_interface",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_executor",
|
"//tensorflow/stream_executor/tpu:tpu_executor",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||||
|
"//tensorflow/stream_executor/tpu:tpu_platform_id",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
alwayslink = True,
|
alwayslink = True,
|
||||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
|
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
|
||||||
|
|
||||||
namespace ApiConverter {
|
namespace ApiConverter {
|
||||||
@ -211,7 +212,7 @@ class TpuCompiler : public Compiler {
|
|||||||
~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }
|
~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }
|
||||||
|
|
||||||
stream_executor::Platform::Id PlatformId() const override {
|
stream_executor::Platform::Id PlatformId() const override {
|
||||||
return tensorflow::tpu::TpuPlatform::kId;
|
return tensorflow::tpu::GetTpuPlatformId();
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
||||||
@ -371,7 +372,7 @@ class TpuCompiler : public Compiler {
|
|||||||
|
|
||||||
static bool InitModule() {
|
static bool InitModule() {
|
||||||
xla::Compiler::RegisterCompilerFactory(
|
xla::Compiler::RegisterCompilerFactory(
|
||||||
tensorflow::tpu::TpuPlatform::kId,
|
tensorflow::tpu::GetTpuPlatformId(),
|
||||||
[]() { return absl::make_unique<TpuCompiler>(); });
|
[]() { return absl::make_unique<TpuCompiler>(); });
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -165,6 +165,13 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tpu_platform_id",
|
||||||
|
srcs = ["tpu_platform_id.cc"],
|
||||||
|
hdrs = ["tpu_platform_id.h"],
|
||||||
|
deps = ["//tensorflow/stream_executor:stream_header"],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_executor_base",
|
name = "tpu_executor_base",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -184,6 +191,7 @@ cc_library(
|
|||||||
":status_helper",
|
":status_helper",
|
||||||
":tpu_executor_c_api_hdrs",
|
":tpu_executor_c_api_hdrs",
|
||||||
":tpu_executor_interface",
|
":tpu_executor_interface",
|
||||||
|
":tpu_platform_id",
|
||||||
":tpu_platform_interface",
|
":tpu_platform_interface",
|
||||||
":tpu_stream_interface",
|
":tpu_stream_interface",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
@ -236,6 +244,7 @@ cc_library(
|
|||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":tpu_executor",
|
":tpu_executor",
|
||||||
|
":tpu_platform_id",
|
||||||
":tpu_transfer_manager_base",
|
":tpu_transfer_manager_base",
|
||||||
"//tensorflow/compiler/xla/service:transfer_manager",
|
"//tensorflow/compiler/xla/service:transfer_manager",
|
||||||
],
|
],
|
||||||
@ -253,6 +262,7 @@ cc_library(
|
|||||||
":status_helper",
|
":status_helper",
|
||||||
":tpu_executor_base",
|
":tpu_executor_base",
|
||||||
":tpu_executor_c_api_hdrs",
|
":tpu_executor_c_api_hdrs",
|
||||||
|
":tpu_platform_id",
|
||||||
":tpu_transfer_manager_interface",
|
":tpu_transfer_manager_interface",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
@ -20,11 +20,12 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/tpu/tpu_api.h"
|
#include "tensorflow/core/tpu/tpu_api.h"
|
||||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
|
||||||
PLATFORM_DEFINE_ID(TpuPlatform::kId);
|
const ::stream_executor::Platform::Id TpuPlatform::kId = GetTpuPlatformId();
|
||||||
TpuPlatform* tpu_registered_platform = nullptr;
|
TpuPlatform* tpu_registered_platform = nullptr;
|
||||||
|
|
||||||
using Status = ::stream_executor::port::Status;
|
using Status = ::stream_executor::port::Status;
|
||||||
|
30
tensorflow/stream_executor/tpu/tpu_platform_id.cc
Normal file
30
tensorflow/stream_executor/tpu/tpu_platform_id.cc
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace tpu {
|
||||||
|
|
||||||
|
::stream_executor::Platform::Id GetTpuPlatformId() {
|
||||||
|
// We can't use the PLATFORM_DEFINE_ID macro because of potential
|
||||||
|
// initialization-order-fiasco errors.
|
||||||
|
static int plugin_id_value = 42;
|
||||||
|
const ::stream_executor::Platform::Id platform_id = &plugin_id_value;
|
||||||
|
return platform_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tpu
|
||||||
|
} // namespace tensorflow
|
29
tensorflow/stream_executor/tpu/tpu_platform_id.h
Normal file
29
tensorflow/stream_executor/tpu/tpu_platform_id.h
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_ID_H_
|
||||||
|
#define TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_ID_H_
|
||||||
|
|
||||||
|
#include "tensorflow/stream_executor/platform.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace tpu {
|
||||||
|
|
||||||
|
::stream_executor::Platform::Id GetTpuPlatformId();
|
||||||
|
|
||||||
|
} // namespace tpu
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_ID_H_
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
@ -45,7 +46,7 @@ TpuTransferManager::~TpuTransferManager() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
|
stream_executor::Platform::Id TpuTransferManager::PlatformId() const {
|
||||||
return TpuPlatform::kId;
|
return GetTpuPlatformId();
|
||||||
}
|
}
|
||||||
|
|
||||||
xla::Shape TpuTransferManager::HostShapeToDeviceShape(
|
xla::Shape TpuTransferManager::HostShapeToDeviceShape(
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_platform_id.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
|
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -27,7 +28,7 @@ static std::unique_ptr<xla::TransferManager> CreateTpuTransferManager() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool InitModule() {
|
static bool InitModule() {
|
||||||
xla::TransferManager::RegisterTransferManager(TpuPlatform::kId,
|
xla::TransferManager::RegisterTransferManager(GetTpuPlatformId(),
|
||||||
CreateTpuTransferManager);
|
CreateTpuTransferManager);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user