diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 1700c977189..99b8f0558e6 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -36,12 +36,12 @@ DeviceMemoryAllocator* ExecutableRunOptions::allocator() const { } ExecutableRunOptions& ExecutableRunOptions::set_stream( - perftools::gputools::Stream* stream) { + stream_executor::Stream* stream) { stream_ = stream; return *this; } -perftools::gputools::Stream* ExecutableRunOptions::stream() const { +stream_executor::Stream* ExecutableRunOptions::stream() const { return stream_; } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 2c1d9ffff10..1a095a82cca 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,29 +16,31 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ -// Intentionally forward declared so that ExecutableRunOptions can be linked +// These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't // need to be linked). -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; class Platform; -} -} +} // namespace stream_executor namespace tensorflow { namespace thread { class ThreadPool; -} -} +} // namespace thread +} // namespace tensorflow namespace Eigen { struct ThreadPoolDevice; -} +} // namespace Eigen namespace xla { +// TODO(b/77980417): Once the perftools::gputools -> stream_executor migration +// is complete, add "namespace se = stream_executor;" here and +// s/stream_executor::/se::/ to match our idiom elsewhere. 
+ class DeviceMemoryAllocator; class DeviceAssignment; class ExecutionProfile; @@ -61,8 +63,8 @@ class ExecutableRunOptions { // If set, this is the stream to run the computation on. The platform of the // stream must match the platform the executable was built for. A value of // nullptr indicates the option has not been set. - ExecutableRunOptions& set_stream(perftools::gputools::Stream* stream); - perftools::gputools::Stream* stream() const; + ExecutableRunOptions& set_stream(stream_executor::Stream* stream); + stream_executor::Stream* stream() const; // Sets the thread pool on which to run parallel CPU backend // computations. Does not take ownership. @@ -91,7 +93,7 @@ class ExecutableRunOptions { DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; DeviceAssignment* device_assignment_ = nullptr; - perftools::gputools::Stream* stream_ = nullptr; + stream_executor::Stream* stream_ = nullptr; tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 9fa4297523b..20f3f1b957c 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -46,4 +46,12 @@ using ::Eigen::half; } // namespace xla +// Alias namespace ::stream_executor as ::xla::se. +namespace stream_executor {} +namespace xla { +// TODO(b/77980417): Uncomment this once all namespace aliases named 'se' are +// removed in ::xla. +// namespace se = ::stream_executor; +} // namespace xla + #endif // TENSORFLOW_COMPILER_XLA_TYPES_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h index d23898e1f26..fd5f50ca4ea 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h +++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h @@ -29,13 +29,11 @@ limitations under the License. 
#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Event; class Stream; class StreamExecutor; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor namespace tensorflow { diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h index 2e90687fe88..5c503d12616 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h @@ -23,7 +23,10 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor.h" namespace tensorflow { -namespace gpu = ::perftools::gputools; + +// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once +// that's available. +namespace gpu = ::stream_executor; // Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids. class GpuIdUtil { diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.h b/tensorflow/core/common_runtime/gpu/gpu_init.h index 927d05d5ba9..bfd7a77f833 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_init.h +++ b/tensorflow/core/common_runtime/gpu/gpu_init.h @@ -18,11 +18,9 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Platform; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor namespace tensorflow { @@ -34,7 +32,7 @@ Status ValidateGPUMachineManager(); // initializing the GPUs on the machine if needed the first time it is // called. Must only be called when there is a valid GPU environment // in the process (e.g., ValidateGPUMachineManager() returns OK). 
-perftools::gputools::Platform* GPUMachineManager(); +stream_executor::Platform* GPUMachineManager(); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index a0f5877d62f..5214ceaae57 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -60,7 +60,9 @@ using perftools::gputools::Stream; namespace tensorflow { -namespace gpu = ::perftools::gputools; +// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once +// that's available. +namespace gpu = ::stream_executor; Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src, const Tensor* dst, diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h index d99a0b1f611..337dc89895c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.h +++ b/tensorflow/core/common_runtime/gpu/gpu_util.h @@ -27,7 +27,9 @@ namespace tensorflow { class RecvTensorResponse; class TensorProto; -namespace gpu = ::perftools::gputools; +// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once +// that's available. +namespace gpu = ::stream_executor; class GPUUtil { public: diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h index 8b1430f0219..38a18cd0877 100644 --- a/tensorflow/core/common_runtime/gpu_device_context.h +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -19,23 +19,22 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/device_base.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor namespace tensorflow { -namespace gpu = ::perftools::gputools; +// TODO(b/77980417): Replace stream_executor:: with se:: once our namespace +// migration is complete and the alias is available. class GPUDeviceContext : public DeviceContext { public: // Does not take ownership of streams. - GPUDeviceContext(int stream_id, gpu::Stream* stream, - gpu::Stream* host_to_device_stream, - gpu::Stream* device_to_host_stream, - gpu::Stream* device_to_device_stream) + GPUDeviceContext(int stream_id, stream_executor::Stream* stream, + stream_executor::Stream* host_to_device_stream, + stream_executor::Stream* device_to_host_stream, + stream_executor::Stream* device_to_device_stream) : stream_id_(stream_id), stream_(stream), host_to_device_stream_(host_to_device_stream), @@ -44,10 +43,14 @@ class GPUDeviceContext : public DeviceContext { ~GPUDeviceContext() override {} - gpu::Stream* stream() const override { return stream_; } - gpu::Stream* host_to_device_stream() const { return host_to_device_stream_; } - gpu::Stream* device_to_host_stream() const { return device_to_host_stream_; } - gpu::Stream* device_to_device_stream() const { + stream_executor::Stream* stream() const override { return stream_; } + stream_executor::Stream* host_to_device_stream() const { + return host_to_device_stream_; + } + stream_executor::Stream* device_to_host_stream() const { + return device_to_host_stream_; + } + stream_executor::Stream* device_to_device_stream() const { return device_to_device_stream_; } int stream_id() const { return stream_id_; } @@ -67,13 +70,13 @@ class GPUDeviceContext : public DeviceContext { int stream_id_; // The default primary stream to use for this context. // All the memory belongs to this stream. 
- gpu::Stream* stream_; + stream_executor::Stream* stream_; // The stream to use for copy data from host into GPU. - gpu::Stream* host_to_device_stream_; + stream_executor::Stream* host_to_device_stream_; // The stream to use for copy data from GPU to host. - gpu::Stream* device_to_host_stream_; + stream_executor::Stream* device_to_host_stream_; // The stream to use for copy data between GPU. - gpu::Stream* device_to_device_stream_; + stream_executor::Stream* device_to_device_stream_; }; } // namespace tensorflow diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 8473b228d3d..223b74857d0 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -34,11 +34,9 @@ struct SyclDevice; #endif } // end namespace Eigen -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor namespace tensorflow { @@ -69,9 +67,10 @@ class PerOpGpuDevice { class DeviceContext : public core::RefCounted { public: ~DeviceContext() override {} - virtual perftools::gputools::Stream* stream() const { return nullptr; } - virtual void MaintainLifetimeOnStream( - const Tensor* t, perftools::gputools::Stream* stream) const {} + virtual stream_executor::Stream* stream() const { return nullptr; } + virtual void MaintainLifetimeOnStream(const Tensor* t, + stream_executor::Stream* stream) const { + } // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into // "device_tensor" which is on a GPU device "device". "device_tensor" @@ -133,7 +132,7 @@ class DeviceBase { // but also by TPU devices (to provide default device context). struct GpuDeviceInfo { // Make sure all the defaults are NULL, so we can spot missing assignments. 
- perftools::gputools::Stream* stream = nullptr; + stream_executor::Stream* stream = nullptr; DeviceContext* default_context = nullptr; EventMgr* event_mgr = nullptr; int gpu_id = -1; diff --git a/tensorflow/core/platform/default/from_stream_executor_status.h b/tensorflow/core/platform/default/from_stream_executor_status.h index 2a2297a6577..36a67a36488 100644 --- a/tensorflow/core/platform/default/from_stream_executor_status.h +++ b/tensorflow/core/platform/default/from_stream_executor_status.h @@ -23,8 +23,6 @@ limitations under the License. namespace tensorflow { -namespace gpu = ::perftools::gputools; - // On the open-source platform, stream_executor currently uses // tensorflow::Status inline Status FromStreamExecutorStatus( diff --git a/tensorflow/core/platform/stream_executor.h b/tensorflow/core/platform/stream_executor.h index f31e556a708..006184ddeff 100644 --- a/tensorflow/core/platform/stream_executor.h +++ b/tensorflow/core/platform/stream_executor.h @@ -37,4 +37,15 @@ limitations under the License. #include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream_executor.h" +namespace perftools { +namespace gputools { + +// Temporarily pull stream_executor into perftools::gputools while we migrate +// code to the new namespace. TODO(jlebar): Remove this once we've completed +// the migration. +using namespace stream_executor; // NOLINT[build/namespaces] + +} // namespace gputools +} // namespace perftools + #endif // TENSORFLOW_PLATFORM_STREAM_EXECUTOR_H_ diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h index 6308e588470..f2471712cca 100644 --- a/tensorflow/core/platform/types.h +++ b/tensorflow/core/platform/types.h @@ -60,4 +60,12 @@ typedef uint64 Fprint; } // namespace tensorflow +// Alias namespace ::stream_executor as ::tensorflow::se. +namespace stream_executor {} +namespace tensorflow { +// TODO(b/77980417): Uncomment this once all namespace aliases named 'se' are +// removed in ::tensorflow. 
+// namespace se = ::stream_executor; +} // namespace tensorflow + #endif // TENSORFLOW_PLATFORM_TYPES_H_ diff --git a/tensorflow/stream_executor/blas.cc b/tensorflow/stream_executor/blas.cc index 31724cf6c9b..906d6fb7020 100644 --- a/tensorflow/stream_executor/blas.cc +++ b/tensorflow/stream_executor/blas.cc @@ -17,8 +17,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/strcat.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace blas { string TransposeString(Transpose t) { @@ -95,5 +94,4 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) { } } // namespace blas -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index c5f778a5c74..6e62b85728a 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -49,8 +49,7 @@ namespace Eigen { struct half; } // namespace Eigen -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; class ScratchAllocator; @@ -2100,7 +2099,6 @@ class BlasSupport { DeviceMemory> *b, int ldb) override; } // namespace blas -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_activation.cc b/tensorflow/stream_executor/cuda/cuda_activation.cc index 5f4cf9dbd78..cf6b9e2c6e4 100644 --- a/tensorflow/stream_executor/cuda/cuda_activation.cc +++ b/tensorflow/stream_executor/cuda/cuda_activation.cc @@ -19,8 +19,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec); @@ -40,5 +39,4 @@ ScopedActivateExecutorContext::~ScopedActivateExecutorContext() { } } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/cuda/cuda_activation.h b/tensorflow/stream_executor/cuda/cuda_activation.h index c9d43a9766e..04ffaef3646 100644 --- a/tensorflow/stream_executor/cuda/cuda_activation.h +++ b/tensorflow/stream_executor/cuda/cuda_activation.h @@ -25,8 +25,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class StreamExecutor; @@ -56,7 +55,6 @@ class ScopedActivateExecutorContext { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_ACTIVATION_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index 1c550dbb136..007c0f1c86c 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -75,15 +75,14 @@ limitations under the License. 
#include "tensorflow/stream_executor/scratch_allocator.h" #include "tensorflow/stream_executor/stream_executor.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin); namespace wrap { -#define PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name) \ +#define STREAM_EXECUTOR_CUBLAS_WRAP(__name) \ struct WrapperShim__##__name { \ static const char *kName; \ template \ @@ -94,8 +93,8 @@ namespace wrap { } __name; \ const char *WrapperShim__##__name::kName = #__name; -#define PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(__name) \ - PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name) +#define STREAM_EXECUTOR_CUBLAS_V2_WRAP(__name) \ + STREAM_EXECUTOR_CUBLAS_WRAP(__name) #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasSnrm2) \ @@ -269,28 +268,28 @@ namespace wrap { __macro(cublasCdgmm) \ __macro(cublasZdgmm) -PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasCreate) -PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasDestroy) -PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetStream) -PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetPointerMode) -PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasGetPointerMode) -PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmBatched) -PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasDgemmBatched) -PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasCgemmBatched) -PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasZgemmBatched) -CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP) +STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasCreate) +STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasDestroy) +STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasSetStream) +STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasSetPointerMode) +STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasGetPointerMode) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmBatched) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmBatched) +CUBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_CUBLAS_V2_WRAP) #if CUDA_VERSION >= 7050 -PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx) 
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx) #endif #if CUDA_VERSION >= 8000 -PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx) #endif #if CUDA_VERSION >= 9000 -PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGetMathMode) -PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSetMathMode) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasGetMathMode) +STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode) #endif } // namespace wrap @@ -2803,46 +2802,39 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, } // namespace cuda -namespace gpu = ::perftools::gputools; - void initialize_cublas() { - gpu::port::Status status = - gpu::PluginRegistry::Instance() - ->RegisterFactory( - gpu::cuda::kCudaPlatformId, gpu::cuda::kCuBlasPlugin, "cuBLAS", - [](gpu::internal::StreamExecutorInterface - *parent) -> gpu::blas::BlasSupport * { - gpu::cuda::CUDAExecutor *cuda_executor = - dynamic_cast(parent); - if (cuda_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the cuBLAS " - << "support library with a non-CUDA StreamExecutor"; - return nullptr; - } + port::Status status = + PluginRegistry::Instance()->RegisterFactory( + cuda::kCudaPlatformId, cuda::kCuBlasPlugin, "cuBLAS", + [](internal::StreamExecutorInterface *parent) -> blas::BlasSupport * { + cuda::CUDAExecutor *cuda_executor = + dynamic_cast(parent); + if (cuda_executor == nullptr) { + LOG(ERROR) + << "Attempting to initialize an instance of the cuBLAS " + << "support library with a non-CUDA StreamExecutor"; + return nullptr; + } - gpu::cuda::CUDABlas *blas = - new gpu::cuda::CUDABlas(cuda_executor); - if (!blas->Init()) { - // Note: Init() will log a more specific error. - delete blas; - return nullptr; - } - return blas; - }); + cuda::CUDABlas *blas = new cuda::CUDABlas(cuda_executor); + if (!blas->Init()) { + // Note: Init() will log a more specific error. 
+ delete blas; + return nullptr; + } + return blas; + }); if (!status.ok()) { LOG(ERROR) << "Unable to register cuBLAS factory: " << status.error_message(); } - gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId, - gpu::PluginKind::kBlas, - gpu::cuda::kCuBlasPlugin); + PluginRegistry::Instance()->SetDefaultFactory( + cuda::kCudaPlatformId, PluginKind::kBlas, cuda::kCuBlasPlugin); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor REGISTER_MODULE_INITIALIZER(register_cublas, - { perftools::gputools::initialize_cublas(); }); + { stream_executor::initialize_cublas(); }); diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h index deb211c04bc..55c414a1f92 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.h +++ b/tensorflow/stream_executor/cuda/cuda_blas.h @@ -29,8 +29,7 @@ limitations under the License. typedef struct cublasContext *cublasHandle_t; -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; @@ -162,7 +161,6 @@ class CUDABlas : public blas::BlasSupport { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc index 933c103f524..feb529297e8 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc @@ -51,8 +51,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/str_util.h" #include "tensorflow/stream_executor/lib/inlined_vector.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { #ifdef __APPLE__ @@ -384,5 +383,4 @@ port::StatusOr Diagnostician::FindKernelDriverVersion() { } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.h b/tensorflow/stream_executor/cuda/cuda_diagnostics.h index aa68321acc8..f2db2eb20a1 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.h +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.h @@ -22,8 +22,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/statusor.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { // e.g. DriverVersion{346, 3, 4} @@ -93,7 +92,6 @@ class Diagnostician { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 4a6b2bf5d7d..d673e19007d 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -59,8 +59,7 @@ NarrowT CheckedNarrowing(const WideT& wide) { } // namespace -namespace perftools { -namespace gputools { +namespace stream_executor { using dnn::BatchDescriptor; using dnn::FilterDescriptor; @@ -159,7 +158,7 @@ static port::ThreadPool* GetCudaThreadpool() { return cudnn_threadpool; } -#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name) \ +#define STREAM_EXECUTOR_CUDNN_WRAP(__name) \ struct WrapperShim__##__name { \ template \ cudnnStatus_t operator()(CUDAExecutor* parent, Args... 
args) { \ @@ -169,7 +168,7 @@ static port::ThreadPool* GetCudaThreadpool() { } \ } __name; -#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM(__name) \ +#define STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM(__name) \ struct WrapperShim__##__name { \ template \ cudnnStatus_t operator()(CudnnSupport* dnn, Stream* s, Args... args) \ @@ -220,7 +219,7 @@ struct WrapperShim__cudnnSetStream { __macro(cudnnSetFilterNdDescriptor) // clang-format on -CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) +CUDNN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_CUDNN_WRAP) #undef CUDNN_DNN_ROUTINE_EACH // clang-format off @@ -242,7 +241,7 @@ CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) // clang-format on CUDNN_DNN_ROUTINE_EACH_WITH_STREAM( - PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM) + STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM) #undef CUDNN_DNN_ROUTINE_EACH_WITH_STREAM // APIs available after R3: @@ -252,7 +251,7 @@ CUDNN_DNN_ROUTINE_EACH_WITH_STREAM( __macro(cudnnGetConvolutionBackwardDataAlgorithm) \ __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \ __macro(cudnnGetConvolutionBackwardDataWorkspaceSize) -CUDNN_DNN_ROUTINE_EACH_AFTER_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) +CUDNN_DNN_ROUTINE_EACH_AFTER_R3(STREAM_EXECUTOR_CUDNN_WRAP) #undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3 #endif @@ -266,7 +265,7 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) // clang-format on CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM( - PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM) + STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM) #undef CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM #endif @@ -293,7 +292,7 @@ CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM( __macro(cudnnGetFilterNdDescriptor) // clang-format on -CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) +CUDNN_DNN_ROUTINE_EACH_R5(STREAM_EXECUTOR_CUDNN_WRAP) #undef CUDNN_DNN_ROUTINE_EACH_R5 // clang-format off @@ -305,7 +304,7 @@ CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) // clang-format on 
CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM( - PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM) + STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM) #undef CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM #endif @@ -316,7 +315,7 @@ CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM( __macro(cudnnSetRNNDescriptor_v6) // clang-format on -CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) +CUDNN_DNN_ROUTINE_EACH_R6(STREAM_EXECUTOR_CUDNN_WRAP) #undef CUDNN_DNN_ROUTINE_EACH_R6 // clang-format off @@ -325,7 +324,7 @@ CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) // clang-format on CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM( - PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM) + STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM) #undef CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM #endif @@ -337,7 +336,7 @@ CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM( __macro(cudnnSetRNNMatrixMathType) // clang-format on -CUDNN_DNN_ROUTINE_EACH_R7(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) +CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP) #undef CUDNN_DNN_ROUTINE_EACH_R7 #endif @@ -4727,46 +4726,39 @@ bool CudnnSupport::DeriveOutputBatchDescriptor( } // namespace cuda -namespace gpu = ::perftools::gputools; - void initialize_cudnn() { - gpu::port::Status status = - gpu::PluginRegistry::Instance() - ->RegisterFactory( - gpu::cuda::kCudaPlatformId, gpu::cuda::kCuDnnPlugin, "cuDNN", - [](gpu::internal::StreamExecutorInterface* - parent) -> gpu::dnn::DnnSupport* { - gpu::cuda::CUDAExecutor* cuda_executor = - dynamic_cast(parent); - if (cuda_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the cuBLAS " - << "support library with a non-CUDA StreamExecutor"; - return nullptr; - } + port::Status status = + PluginRegistry::Instance()->RegisterFactory( + cuda::kCudaPlatformId, cuda::kCuDnnPlugin, "cuDNN", + [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* { + cuda::CUDAExecutor* cuda_executor = + dynamic_cast(parent); + if (cuda_executor == nullptr) { + LOG(ERROR) + << "Attempting to 
initialize an instance of the cuDNN " + << "support library with a non-CUDA StreamExecutor"; + return nullptr; + } - gpu::cuda::CudnnSupport* dnn = - new gpu::cuda::CudnnSupport(cuda_executor); - if (!dnn->Init().ok()) { - // Note: Init() will log a more specific error. - delete dnn; - return nullptr; - } - return dnn; - }); + cuda::CudnnSupport* dnn = new cuda::CudnnSupport(cuda_executor); + if (!dnn->Init().ok()) { + // Note: Init() will log a more specific error. + delete dnn; + return nullptr; + } + return dnn; + }); if (!status.ok()) { LOG(ERROR) << "Unable to register cuDNN factory: " << status.error_message(); } - gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId, - gpu::PluginKind::kDnn, - gpu::cuda::kCuDnnPlugin); + PluginRegistry::Instance()->SetDefaultFactory( + cuda::kCudaPlatformId, PluginKind::kDnn, cuda::kCuDnnPlugin); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor REGISTER_MODULE_INITIALIZER(register_cudnn, - { perftools::gputools::initialize_cudnn(); }); + { stream_executor::initialize_cudnn(); }); diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 7518b23757f..e6d12bfef98 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -26,8 +26,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/temporary_device_memory.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { class CUDAExecutor; @@ -810,7 +809,6 @@ class CudnnSupport : public dnn::DnnSupport { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index 58e1e58c593..fedf4f53b85 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -53,8 +53,7 @@ bool FLAGS_gpuexec_cuda_device_0_only = false; // matches the expected one. constexpr bool kVerifyCudaContext = false; -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { namespace { @@ -1649,5 +1648,4 @@ static port::StatusOr GetSimpleAttribute(CUdevice device, } } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/cuda/cuda_driver.h b/tensorflow/stream_executor/cuda/cuda_driver.h index fa9172b3f00..a9969e247e1 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.h +++ b/tensorflow/stream_executor/cuda/cuda_driver.h @@ -27,8 +27,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" #include "cuda/include/cuda.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { // Identifies the memory space where an allocation resides. 
See @@ -506,7 +505,6 @@ class CudaContext { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_event.cc b/tensorflow/stream_executor/cuda/cuda_event.cc index 1b41502300d..96dcf173566 100644 --- a/tensorflow/stream_executor/cuda/cuda_event.cc +++ b/tensorflow/stream_executor/cuda/cuda_event.cc @@ -19,8 +19,7 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/lib/statusor.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { CUDAEvent::CUDAEvent(CUDAExecutor* parent) @@ -68,5 +67,4 @@ const CUevent& CUDAEvent::cuda_event() { } } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/cuda/cuda_event.h b/tensorflow/stream_executor/cuda/cuda_event.h index 56667e65d38..f62344672ed 100644 --- a/tensorflow/stream_executor/cuda/cuda_event.h +++ b/tensorflow/stream_executor/cuda/cuda_event.h @@ -21,8 +21,7 @@ limitations under the License. #include "tensorflow/stream_executor/event.h" #include "tensorflow/stream_executor/lib/status.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { // CUDAEvent wraps a CUevent in the platform-independent EventInterface @@ -58,7 +57,6 @@ class CUDAEvent : public internal::EventInterface { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_fft.cc b/tensorflow/stream_executor/cuda/cuda_fft.cc index a922f14fb4a..5b34740f9f1 100644 --- a/tensorflow/stream_executor/cuda/cuda_fft.cc +++ b/tensorflow/stream_executor/cuda/cuda_fft.cc @@ -31,8 +31,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuFftPlugin); @@ -44,7 +43,7 @@ namespace wrap { // manner on first use. This dynamic loading technique is used to avoid DSO // dependencies on vendor libraries which may or may not be available in the // deployed binary environment. -#define PERFTOOLS_GPUTOOLS_CUFFT_WRAP(__name) \ +#define STREAM_EXECUTOR_CUFFT_WRAP(__name) \ struct WrapperShim__##__name { \ template \ cufftResult operator()(CUDAExecutor *parent, Args... args) { \ @@ -68,7 +67,7 @@ namespace wrap { __macro(cufftGetSizeMany) \ __macro(cufftMakePlanMany) -CUFFT_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUFFT_WRAP) +CUFFT_ROUTINE_EACH(STREAM_EXECUTOR_CUFFT_WRAP) } // namespace wrap @@ -514,62 +513,59 @@ bool CUDAFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan, return true; } -#define PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \ - __fft_type3) \ - bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \ - const DeviceMemory> &input, \ - DeviceMemory> *output) { \ - return DoFftWithDirectionInternal( \ - stream, plan, wrap::cufftExec##__fft_type1, input, output); \ - } \ - bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \ - const DeviceMemory<__type> &input, \ - DeviceMemory> *output) { \ - return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type2, input, \ - output); \ - } \ - bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \ - const DeviceMemory> &input, \ - DeviceMemory<__type> *output) { \ - return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type3, input, \ - output); \ +#define STREAM_EXECUTOR_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \ + __fft_type3) \ + bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory> &input, \ + DeviceMemory> *output) { \ + return 
DoFftWithDirectionInternal( \ + stream, plan, wrap::cufftExec##__fft_type1, input, output); \ + } \ + bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory<__type> &input, \ + DeviceMemory> *output) { \ + return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type2, input, \ + output); \ + } \ + bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \ + const DeviceMemory> &input, \ + DeviceMemory<__type> *output) { \ + return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type3, input, \ + output); \ } -PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(float, C2C, R2C, C2R) -PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D) +STREAM_EXECUTOR_CUDA_DEFINE_FFT(float, C2C, R2C, C2R) +STREAM_EXECUTOR_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D) -#undef PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT +#undef STREAM_EXECUTOR_CUDA_DEFINE_FFT } // namespace cuda -} // namespace gputools -} // namespace perftools -namespace gpu = ::perftools::gputools; +void initialize_cufft() { + port::Status status = + PluginRegistry::Instance()->RegisterFactory( + cuda::kCudaPlatformId, cuda::kCuFftPlugin, "cuFFT", + [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * { + cuda::CUDAExecutor *cuda_executor = + dynamic_cast(parent); + if (cuda_executor == nullptr) { + LOG(ERROR) << "Attempting to initialize an instance of the cuFFT " + << "support library with a non-CUDA StreamExecutor"; + return nullptr; + } -REGISTER_MODULE_INITIALIZER(register_cufft, { - gpu::port::Status status = - gpu::PluginRegistry::Instance() - ->RegisterFactory( - gpu::cuda::kCudaPlatformId, gpu::cuda::kCuFftPlugin, "cuFFT", - [](gpu::internal::StreamExecutorInterface - *parent) -> gpu::fft::FftSupport * { - gpu::cuda::CUDAExecutor *cuda_executor = - dynamic_cast(parent); - if (cuda_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the cuFFT " - << "support library with a non-CUDA StreamExecutor"; - return nullptr; - } - - return new gpu::cuda::CUDAFft(cuda_executor); 
- }); + return new cuda::CUDAFft(cuda_executor); + }); if (!status.ok()) { LOG(ERROR) << "Unable to register cuFFT factory: " << status.error_message(); } - gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId, - gpu::PluginKind::kFft, - gpu::cuda::kCuFftPlugin); -}); + PluginRegistry::Instance()->SetDefaultFactory( + cuda::kCudaPlatformId, PluginKind::kFft, cuda::kCuFftPlugin); +} + +} // namespace stream_executor + +REGISTER_MODULE_INITIALIZER(register_cufft, + { stream_executor::initialize_cufft(); }); diff --git a/tensorflow/stream_executor/cuda/cuda_fft.h b/tensorflow/stream_executor/cuda/cuda_fft.h index 04c7dfe501c..8171e61418a 100644 --- a/tensorflow/stream_executor/cuda/cuda_fft.h +++ b/tensorflow/stream_executor/cuda/cuda_fft.h @@ -26,8 +26,7 @@ limitations under the License. #include "tensorflow/stream_executor/plugin_registry.h" #include "tensorflow/stream_executor/scratch_allocator.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; @@ -133,7 +132,6 @@ class CUDAFft : public fft::FftSupport { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 5ecaf46b8ca..9700daca890 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -66,8 +66,7 @@ limitations under the License. extern bool FLAGS_check_gpu_leaks; bool FLAGS_prefer_cubin_to_ptx = true; -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { // Hook that can be used to CUBIN-ate PTX before it is loaded into the driver. 
@@ -1168,17 +1167,14 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const { } // namespace cuda -namespace gpu = ::perftools::gputools; - void initialize_cuda_gpu_executor() { - *gpu::internal::MakeCUDAExecutorImplementation() = []( - const gpu::PluginConfig &config) { - return new gpu::cuda::CUDAExecutor{config}; + *internal::MakeCUDAExecutorImplementation() = [](const PluginConfig &config) { + return new cuda::CUDAExecutor{config}; }; } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor -REGISTER_MODULE_INITIALIZER( - cuda_gpu_executor, {perftools::gputools::initialize_cuda_gpu_executor();}); +REGISTER_MODULE_INITIALIZER(cuda_gpu_executor, { + stream_executor::initialize_cuda_gpu_executor(); +}); diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index dbbbcd476f0..f686685474b 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -35,8 +35,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { // CUDA-platform implementation of the platform-agnostic @@ -273,7 +272,6 @@ class CUDAExecutor : public internal::StreamExecutorInterface { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_GPU_EXECUTOR_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_helpers.h b/tensorflow/stream_executor/cuda/cuda_helpers.h index 6a6134bf881..d55706c66a9 100644 --- a/tensorflow/stream_executor/cuda/cuda_helpers.h +++ b/tensorflow/stream_executor/cuda/cuda_helpers.h @@ -27,8 +27,7 @@ limitations under the License. 
#include "cuda/include/cuComplex.h" #include "cuda/include/cuda.h" -namespace perftools { -namespace gputools { +namespace stream_executor { template class DeviceMemory; @@ -101,7 +100,6 @@ inline cuDoubleComplex CUDAComplexValue(std::complex val) { } } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_kernel.h b/tensorflow/stream_executor/cuda/cuda_kernel.h index 88d29fddd06..beaebe8f123 100644 --- a/tensorflow/stream_executor/cuda/cuda_kernel.h +++ b/tensorflow/stream_executor/cuda/cuda_kernel.h @@ -40,8 +40,7 @@ limitations under the License. "CUDA runtime being included into CUDA GPU executor; should be driver only." #endif -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { // Wraps a CUfunction to implement the platform-independent KernelInterface. @@ -124,7 +123,6 @@ inline CUDAKernel *AsCUDAKernel(KernelBase *kernel) { } } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc index 3a738461489..7a6ef5a248f 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform.cc +++ b/tensorflow/stream_executor/cuda/cuda_platform.cc @@ -24,8 +24,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/stringprintf.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { namespace { @@ -41,16 +40,16 @@ const DeviceOptions GetDeviceOptionsFromEnv() { std::getenv("TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE"); if (gpu_schedule_string == nullptr) { - return perftools::gputools::DeviceOptions::Default(); + return DeviceOptions::Default(); } unsigned device_flags = 0; if (strcmp(kScheduleSpinString, gpu_schedule_string) == 0) { - device_flags = perftools::gputools::DeviceOptions::kScheduleSpin; + device_flags = DeviceOptions::kScheduleSpin; } else if (strcmp(kScheduleYieldString, gpu_schedule_string) == 0) { - device_flags = perftools::gputools::DeviceOptions::kScheduleYield; + device_flags = DeviceOptions::kScheduleYield; } else if (strcmp(kScheduleBlockingSyncString, gpu_schedule_string) == 0) { - device_flags = perftools::gputools::DeviceOptions::kScheduleBlockingSync; + device_flags = DeviceOptions::kScheduleBlockingSync; } else { LOG(QFATAL) << "Unknown option for environment variable " "TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE " @@ -59,7 +58,7 @@ const DeviceOptions GetDeviceOptionsFromEnv() { << ", " << kScheduleYieldString << "}"; } - return perftools::gputools::DeviceOptions(device_flags); + return DeviceOptions(device_flags); } } // namespace @@ -202,11 +201,10 @@ static void InitializeCudaPlatform() { SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor REGISTER_MODULE_INITIALIZER(cuda_platform, - perftools::gputools::InitializeCudaPlatform()); + stream_executor::InitializeCudaPlatform()); DECLARE_MODULE_INITIALIZER(multi_platform_manager); // Note that module initialization sequencing is not supported in the diff --git a/tensorflow/stream_executor/cuda/cuda_platform.h b/tensorflow/stream_executor/cuda/cuda_platform.h index 
dab25602d08..fc0e15d5a6a 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform.h +++ b/tensorflow/stream_executor/cuda/cuda_platform.h @@ -31,8 +31,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor_pimpl.h" #include "tensorflow/stream_executor/trace_listener.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { // Opaque and unique identifier for the CUDA platform plugin. @@ -104,7 +103,6 @@ class CudaPlatform : public Platform { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_platform_id.cc b/tensorflow/stream_executor/cuda/cuda_platform_id.cc index dfd11a9abe8..a7bb304cc8c 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform_id.cc +++ b/tensorflow/stream_executor/cuda/cuda_platform_id.cc @@ -15,12 +15,10 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_platform_id.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { PLATFORM_DEFINE_ID(kCudaPlatformId); } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/cuda/cuda_platform_id.h b/tensorflow/stream_executor/cuda/cuda_platform_id.h index c677724517c..92bcfd83722 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform_id.h +++ b/tensorflow/stream_executor/cuda/cuda_platform_id.h @@ -18,8 +18,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { // Opaque and unique identifier for the cuda platform. 
@@ -30,7 +29,6 @@ namespace cuda { extern const Platform::Id kCudaPlatformId; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_rng.cc b/tensorflow/stream_executor/cuda/cuda_rng.cc index 8641b602277..e289e7ced57 100644 --- a/tensorflow/stream_executor/cuda/cuda_rng.cc +++ b/tensorflow/stream_executor/cuda/cuda_rng.cc @@ -54,15 +54,14 @@ std::ostream &operator<<(std::ostream &in, const curandStatus_t &status) { } } -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuRandPlugin); namespace wrap { -#define PERFTOOLS_GPUTOOLS_CURAND_WRAP(__name) \ +#define STREAM_EXECUTOR_CURAND_WRAP(__name) \ struct WrapperShim__##__name { \ template \ curandStatus_t operator()(CUDAExecutor *parent, Args... args) { \ @@ -71,15 +70,15 @@ namespace wrap { } \ } __name; -PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandCreateGenerator); -PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandDestroyGenerator); -PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetStream); -PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniform); -PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniformDouble); -PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetPseudoRandomGeneratorSeed); -PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetGeneratorOffset); -PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormal); -PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormalDouble); +STREAM_EXECUTOR_CURAND_WRAP(curandCreateGenerator); +STREAM_EXECUTOR_CURAND_WRAP(curandDestroyGenerator); +STREAM_EXECUTOR_CURAND_WRAP(curandSetStream); +STREAM_EXECUTOR_CURAND_WRAP(curandGenerateUniform); +STREAM_EXECUTOR_CURAND_WRAP(curandGenerateUniformDouble); +STREAM_EXECUTOR_CURAND_WRAP(curandSetPseudoRandomGeneratorSeed); +STREAM_EXECUTOR_CURAND_WRAP(curandSetGeneratorOffset); +STREAM_EXECUTOR_CURAND_WRAP(curandGenerateNormal); 
+STREAM_EXECUTOR_CURAND_WRAP(curandGenerateNormalDouble); } // namespace wrap @@ -271,42 +270,40 @@ bool CUDARng::SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) { } } // namespace cuda -} // namespace gputools -} // namespace perftools -namespace gpu = ::perftools::gputools; +void initialize_curand() { + port::Status status = + PluginRegistry::Instance()->RegisterFactory( + cuda::kCudaPlatformId, cuda::kCuRandPlugin, "cuRAND", + [](internal::StreamExecutorInterface *parent) -> rng::RngSupport * { + cuda::CUDAExecutor *cuda_executor = + dynamic_cast(parent); + if (cuda_executor == nullptr) { + LOG(ERROR) + << "Attempting to initialize an instance of the cuRAND " + << "support library with a non-CUDA StreamExecutor"; + return nullptr; + } -REGISTER_MODULE_INITIALIZER(register_curand, { - gpu::port::Status status = - gpu::PluginRegistry::Instance() - ->RegisterFactory( - gpu::cuda::kCudaPlatformId, gpu::cuda::kCuRandPlugin, "cuRAND", - [](gpu::internal::StreamExecutorInterface - *parent) -> gpu::rng::RngSupport * { - gpu::cuda::CUDAExecutor *cuda_executor = - dynamic_cast(parent); - if (cuda_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the cuRAND " - << "support library with a non-CUDA StreamExecutor"; - return nullptr; - } - - gpu::cuda::CUDARng *rng = new gpu::cuda::CUDARng(cuda_executor); - if (!rng->Init()) { - // Note: Init() will log a more specific error. - delete rng; - return nullptr; - } - return rng; - }); + cuda::CUDARng *rng = new cuda::CUDARng(cuda_executor); + if (!rng->Init()) { + // Note: Init() will log a more specific error. 
+ delete rng; + return nullptr; + } + return rng; + }); if (!status.ok()) { LOG(ERROR) << "Unable to register cuRAND factory: " << status.error_message(); } - gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId, - gpu::PluginKind::kRng, - gpu::cuda::kCuRandPlugin); -}); + PluginRegistry::Instance()->SetDefaultFactory( + cuda::kCudaPlatformId, PluginKind::kRng, cuda::kCuRandPlugin); +} + +} // namespace stream_executor + +REGISTER_MODULE_INITIALIZER(register_curand, + { stream_executor::initialize_curand(); }); diff --git a/tensorflow/stream_executor/cuda/cuda_rng.h b/tensorflow/stream_executor/cuda/cuda_rng.h index 5bbfd0b37a0..57ef398aaa8 100644 --- a/tensorflow/stream_executor/cuda/cuda_rng.h +++ b/tensorflow/stream_executor/cuda/cuda_rng.h @@ -24,8 +24,7 @@ limitations under the License. typedef struct curandGenerator_st *curandGenerator_t; -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; template @@ -98,7 +97,6 @@ class CUDARng : public rng::RngSupport { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_RNG_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_stream.cc b/tensorflow/stream_executor/cuda/cuda_stream.cc index 3eb37a7d841..b5aa7694f7e 100644 --- a/tensorflow/stream_executor/cuda/cuda_stream.cc +++ b/tensorflow/stream_executor/cuda/cuda_stream.cc @@ -19,8 +19,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/stream.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { bool CUDAStream::Init() { @@ -59,5 +58,4 @@ CUstream AsCUDAStreamValue(Stream *stream) { } } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h index 7358243dc40..02edff64311 100644 --- a/tensorflow/stream_executor/cuda/cuda_stream.h +++ b/tensorflow/stream_executor/cuda/cuda_stream.h @@ -23,8 +23,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { class CUDAExecutor; @@ -89,7 +88,6 @@ CUDAStream *AsCUDAStream(Stream *stream); CUstream AsCUDAStreamValue(Stream *stream); } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_timer.cc b/tensorflow/stream_executor/cuda/cuda_timer.cc index 8532f08725d..991a12a23d6 100644 --- a/tensorflow/stream_executor/cuda/cuda_timer.cc +++ b/tensorflow/stream_executor/cuda/cuda_timer.cc @@ -20,8 +20,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/lib/status.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { bool CUDATimer::Init() { @@ -92,5 +91,4 @@ bool CUDATimer::Stop(CUDAStream* stream) { } } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/cuda/cuda_timer.h b/tensorflow/stream_executor/cuda/cuda_timer.h index 2abc55ec941..70554ec9312 100644 --- a/tensorflow/stream_executor/cuda/cuda_timer.h +++ b/tensorflow/stream_executor/cuda/cuda_timer.h @@ -23,8 +23,7 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_driver.h" #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { class CUDAExecutor; @@ -60,13 +59,13 @@ class CUDATimer : public internal::TimerInterface { // events. float GetElapsedMilliseconds() const; - // See perftools::gputools::Timer::Microseconds(). + // See Timer::Microseconds(). // TODO(leary) make this into an error code interface... uint64 Microseconds() const override { return GetElapsedMilliseconds() * 1e3; } - // See perftools::GPUTools::Timer::Nanoseconds(). + // See Timer::Nanoseconds(). uint64 Nanoseconds() const override { return GetElapsedMilliseconds() * 1e6; } private: @@ -85,7 +84,6 @@ struct TimerDeleter { }; } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_TIMER_H_ diff --git a/tensorflow/stream_executor/cuda/cudnn_version.cc b/tensorflow/stream_executor/cuda/cudnn_version.cc index 5591801aae2..e8fcc036185 100644 --- a/tensorflow/stream_executor/cuda/cudnn_version.cc +++ b/tensorflow/stream_executor/cuda/cudnn_version.cc @@ -15,8 +15,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/cuda/cudnn_version.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version, @@ -38,5 +37,4 @@ bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version, } } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/cuda/cudnn_version.h b/tensorflow/stream_executor/cuda/cudnn_version.h index 2ed02e1700c..6464e7f8e87 100644 --- a/tensorflow/stream_executor/cuda/cudnn_version.h +++ b/tensorflow/stream_executor/cuda/cudnn_version.h @@ -20,8 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { struct CudnnVersion { @@ -46,7 +45,6 @@ bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version, CudnnVersion loaded_version); } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDNN_VERSION_H_ diff --git a/tensorflow/stream_executor/cuda/cudnn_version_test.cc b/tensorflow/stream_executor/cuda/cudnn_version_test.cc index 42b3dc8cc67..7d4c6399d04 100644 --- a/tensorflow/stream_executor/cuda/cudnn_version_test.cc +++ b/tensorflow/stream_executor/cuda/cudnn_version_test.cc @@ -17,8 +17,7 @@ limitations under the License. 
#include "tensorflow/core/platform/test.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace cuda { namespace { @@ -70,5 +69,4 @@ TEST(IsSourceCompatibleWithCudnnLibraryTest, Basic) { } // namespace } // namespace cuda -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/device_description.cc b/tensorflow/stream_executor/device_description.cc index 52f5319a3b1..8ca0677f8a3 100644 --- a/tensorflow/stream_executor/device_description.cc +++ b/tensorflow/stream_executor/device_description.cc @@ -21,8 +21,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/mathutil.h" #include "tensorflow/stream_executor/lib/strcat.h" -namespace perftools { -namespace gputools { +namespace stream_executor { static const uint64 kUninitializedUint64 = -1ULL; /* static */ const char *DeviceDescription::kUndefinedString = ""; @@ -234,6 +233,4 @@ uint64 CalculateRegisterLimitForTargetOccupancy( return 0; } - -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h index fcf0928096e..7f99d81ef3b 100644 --- a/tensorflow/stream_executor/device_description.h +++ b/tensorflow/stream_executor/device_description.h @@ -27,8 +27,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/launch_dim.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace internal { class DeviceDescriptionBuilder; } // namespace internal @@ -388,7 +387,6 @@ uint64 CalculateRegisterLimitForTargetOccupancy( const DeviceDescription &device_description, uint64 shared_memory_per_block, const ThreadDim &thread_dims, uint64 target_blocks_per_core); -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_ diff --git a/tensorflow/stream_executor/device_memory.h b/tensorflow/stream_executor/device_memory.h index 4c92b7dc785..5a5334e0f5f 100644 --- a/tensorflow/stream_executor/device_memory.h +++ b/tensorflow/stream_executor/device_memory.h @@ -32,6 +32,16 @@ limitations under the License. namespace perftools { namespace gputools { +// Temporarily pull stream_executor into perftools::gputools while we migrate +// code to the new namespace. TODO(b/77980417): Remove this once we've +// completed the migration. +using namespace stream_executor; // NOLINT[build/namespaces] + +} // namespace gputools +} // namespace perftools + +namespace stream_executor { + class StreamExecutor; // void*-analogous device memory allocation. 
For the typed variation, see @@ -280,7 +290,6 @@ static_assert(sizeof(Float2) == 2 * sizeof(float), "Float2 must be packed"); static_assert(sizeof(Float4) == 4 * sizeof(float), "Float4 must be packed"); static_assert(sizeof(Double2) == 2 * sizeof(double), "Double2 must be packed"); -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_MEMORY_H_ diff --git a/tensorflow/stream_executor/device_options.h b/tensorflow/stream_executor/device_options.h index 169325e7d12..2646950f42e 100644 --- a/tensorflow/stream_executor/device_options.h +++ b/tensorflow/stream_executor/device_options.h @@ -25,8 +25,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/logging.h" -namespace perftools { -namespace gputools { +namespace stream_executor { // Indicates a set of options for a device's usage, which generally must be // provided at StreamExecutor device-initialization time. @@ -84,7 +83,6 @@ struct DeviceOptions { unsigned flags_; }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_OPTIONS_H_ diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index 0a3c4bcf503..6edb5728201 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -18,8 +18,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/strcat.h" #include "tensorflow/stream_executor/lib/stringprintf.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace dnn { bool DnnSupport::GetConvolveAlgorithms( @@ -554,5 +553,4 @@ string NormalizeDescriptor::ToShortString() const { } } // namespace dnn -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 3c47d2c2e85..8e202d115a5 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -38,8 +38,7 @@ namespace Eigen { struct half; } // namespace Eigen -namespace perftools { -namespace gputools { +namespace stream_executor { class HostBuffer; class Stream; @@ -2285,7 +2284,6 @@ class DnnSupport { }; } // namespace dnn -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_DNN_H_ diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc index 95168836278..114143b3abe 100644 --- a/tensorflow/stream_executor/dso_loader.cc +++ b/tensorflow/stream_executor/dso_loader.cc @@ -37,8 +37,7 @@ limitations under the License. #include "cuda/cuda_config.h" #endif -namespace perftools { -namespace gputools { +namespace stream_executor { namespace internal { string GetCudaVersion() { return TF_CUDA_VERSION; } @@ -291,5 +290,4 @@ static std::vector* CreatePrimordialRpaths() { } } // namespace internal -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/dso_loader.h b/tensorflow/stream_executor/dso_loader.h index 354c7b50b82..9ee081cb3d6 100644 --- a/tensorflow/stream_executor/dso_loader.h +++ b/tensorflow/stream_executor/dso_loader.h @@ -28,8 +28,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/mutex.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace internal { // Permits StreamExecutor code to dynamically load a pre-determined set of @@ -114,7 +113,6 @@ class CachedDsoLoader { }; } // namespace internal -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_ diff --git a/tensorflow/stream_executor/event.cc b/tensorflow/stream_executor/event.cc index c423a453e9f..50a6edd80bd 100644 --- a/tensorflow/stream_executor/event.cc +++ b/tensorflow/stream_executor/event.cc @@ -19,8 +19,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor_pimpl.h" #include "tensorflow/stream_executor/stream.h" -namespace perftools { -namespace gputools { +namespace stream_executor { Event::Event(StreamExecutor* stream_exec) : stream_exec_(stream_exec), @@ -48,5 +47,4 @@ Event::Status Event::PollForStatus() { return stream_exec_->PollForEventStatus(this); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/event.h b/tensorflow/stream_executor/event.h index a06c26ea519..1f37262c78d 100644 --- a/tensorflow/stream_executor/event.h +++ b/tensorflow/stream_executor/event.h @@ -20,8 +20,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace internal { class EventInterface; @@ -76,7 +75,6 @@ class Event { SE_DISALLOW_COPY_AND_ASSIGN(Event); }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_EVENT_H_ diff --git a/tensorflow/stream_executor/executor_cache.cc b/tensorflow/stream_executor/executor_cache.cc index d1a8aae1674..0b3ad7ebbc9 100644 --- a/tensorflow/stream_executor/executor_cache.cc +++ b/tensorflow/stream_executor/executor_cache.cc @@ -17,8 +17,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/stringprintf.h" -namespace perftools { -namespace gputools { +namespace stream_executor { port::StatusOr ExecutorCache::GetOrCreate( const StreamExecutorConfig& config, @@ -104,5 +103,4 @@ ExecutorCache::Entry::~Entry() { configurations.clear(); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/executor_cache.h b/tensorflow/stream_executor/executor_cache.h index 12f2275f6d8..bbeeaed787c 100644 --- a/tensorflow/stream_executor/executor_cache.h +++ b/tensorflow/stream_executor/executor_cache.h @@ -24,8 +24,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" -namespace perftools { -namespace gputools { +namespace stream_executor { // Utility class to allow Platform objects to manage cached StreamExecutors. // Thread-safe. 
@@ -76,7 +75,6 @@ class ExecutorCache { SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache); }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_ diff --git a/tensorflow/stream_executor/fft.h b/tensorflow/stream_executor/fft.h index 6b1728829ab..814efb2e923 100644 --- a/tensorflow/stream_executor/fft.h +++ b/tensorflow/stream_executor/fft.h @@ -48,8 +48,7 @@ limitations under the License. #include #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; template @@ -210,7 +209,7 @@ class FftSupport { // Macro used to quickly declare overrides for abstract virtuals in the // fft::FftSupport base class. Assumes that it's emitted somewhere inside the -// ::perftools::gputools namespace. +// ::stream_executor namespace. #define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \ std::unique_ptr Create1dPlan(Stream *stream, uint64 num_x, \ fft::Type type, bool in_place_fft) \ @@ -265,7 +264,6 @@ class FftSupport { DeviceMemory *output) override; } // namespace fft -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_FFT_H_ diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc index 542f521ef77..2c4819651ac 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.cc +++ b/tensorflow/stream_executor/host/host_gpu_executor.cc @@ -28,8 +28,7 @@ limitations under the License. 
bool FLAGS_stream_executor_cpu_real_clock_rate = false; -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { HostStream *AsHostStream(Stream *stream) { @@ -266,5 +265,4 @@ rng::RngSupport *HostExecutor::CreateRng() { } } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h index e2c0e6d6b77..0c3991c151d 100644 --- a/tensorflow/stream_executor/host/host_gpu_executor.h +++ b/tensorflow/stream_executor/host/host_gpu_executor.h @@ -28,8 +28,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { // An implementation of StreamExecutor that does no communication or interaction @@ -210,7 +209,6 @@ class HostExecutor : public internal::StreamExecutorInterface { }; } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_GPU_EXECUTOR_H_ diff --git a/tensorflow/stream_executor/host/host_platform.cc b/tensorflow/stream_executor/host/host_platform.cc index 2cb7d369675..00a17a05ede 100644 --- a/tensorflow/stream_executor/host/host_platform.cc +++ b/tensorflow/stream_executor/host/host_platform.cc @@ -26,10 +26,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/lib/stringprintf.h" -namespace gpu = ::perftools::gputools; - -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { HostPlatform::HostPlatform() : name_("Host") {} @@ -93,16 +90,15 @@ void HostPlatform::UnregisterTraceListener(TraceListener* listener) { } static void InitializeHostPlatform() { - std::unique_ptr platform(new gpu::host::HostPlatform); - SE_CHECK_OK(gpu::MultiPlatformManager::RegisterPlatform(std::move(platform))); + std::unique_ptr platform(new host::HostPlatform); + SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); } } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor -REGISTER_MODULE_INITIALIZER( - host_platform, perftools::gputools::host::InitializeHostPlatform()); +REGISTER_MODULE_INITIALIZER(host_platform, + stream_executor::host::InitializeHostPlatform()); DECLARE_MODULE_INITIALIZER(multi_platform_manager); // Note that module initialization sequencing is not supported in the diff --git a/tensorflow/stream_executor/host/host_platform.h b/tensorflow/stream_executor/host/host_platform.h index 0faec6c8b78..c6f46a2cc40 100644 --- a/tensorflow/stream_executor/host/host_platform.h +++ b/tensorflow/stream_executor/host/host_platform.h @@ -33,8 +33,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/stream_executor_pimpl.h" #include "tensorflow/stream_executor/trace_listener.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { // Host (CPU) platform plugin, registered as a singleton value via module @@ -79,7 +78,6 @@ class HostPlatform : public Platform { }; } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_PLATFORM_H_ diff --git a/tensorflow/stream_executor/host/host_platform_id.cc b/tensorflow/stream_executor/host/host_platform_id.cc index 69a203f2985..2256bccec3f 100644 --- a/tensorflow/stream_executor/host/host_platform_id.cc +++ b/tensorflow/stream_executor/host/host_platform_id.cc @@ -15,12 +15,10 @@ limitations under the License. #include "tensorflow/stream_executor/host/host_platform_id.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { PLATFORM_DEFINE_ID(kHostPlatformId); } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/host/host_platform_id.h b/tensorflow/stream_executor/host/host_platform_id.h index 61d84ea2e2f..18d1f282f1f 100644 --- a/tensorflow/stream_executor/host/host_platform_id.h +++ b/tensorflow/stream_executor/host/host_platform_id.h @@ -18,8 +18,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { // Opaque and unique identifier for the host platform. 
@@ -30,7 +29,6 @@ namespace host { extern const Platform::Id kHostPlatformId; } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_PLATFORM_ID_H_ diff --git a/tensorflow/stream_executor/host/host_stream.cc b/tensorflow/stream_executor/host/host_stream.cc index 5961c315160..5a7d3b3dd49 100644 --- a/tensorflow/stream_executor/host/host_stream.cc +++ b/tensorflow/stream_executor/host/host_stream.cc @@ -17,8 +17,7 @@ limitations under the License. // the HostExecutor implementation. #include "tensorflow/stream_executor/host/host_stream.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { HostStream::HostStream() @@ -53,5 +52,4 @@ void HostStream::BlockUntilDone() { } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h index 9894d17febc..5d7b8a37826 100644 --- a/tensorflow/stream_executor/host/host_stream.h +++ b/tensorflow/stream_executor/host/host_stream.h @@ -24,8 +24,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { class HostStream : public internal::StreamInterface { @@ -52,7 +51,6 @@ class HostStream : public internal::StreamInterface { }; } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_STREAM_H_ diff --git a/tensorflow/stream_executor/host/host_timer.cc b/tensorflow/stream_executor/host/host_timer.cc index d84d825c92a..e138daf0e11 100644 --- a/tensorflow/stream_executor/host/host_timer.cc +++ b/tensorflow/stream_executor/host/host_timer.cc @@ -19,8 +19,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { using std::chrono::duration_cast; @@ -46,5 +45,4 @@ void HostTimer::StartNow() { start_time_ = clock::now(); } void HostTimer::StopNow() { duration_ = clock::now() - start_time_; } } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/host/host_timer.h b/tensorflow/stream_executor/host/host_timer.h index 17af7c0521d..5954b8023be 100644 --- a/tensorflow/stream_executor/host/host_timer.h +++ b/tensorflow/stream_executor/host/host_timer.h @@ -20,8 +20,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { class HostTimer : public internal::TimerInterface { @@ -57,7 +56,6 @@ class HostTimer : public internal::TimerInterface { }; } // namespace host -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_TIMER_H_ diff --git a/tensorflow/stream_executor/host_buffer.h b/tensorflow/stream_executor/host_buffer.h index 8fa542e9ff8..20299da5172 100644 --- a/tensorflow/stream_executor/host_buffer.h +++ b/tensorflow/stream_executor/host_buffer.h @@ -18,8 +18,7 @@ limitations under the License. #include "tensorflow/stream_executor/dnn.h" -namespace perftools { -namespace gputools { +namespace stream_executor { // A HostBuffer is a block of memory in host memory containing the data for a // dnn::BatchDescriptor using a device-dependent memory layout. @@ -42,7 +41,6 @@ class HostBuffer { const dnn::BatchDescriptor descriptor_; }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_HOST_BUFFER_H_ diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc index 636199cfa27..d1aa596b73d 100644 --- a/tensorflow/stream_executor/kernel.cc +++ b/tensorflow/stream_executor/kernel.cc @@ -27,8 +27,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/stream_executor.h" -namespace perftools { -namespace gputools { +namespace stream_executor { bool KernelMetadata::registers_per_thread(int *registers_per_thread) const { if (has_registers_per_thread_) { @@ -103,5 +102,4 @@ void KernelBase::set_name(port::StringPiece name) { demangled_name_ = port::Demangle(stubless_name.data()); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h index 5358eac1ae0..2216884b873 100644 --- a/tensorflow/stream_executor/kernel.h +++ b/tensorflow/stream_executor/kernel.h @@ -64,7 +64,7 @@ limitations under the License. // // Users typically won't need to type out the TypedKernel signature in full, it // will be typedef'd by automatically generated code; for example, see -// perftools::gputools::executor_sample::VecReduceAddKernel. +// stream_executor::executor_sample::VecReduceAddKernel. #ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ #define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ @@ -82,8 +82,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class DeviceMemoryBase; template @@ -639,8 +638,8 @@ struct KernelInvocationChecker { // NOTE: if you encounter an error here, you can see the mismatch by looking // at the end of the last error message, which will be of the form: // - // ...::Compatible &, - // perftools::gputools::DeviceMemory, true, + // ...::Compatible &, + // stream_executor::DeviceMemory, true, // 0>' // requested here // @@ -711,7 +710,6 @@ struct KernelParamsOk, Args...> { std::tuple, std::tuple>::CheckAllNoStaticAssert(); }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_ diff --git a/tensorflow/stream_executor/kernel_cache_config.h b/tensorflow/stream_executor/kernel_cache_config.h index 9d7ab1b79f6..e63d6c6a0c0 100644 --- a/tensorflow/stream_executor/kernel_cache_config.h +++ b/tensorflow/stream_executor/kernel_cache_config.h @@ -18,8 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_ #define TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_ -namespace perftools { -namespace gputools { +namespace stream_executor { // This enum represents potential configurations of L1/shared memory when // running a particular kernel. These values represent user preference, and @@ -38,7 +37,6 @@ enum class KernelCacheConfig { kPreferEqual, }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_ diff --git a/tensorflow/stream_executor/kernel_spec.cc b/tensorflow/stream_executor/kernel_spec.cc index 0404c573f01..6a1f0a591ff 100644 --- a/tensorflow/stream_executor/kernel_spec.cc +++ b/tensorflow/stream_executor/kernel_spec.cc @@ -15,9 +15,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/kernel_spec.h" - -namespace perftools { -namespace gputools { +namespace stream_executor { KernelLoaderSpec::KernelLoaderSpec(port::StringPiece kernelname) : kernelname_(kernelname.ToString()) {} @@ -247,5 +245,4 @@ MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory( MultiKernelLoaderSpec::MultiKernelLoaderSpec(size_t arity) : arity_(arity) {} -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/kernel_spec.h b/tensorflow/stream_executor/kernel_spec.h index 3811bd833e7..7cc23bb4e64 100644 --- a/tensorflow/stream_executor/kernel_spec.h +++ b/tensorflow/stream_executor/kernel_spec.h @@ -56,8 +56,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/mutex.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { // Describes how to load a kernel on a target platform. // @@ -374,7 +373,6 @@ class MultiKernelLoaderSpec { size_t arity_; }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_ diff --git a/tensorflow/stream_executor/launch_dim.h b/tensorflow/stream_executor/launch_dim.h index b95462667ee..68f2f748407 100644 --- a/tensorflow/stream_executor/launch_dim.h +++ b/tensorflow/stream_executor/launch_dim.h @@ -21,7 +21,7 @@ limitations under the License. // a single PC in a unit called a warp. There is a maximum number of threads // that can execute in a shared-context entity called a block. Presently, that // number is 1024 -- again, something that should not be relied on from this -// comment, but checked via perftools::gputools::DeviceDescription. +// comment, but checked via stream_executor::DeviceDescription. 
// // For additional information, see // http://docs.nvidia.com/cuda/kepler-tuning-guide/#device-utilization-and-occupancy @@ -40,8 +40,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/strcat.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { // Basic type that represents a 3-dimensional index space. struct Dim3D { @@ -74,7 +73,6 @@ struct BlockDim : public Dim3D { } }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LAUNCH_DIM_H_ diff --git a/tensorflow/stream_executor/lib/array_slice.h b/tensorflow/stream_executor/lib/array_slice.h index bef61bb2fc5..8e3c4ca047b 100644 --- a/tensorflow/stream_executor/lib/array_slice.h +++ b/tensorflow/stream_executor/lib/array_slice.h @@ -18,14 +18,23 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::gtl::ArraySlice; using tensorflow::gtl::MutableArraySlice; } // namespace port +} // namespace stream_executor + +namespace perftools { +namespace gputools { + +// Temporarily pull stream_executor into perftools::gputools while we migrate +// code to the new namespace. TODO(b/77980417): Remove this once we've +// completed the migration. +using namespace stream_executor; // NOLINT[build/namespaces] + } // namespace gputools } // namespace perftools diff --git a/tensorflow/stream_executor/lib/casts.h b/tensorflow/stream_executor/lib/casts.h index 2261944e252..ec562e804fa 100644 --- a/tensorflow/stream_executor/lib/casts.h +++ b/tensorflow/stream_executor/lib/casts.h @@ -13,15 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" +// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h" #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_ #include -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { // port::bit_cast is a template function that implements the @@ -96,7 +95,6 @@ inline Dest bit_cast(const Source& source) { } } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_ diff --git a/tensorflow/stream_executor/lib/demangle.cc b/tensorflow/stream_executor/lib/demangle.cc index fa2b4fa005c..adb6b4f2d11 100644 --- a/tensorflow/stream_executor/lib/demangle.cc +++ b/tensorflow/stream_executor/lib/demangle.cc @@ -27,8 +27,7 @@ limitations under the License. #include #endif -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { // The API reference of abi::__cxa_demangle() can be found in @@ -49,5 +48,4 @@ string Demangle(const char *mangled) { } } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/demangle.h b/tensorflow/stream_executor/lib/demangle.h index 30be5225578..af16fa7d8cb 100644 --- a/tensorflow/stream_executor/lib/demangle.h +++ b/tensorflow/stream_executor/lib/demangle.h @@ -18,14 +18,12 @@ limitations under the License. 
#include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { string Demangle(const char* mangled); } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_DEMANGLE_H_ diff --git a/tensorflow/stream_executor/lib/env.h b/tensorflow/stream_executor/lib/env.h index c9a22ebd558..776eba04080 100644 --- a/tensorflow/stream_executor/lib/env.h +++ b/tensorflow/stream_executor/lib/env.h @@ -21,8 +21,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::Env; @@ -37,7 +36,6 @@ inline Status FileExists(const port::StringPiece& filename) { } } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_ENV_H_ diff --git a/tensorflow/stream_executor/lib/error.h b/tensorflow/stream_executor/lib/error.h index 89df70cb5e5..c659f5fc140 100644 --- a/tensorflow/stream_executor/lib/error.h +++ b/tensorflow/stream_executor/lib/error.h @@ -13,21 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" +// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h" #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_ #include "tensorflow/core/lib/core/error_codes.pb.h" // IWYU pragma: export -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { namespace error = tensorflow::error; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_ diff --git a/tensorflow/stream_executor/lib/human_readable.h b/tensorflow/stream_executor/lib/human_readable.h index f918c180d98..893865f6dad 100644 --- a/tensorflow/stream_executor/lib/human_readable.h +++ b/tensorflow/stream_executor/lib/human_readable.h @@ -22,8 +22,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { class HumanReadableNumBytes { @@ -67,7 +66,6 @@ class HumanReadableNumBytes { }; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_HUMAN_READABLE_H_ diff --git a/tensorflow/stream_executor/lib/initialize.h b/tensorflow/stream_executor/lib/initialize.h index 9a09318a6cb..688b0214694 100644 --- a/tensorflow/stream_executor/lib/initialize.h +++ b/tensorflow/stream_executor/lib/initialize.h @@ -26,8 +26,7 @@ limitations under the License. 
#undef DECLARE_MODULE_INITIALIZER #undef REGISTER_MODULE_INITIALIZER_SEQUENCE -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { class Initializer { @@ -49,20 +48,18 @@ class Initializer { }; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor -#define REGISTER_INITIALIZER(type, name, body) \ - static void google_init_##type##_##name() { body; } \ - perftools::gputools::port::Initializer google_initializer_##type##_##name( \ +#define REGISTER_INITIALIZER(type, name, body) \ + static void google_init_##type##_##name() { body; } \ + ::stream_executor::port::Initializer google_initializer_##type##_##name( \ google_init_##type##_##name) #define REGISTER_MODULE_INITIALIZER(name, body) \ REGISTER_INITIALIZER(module, name, body) -#define DECLARE_INITIALIZER(type, name) \ - extern perftools::gputools::port::Initializer \ - google_initializer_##type##_##name +#define DECLARE_INITIALIZER(type, name) \ + extern ::stream_executor::port::Initializer google_initializer_##type##_##name #define DECLARE_MODULE_INITIALIZER(name) DECLARE_INITIALIZER(module, name) diff --git a/tensorflow/stream_executor/lib/inlined_vector.h b/tensorflow/stream_executor/lib/inlined_vector.h index 55a1e3ad102..40bdddb180f 100644 --- a/tensorflow/stream_executor/lib/inlined_vector.h +++ b/tensorflow/stream_executor/lib/inlined_vector.h @@ -18,14 +18,12 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/inlined_vector.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::gtl::InlinedVector; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_ diff --git a/tensorflow/stream_executor/lib/mathutil.h b/tensorflow/stream_executor/lib/mathutil.h index e8310d55dda..c225dc5f3cc 100644 --- a/tensorflow/stream_executor/lib/mathutil.h +++ b/tensorflow/stream_executor/lib/mathutil.h @@ -25,8 +25,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { class MathUtil { @@ -97,7 +96,6 @@ IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator, } } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_MATHUTIL_H_ diff --git a/tensorflow/stream_executor/lib/notification.h b/tensorflow/stream_executor/lib/notification.h index 9bb3e170dc7..472d8c9845c 100644 --- a/tensorflow/stream_executor/lib/notification.h +++ b/tensorflow/stream_executor/lib/notification.h @@ -18,14 +18,12 @@ limitations under the License. 
#include "tensorflow/core/platform/notification.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::Notification; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_NOTIFICATION_H_ diff --git a/tensorflow/stream_executor/lib/numbers.cc b/tensorflow/stream_executor/lib/numbers.cc index 11a65e198d6..b670c42ec84 100644 --- a/tensorflow/stream_executor/lib/numbers.cc +++ b/tensorflow/stream_executor/lib/numbers.cc @@ -17,8 +17,7 @@ limitations under the License. #include -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { bool safe_strto32(const char* str, int32* value) { @@ -38,5 +37,4 @@ bool safe_strto32(const string& str, int32* value) { } } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/numbers.h b/tensorflow/stream_executor/lib/numbers.h index 4a8692b7461..2f48281d2d6 100644 --- a/tensorflow/stream_executor/lib/numbers.h +++ b/tensorflow/stream_executor/lib/numbers.h @@ -18,8 +18,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { // Convert strings to floating point values. @@ -28,7 +27,6 @@ namespace port { bool safe_strto32(const string& str, int32* value); } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_NUMBERS_H_ diff --git a/tensorflow/stream_executor/lib/path.cc b/tensorflow/stream_executor/lib/path.cc index f2591f47f7b..56e08c316f9 100644 --- a/tensorflow/stream_executor/lib/path.cc +++ b/tensorflow/stream_executor/lib/path.cc @@ -16,8 +16,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/path.h" #include "tensorflow/stream_executor/lib/strcat.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { namespace internal { @@ -58,5 +57,4 @@ string JoinPathImpl(std::initializer_list paths) { } // namespace internal } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/path.h b/tensorflow/stream_executor/lib/path.h index 93053dbcb68..325f04ff475 100644 --- a/tensorflow/stream_executor/lib/path.h +++ b/tensorflow/stream_executor/lib/path.h @@ -20,8 +20,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::io::Dirname; @@ -56,7 +55,6 @@ inline string JoinPath(const T&... args) { } } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PATH_H_ diff --git a/tensorflow/stream_executor/lib/process_state.cc b/tensorflow/stream_executor/lib/process_state.cc index 3d856187f05..72d71e62116 100644 --- a/tensorflow/stream_executor/lib/process_state.cc +++ b/tensorflow/stream_executor/lib/process_state.cc @@ -25,8 +25,7 @@ limitations under the License. 
#endif #include -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { string Hostname() { @@ -54,5 +53,4 @@ bool GetCurrentDirectory(string* dir) { } } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/lib/process_state.h b/tensorflow/stream_executor/lib/process_state.h index 205e726d95c..248218c759e 100644 --- a/tensorflow/stream_executor/lib/process_state.h +++ b/tensorflow/stream_executor/lib/process_state.h @@ -18,15 +18,13 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { string Hostname(); bool GetCurrentDirectory(string* dir); } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PROCESS_STATE_H_ diff --git a/tensorflow/stream_executor/lib/ptr_util.h b/tensorflow/stream_executor/lib/ptr_util.h index 3d5e56faf74..3f89794688c 100644 --- a/tensorflow/stream_executor/lib/ptr_util.h +++ b/tensorflow/stream_executor/lib/ptr_util.h @@ -18,8 +18,7 @@ limitations under the License. #include -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { // Trait to select overloads and return types for MakeUnique. @@ -59,8 +58,17 @@ typename MakeUniqueResult::invalid MakeUnique(Args&&... /* args */) = delete; // NOLINT } // namespace port +} // namespace stream_executor + +namespace perftools { +namespace gputools { + +// Temporarily pull stream_executor into perftools::gputools while we migrate +// code to the new namespace. TODO(jlebar): Remove this once we've completed +// the migration. 
+using namespace stream_executor; // NOLINT[build/namespaces] + } // namespace gputools } // namespace perftools - #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_ diff --git a/tensorflow/stream_executor/lib/stacktrace.h b/tensorflow/stream_executor/lib/stacktrace.h index ba7e5317f0f..a15b0f30261 100644 --- a/tensorflow/stream_executor/lib/stacktrace.h +++ b/tensorflow/stream_executor/lib/stacktrace.h @@ -19,14 +19,12 @@ limitations under the License. #include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::CurrentStackTrace; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STACKTRACE_H_ diff --git a/tensorflow/stream_executor/lib/status.h b/tensorflow/stream_executor/lib/status.h index 8c289e1927f..407b71b405b 100644 --- a/tensorflow/stream_executor/lib/status.h +++ b/tensorflow/stream_executor/lib/status.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" +// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h" #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_ @@ -23,15 +23,14 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/platform/logging.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using Status = tensorflow::Status; #define SE_CHECK_OK(val) TF_CHECK_OK(val) #define SE_ASSERT_OK(val) \ - ASSERT_EQ(::perftools::gputools::port::Status::OK(), (val)) + ASSERT_EQ(::stream_executor::port::Status::OK(), (val)) // Define some canonical error helpers. inline Status UnimplementedError(StringPiece message) { @@ -45,6 +44,16 @@ inline Status FailedPreconditionError(StringPiece message) { } } // namespace port +} // namespace stream_executor + +namespace perftools { +namespace gputools { + +// Temporarily pull stream_executor into perftools::gputools while we migrate +// code to the new namespace. TODO(b/77980417): Remove this once we've +// completed the migration. +using namespace stream_executor; // NOLINT[build/namespaces] + } // namespace gputools } // namespace perftools diff --git a/tensorflow/stream_executor/lib/statusor.h b/tensorflow/stream_executor/lib/statusor.h index 3b97929b37d..dab59096740 100644 --- a/tensorflow/stream_executor/lib/statusor.h +++ b/tensorflow/stream_executor/lib/statusor.h @@ -13,15 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" +// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h" #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_ #include "tensorflow/compiler/xla/statusor.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { // Use XLA's StatusOr so we don't duplicate code. 
@@ -29,7 +28,6 @@ template using StatusOr = ::xla::StatusOr; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_ diff --git a/tensorflow/stream_executor/lib/str_util.h b/tensorflow/stream_executor/lib/str_util.h index 5dd3d06affa..a81c6668184 100644 --- a/tensorflow/stream_executor/lib/str_util.h +++ b/tensorflow/stream_executor/lib/str_util.h @@ -19,8 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/stream_executor/lib/stringpiece.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::str_util::Join; @@ -38,7 +37,6 @@ inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix) using tensorflow::str_util::Lowercase; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STR_UTIL_H_ diff --git a/tensorflow/stream_executor/lib/strcat.h b/tensorflow/stream_executor/lib/strcat.h index 424cb75f0e8..c959e4df5b2 100644 --- a/tensorflow/stream_executor/lib/strcat.h +++ b/tensorflow/stream_executor/lib/strcat.h @@ -13,22 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" +// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h" #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_ #include "tensorflow/core/lib/strings/strcat.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::strings::StrCat; using tensorflow::strings::StrAppend; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_ diff --git a/tensorflow/stream_executor/lib/stringpiece.h b/tensorflow/stream_executor/lib/stringpiece.h index 97ee0c92064..b80de5df306 100644 --- a/tensorflow/stream_executor/lib/stringpiece.h +++ b/tensorflow/stream_executor/lib/stringpiece.h @@ -19,14 +19,12 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::StringPiece; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_ diff --git a/tensorflow/stream_executor/lib/stringprintf.h b/tensorflow/stream_executor/lib/stringprintf.h index 504de25a681..2f65ed9c6a8 100644 --- a/tensorflow/stream_executor/lib/stringprintf.h +++ b/tensorflow/stream_executor/lib/stringprintf.h @@ -18,15 +18,13 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/stringprintf.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::strings::Printf; using tensorflow::strings::Appendf; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPRINTF_H_ diff --git a/tensorflow/stream_executor/lib/thread_options.h b/tensorflow/stream_executor/lib/thread_options.h index bd7f63714e2..079cf757acd 100644 --- a/tensorflow/stream_executor/lib/thread_options.h +++ b/tensorflow/stream_executor/lib/thread_options.h @@ -18,14 +18,12 @@ limitations under the License. #include "tensorflow/core/platform/env.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::ThreadOptions; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_THREAD_OPTIONS_H_ diff --git a/tensorflow/stream_executor/lib/threadpool.h b/tensorflow/stream_executor/lib/threadpool.h index 35630c5106a..220068ade11 100644 --- a/tensorflow/stream_executor/lib/threadpool.h +++ b/tensorflow/stream_executor/lib/threadpool.h @@ -21,14 +21,12 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/notification.h" #include "tensorflow/stream_executor/lib/thread_options.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace port { using tensorflow::thread::ThreadPool; } // namespace port -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_THREADPOOL_H_ diff --git a/tensorflow/stream_executor/multi_platform_manager.cc b/tensorflow/stream_executor/multi_platform_manager.cc index f9f3737a06d..5b51398d8ca 100644 --- a/tensorflow/stream_executor/multi_platform_manager.cc +++ b/tensorflow/stream_executor/multi_platform_manager.cc @@ -20,8 +20,7 @@ limitations under the License. #include "tensorflow/stream_executor/lib/str_util.h" #include "tensorflow/stream_executor/lib/stringprintf.h" -namespace perftools { -namespace gputools { +namespace stream_executor { /* static */ mutex MultiPlatformManager::platforms_mutex_{LINKER_INITIALIZED}; @@ -132,8 +131,7 @@ MultiPlatformManager::InitializePlatformWithId( GetPlatformByIdMap()->clear(); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor REGISTER_MODULE_INITIALIZER( multi_platform_manager, diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h index 438653ee20b..672855d5fb6 100644 --- a/tensorflow/stream_executor/multi_platform_manager.h +++ b/tensorflow/stream_executor/multi_platform_manager.h @@ -22,8 +22,8 @@ limitations under the License. // In your BUILD rule, add a dependency on a platform plugin that you'd like // to use, such as: // -// //perftools/gputools/executor/cuda:cuda_platform -// //perftools/gputools/executor/opencl:opencl_platform +// //third_party/tensorflow/stream_executor/cuda:cuda_platform +// //third_party/tensorflow/stream_executor/opencl:opencl_platform // // This will register platform plugins that can be discovered via this // interface. 
Sample API usage: @@ -56,10 +56,10 @@ limitations under the License. // And similarly, for standard interfaces (BLAS, RNG, etc.) you can add // dependencies on support libraries, e.g.: // -// //perftools/gputools/executor/cuda:pluton_blas_plugin -// //perftools/gputools/executor/cuda:cudnn_plugin -// //perftools/gputools/executor/cuda:cublas_plugin -// //perftools/gputools/executor/cuda:curand_plugin +// //third_party/tensorflow/stream_executor/cuda:pluton_blas_plugin +// //third_party/tensorflow/stream_executor/cuda:cudnn_plugin +// //third_party/tensorflow/stream_executor/cuda:cublas_plugin +// //third_party/tensorflow/stream_executor/cuda:curand_plugin #ifndef TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ #define TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ @@ -75,8 +75,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/platform/thread_annotations.h" -namespace perftools { -namespace gputools { +namespace stream_executor { // Manages multiple platforms that may be present on the current machine. class MultiPlatformManager { @@ -181,7 +180,6 @@ class MultiPlatformManager { SE_DISALLOW_COPY_AND_ASSIGN(MultiPlatformManager); }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ diff --git a/tensorflow/stream_executor/platform.cc b/tensorflow/stream_executor/platform.cc index 4cdc22bd16a..777abced863 100644 --- a/tensorflow/stream_executor/platform.cc +++ b/tensorflow/stream_executor/platform.cc @@ -22,8 +22,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" -namespace perftools { -namespace gputools { +namespace stream_executor { string PlatformKindString(PlatformKind kind) { switch (kind) { @@ -135,5 +134,4 @@ port::Status Platform::EnablePeerAccess() { return port::Status::OK(); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h index 54f8aa86c26..5cb7047b6f3 100644 --- a/tensorflow/stream_executor/platform.h +++ b/tensorflow/stream_executor/platform.h @@ -29,8 +29,7 @@ limitations under the License. #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/trace_listener.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class StreamExecutor; @@ -106,7 +105,7 @@ class Platform { namespace { \ int plugin_id_value; \ } \ - const perftools::gputools::Platform::Id ID_VAR_NAME = &plugin_id_value; + const ::stream_executor::Platform::Id ID_VAR_NAME = &plugin_id_value; // Returns a key uniquely identifying this platform. virtual Id id() const = 0; @@ -205,7 +204,6 @@ class Platform { SE_DISALLOW_COPY_AND_ASSIGN(Platform); }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_H_ diff --git a/tensorflow/stream_executor/platform/default/mutex.h b/tensorflow/stream_executor/platform/default/mutex.h index 62de0cbce0b..c9f5a7c609e 100644 --- a/tensorflow/stream_executor/platform/default/mutex.h +++ b/tensorflow/stream_executor/platform/default/mutex.h @@ -18,8 +18,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/platform/mutex.h" -namespace perftools { -namespace gputools { +namespace stream_executor { #undef mutex_lock #undef tf_shared_lock @@ -35,7 +34,6 @@ using tensorflow::tf_shared_lock; #define tf_shared_lock(x) \ static_assert(0, "tf_shared_lock_decl_missing_var_name"); -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_PLATFORM_DEFAULT_MUTEX_H_ diff --git a/tensorflow/stream_executor/platform/port.h b/tensorflow/stream_executor/platform/port.h index 6603df48787..259cf380d6c 100644 --- a/tensorflow/stream_executor/platform/port.h +++ b/tensorflow/stream_executor/platform/port.h @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" +// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h" #ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_ #define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_ @@ -22,8 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" -namespace perftools { -namespace gputools { +namespace stream_executor { using tensorflow::int8; using tensorflow::int16; @@ -50,8 +49,7 @@ using tensorflow::LINKER_INITIALIZED; #define SE_FALLTHROUGH_INTENDED TF_FALLTHROUGH_INTENDED -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #define SE_DISALLOW_COPY_AND_ASSIGN TF_DISALLOW_COPY_AND_ASSIGN #define SE_MUST_USE_RESULT TF_MUST_USE_RESULT diff --git a/tensorflow/stream_executor/plugin.cc b/tensorflow/stream_executor/plugin.cc index 6424658e22f..cfbc52ff17b 100644 --- a/tensorflow/stream_executor/plugin.cc +++ b/tensorflow/stream_executor/plugin.cc @@ -15,8 +15,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/plugin.h" -namespace perftools { -namespace gputools { +namespace stream_executor { // Mostly-arbitrary ID only used as a sentinel "not otherwise initialized" // value. This value should never [need to] be specified aside by initialization @@ -51,5 +50,4 @@ PluginConfig& PluginConfig::SetRng(PluginId rng) { return *this; } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/plugin.h b/tensorflow/stream_executor/plugin.h index 0b88b86e2b1..0505412e7ac 100644 --- a/tensorflow/stream_executor/plugin.h +++ b/tensorflow/stream_executor/plugin.h @@ -16,8 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_ #define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_ -namespace perftools { -namespace gputools { +namespace stream_executor { // A plugin ID is a unique identifier for each registered plugin type. typedef void* PluginId; @@ -83,7 +82,6 @@ class PluginConfig { PluginId blas_, dnn_, fft_, rng_; }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_H_ diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc index 54761139eaf..7812703efd8 100644 --- a/tensorflow/stream_executor/plugin_registry.cc +++ b/tensorflow/stream_executor/plugin_registry.cc @@ -19,8 +19,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/multi_platform_manager.h" -namespace perftools { -namespace gputools { +namespace stream_executor { const PluginId kNullPlugin = nullptr; @@ -244,5 +243,4 @@ EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN"); EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT"); EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG"); -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h index 8636a49ce68..49628ecd246 100644 --- a/tensorflow/stream_executor/plugin_registry.h +++ b/tensorflow/stream_executor/plugin_registry.h @@ -28,8 +28,7 @@ limitations under the License. #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/rng.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace internal { class StreamExecutorInterface; @@ -160,7 +159,6 @@ class PluginRegistry { SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry); }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_ diff --git a/tensorflow/stream_executor/rng.cc b/tensorflow/stream_executor/rng.cc index 1c05005067c..b0efad91084 100644 --- a/tensorflow/stream_executor/rng.cc +++ b/tensorflow/stream_executor/rng.cc @@ -17,8 +17,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/platform/logging.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace rng { bool RngSupport::CheckSeed(const uint8 *seed, uint64 seed_bytes) { @@ -47,5 +46,4 @@ const int RngSupport::kMaxSeedBytes; #endif } // namespace rng -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/rng.h b/tensorflow/stream_executor/rng.h index 36d0fdd454f..acbf8fce4ca 100644 --- a/tensorflow/stream_executor/rng.h +++ b/tensorflow/stream_executor/rng.h @@ -22,8 +22,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; template @@ -89,7 +88,6 @@ class RngSupport { }; } // namespace rng -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_RNG_H_ diff --git a/tensorflow/stream_executor/scratch_allocator.cc b/tensorflow/stream_executor/scratch_allocator.cc index 0c1db414f2e..8fc4c4c509c 100644 --- a/tensorflow/stream_executor/scratch_allocator.cc +++ b/tensorflow/stream_executor/scratch_allocator.cc @@ -18,8 +18,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/lib/status_macros.h" #include "tensorflow/stream_executor/stream.h" -namespace perftools { -namespace gputools { +namespace stream_executor { ScratchAllocator::~ScratchAllocator() {} @@ -38,5 +37,4 @@ port::StatusOr> OneTimeScratchAllocator::AllocateBytes( return temporary_->device_memory(); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/scratch_allocator.h b/tensorflow/stream_executor/scratch_allocator.h index 94d5ede1613..2aed2c44373 100644 --- a/tensorflow/stream_executor/scratch_allocator.h +++ b/tensorflow/stream_executor/scratch_allocator.h @@ -23,8 +23,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/temporary_device_memory.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; @@ -77,7 +76,6 @@ class OneTimeScratchAllocator : public ScratchAllocator { SE_DISALLOW_COPY_AND_ASSIGN(OneTimeScratchAllocator); }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_ diff --git a/tensorflow/stream_executor/shared_memory_config.h b/tensorflow/stream_executor/shared_memory_config.h index de556cb7340..7cbeb3bcd91 100644 --- a/tensorflow/stream_executor/shared_memory_config.h +++ b/tensorflow/stream_executor/shared_memory_config.h @@ -19,8 +19,7 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_SHARED_MEMORY_CONFIG_H_ #define TENSORFLOW_STREAM_EXECUTOR_SHARED_MEMORY_CONFIG_H_ -namespace perftools { -namespace gputools { +namespace stream_executor { // SharedMemoryConfig enum describes potential widths of shared memory banks for // a device or kernel. @@ -30,7 +29,6 @@ enum class SharedMemoryConfig { kEightByte, // Sets shared memory banks to be eight bytes wide. 
}; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_SHARED_MEMORY_CONFIG_H_ diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index fe498507a80..f59d9a13acf 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -28,8 +28,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace { // Code to turn parameters to functions on stream into strings that @@ -5192,5 +5191,4 @@ port::Status Stream::BlockHostUntilDone() { return first_error; } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 4af426001f2..d4a81440e96 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -38,8 +38,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/temporary_memory_manager.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace host { class HostBlas; @@ -2098,7 +2097,6 @@ struct Quantization { dnn::QuantizedActivationMode::k32Bit; }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_H_ diff --git a/tensorflow/stream_executor/stream_executor.h b/tensorflow/stream_executor/stream_executor.h index 2995dccf469..d63d485df56 100644 --- a/tensorflow/stream_executor/stream_executor.h +++ b/tensorflow/stream_executor/stream_executor.h @@ -35,4 +35,15 @@ limitations under the License. 
#include "tensorflow/stream_executor/stream_executor_pimpl.h" // IWYU pragma: export #include "tensorflow/stream_executor/timer.h" // IWYU pragma: export +namespace perftools { +namespace gputools { + +// Temporarily pull stream_executor into perftools::gputools while we migrate +// code to the new namespace. TODO(b/77980417): Remove this once we've +// completed the migration. +using namespace stream_executor; // NOLINT[build/namespaces] + +} // namespace gputools +} // namespace perftools + #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc index 273d970b6fa..8297228e6fe 100644 --- a/tensorflow/stream_executor/stream_executor_internal.cc +++ b/tensorflow/stream_executor/stream_executor_internal.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor_internal.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace internal { // -- CUDA @@ -38,5 +37,4 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() { StreamExecutorFactory MakeHostExecutorImplementation; } // namespace internal -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 37ef182e144..2584c92f0c5 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -45,8 +45,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/trace_listener.h" #include "tensorflow/stream_executor/lib/inlined_vector.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; class Timer; @@ -343,7 +342,6 @@ extern StreamExecutorFactory MakeHostExecutorImplementation; } // namespace internal -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index f55fa684029..2e1adeb31e4 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -39,8 +39,7 @@ namespace { bool FLAGS_check_device_leaks = false; } // namespace -namespace perftools { -namespace gputools { +namespace stream_executor { namespace { string StackTraceIfVLOG10() { @@ -788,5 +787,4 @@ internal::StreamExecutorInterface *StreamExecutor::implementation() { return implementation_->GetUnderlyingExecutor(); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 69d0374d73d..39af7115d8f 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -37,8 +37,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/trace_listener.h" -namespace perftools { -namespace gputools { +namespace stream_executor { // Structure used for device memory leak checking. struct AllocRecord { @@ -95,7 +94,7 @@ class StreamExecutor { // Parameters: // spec: The MultiKernelLoaderSpec is usually generated as a compile-time // constant into an appropriate namespace. 
For example, see - // perftools::gputools::executor_sample::kKernelLoaderSpecs, from which a + // stream_executor::executor_sample::kKernelLoaderSpecs, from which a // MultiKernelLoaderSpec is selected. // kernel: Outparam that the kernel is loaded into. A given Kernel // instantiation should not be loaded into more than once. @@ -803,7 +802,6 @@ inline Stream &Stream::ThenLaunch(ThreadDim thread_dims, BlockDim block_dims, return *this; } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_ diff --git a/tensorflow/stream_executor/temporary_device_memory.cc b/tensorflow/stream_executor/temporary_device_memory.cc index c33166b2246..f113ce9be57 100644 --- a/tensorflow/stream_executor/temporary_device_memory.cc +++ b/tensorflow/stream_executor/temporary_device_memory.cc @@ -17,8 +17,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream.h" -namespace perftools { -namespace gputools { +namespace stream_executor { TemporaryDeviceMemoryBase::~TemporaryDeviceMemoryBase() { parent_->temporary_memory_manager()->MarkFinalized(device_memory_, @@ -64,5 +63,4 @@ TemporaryDeviceMemoryBase::TemporaryDeviceMemoryBase( DCHECK(IsAllocated()); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/temporary_device_memory.h b/tensorflow/stream_executor/temporary_device_memory.h index 2255e7ffd71..77be8599a2d 100644 --- a/tensorflow/stream_executor/temporary_device_memory.h +++ b/tensorflow/stream_executor/temporary_device_memory.h @@ -43,8 +43,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/device_memory.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; namespace internal { @@ -132,7 +131,6 @@ class TemporaryDeviceMemory : public TemporaryDeviceMemoryBase { } }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_DEVICE_MEMORY_H_ diff --git a/tensorflow/stream_executor/temporary_memory_manager.cc b/tensorflow/stream_executor/temporary_memory_manager.cc index 449ab7d3f0b..420dbb0933d 100644 --- a/tensorflow/stream_executor/temporary_memory_manager.cc +++ b/tensorflow/stream_executor/temporary_memory_manager.cc @@ -21,8 +21,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace internal { void TemporaryMemoryManager::ForceDeallocateAll() { @@ -124,5 +123,4 @@ TemporaryMemoryManager::AllocateArrayBase(uint64 element_count, } } // namespace internal -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/temporary_memory_manager.h b/tensorflow/stream_executor/temporary_memory_manager.h index 2e6fbd9d62a..faf13380dc2 100644 --- a/tensorflow/stream_executor/temporary_memory_manager.h +++ b/tensorflow/stream_executor/temporary_memory_manager.h @@ -31,8 +31,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/temporary_device_memory.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace internal { // Record used inside the TemporaryMemoryManager as metadata for a given device @@ -147,7 +146,6 @@ TemporaryMemoryManager::AllocateArray(uint64 element_count) { } } // namespace internal -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_TEMPORARY_MEMORY_MANAGER_H_ diff --git a/tensorflow/stream_executor/timer.cc b/tensorflow/stream_executor/timer.cc index 41d7e4359d4..a29791a1049 100644 --- a/tensorflow/stream_executor/timer.cc +++ b/tensorflow/stream_executor/timer.cc @@ -21,8 +21,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/stream_executor.h" -namespace perftools { -namespace gputools { +namespace stream_executor { Timer::Timer(StreamExecutor *parent) : parent_(parent), @@ -34,5 +33,4 @@ uint64 Timer::Microseconds() const { return implementation_->Microseconds(); } uint64 Timer::Nanoseconds() const { return implementation_->Nanoseconds(); } -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/stream_executor/timer.h b/tensorflow/stream_executor/timer.h index 0a37caa0f2f..fba7dd8f589 100644 --- a/tensorflow/stream_executor/timer.h +++ b/tensorflow/stream_executor/timer.h @@ -20,8 +20,7 @@ limitations under the License. 
#include "tensorflow/stream_executor/platform/port.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace internal { class TimerInterface; @@ -69,7 +68,6 @@ class Timer { SE_DISALLOW_COPY_AND_ASSIGN(Timer); }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_TIMER_H_ diff --git a/tensorflow/stream_executor/trace_listener.h b/tensorflow/stream_executor/trace_listener.h index d1e87c348b1..0e874a1d47b 100644 --- a/tensorflow/stream_executor/trace_listener.h +++ b/tensorflow/stream_executor/trace_listener.h @@ -23,8 +23,7 @@ limitations under the License. #include "tensorflow/stream_executor/launch_dim.h" #include "tensorflow/stream_executor/lib/status.h" -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; @@ -69,7 +68,6 @@ class TraceListener { const port::Status* result) {} }; -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_STREAM_EXECUTOR_TRACE_LISTENER_H_