[StreamExecutor] Rename ::perftools::gputools -> ::stream_executor, part 1.

Step 1 of re-namespacing StreamExecutor into ::stream_executor.

This moves everything under stream_executor/... into the new namespace
and leaves a namespace alias behind in ::perftools::gputools.  The next
steps will migrate users to the new namespace.

This is mostly a mechanical change, but it also includes a number of
non-mechanical changes that would ideally be split out into separate
patches.  Unfortunately they all need to land here together, for a few
reasons:

 - Forward declarations need to be in the same namespace as the actual
   types, so we need to change all forward declarations of
   StreamExecutor types in this one patch (see the sketch after this
   list).

 - Uses of these forward declarations need to be changed to the new
   namespace (or else we would need to add a namespace alias to each
   relevant header, which is pretty ugly).

 - Various initialization code needs to live in StreamExecutor's "real"
   namespace, so all this needs to be changed.
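
A minimal sketch of the forward-declaration constraint (the types and
namespaces here are illustrative, not lifted from the patch):

    // A forward declaration must name the namespace the type really
    // lives in.  After the rename, Stream's definition is in
    // ::stream_executor, so this declaration matches it:
    namespace stream_executor {
    class Stream;
    }  // namespace stream_executor

    // This still compiles, but declares a *distinct* type,
    // ::perftools::gputools::Stream, which shadows the name imported by
    // the compatibility using-directive and never gets a definition:
    namespace perftools {
    namespace gputools {
    class Stream;  // wrong once the real class has moved
    }  // namespace gputools
    }  // namespace perftools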

PiperOrigin-RevId: 193256128
Justin Lebar 2018-04-17 14:26:16 -07:00 committed by TensorFlower Gardener
parent 72df3d60fa
commit 4764bf2986
132 changed files with 566 additions and 738 deletions

View File

@@ -36,12 +36,12 @@ DeviceMemoryAllocator* ExecutableRunOptions::allocator() const {
 }
 ExecutableRunOptions& ExecutableRunOptions::set_stream(
-    perftools::gputools::Stream* stream) {
+    stream_executor::Stream* stream) {
   stream_ = stream;
   return *this;
 }
-perftools::gputools::Stream* ExecutableRunOptions::stream() const {
+stream_executor::Stream* ExecutableRunOptions::stream() const {
   return stream_;
 }

View File

@@ -16,29 +16,31 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
 #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
-// Intentionally forward declared so that ExecutableRunOptions can be linked
+// These classes are forward declared so that ExecutableRunOptions can be linked
 // into an XLA-compiled binary without having to link all of the pointed-to
 // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't
 // need to be linked).
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
 class Platform;
-}
-}
+}  // namespace stream_executor
 namespace tensorflow {
 namespace thread {
 class ThreadPool;
-}
-}
+}  // namespace thread
+}  // namespace tensorflow
 namespace Eigen {
 struct ThreadPoolDevice;
-}
+}  // namespace Eigen
 namespace xla {
+// TODO(b/77980417): Once the perftools::gputools -> stream_executor migration
+// is complete, add "using namespace se = stream_executor" here and
+// s/stream_executor/se::/ to match our idiom elsewhere.
 class DeviceMemoryAllocator;
 class DeviceAssignment;
 class ExecutionProfile;
@@ -61,8 +63,8 @@ class ExecutableRunOptions {
   // If set, this is the stream to run the computation on. The platform of the
   // stream must match the platform the executable was built for. A value of
   // nullptr indicates the option has not been set.
-  ExecutableRunOptions& set_stream(perftools::gputools::Stream* stream);
-  perftools::gputools::Stream* stream() const;
+  ExecutableRunOptions& set_stream(stream_executor::Stream* stream);
+  stream_executor::Stream* stream() const;
   // Sets the thread pool on which to run parallel CPU backend
   // computations. Does not take ownership.
@@ -91,7 +93,7 @@ class ExecutableRunOptions {
   DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
   DeviceAssignment* device_assignment_ = nullptr;
-  perftools::gputools::Stream* stream_ = nullptr;
+  stream_executor::Stream* stream_ = nullptr;
   tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
   const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
   ExecutionProfile* execution_profile_ = nullptr;

View File

@@ -46,4 +46,12 @@ using ::Eigen::half;
 }  // namespace xla
+// Alias namespace ::stream_executor as ::xla::se.
+namespace stream_executor {}
+namespace xla {
+// TODO(b/77980417): Uncomment this once all namespace aliases named 'se' are
+// removed in ::xla.
+// namespace se = ::stream_executor;
+}  // namespace xla
 #endif  // TENSORFLOW_COMPILER_XLA_TYPES_H_
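
A note on the xla/types.h hunk above: a namespace alias can only refer
to a namespace that has already been declared, which is why the new code
declares the empty namespace stream_executor {} before the (still
commented-out) alias.  A minimal sketch of the idiom once the alias is
enabled (illustrative, not part of the patch):

    namespace stream_executor {}  // guarantee the name exists, even if
                                  // nothing has been declared in it yet
    namespace xla {
    namespace se = ::stream_executor;  // now well-formed in every file
                                       // that includes this header
    }  // namespace xla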

View File

@@ -29,13 +29,11 @@ limitations under the License.
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Event;
 class Stream;
 class StreamExecutor;
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 namespace tensorflow {

View File

@@ -23,7 +23,10 @@ limitations under the License.
 #include "tensorflow/core/platform/stream_executor.h"
 namespace tensorflow {
-namespace gpu = ::perftools::gputools;
+// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once
+// that's available.
+namespace gpu = ::stream_executor;
 // Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids.
 class GpuIdUtil {

View File

@@ -18,11 +18,9 @@ limitations under the License.
 #include "tensorflow/core/lib/core/status.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Platform;
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 namespace tensorflow {
@@ -34,7 +32,7 @@ Status ValidateGPUMachineManager();
 // initializing the GPUs on the machine if needed the first time it is
 // called. Must only be called when there is a valid GPU environment
 // in the process (e.g., ValidateGPUMachineManager() returns OK).
-perftools::gputools::Platform* GPUMachineManager();
+stream_executor::Platform* GPUMachineManager();
 }  // namespace tensorflow

View File

@@ -60,7 +60,9 @@ using perftools::gputools::Stream;
 namespace tensorflow {
-namespace gpu = ::perftools::gputools;
+// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once
+// that's available.
+namespace gpu = ::stream_executor;
 Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src,
                    const Tensor* dst,

View File

@@ -27,7 +27,9 @@ namespace tensorflow {
 class RecvTensorResponse;
 class TensorProto;
-namespace gpu = ::perftools::gputools;
+// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once
+// that's available.
+namespace gpu = ::stream_executor;
 class GPUUtil {
  public:

View File

@@ -19,23 +19,22 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/device_base.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 namespace tensorflow {
-namespace gpu = ::perftools::gputools;
+// TODO(b/77980417): Replace stream_executor:: with se:: once our namespace
+// migration is complete and the alias is available.
 class GPUDeviceContext : public DeviceContext {
  public:
   // Does not take ownership of streams.
-  GPUDeviceContext(int stream_id, gpu::Stream* stream,
-                   gpu::Stream* host_to_device_stream,
-                   gpu::Stream* device_to_host_stream,
-                   gpu::Stream* device_to_device_stream)
+  GPUDeviceContext(int stream_id, stream_executor::Stream* stream,
+                   stream_executor::Stream* host_to_device_stream,
+                   stream_executor::Stream* device_to_host_stream,
+                   stream_executor::Stream* device_to_device_stream)
       : stream_id_(stream_id),
         stream_(stream),
         host_to_device_stream_(host_to_device_stream),
@@ -44,10 +43,14 @@ class GPUDeviceContext : public DeviceContext {
   ~GPUDeviceContext() override {}
-  gpu::Stream* stream() const override { return stream_; }
-  gpu::Stream* host_to_device_stream() const { return host_to_device_stream_; }
-  gpu::Stream* device_to_host_stream() const { return device_to_host_stream_; }
-  gpu::Stream* device_to_device_stream() const {
+  stream_executor::Stream* stream() const override { return stream_; }
+  stream_executor::Stream* host_to_device_stream() const {
+    return host_to_device_stream_;
+  }
+  stream_executor::Stream* device_to_host_stream() const {
+    return device_to_host_stream_;
+  }
+  stream_executor::Stream* device_to_device_stream() const {
     return device_to_device_stream_;
   }
   int stream_id() const { return stream_id_; }
@@ -67,13 +70,13 @@ class GPUDeviceContext : public DeviceContext {
   int stream_id_;
   // The default primary stream to use for this context.
   // All the memory belongs to this stream.
-  gpu::Stream* stream_;
+  stream_executor::Stream* stream_;
   // The stream to use for copy data from host into GPU.
-  gpu::Stream* host_to_device_stream_;
+  stream_executor::Stream* host_to_device_stream_;
   // The stream to use for copy data from GPU to host.
-  gpu::Stream* device_to_host_stream_;
+  stream_executor::Stream* device_to_host_stream_;
   // The stream to use for copy data between GPU.
-  gpu::Stream* device_to_device_stream_;
+  stream_executor::Stream* device_to_device_stream_;
 };
 }  // namespace tensorflow

View File

@@ -34,11 +34,9 @@ struct SyclDevice;
 #endif
 }  // end namespace Eigen
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 namespace tensorflow {
@@ -69,9 +67,10 @@ class PerOpGpuDevice {
 class DeviceContext : public core::RefCounted {
  public:
   ~DeviceContext() override {}
-  virtual perftools::gputools::Stream* stream() const { return nullptr; }
-  virtual void MaintainLifetimeOnStream(
-      const Tensor* t, perftools::gputools::Stream* stream) const {}
+  virtual stream_executor::Stream* stream() const { return nullptr; }
+  virtual void MaintainLifetimeOnStream(const Tensor* t,
+                                        stream_executor::Stream* stream) const {
+  }
   // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into
   // "device_tensor" which is on a GPU device "device". "device_tensor"
@@ -133,7 +132,7 @@ class DeviceBase {
   // but also by TPU devices (to provide default device context).
   struct GpuDeviceInfo {
     // Make sure all the defaults are NULL, so we can spot missing assignments.
-    perftools::gputools::Stream* stream = nullptr;
+    stream_executor::Stream* stream = nullptr;
     DeviceContext* default_context = nullptr;
     EventMgr* event_mgr = nullptr;
     int gpu_id = -1;

View File

@@ -23,8 +23,6 @@ limitations under the License.
 namespace tensorflow {
-namespace gpu = ::perftools::gputools;
 // On the open-source platform, stream_executor currently uses
 // tensorflow::Status
 inline Status FromStreamExecutorStatus(

View File

@@ -37,4 +37,15 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/stream_executor/stream_executor.h"
+namespace perftools {
+namespace gputools {
+// Temporarily pull stream_executor into perftools::gputools while we migrate
+// code to the new namespace.  TODO(jlebar): Remove this once we've completed
+// the migration.
+using namespace stream_executor;  // NOLINT[build/namespaces]
+}  // namespace gputools
+}  // namespace perftools
 #endif  // TENSORFLOW_PLATFORM_STREAM_EXECUTOR_H_
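
The using-directive in the hunk above is the compatibility shim the
commit message describes: every name declared in ::stream_executor
becomes visible through the old ::perftools::gputools path, so
unmigrated callers keep compiling.  A minimal sketch of the effect, with
an illustrative Stream type standing in for the real class:

    namespace stream_executor {
    class Stream {};  // after the rename, definitions live here
    }  // namespace stream_executor

    namespace perftools {
    namespace gputools {
    using namespace ::stream_executor;  // temporary shim, as in the patch
    }  // namespace gputools
    }  // namespace perftools

    void Demo() {
      // Both spellings now name the same type, so old code still builds.
      perftools::gputools::Stream* old_style = nullptr;
      stream_executor::Stream* new_style = old_style;
      (void)new_style;
    }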

View File

@@ -60,4 +60,12 @@ typedef uint64 Fprint;
 }  // namespace tensorflow
+// Alias namespace ::stream_executor as ::tensorflow::se.
+namespace stream_executor {}
+namespace tensorflow {
+// TODO(b/77980417): Uncomment this once all namespace aliases named 'se' are
+// removed in ::xla.
+// namespace se = ::stream_executor;
+}  // namespace tensorflow
 #endif  // TENSORFLOW_PLATFORM_TYPES_H_

View File

@@ -17,8 +17,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/strcat.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace blas {
 string TransposeString(Transpose t) {
@@ -95,5 +94,4 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
 }
 }  // namespace blas
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

View File

@@ -49,8 +49,7 @@ namespace Eigen {
 struct half;
 }  // namespace Eigen
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
 class ScratchAllocator;
@@ -2100,7 +2099,6 @@ class BlasSupport {
       DeviceMemory<std::complex<double>> *b, int ldb) override;
 }  // namespace blas
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_

View File

@@ -19,8 +19,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec);
@@ -40,5 +39,4 @@ ScopedActivateExecutorContext::~ScopedActivateExecutorContext() {
 }
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

View File

@@ -25,8 +25,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class StreamExecutor;
@@ -56,7 +55,6 @@ class ScopedActivateExecutorContext {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_ACTIVATION_H_

View File

@@ -75,15 +75,14 @@ limitations under the License.
 #include "tensorflow/stream_executor/scratch_allocator.h"
 #include "tensorflow/stream_executor/stream_executor.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);
 namespace wrap {
-#define PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name) \
+#define STREAM_EXECUTOR_CUBLAS_WRAP(__name) \
   struct WrapperShim__##__name { \
     static const char *kName; \
     template <typename... Args> \
@@ -94,8 +93,8 @@ namespace wrap {
   } __name; \
   const char *WrapperShim__##__name::kName = #__name;
-#define PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(__name) \
-  PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name)
+#define STREAM_EXECUTOR_CUBLAS_V2_WRAP(__name) \
+  STREAM_EXECUTOR_CUBLAS_WRAP(__name)
 #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
   __macro(cublasSnrm2) \
@@ -269,28 +268,28 @@ namespace wrap {
   __macro(cublasCdgmm) \
   __macro(cublasZdgmm)
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasCreate)
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasDestroy)
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetStream)
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetPointerMode)
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasGetPointerMode)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmBatched)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasDgemmBatched)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasCgemmBatched)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasZgemmBatched)
-CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasCreate)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasDestroy)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasSetStream)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasSetPointerMode)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasGetPointerMode)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmBatched)
+CUBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_CUBLAS_V2_WRAP)
 #if CUDA_VERSION >= 7050
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx)
 #endif
 #if CUDA_VERSION >= 8000
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx)
 #endif
 #if CUDA_VERSION >= 9000
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGetMathMode)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSetMathMode)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGetMathMode)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
 #endif
 }  // namespace wrap
@@ -2803,46 +2802,39 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
 }  // namespace cuda
-namespace gpu = ::perftools::gputools;
 void initialize_cublas() {
-  gpu::port::Status status =
-      gpu::PluginRegistry::Instance()
-          ->RegisterFactory<gpu::PluginRegistry::BlasFactory>(
-              gpu::cuda::kCudaPlatformId, gpu::cuda::kCuBlasPlugin, "cuBLAS",
-              [](gpu::internal::StreamExecutorInterface
-                     *parent) -> gpu::blas::BlasSupport * {
-                gpu::cuda::CUDAExecutor *cuda_executor =
-                    dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
-                if (cuda_executor == nullptr) {
-                  LOG(ERROR)
-                      << "Attempting to initialize an instance of the cuBLAS "
-                      << "support library with a non-CUDA StreamExecutor";
-                  return nullptr;
-                }
-                gpu::cuda::CUDABlas *blas =
-                    new gpu::cuda::CUDABlas(cuda_executor);
-                if (!blas->Init()) {
-                  // Note: Init() will log a more specific error.
-                  delete blas;
-                  return nullptr;
-                }
-                return blas;
-              });
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
+          cuda::kCudaPlatformId, cuda::kCuBlasPlugin, "cuBLAS",
+          [](internal::StreamExecutorInterface *parent) -> blas::BlasSupport * {
+            cuda::CUDAExecutor *cuda_executor =
+                dynamic_cast<cuda::CUDAExecutor *>(parent);
+            if (cuda_executor == nullptr) {
+              LOG(ERROR)
+                  << "Attempting to initialize an instance of the cuBLAS "
+                  << "support library with a non-CUDA StreamExecutor";
+              return nullptr;
+            }
+            cuda::CUDABlas *blas = new cuda::CUDABlas(cuda_executor);
+            if (!blas->Init()) {
+              // Note: Init() will log a more specific error.
+              delete blas;
+              return nullptr;
+            }
+            return blas;
+          });
   if (!status.ok()) {
     LOG(ERROR) << "Unable to register cuBLAS factory: "
                << status.error_message();
   }
-  gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
-                                                     gpu::PluginKind::kBlas,
-                                                     gpu::cuda::kCuBlasPlugin);
+  PluginRegistry::Instance()->SetDefaultFactory(
+      cuda::kCudaPlatformId, PluginKind::kBlas, cuda::kCuBlasPlugin);
 }
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 REGISTER_MODULE_INITIALIZER(register_cublas,
-                            { perftools::gputools::initialize_cublas(); });
+                            { stream_executor::initialize_cublas(); });

View File

@@ -29,8 +29,7 @@ limitations under the License.
 typedef struct cublasContext *cublasHandle_t;
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
@@ -162,7 +161,6 @@ class CUDABlas : public blas::BlasSupport {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_

View File

@@ -51,8 +51,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/str_util.h"
 #include "tensorflow/stream_executor/lib/inlined_vector.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 #ifdef __APPLE__
@@ -384,5 +383,4 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

View File

@@ -22,8 +22,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 // e.g. DriverVersion{346, 3, 4}
@@ -93,7 +92,6 @@ class Diagnostician {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_

View File

@@ -59,8 +59,7 @@ NarrowT CheckedNarrowing(const WideT& wide) {
 }  // namespace
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 using dnn::BatchDescriptor;
 using dnn::FilterDescriptor;
@@ -159,7 +158,7 @@ static port::ThreadPool* GetCudaThreadpool() {
   return cudnn_threadpool;
 }
-#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name) \
+#define STREAM_EXECUTOR_CUDNN_WRAP(__name) \
   struct WrapperShim__##__name { \
     template <typename... Args> \
     cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) { \
@@ -169,7 +168,7 @@ static port::ThreadPool* GetCudaThreadpool() {
   } \
   } __name;
-#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM(__name) \
+#define STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM(__name) \
   struct WrapperShim__##__name { \
     template <typename... Args> \
     cudnnStatus_t operator()(CudnnSupport* dnn, Stream* s, Args... args) \
@@ -220,7 +219,7 @@ struct WrapperShim__cudnnSetStream {
   __macro(cudnnSetFilterNdDescriptor)
 // clang-format on
-CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH
 // clang-format off
@@ -242,7 +241,7 @@ CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 // clang-format on
 CUDNN_DNN_ROUTINE_EACH_WITH_STREAM(
-    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+    STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
 #undef CUDNN_DNN_ROUTINE_EACH_WITH_STREAM
 // APIs available after R3:
@@ -252,7 +251,7 @@ CUDNN_DNN_ROUTINE_EACH_WITH_STREAM(
   __macro(cudnnGetConvolutionBackwardDataAlgorithm) \
   __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \
   __macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
-CUDNN_DNN_ROUTINE_EACH_AFTER_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH_AFTER_R3(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
 #endif
@@ -266,7 +265,7 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 // clang-format on
 CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM(
-    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+    STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
 #undef CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM
 #endif
@@ -293,7 +292,7 @@ CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM(
   __macro(cudnnGetFilterNdDescriptor)
 // clang-format on
-CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH_R5(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_R5
 // clang-format off
@@ -305,7 +304,7 @@ CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 // clang-format on
 CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(
-    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+    STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
 #undef CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM
 #endif
@@ -316,7 +315,7 @@ CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(
   __macro(cudnnSetRNNDescriptor_v6)
 // clang-format on
-CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH_R6(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_R6
 // clang-format off
@@ -325,7 +324,7 @@ CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 // clang-format on
 CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
-    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+    STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
 #undef CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM
 #endif
@@ -337,7 +336,7 @@ CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
   __macro(cudnnSetRNNMatrixMathType)
 // clang-format on
-CUDNN_DNN_ROUTINE_EACH_R7(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_R7
 #endif
@@ -4727,46 +4726,39 @@ bool CudnnSupport::DeriveOutputBatchDescriptor(
 }  // namespace cuda
-namespace gpu = ::perftools::gputools;
 void initialize_cudnn() {
-  gpu::port::Status status =
-      gpu::PluginRegistry::Instance()
-          ->RegisterFactory<gpu::PluginRegistry::DnnFactory>(
-              gpu::cuda::kCudaPlatformId, gpu::cuda::kCuDnnPlugin, "cuDNN",
-              [](gpu::internal::StreamExecutorInterface*
-                     parent) -> gpu::dnn::DnnSupport* {
-                gpu::cuda::CUDAExecutor* cuda_executor =
-                    dynamic_cast<gpu::cuda::CUDAExecutor*>(parent);
-                if (cuda_executor == nullptr) {
-                  LOG(ERROR)
-                      << "Attempting to initialize an instance of the cuBLAS "
-                      << "support library with a non-CUDA StreamExecutor";
-                  return nullptr;
-                }
-                gpu::cuda::CudnnSupport* dnn =
-                    new gpu::cuda::CudnnSupport(cuda_executor);
-                if (!dnn->Init().ok()) {
-                  // Note: Init() will log a more specific error.
-                  delete dnn;
-                  return nullptr;
-                }
-                return dnn;
-              });
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
+          cuda::kCudaPlatformId, cuda::kCuDnnPlugin, "cuDNN",
+          [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
+            cuda::CUDAExecutor* cuda_executor =
+                dynamic_cast<cuda::CUDAExecutor*>(parent);
+            if (cuda_executor == nullptr) {
+              LOG(ERROR)
+                  << "Attempting to initialize an instance of the cuBLAS "
+                  << "support library with a non-CUDA StreamExecutor";
+              return nullptr;
+            }
+            cuda::CudnnSupport* dnn = new cuda::CudnnSupport(cuda_executor);
+            if (!dnn->Init().ok()) {
+              // Note: Init() will log a more specific error.
+              delete dnn;
+              return nullptr;
+            }
+            return dnn;
+          });
   if (!status.ok()) {
     LOG(ERROR) << "Unable to register cuDNN factory: "
                << status.error_message();
   }
-  gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
-                                                     gpu::PluginKind::kDnn,
-                                                     gpu::cuda::kCuDnnPlugin);
+  PluginRegistry::Instance()->SetDefaultFactory(
+      cuda::kCudaPlatformId, PluginKind::kDnn, cuda::kCuDnnPlugin);
 }
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 REGISTER_MODULE_INITIALIZER(register_cudnn,
-                            { perftools::gputools::initialize_cudnn(); });
+                            { stream_executor::initialize_cudnn(); });

View File

@@ -26,8 +26,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/temporary_device_memory.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 class CUDAExecutor;
@@ -810,7 +809,6 @@ class CudnnSupport : public dnn::DnnSupport {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_

View File

@@ -53,8 +53,7 @@ bool FLAGS_gpuexec_cuda_device_0_only = false;
 // matches the expected one.
 constexpr bool kVerifyCudaContext = false;
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 namespace {
@@ -1649,5 +1648,4 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
 }
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

View File

@@ -27,8 +27,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
 #include "cuda/include/cuda.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 // Identifies the memory space where an allocation resides. See
@@ -506,7 +505,6 @@ class CudaContext {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_

View File

@@ -19,8 +19,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 CUDAEvent::CUDAEvent(CUDAExecutor* parent)
@@ -68,5 +67,4 @@ const CUevent& CUDAEvent::cuda_event() {
 }
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

View File

@@ -21,8 +21,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/event.h"
 #include "tensorflow/stream_executor/lib/status.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 // CUDAEvent wraps a CUevent in the platform-independent EventInterface
@@ -58,7 +57,6 @@ class CUDAEvent : public internal::EventInterface {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_

View File

@@ -31,8 +31,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuFftPlugin);
@@ -44,7 +43,7 @@ namespace wrap {
 // manner on first use. This dynamic loading technique is used to avoid DSO
 // dependencies on vendor libraries which may or may not be available in the
 // deployed binary environment.
-#define PERFTOOLS_GPUTOOLS_CUFFT_WRAP(__name) \
+#define STREAM_EXECUTOR_CUFFT_WRAP(__name) \
   struct WrapperShim__##__name { \
     template <typename... Args> \
     cufftResult operator()(CUDAExecutor *parent, Args... args) { \
@@ -68,7 +67,7 @@ namespace wrap {
   __macro(cufftGetSizeMany) \
   __macro(cufftMakePlanMany)
-CUFFT_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUFFT_WRAP)
+CUFFT_ROUTINE_EACH(STREAM_EXECUTOR_CUFFT_WRAP)
 }  // namespace wrap
@@ -514,62 +513,59 @@ bool CUDAFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
   return true;
 }
-#define PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
-                                           __fft_type3)                      \
+#define STREAM_EXECUTOR_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
+                                        __fft_type3)                      \
   bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan,                      \
                       const DeviceMemory<std::complex<__type>> &input,      \
                       DeviceMemory<std::complex<__type>> *output) {         \
     return DoFftWithDirectionInternal(                                      \
        stream, plan, wrap::cufftExec##__fft_type1, input, output);          \
  }                                                                          \
  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<__type> &input,                     \
                      DeviceMemory<std::complex<__type>> *output) {          \
    return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type2, input,  \
                         output);                                            \
  }                                                                          \
  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan,                       \
                      const DeviceMemory<std::complex<__type>> &input,       \
                      DeviceMemory<__type> *output) {                        \
    return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type3, input,  \
                         output);                                            \
  }
-PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(float, C2C, R2C, C2R)
-PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
-#undef PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT
+STREAM_EXECUTOR_CUDA_DEFINE_FFT(float, C2C, R2C, C2R)
+STREAM_EXECUTOR_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
+#undef STREAM_EXECUTOR_CUDA_DEFINE_FFT
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
-namespace gpu = ::perftools::gputools;
-REGISTER_MODULE_INITIALIZER(register_cufft, {
-  gpu::port::Status status =
-      gpu::PluginRegistry::Instance()
-          ->RegisterFactory<gpu::PluginRegistry::FftFactory>(
-              gpu::cuda::kCudaPlatformId, gpu::cuda::kCuFftPlugin, "cuFFT",
-              [](gpu::internal::StreamExecutorInterface
-                     *parent) -> gpu::fft::FftSupport * {
-                gpu::cuda::CUDAExecutor *cuda_executor =
-                    dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
-                if (cuda_executor == nullptr) {
-                  LOG(ERROR)
-                      << "Attempting to initialize an instance of the cuFFT "
-                      << "support library with a non-CUDA StreamExecutor";
-                  return nullptr;
-                }
-                return new gpu::cuda::CUDAFft(cuda_executor);
-              });
+void initialize_cufft() {
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
+          cuda::kCudaPlatformId, cuda::kCuFftPlugin, "cuFFT",
+          [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
+            cuda::CUDAExecutor *cuda_executor =
+                dynamic_cast<cuda::CUDAExecutor *>(parent);
+            if (cuda_executor == nullptr) {
+              LOG(ERROR) << "Attempting to initialize an instance of the cuFFT "
+                         << "support library with a non-CUDA StreamExecutor";
+              return nullptr;
+            }
+            return new cuda::CUDAFft(cuda_executor);
+          });
   if (!status.ok()) {
     LOG(ERROR) << "Unable to register cuFFT factory: "
                << status.error_message();
   }
-  gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
-                                                     gpu::PluginKind::kFft,
-                                                     gpu::cuda::kCuFftPlugin);
-});
+  PluginRegistry::Instance()->SetDefaultFactory(
+      cuda::kCudaPlatformId, PluginKind::kFft, cuda::kCuFftPlugin);
+}
+}  // namespace stream_executor
+REGISTER_MODULE_INITIALIZER(register_cufft,
+                            { stream_executor::initialize_cufft(); });

View File

@@ -26,8 +26,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/scratch_allocator.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
@@ -133,7 +132,6 @@ class CUDAFft : public fft::FftSupport {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_

View File

@@ -66,8 +66,7 @@ limitations under the License.
 extern bool FLAGS_check_gpu_leaks;
 bool FLAGS_prefer_cubin_to_ptx = true;
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 // Hook that can be used to CUBIN-ate PTX before it is loaded into the driver.
@@ -1168,17 +1167,14 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const {
 }  // namespace cuda
-namespace gpu = ::perftools::gputools;
 void initialize_cuda_gpu_executor() {
-  *gpu::internal::MakeCUDAExecutorImplementation() = [](
-      const gpu::PluginConfig &config) {
-    return new gpu::cuda::CUDAExecutor{config};
+  *internal::MakeCUDAExecutorImplementation() = [](const PluginConfig &config) {
+    return new cuda::CUDAExecutor{config};
   };
 }
-}  // namespace gputools
-}  // namespace perftools
-REGISTER_MODULE_INITIALIZER(
-    cuda_gpu_executor, {perftools::gputools::initialize_cuda_gpu_executor();});
+}  // namespace stream_executor
+REGISTER_MODULE_INITIALIZER(cuda_gpu_executor, {
+  stream_executor::initialize_cuda_gpu_executor();
+});

View File

@@ -35,8 +35,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 // CUDA-platform implementation of the platform-agnostic
@@ -273,7 +272,6 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_GPU_EXECUTOR_H_

View File

@@ -27,8 +27,7 @@ limitations under the License.
 #include "cuda/include/cuComplex.h"
 #include "cuda/include/cuda.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 template <typename ElemT>
 class DeviceMemory;
@@ -101,7 +100,6 @@ inline cuDoubleComplex CUDAComplexValue(std::complex<double> val) {
 }
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_

View File

@@ -40,8 +40,7 @@ limitations under the License.
     "CUDA runtime being included into CUDA GPU executor; should be driver only."
 #endif
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 // Wraps a CUfunction to implement the platform-independent KernelInterface.
@@ -124,7 +123,6 @@ inline CUDAKernel *AsCUDAKernel(KernelBase *kernel) {
 }
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_

View File

@@ -24,8 +24,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/stringprintf.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 namespace {
@@ -41,16 +40,16 @@ const DeviceOptions GetDeviceOptionsFromEnv() {
       std::getenv("TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE");
   if (gpu_schedule_string == nullptr) {
-    return perftools::gputools::DeviceOptions::Default();
+    return DeviceOptions::Default();
   }
   unsigned device_flags = 0;
   if (strcmp(kScheduleSpinString, gpu_schedule_string) == 0) {
-    device_flags = perftools::gputools::DeviceOptions::kScheduleSpin;
+    device_flags = DeviceOptions::kScheduleSpin;
   } else if (strcmp(kScheduleYieldString, gpu_schedule_string) == 0) {
-    device_flags = perftools::gputools::DeviceOptions::kScheduleYield;
+    device_flags = DeviceOptions::kScheduleYield;
   } else if (strcmp(kScheduleBlockingSyncString, gpu_schedule_string) == 0) {
-    device_flags = perftools::gputools::DeviceOptions::kScheduleBlockingSync;
+    device_flags = DeviceOptions::kScheduleBlockingSync;
   } else {
     LOG(QFATAL) << "Unknown option for environment variable "
                    "TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE "
@@ -59,7 +58,7 @@ const DeviceOptions GetDeviceOptionsFromEnv() {
                 << ", " << kScheduleYieldString << "}";
   }
-  return perftools::gputools::DeviceOptions(device_flags);
+  return DeviceOptions(device_flags);
 }
 }  // namespace
@@ -202,11 +201,10 @@ static void InitializeCudaPlatform() {
   SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));
 }
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 REGISTER_MODULE_INITIALIZER(cuda_platform,
-                            perftools::gputools::InitializeCudaPlatform());
+                            stream_executor::InitializeCudaPlatform());
 DECLARE_MODULE_INITIALIZER(multi_platform_manager);
 // Note that module initialization sequencing is not supported in the

View File

@@ -31,8 +31,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
 #include "tensorflow/stream_executor/trace_listener.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 // Opaque and unique identifier for the CUDA platform plugin.
@@ -104,7 +103,6 @@ class CudaPlatform : public Platform {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_H_

View File

@@ -15,12 +15,10 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 PLATFORM_DEFINE_ID(kCudaPlatformId);
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

View File

@@ -18,8 +18,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 // Opaque and unique identifier for the cuda platform.
@@ -30,7 +29,6 @@ namespace cuda {
 extern const Platform::Id kCudaPlatformId;
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_

View File

@@ -54,15 +54,14 @@ std::ostream &operator<<(std::ostream &in, const curandStatus_t &status) {
   }
 }
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuRandPlugin);
 namespace wrap {
-#define PERFTOOLS_GPUTOOLS_CURAND_WRAP(__name) \
+#define STREAM_EXECUTOR_CURAND_WRAP(__name) \
   struct WrapperShim__##__name { \
     template <typename... Args> \
     curandStatus_t operator()(CUDAExecutor *parent, Args... args) { \
@@ -71,15 +70,15 @@ namespace wrap {
   } \
   } __name;
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandCreateGenerator);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandDestroyGenerator);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetStream);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniform);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniformDouble);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetPseudoRandomGeneratorSeed);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetGeneratorOffset);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormal);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormalDouble);
+STREAM_EXECUTOR_CURAND_WRAP(curandCreateGenerator);
+STREAM_EXECUTOR_CURAND_WRAP(curandDestroyGenerator);
+STREAM_EXECUTOR_CURAND_WRAP(curandSetStream);
+STREAM_EXECUTOR_CURAND_WRAP(curandGenerateUniform);
+STREAM_EXECUTOR_CURAND_WRAP(curandGenerateUniformDouble);
+STREAM_EXECUTOR_CURAND_WRAP(curandSetPseudoRandomGeneratorSeed);
+STREAM_EXECUTOR_CURAND_WRAP(curandSetGeneratorOffset);
+STREAM_EXECUTOR_CURAND_WRAP(curandGenerateNormal);
+STREAM_EXECUTOR_CURAND_WRAP(curandGenerateNormalDouble);
 }  // namespace wrap
@@ -271,42 +270,40 @@ bool CUDARng::SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) {
 }
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
-namespace gpu = ::perftools::gputools;
-REGISTER_MODULE_INITIALIZER(register_curand, {
-  gpu::port::Status status =
-      gpu::PluginRegistry::Instance()
-          ->RegisterFactory<gpu::PluginRegistry::RngFactory>(
-              gpu::cuda::kCudaPlatformId, gpu::cuda::kCuRandPlugin, "cuRAND",
-              [](gpu::internal::StreamExecutorInterface
-                     *parent) -> gpu::rng::RngSupport * {
-                gpu::cuda::CUDAExecutor *cuda_executor =
-                    dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
-                if (cuda_executor == nullptr) {
-                  LOG(ERROR)
-                      << "Attempting to initialize an instance of the cuRAND "
-                      << "support library with a non-CUDA StreamExecutor";
-                  return nullptr;
-                }
-                gpu::cuda::CUDARng *rng = new gpu::cuda::CUDARng(cuda_executor);
-                if (!rng->Init()) {
-                  // Note: Init() will log a more specific error.
-                  delete rng;
-                  return nullptr;
-                }
-                return rng;
-              });
+void initialize_curand() {
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
+          cuda::kCudaPlatformId, cuda::kCuRandPlugin, "cuRAND",
+          [](internal::StreamExecutorInterface *parent) -> rng::RngSupport * {
+            cuda::CUDAExecutor *cuda_executor =
+                dynamic_cast<cuda::CUDAExecutor *>(parent);
+            if (cuda_executor == nullptr) {
+              LOG(ERROR)
+                  << "Attempting to initialize an instance of the cuRAND "
+                  << "support library with a non-CUDA StreamExecutor";
+              return nullptr;
+            }
+            cuda::CUDARng *rng = new cuda::CUDARng(cuda_executor);
+            if (!rng->Init()) {
+              // Note: Init() will log a more specific error.
+              delete rng;
+              return nullptr;
+            }
+            return rng;
+          });
   if (!status.ok()) {
     LOG(ERROR) << "Unable to register cuRAND factory: "
                << status.error_message();
   }
-  gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
-                                                     gpu::PluginKind::kRng,
-                                                     gpu::cuda::kCuRandPlugin);
-});
+  PluginRegistry::Instance()->SetDefaultFactory(
+      cuda::kCudaPlatformId, PluginKind::kRng, cuda::kCuRandPlugin);
+}
+}  // namespace stream_executor
+REGISTER_MODULE_INITIALIZER(register_curand,
+                            { stream_executor::initialize_curand(); });

View File

@@ -24,8 +24,7 @@ limitations under the License.
 typedef struct curandGenerator_st *curandGenerator_t;
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
 template <typename ElemT>
@@ -98,7 +97,6 @@ class CUDARng : public rng::RngSupport {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_RNG_H_

View File

@@ -19,8 +19,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/stream.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 bool CUDAStream::Init() {
@@ -59,5 +58,4 @@ CUstream AsCUDAStreamValue(Stream *stream) {
 }
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

View File

@@ -23,8 +23,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 class CUDAExecutor;
@@ -89,7 +88,6 @@ CUDAStream *AsCUDAStream(Stream *stream);
 CUstream AsCUDAStreamValue(Stream *stream);
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_

View File

@@ -20,8 +20,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
 #include "tensorflow/stream_executor/lib/status.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 bool CUDATimer::Init() {
@@ -92,5 +91,4 @@ bool CUDATimer::Stop(CUDAStream* stream) {
 }
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

View File

@@ -23,8 +23,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 class CUDAExecutor;
@@ -60,13 +59,13 @@ class CUDATimer : public internal::TimerInterface {
   // events.
   float GetElapsedMilliseconds() const;
-  // See perftools::gputools::Timer::Microseconds().
+  // See Timer::Microseconds().
   // TODO(leary) make this into an error code interface...
   uint64 Microseconds() const override {
     return GetElapsedMilliseconds() * 1e3;
   }
-  // See perftools::GPUTools::Timer::Nanoseconds().
+  // See Timer::Nanoseconds().
   uint64 Nanoseconds() const override { return GetElapsedMilliseconds() * 1e6; }
  private:
@@ -85,7 +84,6 @@ struct TimerDeleter {
 };
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_TIMER_H_

@@ -15,8 +15,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cudnn_version.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version,
@@ -38,5 +37,4 @@ bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version,
 }
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -20,8 +20,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 struct CudnnVersion {
@@ -46,7 +45,6 @@ bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version,
                                         CudnnVersion loaded_version);
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDNN_VERSION_H_

@@ -17,8 +17,7 @@ limitations under the License.
 #include "tensorflow/core/platform/test.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 namespace {
@@ -70,5 +69,4 @@ TEST(IsSourceCompatibleWithCudnnLibraryTest, Basic) {
 }  // namespace
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -21,8 +21,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/mathutil.h"
 #include "tensorflow/stream_executor/lib/strcat.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 static const uint64 kUninitializedUint64 = -1ULL;
 /* static */ const char *DeviceDescription::kUndefinedString = "<undefined>";
@@ -234,6 +233,4 @@ uint64 CalculateRegisterLimitForTargetOccupancy(
   return 0;
 }
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -27,8 +27,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/launch_dim.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace internal {
 class DeviceDescriptionBuilder;
 }  // namespace internal
@@ -388,7 +387,6 @@ uint64 CalculateRegisterLimitForTargetOccupancy(
     const DeviceDescription &device_description, uint64 shared_memory_per_block,
     const ThreadDim &thread_dims, uint64 target_blocks_per_core);
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_

@@ -32,6 +32,16 @@ limitations under the License.
 namespace perftools {
 namespace gputools {
+
+// Temporarily pull stream_executor into perftools::gputools while we migrate
+// code to the new namespace. TODO(b/77980417): Remove this once we've
+// completed the migration.
+using namespace stream_executor;  // NOLINT[build/namespaces]
+
+}  // namespace gputools
+}  // namespace perftools
+
+namespace stream_executor {
 class StreamExecutor;
 // void*-analogous device memory allocation. For the typed variation, see
@@ -280,7 +290,6 @@ static_assert(sizeof(Float2) == 2 * sizeof(float), "Float2 must be packed");
 static_assert(sizeof(Float4) == 4 * sizeof(float), "Float4 must be packed");
 static_assert(sizeof(Double2) == 2 * sizeof(double), "Double2 must be packed");
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_DEVICE_MEMORY_H_
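Note on the shim added above: forward declarations must live in a type's real namespace, so the types themselves move to ::stream_executor, and the using-directive is what keeps unmigrated callers of perftools::gputools compiling in the meantime. A minimal self-contained sketch of the pattern (Widget and MakeWidget are hypothetical stand-ins, not names from this patch):

// The "real" namespace after the rename.
namespace stream_executor {
class Widget {};  // hypothetical stand-in for Stream, DeviceMemoryBase, etc.
}  // namespace stream_executor

// Legacy namespace, kept alive only as a compatibility shim.
namespace perftools {
namespace gputools {
using namespace stream_executor;  // NOLINT[build/namespaces]
}  // namespace gputools
}  // namespace perftools

// Unmigrated code still compiles: name lookup in perftools::gputools now
// also finds everything declared in ::stream_executor.
perftools::gputools::Widget MakeWidget() { return perftools::gputools::Widget(); }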

@@ -25,8 +25,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
 #include "tensorflow/stream_executor/platform/logging.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 // Indicates a set of options for a device's usage, which generally must be
 // provided at StreamExecutor device-initialization time.
@@ -84,7 +83,6 @@ struct DeviceOptions {
   unsigned flags_;
 };
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_DEVICE_OPTIONS_H_

@@ -18,8 +18,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/strcat.h"
 #include "tensorflow/stream_executor/lib/stringprintf.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace dnn {
 bool DnnSupport::GetConvolveAlgorithms(
@@ -554,5 +553,4 @@ string NormalizeDescriptor::ToShortString() const {
 }
 }  // namespace dnn
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -38,8 +38,7 @@ namespace Eigen {
 struct half;
 }  // namespace Eigen
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class HostBuffer;
 class Stream;
@@ -2285,7 +2284,6 @@ class DnnSupport {
 };
 }  // namespace dnn
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_DNN_H_

@@ -37,8 +37,7 @@ limitations under the License.
 #include "cuda/cuda_config.h"
 #endif
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace internal {
 string GetCudaVersion() { return TF_CUDA_VERSION; }
@@ -291,5 +290,4 @@ static std::vector<string>* CreatePrimordialRpaths() {
 }
 }  // namespace internal
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -28,8 +28,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform.h"
 #include "tensorflow/stream_executor/platform/mutex.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace internal {
 // Permits StreamExecutor code to dynamically load a pre-determined set of
@@ -114,7 +113,6 @@ class CachedDsoLoader {
 };
 }  // namespace internal
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_

@@ -19,8 +19,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
 #include "tensorflow/stream_executor/stream.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 Event::Event(StreamExecutor* stream_exec)
     : stream_exec_(stream_exec),
@@ -48,5 +47,4 @@ Event::Status Event::PollForStatus() {
   return stream_exec_->PollForEventStatus(this);
 }
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -20,8 +20,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace internal {
 class EventInterface;
@@ -76,7 +75,6 @@ class Event {
   SE_DISALLOW_COPY_AND_ASSIGN(Event);
 };
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_EVENT_H_

@@ -17,8 +17,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/stringprintf.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
     const StreamExecutorConfig& config,
@@ -104,5 +103,4 @@ ExecutorCache::Entry::~Entry() {
   configurations.clear();
 }
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -24,8 +24,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 // Utility class to allow Platform objects to manage cached StreamExecutors.
 // Thread-safe.
@@ -76,7 +75,6 @@ class ExecutorCache {
   SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
 };
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_

@@ -48,8 +48,7 @@ limitations under the License.
 #include <memory>
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
 template <typename ElemT>
@@ -210,7 +209,7 @@ class FftSupport {
 // Macro used to quickly declare overrides for abstract virtuals in the
 // fft::FftSupport base class. Assumes that it's emitted somewhere inside the
-// ::perftools::gputools namespace.
+// ::stream_executor namespace.
 #define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
   std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \
                                           fft::Type type, bool in_place_fft) \
@@ -265,7 +264,6 @@ class FftSupport {
                    DeviceMemory<double> *output) override;
 }  // namespace fft
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_FFT_H_

@@ -28,8 +28,7 @@ limitations under the License.
 bool FLAGS_stream_executor_cpu_real_clock_rate = false;
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 HostStream *AsHostStream(Stream *stream) {
@@ -266,5 +265,4 @@ rng::RngSupport *HostExecutor::CreateRng() {
 }
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -28,8 +28,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 // An implementation of StreamExecutor that does no communication or interaction
@@ -210,7 +209,6 @@ class HostExecutor : public internal::StreamExecutorInterface {
 };
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_GPU_EXECUTOR_H_

@@ -26,10 +26,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/status_macros.h"
 #include "tensorflow/stream_executor/lib/stringprintf.h"
-namespace gpu = ::perftools::gputools;
-
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 HostPlatform::HostPlatform() : name_("Host") {}
@@ -93,16 +90,15 @@ void HostPlatform::UnregisterTraceListener(TraceListener* listener) {
 }
 static void InitializeHostPlatform() {
-  std::unique_ptr<gpu::Platform> platform(new gpu::host::HostPlatform);
-  SE_CHECK_OK(gpu::MultiPlatformManager::RegisterPlatform(std::move(platform)));
+  std::unique_ptr<Platform> platform(new host::HostPlatform);
+  SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));
 }
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
-REGISTER_MODULE_INITIALIZER(
-    host_platform, perftools::gputools::host::InitializeHostPlatform());
+REGISTER_MODULE_INITIALIZER(host_platform,
+                            stream_executor::host::InitializeHostPlatform());
 DECLARE_MODULE_INITIALIZER(multi_platform_manager);
 // Note that module initialization sequencing is not supported in the

@@ -33,8 +33,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
 #include "tensorflow/stream_executor/trace_listener.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 // Host (CPU) platform plugin, registered as a singleton value via module
@@ -79,7 +78,6 @@ class HostPlatform : public Platform {
 };
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_PLATFORM_H_

@@ -15,12 +15,10 @@ limitations under the License.
 #include "tensorflow/stream_executor/host/host_platform_id.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 PLATFORM_DEFINE_ID(kHostPlatformId);
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -18,8 +18,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 // Opaque and unique identifier for the host platform.
@@ -30,7 +29,6 @@ namespace host {
 extern const Platform::Id kHostPlatformId;
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_PLATFORM_ID_H_

@@ -17,8 +17,7 @@ limitations under the License.
 // the HostExecutor implementation.
 #include "tensorflow/stream_executor/host/host_stream.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 HostStream::HostStream()
@@ -53,5 +52,4 @@ void HostStream::BlockUntilDone() {
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -24,8 +24,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/threadpool.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 class HostStream : public internal::StreamInterface {
@@ -52,7 +51,6 @@ class HostStream : public internal::StreamInterface {
 };
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_STREAM_H_

@@ -19,8 +19,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 using std::chrono::duration_cast;
@@ -46,5 +45,4 @@ void HostTimer::StartNow() { start_time_ = clock::now(); }
 void HostTimer::StopNow() { duration_ = clock::now() - start_time_; }
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -20,8 +20,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor_internal.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace host {
 class HostTimer : public internal::TimerInterface {
@@ -57,7 +56,6 @@ class HostTimer : public internal::TimerInterface {
 };
 }  // namespace host
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_TIMER_H_

@@ -18,8 +18,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/dnn.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 // A HostBuffer is a block of memory in host memory containing the data for a
 // dnn::BatchDescriptor using a device-dependent memory layout.
@@ -42,7 +41,6 @@ class HostBuffer {
   const dnn::BatchDescriptor descriptor_;
 };
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_HOST_BUFFER_H_

@@ -27,8 +27,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/stream_executor.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 bool KernelMetadata::registers_per_thread(int *registers_per_thread) const {
   if (has_registers_per_thread_) {
@@ -103,5 +102,4 @@ void KernelBase::set_name(port::StringPiece name) {
   demangled_name_ = port::Demangle(stubless_name.data());
 }
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -64,7 +64,7 @@ limitations under the License.
 //
 // Users typically won't need to type out the TypedKernel signature in full, it
 // will be typedef'd by automatically generated code; for example, see
-// perftools::gputools::executor_sample::VecReduceAddKernel.
+// stream_executor::executor_sample::VecReduceAddKernel.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
 #define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
@@ -82,8 +82,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/stringpiece.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class DeviceMemoryBase;
 template <typename ElemT>
@@ -639,8 +638,8 @@ struct KernelInvocationChecker {
   // NOTE: if you encounter an error here, you can see the mismatch by looking
   // at the end of the last error message, which will be of the form:
   //
-  //   ...::Compatible<const perftools::gputools::DeviceMemory<OneThing> &,
-  //                   perftools::gputools::DeviceMemory<AnotherThing>, true,
+  //   ...::Compatible<const stream_executor::DeviceMemory<OneThing> &,
+  //                   stream_executor::DeviceMemory<AnotherThing>, true,
   //                   0>'
   //   requested here
   //
@@ -711,7 +710,6 @@ struct KernelParamsOk<TypedKernel<Params...>, Args...> {
       std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
 };
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_

@@ -18,8 +18,7 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_
 #define TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 // This enum represents potential configurations of L1/shared memory when
 // running a particular kernel. These values represent user preference, and
@@ -38,7 +37,6 @@ enum class KernelCacheConfig {
   kPreferEqual,
 };
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_

@@ -15,9 +15,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/kernel_spec.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 KernelLoaderSpec::KernelLoaderSpec(port::StringPiece kernelname)
     : kernelname_(kernelname.ToString()) {}
@@ -247,5 +245,4 @@ MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory(
 MultiKernelLoaderSpec::MultiKernelLoaderSpec(size_t arity) : arity_(arity) {}
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -56,8 +56,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 // Describes how to load a kernel on a target platform.
 //
@@ -374,7 +373,6 @@ class MultiKernelLoaderSpec {
   size_t arity_;
 };
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_

@@ -21,7 +21,7 @@ limitations under the License.
 // a single PC in a unit called a warp. There is a maximum number of threads
 // that can execute in a shared-context entity called a block. Presently, that
 // number is 1024 -- again, something that should not be relied on from this
-// comment, but checked via perftools::gputools::DeviceDescription.
+// comment, but checked via stream_executor::DeviceDescription.
 //
 // For additional information, see
 // http://docs.nvidia.com/cuda/kepler-tuning-guide/#device-utilization-and-occupancy
@@ -40,8 +40,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/strcat.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 // Basic type that represents a 3-dimensional index space.
 struct Dim3D {
@@ -74,7 +73,6 @@ struct BlockDim : public Dim3D {
   }
 };
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LAUNCH_DIM_H_
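For readers following the comment this hunk rewrites: ThreadDim and BlockDim are plain 3-D extents, and a launch's total thread count is the componentwise product of the two. A self-contained sketch under that assumption (TotalThreads is a hypothetical helper, not part of the header):

#include <cstdint>

// Stand-in for the header's Dim3D: a 3-dimensional index-space extent.
struct Dim3D { uint64_t x, y, z; };

// Hypothetical helper: threads-per-block times blocks-per-launch.
uint64_t TotalThreads(const Dim3D& thread_dims, const Dim3D& block_dims) {
  return thread_dims.x * thread_dims.y * thread_dims.z *
         block_dims.x * block_dims.y * block_dims.z;
}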

@@ -18,14 +18,23 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/array_slice.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using tensorflow::gtl::ArraySlice;
 using tensorflow::gtl::MutableArraySlice;
 }  // namespace port
+}  // namespace stream_executor
+
+namespace perftools {
+namespace gputools {
+
+// Temporarily pull stream_executor into perftools::gputools while we migrate
+// code to the new namespace. TODO(b/77980417): Remove this once we've
+// completed the migration.
+using namespace stream_executor;  // NOLINT[build/namespaces]
+
 }  // namespace gputools
 }  // namespace perftools

@@ -13,15 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"
 #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
 #define TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
 #include <stdlib.h>
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 // port::bit_cast<Dest,Source> is a template function that implements the
@@ -96,7 +95,6 @@ inline Dest bit_cast(const Source& source) {
 }
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
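The port::bit_cast this hunk re-namespaces copies the object representation of one type into another of equal size; memcpy keeps the copy well-defined where *reinterpret_cast<Dest*>(&source) would violate aliasing rules. A minimal sketch of that contract (FloatBits is a hypothetical usage example):

#include <cstdint>
#include <cstring>

template <class Dest, class Source>
Dest bit_cast(const Source& source) {
  static_assert(sizeof(Dest) == sizeof(Source),
                "bit_cast requires equal sizes");
  Dest dest;
  std::memcpy(&dest, &source, sizeof(dest));  // copy bytes, not values
  return dest;
}

// Usage: recover the IEEE-754 bit pattern of a float.
uint32_t FloatBits(float f) { return bit_cast<uint32_t>(f); }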

@@ -27,8 +27,7 @@ limitations under the License.
 #include <cxxabi.h>
 #endif
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 // The API reference of abi::__cxa_demangle() can be found in
@@ -49,5 +48,4 @@ string Demangle(const char *mangled) {
 }
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -18,14 +18,12 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 string Demangle(const char* mangled);
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_DEMANGLE_H_

@@ -21,8 +21,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/stringpiece.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using tensorflow::Env;
@@ -37,7 +36,6 @@ inline Status FileExists(const port::StringPiece& filename) {
 }
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_ENV_H_

@@ -13,21 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"
 #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_
 #define TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_
 #include "tensorflow/core/lib/core/error_codes.pb.h"  // IWYU pragma: export
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 namespace error = tensorflow::error;
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_

@@ -22,8 +22,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/stringprintf.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 class HumanReadableNumBytes {
@@ -67,7 +66,6 @@ class HumanReadableNumBytes {
 };
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_HUMAN_READABLE_H_

@@ -26,8 +26,7 @@ limitations under the License.
 #undef DECLARE_MODULE_INITIALIZER
 #undef REGISTER_MODULE_INITIALIZER_SEQUENCE
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 class Initializer {
@@ -49,20 +48,18 @@ class Initializer {
 };
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #define REGISTER_INITIALIZER(type, name, body)                                \
   static void google_init_##type##_##name() { body; }                        \
-  perftools::gputools::port::Initializer google_initializer_##type##_##name( \
+  ::stream_executor::port::Initializer google_initializer_##type##_##name(   \
       google_init_##type##_##name)
 #define REGISTER_MODULE_INITIALIZER(name, body) \
   REGISTER_INITIALIZER(module, name, body)
 #define DECLARE_INITIALIZER(type, name) \
-  extern perftools::gputools::port::Initializer \
-      google_initializer_##type##_##name
+  extern ::stream_executor::port::Initializer google_initializer_##type##_##name
 #define DECLARE_MODULE_INITIALIZER(name) DECLARE_INITIALIZER(module, name)
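These macros are the static-registration hook used throughout this patch (see host_platform.cc and the cuRAND registration above): REGISTER_MODULE_INITIALIZER(name, body) defines a function wrapping body plus a file-scope port::Initializer whose constructor runs it before main(). A self-contained sketch of the expansion; the one-line Initializer below is an assumption that mirrors the macro's contract, not the real class:

#include <cstdio>

namespace stream_executor {
namespace port {
class Initializer {  // assumed shape: run the callback at static-init time
 public:
  explicit Initializer(void (*func)()) { func(); }
};
}  // namespace port
}  // namespace stream_executor

// Roughly what REGISTER_MODULE_INITIALIZER(my_plugin, { std::puts("hi"); })
// expands to under the macros in this hunk:
static void google_init_module_my_plugin() { std::puts("hi"); }
::stream_executor::port::Initializer google_initializer_module_my_plugin(
    google_init_module_my_plugin);

int main() {}  // "hi" prints during static initialization, before main runs.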

@@ -18,14 +18,12 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using tensorflow::gtl::InlinedVector;
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_

@@ -25,8 +25,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 class MathUtil {
@@ -97,7 +96,6 @@ IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator,
 }
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_MATHUTIL_H_
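The CeilOrFloorOfRatio referenced in this hunk is integer division with explicit rounding; computing the quotient and remainder separately avoids the overflow that the naive (n + d - 1) / d can hit near the top of the type's range. A sketch of the ceiling case for unsigned operands (the signed-aware original also handles negative values):

#include <cstdint>

// Ceiling of numerator/denominator without ever forming n + d - 1.
uint64_t CeilOfRatio(uint64_t numerator, uint64_t denominator) {
  const uint64_t quotient = numerator / denominator;
  return quotient + (numerator % denominator != 0);  // round up on remainder
}
// CeilOfRatio(10, 3) == 4; CeilOfRatio(9, 3) == 3.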

@@ -18,14 +18,12 @@ limitations under the License.
 #include "tensorflow/core/platform/notification.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using tensorflow::Notification;
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_NOTIFICATION_H_

@@ -17,8 +17,7 @@ limitations under the License.
 #include <stdlib.h>
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 bool safe_strto32(const char* str, int32* value) {
@@ -38,5 +37,4 @@ bool safe_strto32(const string& str, int32* value) {
 }
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -18,8 +18,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 // Convert strings to floating point values.
@@ -28,7 +27,6 @@ namespace port {
 bool safe_strto32(const string& str, int32* value);
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_NUMBERS_H_

@@ -16,8 +16,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/path.h"
 #include "tensorflow/stream_executor/lib/strcat.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 namespace internal {
@@ -58,5 +57,4 @@ string JoinPathImpl(std::initializer_list<port::StringPiece> paths) {
 }  // namespace internal
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -20,8 +20,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/stringpiece.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using tensorflow::io::Dirname;
@@ -56,7 +55,6 @@ inline string JoinPath(const T&... args) {
 }
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_PATH_H_
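The JoinPath this hunk re-namespaces follows a common shape: a variadic template that funnels all of its arguments into one initializer_list-based implementation (JoinPathImpl in path.cc above). A simplified self-contained sketch of that shape, with separator handling reduced to the '/' cases:

#include <initializer_list>
#include <string>

inline std::string JoinPathImpl(std::initializer_list<std::string> paths) {
  std::string result;
  for (const std::string& path : paths) {
    if (path.empty()) continue;
    if (result.empty()) { result = path; continue; }
    const bool lhs_slash = result.back() == '/';
    const bool rhs_slash = path.front() == '/';
    if (lhs_slash && rhs_slash) result += path.substr(1);   // drop one '/'
    else if (lhs_slash || rhs_slash) result += path;        // exactly one '/'
    else result += "/" + path;                              // insert a '/'
  }
  return result;
}

template <typename... T>
std::string JoinPath(const T&... args) {
  return JoinPathImpl({std::string(args)...});
}
// JoinPath("/usr", "local", "cuda") == "/usr/local/cuda"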

@@ -25,8 +25,7 @@ limitations under the License.
 #endif
 #include <memory>
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 string Hostname() {
@@ -54,5 +53,4 @@ bool GetCurrentDirectory(string* dir) {
 }
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor

@@ -18,15 +18,13 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 string Hostname();
 bool GetCurrentDirectory(string* dir);
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_PROCESS_STATE_H_

@@ -18,8 +18,7 @@ limitations under the License.
 #include <memory>
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 // Trait to select overloads and return types for MakeUnique.
@@ -59,8 +58,17 @@ typename MakeUniqueResult<T>::invalid MakeUnique(Args&&... /* args */) =
     delete;  // NOLINT
 }  // namespace port
+}  // namespace stream_executor
+
+namespace perftools {
+namespace gputools {
+
+// Temporarily pull stream_executor into perftools::gputools while we migrate
+// code to the new namespace. TODO(jlebar): Remove this once we've completed
+// the migration.
+using namespace stream_executor;  // NOLINT[build/namespaces]
+
 }  // namespace gputools
 }  // namespace perftools
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_
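The MakeUnique that moves here predates std::make_unique (C++14); the MakeUniqueResult trait in this hunk exists to select among scalar, unbounded-array, and (deleted) bounded-array overloads. A minimal sketch of the scalar overload:

#include <memory>
#include <utility>

// Perfect-forward the constructor arguments into a freshly allocated T,
// as std::make_unique later standardized.
template <typename T, typename... Args>
std::unique_ptr<T> MakeUnique(Args&&... args) {
  return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}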

@@ -19,14 +19,12 @@ limitations under the License.
 #include "tensorflow/core/platform/stacktrace.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using tensorflow::CurrentStackTrace;
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_STACKTRACE_H_

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"
 #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_
 #define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_
@@ -23,15 +23,14 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/stringpiece.h"
 #include "tensorflow/stream_executor/platform/logging.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using Status = tensorflow::Status;
 #define SE_CHECK_OK(val) TF_CHECK_OK(val)
 #define SE_ASSERT_OK(val) \
-  ASSERT_EQ(::perftools::gputools::port::Status::OK(), (val))
+  ASSERT_EQ(::stream_executor::port::Status::OK(), (val))
 // Define some canonical error helpers.
 inline Status UnimplementedError(StringPiece message) {
@@ -45,6 +44,16 @@ inline Status FailedPreconditionError(StringPiece message) {
 }
 }  // namespace port
+}  // namespace stream_executor
+
+namespace perftools {
+namespace gputools {
+
+// Temporarily pull stream_executor into perftools::gputools while we migrate
+// code to the new namespace. TODO(b/77980417): Remove this once we've
+// completed the migration.
+using namespace stream_executor;  // NOLINT[build/namespaces]
+
 }  // namespace gputools
 }  // namespace perftools
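The canonical error helpers in this hunk all share one shape: wrap the caller's message in a tensorflow::Status carrying a fixed canonical error code. A sketch of that shape, written against plain tensorflow types rather than the port:: aliases (the helper name mirrors, but is not, the one in the header):

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"

// Same pattern as port::UnimplementedError / FailedPreconditionError:
// fixed canonical code, caller-supplied message.
inline tensorflow::Status UnimplementedError(tensorflow::StringPiece message) {
  return tensorflow::Status(tensorflow::error::UNIMPLEMENTED, message);
}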

@@ -13,15 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"
 #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_
 #define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_
 #include "tensorflow/compiler/xla/statusor.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 // Use XLA's StatusOr so we don't duplicate code.
@@ -29,7 +28,6 @@ template <typename T>
 using StatusOr = ::xla::StatusOr<T>;
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_

@@ -19,8 +19,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/stream_executor/lib/stringpiece.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using tensorflow::str_util::Join;
@@ -38,7 +37,6 @@ inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix)
 using tensorflow::str_util::Lowercase;
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_STR_UTIL_H_

@@ -13,22 +13,20 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"
 #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
 #define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
 #include "tensorflow/core/lib/strings/strcat.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using tensorflow::strings::StrCat;
 using tensorflow::strings::StrAppend;
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_

@@ -19,14 +19,12 @@ limitations under the License.
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/stream_executor/platform/port.h"
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace port {
 using tensorflow::StringPiece;
 }  // namespace port
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 #endif  // TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_

Some files were not shown because too many files have changed in this diff.