Put a bunch of TPU classes in the tensorflow::tpu namespace.

PiperOrigin-RevId: 332925458
Change-Id: Ibdecb012419112f6a82e4c5bf88160b2c2f71033
Author: Skye Wanderman-Milne
Date: 2020-09-21 13:48:44 -07:00
Committed by: TensorFlower Gardener
Parent: 64e13aa79c
Commit: 81282b0a40
14 changed files with 39 additions and 15 deletions
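
The hunks below are mechanical: each implementation file and header gains an enclosing namespace tpu { ... } block inside namespace tensorflow, and call sites pick up the extra tpu:: qualifier. A minimal, self-contained sketch of the pattern, using toy names rather than the actual TensorFlow classes:

#include <iostream>

namespace tensorflow {
namespace tpu {  // newly introduced inner namespace

// Toy stand-in for a class such as TpuPlatform.
class TpuPlatform {
 public:
  static constexpr int kId = 0;  // placeholder value, not the real platform id
};

}  // namespace tpu
}  // namespace tensorflow

int main() {
  // Call sites outside the namespace now need the extra component:
  //   tensorflow::TpuPlatform::kId  ->  tensorflow::tpu::TpuPlatform::kId
  std::cout << tensorflow::tpu::TpuPlatform::kId << "\n";
  return 0;
}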


@@ -247,7 +247,7 @@ class TpuCompiler : public Compiler {
   ~TpuCompiler() override { ExecutorApiFn()->TpuCompiler_FreeFn(compiler_); }
   stream_executor::Platform::Id PlatformId() const override {
-    return tensorflow::TpuPlatform::kId;
+    return tensorflow::tpu::TpuPlatform::kId;
   }
   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
@@ -267,7 +267,7 @@ class TpuCompiler : public Compiler {
     StatusHelper status;
     ExecutorApiFn()->TpuCompiler_RunHloPassesFn(
         compiler_, &hlo_module,
-        static_cast<tensorflow::TpuExecutor*>(executor->implementation())
+        static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
            ->se_executor(),
         &allocator, &result, status.c_status);
     if (!status.ok()) {
@@ -305,7 +305,7 @@ class TpuCompiler : public Compiler {
     StatusHelper status;
     ExecutorApiFn()->TpuCompiler_RunBackendFn(
         compiler_, &hlo_module,
-        static_cast<tensorflow::TpuExecutor*>(executor->implementation())
+        static_cast<tensorflow::tpu::TpuExecutor*>(executor->implementation())
            ->se_executor(),
         &allocator, &result, status.c_status);
     if (!status.ok()) {
@@ -345,7 +345,7 @@ class TpuCompiler : public Compiler {
       se_lists_storage.emplace_back(stream_exec[i].size());
       se_lists[i].exec = se_lists_storage.back().data();
       for (int j = 0; j < stream_exec[i].size(); ++j) {
-        se_lists[i].exec[j] = static_cast<tensorflow::TpuExecutor*>(
+        se_lists[i].exec[j] = static_cast<tensorflow::tpu::TpuExecutor*>(
                                   stream_exec[i][j]->implementation())
                                   ->se_executor();
       }
@@ -415,9 +415,9 @@ class TpuCompiler : public Compiler {
 };
 static bool InitModule() {
-  xla::Compiler::RegisterCompilerFactory(tensorflow::TpuPlatform::kId, []() {
-    return absl::make_unique<TpuCompiler>();
-  });
+  xla::Compiler::RegisterCompilerFactory(
+      tensorflow::tpu::TpuPlatform::kId,
+      []() { return absl::make_unique<TpuCompiler>(); });
   return true;
 }


@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
 namespace tensorflow {
+namespace tpu {
 namespace {
 class TpuSystemDeviceFactory : public DeviceFactory {
@@ -75,4 +76,5 @@ void RegisterTpuSystemDevice() {
   REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_TPU_SYSTEM, TpuSystemDeviceFactory);
 }
+}  // namespace tpu
 }  // namespace tensorflow


@@ -17,9 +17,11 @@ limitations under the License.
 #define TENSORFLOW_CORE_TPU_TPU_SYSTEM_DEVICE_H_
 namespace tensorflow {
+namespace tpu {
 void RegisterTpuSystemDevice();
+}  // namespace tpu
 }  // namespace tensorflow
 #endif  // TENSORFLOW_CORE_TPU_TPU_SYSTEM_DEVICE_H_


@@ -72,8 +72,8 @@ static std::unique_ptr<xla::ComputationPlacer> CreateTpuComputationPlacer() {
 }
 static bool InitModule() {
-  xla::ComputationPlacer::RegisterComputationPlacer(
-      tensorflow::TpuPlatform::kId, CreateTpuComputationPlacer);
+  xla::ComputationPlacer::RegisterComputationPlacer(TpuPlatform::kId,
+                                                    CreateTpuComputationPlacer);
   return true;
 }
 static bool module_initialized = InitModule();


@@ -19,6 +19,9 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 #include "tensorflow/stream_executor/tpu/c_api_decl.h"
 namespace tensorflow {
+namespace tpu {
 class TpuEvent : public ::stream_executor::internal::EventInterface {
  public:
   explicit TpuEvent(SE_Event* event) : event_(event) {}
@@ -30,4 +33,7 @@ class TpuEvent : public ::stream_executor::internal::EventInterface {
   SE_Event* event_;
 };
+}  // namespace tpu
 }  // namespace tensorflow
 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EVENT_H_


@@ -73,7 +73,7 @@ Status TpuExecutable::LoadProgramAndEnqueueToStream(
     c_dev_assign.size = dev_assign_serialized.size;
   }
-  auto platform = tensorflow::down_cast<tensorflow::TpuPlatform*>(
+  auto platform = tensorflow::down_cast<tensorflow::tpu::TpuPlatform*>(
       tensorflow::tpu::TpuPlatformInterface::GetRegisteredPlatform());
   auto stream = platform->stream_map()->at(
       run_options.run_options().stream()->implementation());


@@ -26,6 +26,7 @@ limitations under the License.
 using stream_executor::DeviceMemoryBase;
 namespace tensorflow {
+namespace tpu {
 namespace {
 using ::stream_executor::port::Status;
@@ -372,4 +373,5 @@ TpuExecutor::CreateDeviceDescription() const {
   return status.status();
 }
+}  // namespace tpu
 }  // namespace tensorflow


@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_stream.h"
 namespace tensorflow {
+namespace tpu {
 class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
  public:
@@ -237,6 +238,7 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface {
   SE_StreamExecutor* executor_;
 };
+}  // namespace tpu
 }  // namespace tensorflow
 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_EXECUTOR_H_


@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_executor.h"
 namespace tensorflow {
+namespace tpu {
 PLATFORM_DEFINE_ID(TpuPlatform::kId);
 TpuPlatform* tpu_registered_platform = nullptr;
@@ -99,8 +100,7 @@ TpuPlatform::GetUncachedExecutor(
     return status.status();
   }
   return std::make_unique<stream_executor::StreamExecutor>(
-      this, std::make_unique<tensorflow::TpuExecutor>(this, executor),
-      config.ordinal);
+      this, std::make_unique<TpuExecutor>(this, executor), config.ordinal);
 }
 ::stream_executor::Platform::Id TpuPlatform::id() const {
@@ -165,9 +165,9 @@ Status TpuPlatform::TpuMemoryLimit(int64* memory_limit) {
 bool RegisterTpuPlatform() {
   static bool tpu_platform_registered = false;
   if (!tpu_platform_registered) {
-    tensorflow::tpu_registered_platform = new tensorflow::TpuPlatform();
+    tpu_registered_platform = new TpuPlatform();
     std::unique_ptr<stream_executor::Platform> platform(
-        tensorflow::tpu_registered_platform);
+        tpu_registered_platform);
     SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
         std::move(platform)));
     tpu_platform_registered = true;
@@ -175,4 +175,5 @@ bool RegisterTpuPlatform() {
   return true;
 }
+}  // namespace tpu
 }  // namespace tensorflow


@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_platform_interface.h"
 namespace tensorflow {
+namespace tpu {
 class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
  public:
@@ -147,6 +148,7 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface {
 bool RegisterTpuPlatform();
+}  // namespace tpu
 }  // namespace tensorflow
 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_PLATFORM_H_


@@ -17,7 +17,8 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
 #if defined(PLATFORM_GOOGLE)
-REGISTER_MODULE_INITIALIZER(tpu_platform, tensorflow::RegisterTpuPlatform());
+REGISTER_MODULE_INITIALIZER(tpu_platform,
+                            tensorflow::tpu::RegisterTpuPlatform());
 DECLARE_MODULE_INITIALIZER(multi_platform_manager);
 DECLARE_MODULE_INITIALIZER(multi_platform_manager_listener);
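
The REGISTER_MODULE_INITIALIZER call above sits at file scope, outside any namespace, which is why its argument changes from tensorflow::RegisterTpuPlatform() to the fully qualified tensorflow::tpu::RegisterTpuPlatform(). A rough, self-contained illustration of that qualification rule, using a made-up REGISTER_INIT macro in place of the real one:

#include <iostream>

namespace tensorflow {
namespace tpu {
// Toy stand-in for the real registration function.
bool RegisterTpuPlatform() {
  std::cout << "platform registered\n";
  return true;
}
}  // namespace tpu
}  // namespace tensorflow

// Toy stand-in for REGISTER_MODULE_INITIALIZER: evaluates `expr` during static
// initialization, before main() runs.
#define REGISTER_INIT(name, expr) static bool name##_done = (expr)

// At file scope there is no enclosing tensorflow namespace, so the call must
// be fully qualified; an unqualified tpu::RegisterTpuPlatform() would not resolve.
REGISTER_INIT(tpu_platform, tensorflow::tpu::RegisterTpuPlatform());

int main() { return 0; }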


@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_platform.h"
 namespace tensorflow {
+namespace tpu {
 using Status = stream_executor::port::Status;
 template <typename T>
@@ -319,4 +320,5 @@ Status TpuTransferManager::LinearizeToBuffers(
   return status.status();
 }
+}  // namespace tpu
 }  // namespace tensorflow


@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_transfer_manager_interface.h"
 namespace tensorflow {
+namespace tpu {
 class TpuTransferManager : public xla::TpuTransferManagerInterface {
  public:
@@ -92,6 +93,7 @@ class TpuTransferManager : public xla::TpuTransferManagerInterface {
   XLA_TransferManager* manager_;
 };
+}  // namespace tpu
 }  // namespace tensorflow
 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_TRANSFER_MANAGER_H_


@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/tpu/tpu_transfer_manager.h"
 namespace tensorflow {
+namespace tpu {
 static std::unique_ptr<xla::TransferManager> CreateTpuTransferManager() {
   return std::make_unique<TpuTransferManager>();
@@ -32,4 +33,5 @@ static bool InitModule() {
 }
 static bool module_initialized = InitModule();
+}  // namespace tpu
 }  // namespace tensorflow