[StreamExecutor] Rename ::perftools::gputools -> ::stream_executor, part 1.
Step 1 of re-namespace'ing StreamExecutor into ::stream_executor.

This moves everything inside of stream_executor/... and leaves a namespace alias into ::perftools::gputools. The next steps will clean up users to use the new namespace.

This is mostly a mechanical change, but it also includes a bunch of non-mechanical changes that ideally would be split out into separate patches. Unfortunately they all sort of need to be shoved in here for various reasons:

- Forward declarations need to be in the same namespace as the actual types, so we need to change all forward declarations of StreamExecutor types in this one patch.
- Uses of these forward declarations need to be changed to the new namespace (or otherwise we need to add a namespace alias to the relevant header, but this is pretty ugly).
- Various initialization code needs to live in StreamExecutor's "real" namespace, so all of it needs to be changed.

PiperOrigin-RevId: 193256128
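The heart of this step is a two-part compatibility shim. A minimal sketch (illustration only, not code from this patch): the types now live in ::stream_executor, and a using-directive temporarily re-exports them under the old ::perftools::gputools name, so existing qualified uses keep compiling until callers are migrated.

    namespace stream_executor {
    class Stream;  // The real type now lives here.
    }  // namespace stream_executor

    namespace perftools {
    namespace gputools {
    // Temporary bridge: old qualified names keep resolving while callers migrate.
    using namespace stream_executor;  // NOLINT[build/namespaces]
    }  // namespace gputools
    }  // namespace perftools

    // Both spellings now name the same type:
    void Use(perftools::gputools::Stream* a, stream_executor::Stream* b);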
parent 72df3d60fa
commit 4764bf2986
Changed paths:
tensorflow/
  compiler/xla
  core/
    common_runtime
    framework
    platform
  stream_executor/
    blas.cc  blas.h
    cuda/
      cuda_activation.cc  cuda_activation.h  cuda_blas.cc  cuda_blas.h
      cuda_diagnostics.cc  cuda_diagnostics.h  cuda_dnn.cc  cuda_dnn.h
      cuda_driver.cc  cuda_driver.h  cuda_event.cc  cuda_event.h
      cuda_fft.cc  cuda_fft.h  cuda_gpu_executor.cc  cuda_gpu_executor.h
      cuda_helpers.h  cuda_kernel.h  cuda_platform.cc  cuda_platform.h
      cuda_platform_id.cc  cuda_platform_id.h  cuda_rng.cc  cuda_rng.h
      cuda_stream.cc  cuda_stream.h  cuda_timer.cc  cuda_timer.h
      cudnn_version.cc  cudnn_version.h  cudnn_version_test.cc
    device_description.cc  device_description.h  device_memory.h
    device_options.h  dnn.cc  dnn.h  dso_loader.cc  dso_loader.h
    event.cc  event.h  executor_cache.cc  executor_cache.h  fft.h
    host/
      host_gpu_executor.cc  host_gpu_executor.h  host_platform.cc
      host_platform.h  host_platform_id.cc  host_platform_id.h
      host_stream.cc  host_stream.h  host_timer.cc  host_timer.h
    host_buffer.h  kernel.cc  kernel.h  kernel_cache_config.h
    kernel_spec.cc  kernel_spec.h  launch_dim.h  lib/
@@ -36,12 +36,12 @@ DeviceMemoryAllocator* ExecutableRunOptions::allocator() const {
 }
 
 ExecutableRunOptions& ExecutableRunOptions::set_stream(
-    perftools::gputools::Stream* stream) {
+    stream_executor::Stream* stream) {
   stream_ = stream;
   return *this;
 }
 
-perftools::gputools::Stream* ExecutableRunOptions::stream() const {
+stream_executor::Stream* ExecutableRunOptions::stream() const {
   return stream_;
 }
 
@@ -16,29 +16,31 @@ limitations under the License.
 
 #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
 #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
 
-// Intentionally forward declared so that ExecutableRunOptions can be linked
+// These classes are forward declared so that ExecutableRunOptions can be linked
 // into an XLA-compiled binary without having to link all of the pointed-to
 // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't
 // need to be linked).
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
 class Platform;
-}
-}
+}  // namespace stream_executor
 
 namespace tensorflow {
 namespace thread {
 class ThreadPool;
-}
-}
+}  // namespace thread
+}  // namespace tensorflow
 
 namespace Eigen {
 struct ThreadPoolDevice;
-}
+}  // namespace Eigen
 
 namespace xla {
 
+// TODO(b/77980417): Once the perftools::gputools -> stream_executor migration
+// is complete, add "using namespace se = stream_executor" here and
+// s/stream_executor/se::/ to match our idiom elsewhere.
+
 class DeviceMemoryAllocator;
 class DeviceAssignment;
 class ExecutionProfile;
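Why the forward declarations above had to move in this same patch: a using-directive only re-exports names; it cannot retarget a forward declaration written in the old namespace. A hedged sketch of the failure mode (hypothetical code, not from this patch):

    namespace perftools {
    namespace gputools {
    class Stream;  // Declares a brand-new, distinct type
                   // ::perftools::gputools::Stream.
    }  // namespace gputools
    }  // namespace perftools
    // Qualified lookup of perftools::gputools::Stream now finds this incomplete
    // type instead of the alias for ::stream_executor::Stream, so every use of
    // the forward declaration fails to match the real definitions.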
@@ -61,8 +63,8 @@ class ExecutableRunOptions {
   // If set, this is the stream to run the computation on. The platform of the
   // stream must match the platform the executable was built for. A value of
   // nullptr indicates the option has not been set.
-  ExecutableRunOptions& set_stream(perftools::gputools::Stream* stream);
-  perftools::gputools::Stream* stream() const;
+  ExecutableRunOptions& set_stream(stream_executor::Stream* stream);
+  stream_executor::Stream* stream() const;
 
   // Sets the thread pool on which to run parallel CPU backend
   // computations. Does not take ownership.

@@ -91,7 +93,7 @@ class ExecutableRunOptions {
   DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
   DeviceAssignment* device_assignment_ = nullptr;
-  perftools::gputools::Stream* stream_ = nullptr;
+  stream_executor::Stream* stream_ = nullptr;
   tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
   const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
   ExecutionProfile* execution_profile_ = nullptr;
@@ -46,4 +46,12 @@ using ::Eigen::half;
 
 }  // namespace xla
 
+// Alias namespace ::stream_executor as ::xla::se.
+namespace stream_executor {}
+namespace xla {
+// TODO(b/77980417): Uncomment this once all namespace aliases named 'se' are
+// removed in ::xla.
+// namespace se = ::stream_executor;
+}  // namespace xla
+
 #endif  // TENSORFLOW_COMPILER_XLA_TYPES_H_
@@ -29,13 +29,11 @@ limitations under the License.
 #include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Event;
 class Stream;
 class StreamExecutor;
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 namespace tensorflow {
 
@@ -23,7 +23,10 @@ limitations under the License.
 #include "tensorflow/core/platform/stream_executor.h"
 
 namespace tensorflow {
-namespace gpu = ::perftools::gputools;
+
+// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once
+// that's available.
+namespace gpu = ::stream_executor;
 
 // Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids.
 class GpuIdUtil {
@@ -18,11 +18,9 @@ limitations under the License.
 
 #include "tensorflow/core/lib/core/status.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Platform;
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 namespace tensorflow {
 

@@ -34,7 +32,7 @@ Status ValidateGPUMachineManager();
 // initializing the GPUs on the machine if needed the first time it is
 // called. Must only be called when there is a valid GPU environment
 // in the process (e.g., ValidateGPUMachineManager() returns OK).
-perftools::gputools::Platform* GPUMachineManager();
+stream_executor::Platform* GPUMachineManager();
 
 }  // namespace tensorflow
 
@@ -60,7 +60,9 @@ using perftools::gputools::Stream;
 
 namespace tensorflow {
 
-namespace gpu = ::perftools::gputools;
+// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once
+// that's available.
+namespace gpu = ::stream_executor;
 
 Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src,
                    const Tensor* dst,
@@ -27,7 +27,9 @@ namespace tensorflow {
 class RecvTensorResponse;
 class TensorProto;
 
-namespace gpu = ::perftools::gputools;
+// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once
+// that's available.
+namespace gpu = ::stream_executor;
 
 class GPUUtil {
  public:
@@ -19,23 +19,22 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/device_base.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 namespace tensorflow {
 
-namespace gpu = ::perftools::gputools;
+// TODO(b/77980417): Replace stream_executor:: with se:: once our namespace
+// migration is complete and the alias is available.
 
 class GPUDeviceContext : public DeviceContext {
  public:
   // Does not take ownership of streams.
-  GPUDeviceContext(int stream_id, gpu::Stream* stream,
-                   gpu::Stream* host_to_device_stream,
-                   gpu::Stream* device_to_host_stream,
-                   gpu::Stream* device_to_device_stream)
+  GPUDeviceContext(int stream_id, stream_executor::Stream* stream,
+                   stream_executor::Stream* host_to_device_stream,
+                   stream_executor::Stream* device_to_host_stream,
+                   stream_executor::Stream* device_to_device_stream)
       : stream_id_(stream_id),
         stream_(stream),
         host_to_device_stream_(host_to_device_stream),

@@ -44,10 +43,14 @@ class GPUDeviceContext : public DeviceContext {
 
   ~GPUDeviceContext() override {}
 
-  gpu::Stream* stream() const override { return stream_; }
-  gpu::Stream* host_to_device_stream() const { return host_to_device_stream_; }
-  gpu::Stream* device_to_host_stream() const { return device_to_host_stream_; }
-  gpu::Stream* device_to_device_stream() const {
+  stream_executor::Stream* stream() const override { return stream_; }
+  stream_executor::Stream* host_to_device_stream() const {
+    return host_to_device_stream_;
+  }
+  stream_executor::Stream* device_to_host_stream() const {
+    return device_to_host_stream_;
+  }
+  stream_executor::Stream* device_to_device_stream() const {
     return device_to_device_stream_;
   }
   int stream_id() const { return stream_id_; }

@@ -67,13 +70,13 @@ class GPUDeviceContext : public DeviceContext {
   int stream_id_;
   // The default primary stream to use for this context.
   // All the memory belongs to this stream.
-  gpu::Stream* stream_;
+  stream_executor::Stream* stream_;
   // The stream to use for copy data from host into GPU.
-  gpu::Stream* host_to_device_stream_;
+  stream_executor::Stream* host_to_device_stream_;
   // The stream to use for copy data from GPU to host.
-  gpu::Stream* device_to_host_stream_;
+  stream_executor::Stream* device_to_host_stream_;
   // The stream to use for copy data between GPU.
-  gpu::Stream* device_to_device_stream_;
+  stream_executor::Stream* device_to_device_stream_;
 };
 
 }  // namespace tensorflow
@@ -34,11 +34,9 @@ struct SyclDevice;
 #endif
 }  // end namespace Eigen
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 class Stream;
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 namespace tensorflow {
 

@@ -69,9 +67,10 @@ class PerOpGpuDevice {
 class DeviceContext : public core::RefCounted {
  public:
   ~DeviceContext() override {}
-  virtual perftools::gputools::Stream* stream() const { return nullptr; }
-  virtual void MaintainLifetimeOnStream(
-      const Tensor* t, perftools::gputools::Stream* stream) const {}
+  virtual stream_executor::Stream* stream() const { return nullptr; }
+  virtual void MaintainLifetimeOnStream(const Tensor* t,
+                                        stream_executor::Stream* stream) const {
+  }
 
   // "cpu_tensor" is a tensor on a CPU. Copies "cpu_tensor" into
   // "device_tensor" which is on a GPU device "device". "device_tensor"

@@ -133,7 +132,7 @@ class DeviceBase {
   // but also by TPU devices (to provide default device context).
   struct GpuDeviceInfo {
     // Make sure all the defaults are NULL, so we can spot missing assignments.
-    perftools::gputools::Stream* stream = nullptr;
+    stream_executor::Stream* stream = nullptr;
     DeviceContext* default_context = nullptr;
     EventMgr* event_mgr = nullptr;
     int gpu_id = -1;
|
@ -23,8 +23,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace gpu = ::perftools::gputools;
|
||||
|
||||
// On the open-source platform, stream_executor currently uses
|
||||
// tensorflow::Status
|
||||
inline Status FromStreamExecutorStatus(
|
||||
|
@@ -37,4 +37,15 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/stream_executor/stream_executor.h"
 
+namespace perftools {
+namespace gputools {
+
+// Temporarily pull stream_executor into perftools::gputools while we migrate
+// code to the new namespace.  TODO(jlebar): Remove this once we've completed
+// the migration.
+using namespace stream_executor;  // NOLINT[build/namespaces]
+
+}  // namespace gputools
+}  // namespace perftools
+
 #endif  // TENSORFLOW_PLATFORM_STREAM_EXECUTOR_H_
|
@ -60,4 +60,12 @@ typedef uint64 Fprint;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
// Alias namespace ::stream_executor as ::tensorflow::se.
|
||||
namespace stream_executor {}
|
||||
namespace tensorflow {
|
||||
// TODO(b/77980417): Uncomment this once all namespace aliases named 'se' are
|
||||
// removed in ::xla.
|
||||
// namespace se = ::stream_executor;
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PLATFORM_TYPES_H_
|
||||
|
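Once the commented-out alias above is enabled, call sites can shorten the namespace. A hypothetical sketch of the intended end state (the function F is illustrative only):

    namespace tensorflow {
    namespace se = ::stream_executor;  // The alias this TODO will enable.
    void F(se::Stream* stream);        // Same type as stream_executor::Stream.
    }  // namespace tensorflow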
@@ -17,8 +17,7 @@ limitations under the License.
 
 #include "tensorflow/stream_executor/lib/strcat.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace blas {
 
 string TransposeString(Transpose t) {

@@ -95,5 +94,4 @@ std::ostream& operator<<(std::ostream& os, ComputationType ty) {
 }
 
 }  // namespace blas
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -49,8 +49,7 @@ namespace Eigen {
 struct half;
 }  // namespace Eigen
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 
 class Stream;
 class ScratchAllocator;

@@ -2100,7 +2099,6 @@ class BlasSupport {
       DeviceMemory<std::complex<double>> *b, int ldb) override;
 
 }  // namespace blas
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
@@ -19,8 +19,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 CudaContext* ExtractCudaContext(CUDAExecutor *cuda_exec);

@@ -40,5 +39,4 @@ ScopedActivateExecutorContext::~ScopedActivateExecutorContext() {
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -25,8 +25,7 @@ limitations under the License.
 
 #include "tensorflow/stream_executor/platform/port.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 
 class StreamExecutor;
 

@@ -56,7 +55,6 @@ class ScopedActivateExecutorContext {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_ACTIVATION_H_
@@ -75,15 +75,14 @@ limitations under the License.
 #include "tensorflow/stream_executor/scratch_allocator.h"
 #include "tensorflow/stream_executor/stream_executor.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuBlasPlugin);
 
 namespace wrap {
 
-#define PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name) \
+#define STREAM_EXECUTOR_CUBLAS_WRAP(__name) \
   struct WrapperShim__##__name { \
     static const char *kName; \
     template <typename... Args> \

@@ -94,8 +93,8 @@ namespace wrap {
   } __name; \
   const char *WrapperShim__##__name::kName = #__name;
 
-#define PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(__name) \
-  PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(__name)
+#define STREAM_EXECUTOR_CUBLAS_V2_WRAP(__name) \
+  STREAM_EXECUTOR_CUBLAS_WRAP(__name)
 
 #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
   __macro(cublasSnrm2) \

@@ -269,28 +268,28 @@ namespace wrap {
   __macro(cublasCdgmm) \
   __macro(cublasZdgmm)
 
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasCreate)
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasDestroy)
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetStream)
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasSetPointerMode)
-PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP(cublasGetPointerMode)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmBatched)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasDgemmBatched)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasCgemmBatched)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasZgemmBatched)
-CUBLAS_BLAS_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUBLAS_V2_WRAP)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasCreate)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasDestroy)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasSetStream)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasSetPointerMode)
+STREAM_EXECUTOR_CUBLAS_V2_WRAP(cublasGetPointerMode)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasDgemmBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasCgemmBatched)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasZgemmBatched)
+CUBLAS_BLAS_ROUTINE_EACH(STREAM_EXECUTOR_CUBLAS_V2_WRAP)
 
 #if CUDA_VERSION >= 7050
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSgemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSgemmEx)
 #endif
 
 #if CUDA_VERSION >= 8000
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGemmEx)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGemmEx)
 #endif
 
 #if CUDA_VERSION >= 9000
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasGetMathMode)
-PERFTOOLS_GPUTOOLS_CUBLAS_WRAP(cublasSetMathMode)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasGetMathMode)
+STREAM_EXECUTOR_CUBLAS_WRAP(cublasSetMathMode)
 #endif
 
 }  // namespace wrap

@@ -2803,46 +2802,39 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side,
 
 }  // namespace cuda
 
-namespace gpu = ::perftools::gputools;
-
 void initialize_cublas() {
-  gpu::port::Status status =
-      gpu::PluginRegistry::Instance()
-          ->RegisterFactory<gpu::PluginRegistry::BlasFactory>(
-              gpu::cuda::kCudaPlatformId, gpu::cuda::kCuBlasPlugin, "cuBLAS",
-              [](gpu::internal::StreamExecutorInterface
-                     *parent) -> gpu::blas::BlasSupport * {
-                gpu::cuda::CUDAExecutor *cuda_executor =
-                    dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
-                if (cuda_executor == nullptr) {
-                  LOG(ERROR)
-                      << "Attempting to initialize an instance of the cuBLAS "
-                      << "support library with a non-CUDA StreamExecutor";
-                  return nullptr;
-                }
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::BlasFactory>(
+          cuda::kCudaPlatformId, cuda::kCuBlasPlugin, "cuBLAS",
+          [](internal::StreamExecutorInterface *parent) -> blas::BlasSupport * {
+            cuda::CUDAExecutor *cuda_executor =
+                dynamic_cast<cuda::CUDAExecutor *>(parent);
+            if (cuda_executor == nullptr) {
+              LOG(ERROR)
+                  << "Attempting to initialize an instance of the cuBLAS "
+                  << "support library with a non-CUDA StreamExecutor";
+              return nullptr;
+            }
 
-                gpu::cuda::CUDABlas *blas =
-                    new gpu::cuda::CUDABlas(cuda_executor);
-                if (!blas->Init()) {
-                  // Note: Init() will log a more specific error.
-                  delete blas;
-                  return nullptr;
-                }
-                return blas;
-              });
+            cuda::CUDABlas *blas = new cuda::CUDABlas(cuda_executor);
+            if (!blas->Init()) {
+              // Note: Init() will log a more specific error.
+              delete blas;
+              return nullptr;
+            }
+            return blas;
+          });
 
   if (!status.ok()) {
     LOG(ERROR) << "Unable to register cuBLAS factory: "
               << status.error_message();
   }
 
-  gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
-                                                     gpu::PluginKind::kBlas,
-                                                     gpu::cuda::kCuBlasPlugin);
+  PluginRegistry::Instance()->SetDefaultFactory(
+      cuda::kCudaPlatformId, PluginKind::kBlas, cuda::kCuBlasPlugin);
 }
 
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 REGISTER_MODULE_INITIALIZER(register_cublas,
-                            { perftools::gputools::initialize_cublas(); });
+                            { stream_executor::initialize_cublas(); });
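A note on the initialization changes in the hunk above, illustrating the commit message's third bullet: the registration body runs at static-initialization time via REGISTER_MODULE_INITIALIZER, and once initialize_cublas() is defined in the library's real namespace, the initializer must name it there. A minimal sketch of the pattern (assuming the REGISTER_MODULE_INITIALIZER macro from StreamExecutor's platform headers):

    namespace stream_executor {
    void initialize_cublas();  // Defined with the plugin, in the real namespace.
    }  // namespace stream_executor

    // Runs at program start; the call is now qualified with the new namespace.
    REGISTER_MODULE_INITIALIZER(register_cublas,
                                { stream_executor::initialize_cublas(); });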
@@ -29,8 +29,7 @@ limitations under the License.
 
 typedef struct cublasContext *cublasHandle_t;
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 
 class Stream;
 

@@ -162,7 +161,6 @@ class CUDABlas : public blas::BlasSupport {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
@@ -51,8 +51,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/str_util.h"
 #include "tensorflow/stream_executor/lib/inlined_vector.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 #ifdef __APPLE__

@@ -384,5 +383,4 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -22,8 +22,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/statusor.h"
 #include "tensorflow/stream_executor/platform/port.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 // e.g. DriverVersion{346, 3, 4}

@@ -93,7 +92,6 @@ class Diagnostician {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_
@@ -59,8 +59,7 @@ NarrowT CheckedNarrowing(const WideT& wide) {
 
 }  // namespace
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 
 using dnn::BatchDescriptor;
 using dnn::FilterDescriptor;

@@ -159,7 +158,7 @@ static port::ThreadPool* GetCudaThreadpool() {
   return cudnn_threadpool;
 }
 
-#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name) \
+#define STREAM_EXECUTOR_CUDNN_WRAP(__name) \
   struct WrapperShim__##__name { \
     template <typename... Args> \
     cudnnStatus_t operator()(CUDAExecutor* parent, Args... args) { \

@@ -169,7 +168,7 @@ static port::ThreadPool* GetCudaThreadpool() {
     } \
   } __name;
 
-#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM(__name) \
+#define STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM(__name) \
   struct WrapperShim__##__name { \
     template <typename... Args> \
     cudnnStatus_t operator()(CudnnSupport* dnn, Stream* s, Args... args) \

@@ -220,7 +219,7 @@ struct WrapperShim__cudnnSetStream {
   __macro(cudnnSetFilterNdDescriptor)
 
 // clang-format on
-CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH
 
 // clang-format off

@@ -242,7 +241,7 @@ CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 
 // clang-format on
 CUDNN_DNN_ROUTINE_EACH_WITH_STREAM(
-    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+    STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
 #undef CUDNN_DNN_ROUTINE_EACH_WITH_STREAM
 
 // APIs available after R3:

@@ -252,7 +251,7 @@ CUDNN_DNN_ROUTINE_EACH_WITH_STREAM(
   __macro(cudnnGetConvolutionBackwardDataAlgorithm) \
   __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \
   __macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
-CUDNN_DNN_ROUTINE_EACH_AFTER_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH_AFTER_R3(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
 #endif
 

@@ -266,7 +265,7 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 // clang-format on
 
 CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM(
-    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+    STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
 #undef CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM
 #endif
 

@@ -293,7 +292,7 @@ CUDNN_DNN_ROUTINE_EACH_R3_WITH_STREAM(
   __macro(cudnnGetFilterNdDescriptor)
 
 // clang-format on
-CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH_R5(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_R5
 
 // clang-format off

@@ -305,7 +304,7 @@ CUDNN_DNN_ROUTINE_EACH_R5(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 
 // clang-format on
 CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(
-    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+    STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
 #undef CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM
 #endif
 

@@ -316,7 +315,7 @@ CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(
   __macro(cudnnSetRNNDescriptor_v6)
 
 // clang-format on
-CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH_R6(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_R6
 
 // clang-format off

@@ -325,7 +324,7 @@ CUDNN_DNN_ROUTINE_EACH_R6(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
 
 // clang-format on
 CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
-    PERFTOOLS_GPUTOOLS_CUDNN_WRAP_WITH_CHECKED_STREAM)
+    STREAM_EXECUTOR_CUDNN_WRAP_WITH_CHECKED_STREAM)
 #undef CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM
 #endif
 

@@ -337,7 +336,7 @@ CUDNN_DNN_ROUTINE_EACH_R6_WITH_STREAM(
   __macro(cudnnSetRNNMatrixMathType)
 
 // clang-format on
-CUDNN_DNN_ROUTINE_EACH_R7(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH_R7(STREAM_EXECUTOR_CUDNN_WRAP)
 #undef CUDNN_DNN_ROUTINE_EACH_R7
 #endif
 

@@ -4727,46 +4726,39 @@ bool CudnnSupport::DeriveOutputBatchDescriptor(
 
 }  // namespace cuda
 
-namespace gpu = ::perftools::gputools;
-
 void initialize_cudnn() {
-  gpu::port::Status status =
-      gpu::PluginRegistry::Instance()
-          ->RegisterFactory<gpu::PluginRegistry::DnnFactory>(
-              gpu::cuda::kCudaPlatformId, gpu::cuda::kCuDnnPlugin, "cuDNN",
-              [](gpu::internal::StreamExecutorInterface*
-                     parent) -> gpu::dnn::DnnSupport* {
-                gpu::cuda::CUDAExecutor* cuda_executor =
-                    dynamic_cast<gpu::cuda::CUDAExecutor*>(parent);
-                if (cuda_executor == nullptr) {
-                  LOG(ERROR)
-                      << "Attempting to initialize an instance of the cuBLAS "
-                      << "support library with a non-CUDA StreamExecutor";
-                  return nullptr;
-                }
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::DnnFactory>(
+          cuda::kCudaPlatformId, cuda::kCuDnnPlugin, "cuDNN",
+          [](internal::StreamExecutorInterface* parent) -> dnn::DnnSupport* {
+            cuda::CUDAExecutor* cuda_executor =
+                dynamic_cast<cuda::CUDAExecutor*>(parent);
+            if (cuda_executor == nullptr) {
+              LOG(ERROR)
+                  << "Attempting to initialize an instance of the cuBLAS "
+                  << "support library with a non-CUDA StreamExecutor";
+              return nullptr;
+            }
 
-                gpu::cuda::CudnnSupport* dnn =
-                    new gpu::cuda::CudnnSupport(cuda_executor);
-                if (!dnn->Init().ok()) {
-                  // Note: Init() will log a more specific error.
-                  delete dnn;
-                  return nullptr;
-                }
-                return dnn;
-              });
+            cuda::CudnnSupport* dnn = new cuda::CudnnSupport(cuda_executor);
+            if (!dnn->Init().ok()) {
+              // Note: Init() will log a more specific error.
+              delete dnn;
+              return nullptr;
+            }
+            return dnn;
+          });
 
   if (!status.ok()) {
     LOG(ERROR) << "Unable to register cuDNN factory: "
              << status.error_message();
   }
 
-  gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
-                                                     gpu::PluginKind::kDnn,
-                                                     gpu::cuda::kCuDnnPlugin);
+  PluginRegistry::Instance()->SetDefaultFactory(
+      cuda::kCudaPlatformId, PluginKind::kDnn, cuda::kCuDnnPlugin);
 }
 
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 REGISTER_MODULE_INITIALIZER(register_cudnn,
-                            { perftools::gputools::initialize_cudnn(); });
+                            { stream_executor::initialize_cudnn(); });
@@ -26,8 +26,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/temporary_device_memory.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 class CUDAExecutor;

@@ -810,7 +809,6 @@ class CudnnSupport : public dnn::DnnSupport {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DNN_H_
@@ -53,8 +53,7 @@ bool FLAGS_gpuexec_cuda_device_0_only = false;
 // matches the expected one.
 constexpr bool kVerifyCudaContext = false;
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 namespace {

@@ -1649,5 +1648,4 @@ static port::StatusOr<T> GetSimpleAttribute(CUdevice device,
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -27,8 +27,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/port.h"
 #include "cuda/include/cuda.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 // Identifies the memory space where an allocation resides. See

@@ -506,7 +505,6 @@ class CudaContext {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_
@@ -19,8 +19,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 CUDAEvent::CUDAEvent(CUDAExecutor* parent)

@@ -68,5 +67,4 @@ const CUevent& CUDAEvent::cuda_event() {
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -21,8 +21,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/event.h"
 #include "tensorflow/stream_executor/lib/status.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 // CUDAEvent wraps a CUevent in the platform-independent EventInterface

@@ -58,7 +57,6 @@ class CUDAEvent : public internal::EventInterface {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_
@@ -31,8 +31,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuFftPlugin);

@@ -44,7 +43,7 @@ namespace wrap {
 // manner on first use. This dynamic loading technique is used to avoid DSO
 // dependencies on vendor libraries which may or may not be available in the
 // deployed binary environment.
-#define PERFTOOLS_GPUTOOLS_CUFFT_WRAP(__name) \
+#define STREAM_EXECUTOR_CUFFT_WRAP(__name) \
   struct WrapperShim__##__name { \
     template <typename... Args> \
     cufftResult operator()(CUDAExecutor *parent, Args... args) { \

@@ -68,7 +67,7 @@ namespace wrap {
   __macro(cufftGetSizeMany) \
   __macro(cufftMakePlanMany)
 
-CUFFT_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUFFT_WRAP)
+CUFFT_ROUTINE_EACH(STREAM_EXECUTOR_CUFFT_WRAP)
 
 }  // namespace wrap
 

@@ -514,62 +513,59 @@ bool CUDAFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
   return true;
 }
 
-#define PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
-                                           __fft_type3) \
-  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
-                      const DeviceMemory<std::complex<__type>> &input, \
-                      DeviceMemory<std::complex<__type>> *output) { \
-    return DoFftWithDirectionInternal( \
-        stream, plan, wrap::cufftExec##__fft_type1, input, output); \
-  } \
-  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
-                      const DeviceMemory<__type> &input, \
-                      DeviceMemory<std::complex<__type>> *output) { \
-    return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type2, input, \
-                         output); \
-  } \
-  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
-                      const DeviceMemory<std::complex<__type>> &input, \
-                      DeviceMemory<__type> *output) { \
-    return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type3, input, \
-                         output); \
+#define STREAM_EXECUTOR_CUDA_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
+                                        __fft_type3) \
+  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
+                      const DeviceMemory<std::complex<__type>> &input, \
+                      DeviceMemory<std::complex<__type>> *output) { \
+    return DoFftWithDirectionInternal( \
+        stream, plan, wrap::cufftExec##__fft_type1, input, output); \
+  } \
+  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
+                      const DeviceMemory<__type> &input, \
+                      DeviceMemory<std::complex<__type>> *output) { \
+    return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type2, input, \
+                         output); \
+  } \
+  bool CUDAFft::DoFft(Stream *stream, fft::Plan *plan, \
+                      const DeviceMemory<std::complex<__type>> &input, \
+                      DeviceMemory<__type> *output) { \
+    return DoFftInternal(stream, plan, wrap::cufftExec##__fft_type3, input, \
+                         output); \
   }
 
-PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(float, C2C, R2C, C2R)
-PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
+STREAM_EXECUTOR_CUDA_DEFINE_FFT(float, C2C, R2C, C2R)
+STREAM_EXECUTOR_CUDA_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
 
-#undef PERFTOOLS_GPUTOOLS_CUDA_DEFINE_FFT
+#undef STREAM_EXECUTOR_CUDA_DEFINE_FFT
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
 
-namespace gpu = ::perftools::gputools;
-
-REGISTER_MODULE_INITIALIZER(register_cufft, {
-  gpu::port::Status status =
-      gpu::PluginRegistry::Instance()
-          ->RegisterFactory<gpu::PluginRegistry::FftFactory>(
-              gpu::cuda::kCudaPlatformId, gpu::cuda::kCuFftPlugin, "cuFFT",
-              [](gpu::internal::StreamExecutorInterface
-                     *parent) -> gpu::fft::FftSupport * {
-                gpu::cuda::CUDAExecutor *cuda_executor =
-                    dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
-                if (cuda_executor == nullptr) {
-                  LOG(ERROR)
-                      << "Attempting to initialize an instance of the cuFFT "
-                      << "support library with a non-CUDA StreamExecutor";
-                  return nullptr;
-                }
+void initialize_cufft() {
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
+          cuda::kCudaPlatformId, cuda::kCuFftPlugin, "cuFFT",
+          [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
+            cuda::CUDAExecutor *cuda_executor =
+                dynamic_cast<cuda::CUDAExecutor *>(parent);
+            if (cuda_executor == nullptr) {
+              LOG(ERROR) << "Attempting to initialize an instance of the cuFFT "
+                         << "support library with a non-CUDA StreamExecutor";
+              return nullptr;
+            }
 
-                return new gpu::cuda::CUDAFft(cuda_executor);
-              });
+            return new cuda::CUDAFft(cuda_executor);
+          });
 
   if (!status.ok()) {
     LOG(ERROR) << "Unable to register cuFFT factory: "
              << status.error_message();
   }
 
-  gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
-                                                     gpu::PluginKind::kFft,
-                                                     gpu::cuda::kCuFftPlugin);
-});
+  PluginRegistry::Instance()->SetDefaultFactory(
+      cuda::kCudaPlatformId, PluginKind::kFft, cuda::kCuFftPlugin);
+}
+
+}  // namespace stream_executor
+
+REGISTER_MODULE_INITIALIZER(register_cufft,
+                            { stream_executor::initialize_cufft(); });
@@ -26,8 +26,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/plugin_registry.h"
 #include "tensorflow/stream_executor/scratch_allocator.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 
 class Stream;
 

@@ -133,7 +132,6 @@ class CUDAFft : public fft::FftSupport {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_FFT_H_
@@ -66,8 +66,7 @@ limitations under the License.
 extern bool FLAGS_check_gpu_leaks;
 bool FLAGS_prefer_cubin_to_ptx = true;
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 // Hook that can be used to CUBIN-ate PTX before it is loaded into the driver.

@@ -1168,17 +1167,14 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const {
 
 }  // namespace cuda
 
-namespace gpu = ::perftools::gputools;
-
 void initialize_cuda_gpu_executor() {
-  *gpu::internal::MakeCUDAExecutorImplementation() = [](
-      const gpu::PluginConfig &config) {
-    return new gpu::cuda::CUDAExecutor{config};
+  *internal::MakeCUDAExecutorImplementation() = [](const PluginConfig &config) {
+    return new cuda::CUDAExecutor{config};
   };
 }
 
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
-REGISTER_MODULE_INITIALIZER(
-    cuda_gpu_executor, {perftools::gputools::initialize_cuda_gpu_executor();});
+REGISTER_MODULE_INITIALIZER(cuda_gpu_executor, {
+  stream_executor::initialize_cuda_gpu_executor();
+});
@@ -35,8 +35,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 // CUDA-platform implementation of the platform-agnostic

@@ -273,7 +272,6 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_GPU_EXECUTOR_H_
@@ -27,8 +27,7 @@ limitations under the License.
 #include "cuda/include/cuComplex.h"
 #include "cuda/include/cuda.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 
 template <typename ElemT>
 class DeviceMemory;

@@ -101,7 +100,6 @@ inline cuDoubleComplex CUDAComplexValue(std::complex<double> val) {
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_HELPERS_H_
@@ -40,8 +40,7 @@ limitations under the License.
     "CUDA runtime being included into CUDA GPU executor; should be driver only."
 #endif
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 // Wraps a CUfunction to implement the platform-independent KernelInterface.

@@ -124,7 +123,6 @@ inline CUDAKernel *AsCUDAKernel(KernelBase *kernel) {
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_
@@ -24,8 +24,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/stringprintf.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 namespace {
 

@@ -41,16 +40,16 @@ const DeviceOptions GetDeviceOptionsFromEnv() {
       std::getenv("TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE");
 
   if (gpu_schedule_string == nullptr) {
-    return perftools::gputools::DeviceOptions::Default();
+    return DeviceOptions::Default();
   }
 
   unsigned device_flags = 0;
   if (strcmp(kScheduleSpinString, gpu_schedule_string) == 0) {
-    device_flags = perftools::gputools::DeviceOptions::kScheduleSpin;
+    device_flags = DeviceOptions::kScheduleSpin;
   } else if (strcmp(kScheduleYieldString, gpu_schedule_string) == 0) {
-    device_flags = perftools::gputools::DeviceOptions::kScheduleYield;
+    device_flags = DeviceOptions::kScheduleYield;
  } else if (strcmp(kScheduleBlockingSyncString, gpu_schedule_string) == 0) {
-    device_flags = perftools::gputools::DeviceOptions::kScheduleBlockingSync;
+    device_flags = DeviceOptions::kScheduleBlockingSync;
   } else {
     LOG(QFATAL) << "Unknown option for environment variable "
                    "TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE "

@@ -59,7 +58,7 @@ const DeviceOptions GetDeviceOptionsFromEnv() {
                 << ", " << kScheduleYieldString << "}";
   }
 
-  return perftools::gputools::DeviceOptions(device_flags);
+  return DeviceOptions(device_flags);
 }
 
 }  // namespace

@@ -202,11 +201,10 @@ static void InitializeCudaPlatform() {
   SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));
 }
 
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 REGISTER_MODULE_INITIALIZER(cuda_platform,
-                            perftools::gputools::InitializeCudaPlatform());
+                            stream_executor::InitializeCudaPlatform());
 
 DECLARE_MODULE_INITIALIZER(multi_platform_manager);
 // Note that module initialization sequencing is not supported in the
@@ -31,8 +31,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
 #include "tensorflow/stream_executor/trace_listener.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 // Opaque and unique identifier for the CUDA platform plugin.

@@ -104,7 +103,6 @@ class CudaPlatform : public Platform {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_H_
@@ -15,12 +15,10 @@ limitations under the License.
 
 #include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 PLATFORM_DEFINE_ID(kCudaPlatformId);
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -18,8 +18,7 @@ limitations under the License.
 
 #include "tensorflow/stream_executor/platform.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 // Opaque and unique identifier for the cuda platform.

@@ -30,7 +29,6 @@ namespace cuda {
 extern const Platform::Id kCudaPlatformId;
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_
@@ -54,15 +54,14 @@ std::ostream &operator<<(std::ostream &in, const curandStatus_t &status) {
   }
 }
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuRandPlugin);
 
 namespace wrap {
 
-#define PERFTOOLS_GPUTOOLS_CURAND_WRAP(__name) \
+#define STREAM_EXECUTOR_CURAND_WRAP(__name) \
   struct WrapperShim__##__name { \
     template <typename... Args> \
     curandStatus_t operator()(CUDAExecutor *parent, Args... args) { \

@@ -71,15 +70,15 @@ namespace wrap {
     } \
   } __name;
 
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandCreateGenerator);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandDestroyGenerator);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetStream);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniform);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateUniformDouble);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetPseudoRandomGeneratorSeed);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandSetGeneratorOffset);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormal);
-PERFTOOLS_GPUTOOLS_CURAND_WRAP(curandGenerateNormalDouble);
+STREAM_EXECUTOR_CURAND_WRAP(curandCreateGenerator);
+STREAM_EXECUTOR_CURAND_WRAP(curandDestroyGenerator);
+STREAM_EXECUTOR_CURAND_WRAP(curandSetStream);
+STREAM_EXECUTOR_CURAND_WRAP(curandGenerateUniform);
+STREAM_EXECUTOR_CURAND_WRAP(curandGenerateUniformDouble);
+STREAM_EXECUTOR_CURAND_WRAP(curandSetPseudoRandomGeneratorSeed);
+STREAM_EXECUTOR_CURAND_WRAP(curandSetGeneratorOffset);
+STREAM_EXECUTOR_CURAND_WRAP(curandGenerateNormal);
+STREAM_EXECUTOR_CURAND_WRAP(curandGenerateNormalDouble);
 
 }  // namespace wrap
 

@@ -271,42 +270,40 @@ bool CUDARng::SetSeed(Stream *stream, const uint8 *seed, uint64 seed_bytes) {
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
 
-namespace gpu = ::perftools::gputools;
-
-REGISTER_MODULE_INITIALIZER(register_curand, {
-  gpu::port::Status status =
-      gpu::PluginRegistry::Instance()
-          ->RegisterFactory<gpu::PluginRegistry::RngFactory>(
-              gpu::cuda::kCudaPlatformId, gpu::cuda::kCuRandPlugin, "cuRAND",
-              [](gpu::internal::StreamExecutorInterface
-                     *parent) -> gpu::rng::RngSupport * {
-                gpu::cuda::CUDAExecutor *cuda_executor =
-                    dynamic_cast<gpu::cuda::CUDAExecutor *>(parent);
-                if (cuda_executor == nullptr) {
-                  LOG(ERROR)
-                      << "Attempting to initialize an instance of the cuRAND "
-                      << "support library with a non-CUDA StreamExecutor";
-                  return nullptr;
-                }
+void initialize_curand() {
+  port::Status status =
+      PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
+          cuda::kCudaPlatformId, cuda::kCuRandPlugin, "cuRAND",
+          [](internal::StreamExecutorInterface *parent) -> rng::RngSupport * {
+            cuda::CUDAExecutor *cuda_executor =
+                dynamic_cast<cuda::CUDAExecutor *>(parent);
+            if (cuda_executor == nullptr) {
+              LOG(ERROR)
+                  << "Attempting to initialize an instance of the cuRAND "
+                  << "support library with a non-CUDA StreamExecutor";
+              return nullptr;
+            }
 
-                gpu::cuda::CUDARng *rng = new gpu::cuda::CUDARng(cuda_executor);
-                if (!rng->Init()) {
-                  // Note: Init() will log a more specific error.
-                  delete rng;
-                  return nullptr;
-                }
-                return rng;
-              });
+            cuda::CUDARng *rng = new cuda::CUDARng(cuda_executor);
+            if (!rng->Init()) {
+              // Note: Init() will log a more specific error.
+              delete rng;
+              return nullptr;
+            }
+            return rng;
+          });
 
   if (!status.ok()) {
     LOG(ERROR) << "Unable to register cuRAND factory: "
              << status.error_message();
   }
 
-  gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
-                                                     gpu::PluginKind::kRng,
-                                                     gpu::cuda::kCuRandPlugin);
-});
+  PluginRegistry::Instance()->SetDefaultFactory(
+      cuda::kCudaPlatformId, PluginKind::kRng, cuda::kCuRandPlugin);
+}
+
+}  // namespace stream_executor
+
+REGISTER_MODULE_INITIALIZER(register_curand,
+                            { stream_executor::initialize_curand(); });
@@ -24,8 +24,7 @@ limitations under the License.
 
 typedef struct curandGenerator_st *curandGenerator_t;
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 
 class Stream;
 template <typename ElemT>

@@ -98,7 +97,6 @@ class CUDARng : public rng::RngSupport {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_RNG_H_
@@ -19,8 +19,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/stream.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 bool CUDAStream::Init() {

@@ -59,5 +58,4 @@ CUstream AsCUDAStreamValue(Stream *stream) {
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -23,8 +23,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/thread_annotations.h"
 #include "tensorflow/stream_executor/stream_executor_internal.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 class CUDAExecutor;

@@ -89,7 +88,6 @@ CUDAStream *AsCUDAStream(Stream *stream);
 CUstream AsCUDAStreamValue(Stream *stream);
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_
@@ -20,8 +20,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_stream.h"
 #include "tensorflow/stream_executor/lib/status.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 bool CUDATimer::Init() {

@@ -92,5 +91,4 @@ bool CUDATimer::Stop(CUDAStream* stream) {
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -23,8 +23,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/cuda/cuda_driver.h"
 #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 class CUDAExecutor;

@@ -60,13 +59,13 @@ class CUDATimer : public internal::TimerInterface {
   // events.
   float GetElapsedMilliseconds() const;
 
-  // See perftools::gputools::Timer::Microseconds().
+  // See Timer::Microseconds().
   // TODO(leary) make this into an error code interface...
   uint64 Microseconds() const override {
     return GetElapsedMilliseconds() * 1e3;
   }
 
-  // See perftools::GPUTools::Timer::Nanoseconds().
+  // See Timer::Nanoseconds().
   uint64 Nanoseconds() const override { return GetElapsedMilliseconds() * 1e6; }
 
  private:

@@ -85,7 +84,6 @@ struct TimerDeleter {
 };
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_TIMER_H_
@@ -15,8 +15,7 @@ limitations under the License.
 
 #include "tensorflow/stream_executor/cuda/cudnn_version.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version,

@@ -38,5 +37,4 @@ bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version,
 }
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -20,8 +20,7 @@ limitations under the License.
 
 #include "tensorflow/core/lib/strings/strcat.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 
 struct CudnnVersion {

@@ -46,7 +45,6 @@ bool IsSourceCompatibleWithCudnnLibrary(CudnnVersion source_version,
                                         CudnnVersion loaded_version);
 
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
 
 #endif  // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDNN_VERSION_H_
@@ -17,8 +17,7 @@ limitations under the License.
 
 #include "tensorflow/core/platform/test.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 namespace cuda {
 namespace {
 

@@ -70,5 +69,4 @@ TEST(IsSourceCompatibleWithCudnnLibraryTest, Basic) {
 
 }  // namespace
 }  // namespace cuda
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@@ -21,8 +21,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/mathutil.h"
 #include "tensorflow/stream_executor/lib/strcat.h"
 
-namespace perftools {
-namespace gputools {
+namespace stream_executor {
 
 static const uint64 kUninitializedUint64 = -1ULL;
 /* static */ const char *DeviceDescription::kUndefinedString = "<undefined>";

@@ -234,6 +233,4 @@ uint64 CalculateRegisterLimitForTargetOccupancy(
   return 0;
 }
 
-
-}  // namespace gputools
-}  // namespace perftools
+}  // namespace stream_executor
@ -27,8 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/launch_dim.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
|
||||
namespace perftools {
|
||||
namespace gputools {
|
||||
namespace stream_executor {
|
||||
namespace internal {
|
||||
class DeviceDescriptionBuilder;
|
||||
} // namespace internal
|
||||
@ -388,7 +387,6 @@ uint64 CalculateRegisterLimitForTargetOccupancy(
|
||||
const DeviceDescription &device_description, uint64 shared_memory_per_block,
|
||||
const ThreadDim &thread_dims, uint64 target_blocks_per_core);
|
||||
|
||||
} // namespace gputools
|
||||
} // namespace perftools
|
||||
} // namespace stream_executor
|
||||
|
||||
#endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
|
||||
|
@ -32,6 +32,16 @@ limitations under the License.
|
||||
namespace perftools {
|
||||
namespace gputools {
|
||||
|
||||
// Temporarily pull stream_executor into perftools::gputools while we migrate
|
||||
// code to the new namespace. TODO(b/77980417): Remove this once we've
|
||||
// completed the migration.
|
||||
using namespace stream_executor; // NOLINT[build/namespaces]
|
||||
|
||||
} // namespace gputools
|
||||
} // namespace perftools
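The using-directive just above is the heart of the compatibility shim: qualified lookup of perftools::gputools::X falls through the directive to stream_executor::X. A minimal self-contained sketch of the effect (Widget is a hypothetical stand-in, not a type from this commit):

#include <type_traits>

namespace stream_executor {
class Widget {};  // stands in for Stream, Platform, DeviceMemory, ...
}  // namespace stream_executor

namespace perftools {
namespace gputools {
using namespace stream_executor;  // the same alias technique as above
}  // namespace gputools
}  // namespace perftools

// Old and new spellings now name one type, so callers can migrate gradually:
static_assert(
    std::is_same<perftools::gputools::Widget, stream_executor::Widget>::value,
    "old spelling still resolves");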

namespace stream_executor {

class StreamExecutor;

// void*-analogous device memory allocation. For the typed variation, see
@ -280,7 +290,6 @@ static_assert(sizeof(Float2) == 2 * sizeof(float), "Float2 must be packed");
static_assert(sizeof(Float4) == 4 * sizeof(float), "Float4 must be packed");
static_assert(sizeof(Double2) == 2 * sizeof(double), "Double2 must be packed");

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_MEMORY_H_

@ -25,8 +25,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/logging.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

// Indicates a set of options for a device's usage, which generally must be
// provided at StreamExecutor device-initialization time.
@ -84,7 +83,6 @@ struct DeviceOptions {
unsigned flags_;
};

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_OPTIONS_H_

@ -18,8 +18,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace dnn {

bool DnnSupport::GetConvolveAlgorithms(
@ -554,5 +553,4 @@ string NormalizeDescriptor::ToShortString() const {
}

} // namespace dnn
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -38,8 +38,7 @@ namespace Eigen {
struct half;
} // namespace Eigen

namespace perftools {
namespace gputools {
namespace stream_executor {

class HostBuffer;
class Stream;
@ -2285,7 +2284,6 @@ class DnnSupport {
};

} // namespace dnn
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_DNN_H_

@ -37,8 +37,7 @@ limitations under the License.
#include "cuda/cuda_config.h"
#endif

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace internal {

string GetCudaVersion() { return TF_CUDA_VERSION; }
@ -291,5 +290,4 @@ static std::vector<string>* CreatePrimordialRpaths() {
}

} // namespace internal
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -28,8 +28,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/mutex.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace internal {

// Permits StreamExecutor code to dynamically load a pre-determined set of
@ -114,7 +113,6 @@ class CachedDsoLoader {
};

} // namespace internal
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_DSO_LOADER_H_

@ -19,8 +19,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/stream.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

Event::Event(StreamExecutor* stream_exec)
: stream_exec_(stream_exec),
@ -48,5 +47,4 @@ Event::Status Event::PollForStatus() {
return stream_exec_->PollForEventStatus(this);
}

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -20,8 +20,7 @@ limitations under the License.

#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

namespace internal {
class EventInterface;
@ -76,7 +75,6 @@ class Event {
SE_DISALLOW_COPY_AND_ASSIGN(Event);
};

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_EVENT_H_

@ -17,8 +17,7 @@ limitations under the License.

#include "tensorflow/stream_executor/lib/stringprintf.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
const StreamExecutorConfig& config,
@ -104,5 +103,4 @@ ExecutorCache::Entry::~Entry() {
configurations.clear();
}

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -24,8 +24,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

// Utility class to allow Platform objects to manage cached StreamExecutors.
// Thread-safe.
@ -76,7 +75,6 @@ class ExecutorCache {
SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
};

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_

@ -48,8 +48,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

class Stream;
template <typename ElemT>
@ -210,7 +209,7 @@ class FftSupport {

// Macro used to quickly declare overrides for abstract virtuals in the
// fft::FftSupport base class. Assumes that it's emitted somewhere inside the
// ::perftools::gputools namespace.
// ::stream_executor namespace.
#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \
fft::Type type, bool in_place_fft) \
@ -265,7 +264,6 @@ class FftSupport {
DeviceMemory<double> *output) override;

} // namespace fft
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_FFT_H_

@ -28,8 +28,7 @@ limitations under the License.

bool FLAGS_stream_executor_cpu_real_clock_rate = false;

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

HostStream *AsHostStream(Stream *stream) {
@ -266,5 +265,4 @@ rng::RngSupport *HostExecutor::CreateRng() {
}

} // namespace host
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -28,8 +28,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

// An implementation of StreamExecutor that does no communication or interaction
@ -210,7 +209,6 @@ class HostExecutor : public internal::StreamExecutorInterface {
};

} // namespace host
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_GPU_EXECUTOR_H_

@ -26,10 +26,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/status_macros.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"

namespace gpu = ::perftools::gputools;

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

HostPlatform::HostPlatform() : name_("Host") {}
@ -93,16 +90,15 @@ void HostPlatform::UnregisterTraceListener(TraceListener* listener) {
}

static void InitializeHostPlatform() {
std::unique_ptr<gpu::Platform> platform(new gpu::host::HostPlatform);
SE_CHECK_OK(gpu::MultiPlatformManager::RegisterPlatform(std::move(platform)));
std::unique_ptr<Platform> platform(new host::HostPlatform);
SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform)));
}

} // namespace host
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

REGISTER_MODULE_INITIALIZER(
host_platform, perftools::gputools::host::InitializeHostPlatform());
REGISTER_MODULE_INITIALIZER(host_platform,
stream_executor::host::InitializeHostPlatform());

DECLARE_MODULE_INITIALIZER(multi_platform_manager);
// Note that module initialization sequencing is not supported in the

@ -33,8 +33,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/trace_listener.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

// Host (CPU) platform plugin, registered as a singleton value via module
@ -79,7 +78,6 @@ class HostPlatform : public Platform {
};

} // namespace host
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_PLATFORM_H_

@ -15,12 +15,10 @@ limitations under the License.

#include "tensorflow/stream_executor/host/host_platform_id.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

PLATFORM_DEFINE_ID(kHostPlatformId);

} // namespace host
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -18,8 +18,7 @@ limitations under the License.

#include "tensorflow/stream_executor/platform.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

// Opaque and unique identifier for the host platform.
@ -30,7 +29,6 @@ namespace host {
extern const Platform::Id kHostPlatformId;

} // namespace host
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_PLATFORM_ID_H_

@ -17,8 +17,7 @@ limitations under the License.
// the HostExecutor implementation.
#include "tensorflow/stream_executor/host/host_stream.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

HostStream::HostStream()
@ -53,5 +52,4 @@ void HostStream::BlockUntilDone() {

} // namespace host

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -24,8 +24,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

class HostStream : public internal::StreamInterface {
@ -52,7 +51,6 @@ class HostStream : public internal::StreamInterface {
};

} // namespace host
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_STREAM_H_

@ -19,8 +19,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

using std::chrono::duration_cast;
@ -46,5 +45,4 @@ void HostTimer::StartNow() { start_time_ = clock::now(); }
void HostTimer::StopNow() { duration_ = clock::now() - start_time_; }

} // namespace host
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -20,8 +20,7 @@ limitations under the License.

#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace host {

class HostTimer : public internal::TimerInterface {
@ -57,7 +56,6 @@ class HostTimer : public internal::TimerInterface {
};

} // namespace host
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_HOST_TIMER_H_

@ -18,8 +18,7 @@ limitations under the License.

#include "tensorflow/stream_executor/dnn.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

// A HostBuffer is a block of memory in host memory containing the data for a
// dnn::BatchDescriptor using a device-dependent memory layout.
@ -42,7 +41,6 @@ class HostBuffer {
const dnn::BatchDescriptor descriptor_;
};

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_BUFFER_H_

@ -27,8 +27,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/stream_executor.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

bool KernelMetadata::registers_per_thread(int *registers_per_thread) const {
if (has_registers_per_thread_) {
@ -103,5 +102,4 @@ void KernelBase::set_name(port::StringPiece name) {
demangled_name_ = port::Demangle(stubless_name.data());
}

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -64,7 +64,7 @@ limitations under the License.
//
// Users typically won't need to type out the TypedKernel signature in full, it
// will be typedef'd by automatically generated code; for example, see
// perftools::gputools::executor_sample::VecReduceAddKernel.
// stream_executor::executor_sample::VecReduceAddKernel.

#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
@ -82,8 +82,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

class DeviceMemoryBase;
template <typename ElemT>
@ -639,8 +638,8 @@ struct KernelInvocationChecker {
// NOTE: if you encounter an error here, you can see the mismatch by looking
// at the end of the last error message, which will be of the form:
//
// ...::Compatible<const perftools::gputools::DeviceMemory<OneThing> &,
// perftools::gputools::DeviceMemory<AnotherThing>, true,
// ...::Compatible<const stream_executor::DeviceMemory<OneThing> &,
// stream_executor::DeviceMemory<AnotherThing>, true,
// 0>'
// requested here
//
@ -711,7 +710,6 @@ struct KernelParamsOk<TypedKernel<Params...>, Args...> {
std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
};

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_

@ -18,8 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_
#define TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_

namespace perftools {
namespace gputools {
namespace stream_executor {

// This enum represents potential configurations of L1/shared memory when
// running a particular kernel. These values represent user preference, and
@ -38,7 +37,6 @@ enum class KernelCacheConfig {
kPreferEqual,
};

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_CACHE_CONFIG_H_

@ -15,9 +15,7 @@ limitations under the License.

#include "tensorflow/stream_executor/kernel_spec.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

KernelLoaderSpec::KernelLoaderSpec(port::StringPiece kernelname)
: kernelname_(kernelname.ToString()) {}
@ -247,5 +245,4 @@ MultiKernelLoaderSpec *MultiKernelLoaderSpec::AddCudaCompressedPtxInMemory(

MultiKernelLoaderSpec::MultiKernelLoaderSpec(size_t arity) : arity_(arity) {}

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -56,8 +56,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

// Describes how to load a kernel on a target platform.
//
@ -374,7 +373,6 @@ class MultiKernelLoaderSpec {
size_t arity_;
};

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_KERNEL_SPEC_H_

@ -21,7 +21,7 @@ limitations under the License.
// a single PC in a unit called a warp. There is a maximum number of threads
// that can execute in a shared-context entity called a block. Presently, that
// number is 1024 -- again, something that should not be relied on from this
// comment, but checked via perftools::gputools::DeviceDescription.
// comment, but checked via stream_executor::DeviceDescription.
//
// For additional information, see
// http://docs.nvidia.com/cuda/kepler-tuning-guide/#device-utilization-and-occupancy
@ -40,8 +40,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {

// Basic type that represents a 3-dimensional index space.
struct Dim3D {
@ -74,7 +73,6 @@ struct BlockDim : public Dim3D {
}
};

} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LAUNCH_DIM_H_
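Since the comment in this header describes the thread/block hierarchy that ThreadDim and BlockDim model, here is a hedged sketch of sizing a 1-D launch; the helper name and the 256-thread block size are illustrative and assume the x/y/z constructor defaults, they are not part of this diff:

#include <utility>

#include "tensorflow/stream_executor/launch_dim.h"

namespace se = stream_executor;

// One thread per element, rounded up to whole 256-thread blocks.
std::pair<se::ThreadDim, se::BlockDim> LaunchShapeFor(unsigned long long n) {
  const unsigned long long kThreadsPerBlock = 256;
  return {se::ThreadDim(kThreadsPerBlock),
          se::BlockDim((n + kThreadsPerBlock - 1) / kThreadsPerBlock)};
}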

@ -18,14 +18,23 @@ limitations under the License.

#include "tensorflow/core/lib/gtl/array_slice.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using tensorflow::gtl::ArraySlice;
using tensorflow::gtl::MutableArraySlice;

} // namespace port
} // namespace stream_executor

namespace perftools {
namespace gputools {

// Temporarily pull stream_executor into perftools::gputools while we migrate
// code to the new namespace. TODO(b/77980417): Remove this once we've
// completed the migration.
using namespace stream_executor; // NOLINT[build/namespaces]

} // namespace gputools
} // namespace perftools

@ -13,15 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"

#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_

#include <stdlib.h>

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

// port::bit_cast<Dest,Source> is a template function that implements the
@ -96,7 +95,6 @@ inline Dest bit_cast(const Source& source) {
}

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
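As a hedged illustration of the bit_cast declared in this header (the FloatBits helper is hypothetical, not from this commit): it copies a value's object representation instead of punning through pointers, which is the undefined-behavior trap the header's comment warns about.

#include "tensorflow/stream_executor/lib/casts.h"

unsigned int FloatBits(float f) {
  // Same size in, same size out; the bit pattern is copied, not converted.
  return stream_executor::port::bit_cast<unsigned int>(f);
}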

@ -27,8 +27,7 @@ limitations under the License.
#include <cxxabi.h>
#endif

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

// The API reference of abi::__cxa_demangle() can be found in
@ -49,5 +48,4 @@ string Demangle(const char *mangled) {
}

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -18,14 +18,12 @@ limitations under the License.

#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

string Demangle(const char* mangled);

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_DEMANGLE_H_

@ -21,8 +21,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using tensorflow::Env;
@ -37,7 +36,6 @@ inline Status FileExists(const port::StringPiece& filename) {
}

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_ENV_H_

@ -13,21 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"

#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_

#include "tensorflow/core/lib/core/error_codes.pb.h" // IWYU pragma: export

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

namespace error = tensorflow::error;

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_

@ -22,8 +22,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

class HumanReadableNumBytes {
@ -67,7 +66,6 @@ class HumanReadableNumBytes {
};

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_HUMAN_READABLE_H_

@ -26,8 +26,7 @@ limitations under the License.
#undef DECLARE_MODULE_INITIALIZER
#undef REGISTER_MODULE_INITIALIZER_SEQUENCE

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

class Initializer {
@ -49,20 +48,18 @@ class Initializer {
};

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#define REGISTER_INITIALIZER(type, name, body) \
static void google_init_##type##_##name() { body; } \
perftools::gputools::port::Initializer google_initializer_##type##_##name( \
#define REGISTER_INITIALIZER(type, name, body) \
static void google_init_##type##_##name() { body; } \
::stream_executor::port::Initializer google_initializer_##type##_##name( \
google_init_##type##_##name)

#define REGISTER_MODULE_INITIALIZER(name, body) \
REGISTER_INITIALIZER(module, name, body)

#define DECLARE_INITIALIZER(type, name) \
extern perftools::gputools::port::Initializer \
google_initializer_##type##_##name
#define DECLARE_INITIALIZER(type, name) \
extern ::stream_executor::port::Initializer google_initializer_##type##_##name

#define DECLARE_MODULE_INITIALIZER(name) DECLARE_INITIALIZER(module, name)
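The rewritten macros now spell the Initializer type with a leading ::, so the expansion resolves from inside any namespace. A hedged usage sketch mirroring the host_platform registration elsewhere in this diff (my_platform and InitializeMyPlatform are hypothetical names):

static void InitializeMyPlatform() {
  // ... one-time registration work ...
}

// Runs InitializeMyPlatform() during static initialization of this TU.
REGISTER_MODULE_INITIALIZER(my_platform, InitializeMyPlatform());

// Lets another translation unit name the same initializer object.
DECLARE_MODULE_INITIALIZER(my_platform);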

@ -18,14 +18,12 @@ limitations under the License.

#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using tensorflow::gtl::InlinedVector;

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_

@ -25,8 +25,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

class MathUtil {
@ -97,7 +96,6 @@ IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator,
}

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_MATHUTIL_H_

@ -18,14 +18,12 @@ limitations under the License.

#include "tensorflow/core/platform/notification.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using tensorflow::Notification;

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_NOTIFICATION_H_

@ -17,8 +17,7 @@ limitations under the License.

#include <stdlib.h>

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

bool safe_strto32(const char* str, int32* value) {
@ -38,5 +37,4 @@ bool safe_strto32(const string& str, int32* value) {
}

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -18,8 +18,7 @@ limitations under the License.

#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

// Convert strings to floating point values.
@ -28,7 +27,6 @@ namespace port {
bool safe_strto32(const string& str, int32* value);

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_NUMBERS_H_

@ -16,8 +16,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/path.h"
#include "tensorflow/stream_executor/lib/strcat.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {
namespace internal {

@ -58,5 +57,4 @@ string JoinPathImpl(std::initializer_list<port::StringPiece> paths) {

} // namespace internal
} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -20,8 +20,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using tensorflow::io::Dirname;
@ -56,7 +55,6 @@ inline string JoinPath(const T&... args) {
}

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PATH_H_

@ -25,8 +25,7 @@ limitations under the License.
#endif
#include <memory>

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

string Hostname() {
@ -54,5 +53,4 @@ bool GetCurrentDirectory(string* dir) {
}

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

@ -18,15 +18,13 @@ limitations under the License.

#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

string Hostname();
bool GetCurrentDirectory(string* dir);

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PROCESS_STATE_H_

@ -18,8 +18,7 @@ limitations under the License.

#include <memory>

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

// Trait to select overloads and return types for MakeUnique.
@ -59,8 +58,17 @@ typename MakeUniqueResult<T>::invalid MakeUnique(Args&&... /* args */) =
delete; // NOLINT

} // namespace port
} // namespace stream_executor

namespace perftools {
namespace gputools {

// Temporarily pull stream_executor into perftools::gputools while we migrate
// code to the new namespace. TODO(jlebar): Remove this once we've completed
// the migration.
using namespace stream_executor; // NOLINT[build/namespaces]

} // namespace gputools
} // namespace perftools

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_
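The MakeUnique trait machinery in this header behaves like C++14's std::make_unique for a C++11 codebase; a hedged sketch of a call site (Gadget and MakeGadget are hypothetical):

#include <memory>

#include "tensorflow/stream_executor/lib/ptr_util.h"

struct Gadget {
  Gadget(int a, int b) : sum(a + b) {}
  int sum;
};

std::unique_ptr<Gadget> MakeGadget() {
  // Forwards (1, 2) to Gadget's constructor; the array overloads differ.
  return stream_executor::port::MakeUnique<Gadget>(1, 2);
}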

@ -19,14 +19,12 @@ limitations under the License.
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using tensorflow::CurrentStackTrace;

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STACKTRACE_H_

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"

#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_
@ -23,15 +23,14 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/logging.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using Status = tensorflow::Status;

#define SE_CHECK_OK(val) TF_CHECK_OK(val)
#define SE_ASSERT_OK(val) \
ASSERT_EQ(::perftools::gputools::port::Status::OK(), (val))
ASSERT_EQ(::stream_executor::port::Status::OK(), (val))

// Define some canonical error helpers.
inline Status UnimplementedError(StringPiece message) {
@ -45,6 +44,16 @@ inline Status FailedPreconditionError(StringPiece message) {
}

} // namespace port
} // namespace stream_executor

namespace perftools {
namespace gputools {

// Temporarily pull stream_executor into perftools::gputools while we migrate
// code to the new namespace. TODO(b/77980417): Remove this once we've
// completed the migration.
using namespace stream_executor; // NOLINT[build/namespaces]

} // namespace gputools
} // namespace perftools
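A hedged sketch of the canonical error helpers above in use (DoThing and its error message are illustrative, not from this commit):

#include "tensorflow/stream_executor/lib/status.h"

stream_executor::port::Status DoThing(bool supported) {
  if (!supported) {
    return stream_executor::port::UnimplementedError("DoThing: not supported");
  }
  return stream_executor::port::Status::OK();
}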

@ -13,15 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"

#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_

#include "tensorflow/compiler/xla/statusor.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

// Use XLA's StatusOr so we don't duplicate code.
@ -29,7 +28,6 @@ template <typename T>
using StatusOr = ::xla::StatusOr<T>;

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STATUSOR_H_

@ -19,8 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using tensorflow::str_util::Join;
@ -38,7 +37,6 @@ inline string StripSuffixString(port::StringPiece str, port::StringPiece suffix)
using tensorflow::str_util::Lowercase;

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STR_UTIL_H_

@ -13,22 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
// IWYU pragma: private, include "third_party/tensorflow/stream_executor/stream_executor.h"

#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_

#include "tensorflow/core/lib/strings/strcat.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using tensorflow::strings::StrCat;
using tensorflow::strings::StrAppend;

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_

@ -19,14 +19,12 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/stream_executor/platform/port.h"

namespace perftools {
namespace gputools {
namespace stream_executor {
namespace port {

using tensorflow::StringPiece;

} // namespace port
} // namespace gputools
} // namespace perftools
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
Some files were not shown because too many files have changed in this diff.