Put a bunch of TPU classes in the tensorflow::tpu namespace.
PiperOrigin-RevId: 332925458 Change-Id: Ibdecb012419112f6a82e4c5bf88160b2c2f71033
This commit is contained in:
parent
64e13aa79c
commit
81282b0a40
@ -247,7 +247,7 @@ class TpuCompiler : public Compiler {
|
||||
~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }
|
||||
|
||||
stream_executor::Platform::Id PlatformId() const override {
|
||||
return tensorflow::TpuPlatform::kId;
|
||||
return tensorflow::tpu::TpuPlatform::kId;
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
|
||||
@ -267,7 +267,7 @@ class TpuCompiler : public Compiler {
|
||||
StatusHelper status;
|
||||
ExecutorApiFn()->TpuCompiler_RunHloPassesFn(
|
||||
compiler_, &hlo_module,
|
||||
static_cast<tensorflow::TpuExecutor*>(executor->implementation())
|
||||
static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
|
||||
->se_executor(),
|
||||
&allocator, &result, status.c_status);
|
||||
if (!status.ok()) {
|
||||
@ -305,7 +305,7 @@ class TpuCompiler : public Compiler {
|
||||
StatusHelper status;
|
||||
ExecutorApiFn()->TpuCompiler_RunBackendFn(
|
||||
compiler_, &hlo_module,
|
||||
static_cast<tensorflow::TpuExecutor*>(executor->implementation())
|
||||
static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
|
||||
->se_executor(),
|
||||
&allocator, &result, status.c_status);
|
||||
if (!status.ok()) {
|
||||
@ -345,7 +345,7 @@ class TpuCompiler : public Compiler {
|
||||
se_lists_storage.emplace_back(stream_exec[i].size());
|
||||
se_lists[i].exec = se_lists_storage.back().data();
|
||||
for (int j = 0; j < stream_exec[i].size(); ++j) {
|
||||
se_lists[i].exec[j] = static_cast<tensorflow::TpuExecutor*>(
|
||||
se_lists[i].exec[j] = static_cast<tensorflow::tpu::TpuExecutor*>(
|
||||
stream_exec[i][j]->implementation())
|
||||
->se_executor();
|
||||
}
|
||||
@ -415,9 +415,9 @@ class TpuCompiler : public Compiler {
|
||||
};
|
||||
|
||||
static bool InitModule() {
|
||||
xla::Compiler::RegisterCompilerFactory(tensorflow::TpuPlatform::kId, []() {
|
||||
return absl::make_unique<TpuCompiler>();
|
||||
});
|
||||
xla::Compiler::RegisterCompilerFactory(
|
||||
tensorflow::tpu::TpuPlatform::kId,
|
||||
[]() { return absl::make_unique<TpuCompiler>(); });
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
namespace {
|
||||
|
||||
class TpuSystemDeviceFactory : public DeviceFactory {
|
||||
@ -75,4 +76,5 @@ void RegisterTpuSystemDevice() {
|
||||
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -17,9 +17,11 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_TPU_TPU_SYSTEM_DEVICE_H_
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
void RegisterTpuSystemDevice();
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_TPU_SYSTEM_DEVICE_H_
|
||||
|
@ -72,8 +72,8 @@ static std::unique_ptr<xla::ComputationPlacer> CreateTpuComputationPlacer() {
|
||||
}
|
||||
|
||||
static bool InitModule() {
|
||||
xla::ComputationPlacer::RegisterComputationPlacer(
|
||||
tensorflow::TpuPlatform::kId, CreateTpuComputationPlacer);
|
||||
xla::ComputationPlacer::RegisterComputationPlacer(TpuPlatform::kId,
|
||||
CreateTpuComputationPlacer);
|
||||
return true;
|
||||
}
|
||||
static bool module_initialized = InitModule();
|
||||
|
@ -19,6 +19,9 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
||||
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
class TpuEvent : public ::stream_executor::internal::EventInterface {
|
||||
public:
|
||||
explicit TpuEvent(SE_Event* event) : event_(event) {}
|
||||
@ -30,4 +33,7 @@ class TpuEvent : public ::stream_executor::internal::EventInterface {
|
||||
SE_Event* event_;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EVENT_H_
|
||||
|
@ -73,7 +73,7 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream(
|
||||
c_dev_assign.size = dev_assign_serialized.size;
|
||||
}
|
||||
|
||||
auto platform = tensorflow::down_cast<tensorflow::TpuPlatform*>(
|
||||
auto platform = tensorflow::down_cast<tensorflow::tpu::TpuPlatform*>(
|
||||
tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform());
|
||||
auto stream = platform->stream_map()->at(
|
||||
run_options.run_options().stream()->implementation());
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
using stream_executor::DeviceMemoryBase;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
namespace {
|
||||
using ::stream_executor::port::Status;
|
||||
@ -372,4 +373,5 @@ TpuExecutor::CreateDeviceDescription() const {
|
||||
return status.status();
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_stream.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
|
||||
public:
|
||||
@ -237,6 +238,7 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
|
||||
SE_StreamExecutor* executor_;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
PLATFORM_DEFINE_ID(TpuPlatform::kId);
|
||||
TpuPlatform* tpu_registered_platform = nullptr;
|
||||
@ -99,8 +100,7 @@ TpuPlatform::GetUncachedExecutor(
|
||||
return status.status();
|
||||
}
|
||||
return std::make_unique<stream_executor::StreamExecutor>(
|
||||
this, std::make_unique<tensorflow::TpuExecutor>(this, executor),
|
||||
config.ordinal);
|
||||
this, std::make_unique<TpuExecutor>(this, executor), config.ordinal);
|
||||
}
|
||||
|
||||
::stream_executor::Platform::Id TpuPlatform::id() const {
|
||||
@ -165,9 +165,9 @@ Status TpuPlatform::TpuMemoryLimit(int64* memory_limit) {
|
||||
bool RegisterTpuPlatform() {
|
||||
static bool tpu_platform_registered = false;
|
||||
if (!tpu_platform_registered) {
|
||||
tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
|
||||
tpu_registered_platform = new TpuPlatform();
|
||||
std::unique_ptr<stream_executor::Platform> platform(
|
||||
tensorflow::tpu_registered_platform);
|
||||
tpu_registered_platform);
|
||||
SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
|
||||
std::move(platform)));
|
||||
tpu_platform_registered = true;
|
||||
@ -175,4 +175,5 @@ bool RegisterTpuPlatform() {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
|
||||
public:
|
||||
@ -147,6 +148,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
|
||||
|
||||
bool RegisterTpuPlatform();
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_
|
||||
|
@ -17,7 +17,8 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
REGISTER_MODULE_INITIALIZER(tpu_platform, tensorflow::RegisterTpuPlatform());
|
||||
REGISTER_MODULE_INITIALIZER(tpu_platform,
|
||||
tensorflow::tpu::RegisterTpuPlatform());
|
||||
|
||||
DECLARE_MODULE_INITIALIZER(multi_platform_manager);
|
||||
DECLARE_MODULE_INITIALIZER(multi_platform_manager_listener);
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
using Status = stream_executor::port::Status;
|
||||
template <typename T>
|
||||
@ -319,4 +320,5 @@ Status TpuTransferManager::LinearizeToBuffers(
|
||||
return status.status();
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
class TpuTransferManager : public xla::TpuTransferManagerInterface {
|
||||
public:
|
||||
@ -92,6 +93,7 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface {
|
||||
XLA_TransferManager* manager_;
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
static std::unique_ptr<xla::TransferManager> CreateTpuTransferManager() {
|
||||
return std::make_unique<TpuTransferManager>();
|
||||
@ -32,4 +33,5 @@ static bool InitModule() {
|
||||
}
|
||||
static bool module_initialized = InitModule();
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user