Merge pull request #25424 from ROCmSoftwarePlatform:google_upstream_add_rocm_blas_fft_rand

PiperOrigin-RevId: 234978728
Commit: 6a25b9d561
tensorflow/stream_executor/platform/dso_loader.cc

@@ -39,7 +39,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
   port::Status status =
       port::Env::Default()->LoadLibrary(filename.c_str(), &dso_handle);
   if (status.ok()) {
-    LOG(INFO) << "Successfully opened CUDA library " << filename;
+    LOG(INFO) << "Successfully opened dynamic library " << filename;
     return dso_handle;
   }
@@ -99,6 +99,25 @@ port::StatusOr<void*> GetCuptiDsoHandle() {
 port::StatusOr<void*> GetCudnnDsoHandle() {
   return GetDsoHandle("cudnn", GetCudnnVersion());
 }

+port::StatusOr<void*> GetRocblasDsoHandle() {
+  return GetDsoHandle("rocblas", "");
+}
+
+port::StatusOr<void*> GetMiopenDsoHandle() {
+  return GetDsoHandle("MIOpen", "");
+}
+
+port::StatusOr<void*> GetRocfftDsoHandle() {
+  return GetDsoHandle("rocfft", "");
+}
+
+port::StatusOr<void*> GetRocrandDsoHandle() {
+  return GetDsoHandle("rocrand", "");
+}
+
+port::StatusOr<void*> GetHipDsoHandle() { return GetDsoHandle("hip_hcc", ""); }
+
 }  // namespace DsoLoader

 namespace CachedDsoLoader {
@@ -136,6 +155,32 @@ port::StatusOr<void*> GetCudnnDsoHandle() {
   static auto result = new auto(DsoLoader::GetCudnnDsoHandle());
   return *result;
 }

+port::StatusOr<void*> GetRocblasDsoHandle() {
+  static auto result = new auto(DsoLoader::GetRocblasDsoHandle());
+  return *result;
+}
+
+port::StatusOr<void*> GetMiopenDsoHandle() {
+  static auto result = new auto(DsoLoader::GetMiopenDsoHandle());
+  return *result;
+}
+
+port::StatusOr<void*> GetRocfftDsoHandle() {
+  static auto result = new auto(DsoLoader::GetRocfftDsoHandle());
+  return *result;
+}
+
+port::StatusOr<void*> GetRocrandDsoHandle() {
+  static auto result = new auto(DsoLoader::GetRocrandDsoHandle());
+  return *result;
+}
+
+port::StatusOr<void*> GetHipDsoHandle() {
+  static auto result = new auto(DsoLoader::GetHipDsoHandle());
+  return *result;
+}
+
 }  // namespace CachedDsoLoader
 }  // namespace internal
 }  // namespace stream_executor
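Each CachedDsoLoader function above uses the same memoization idiom: a function-local static pointing at a heap allocation that is computed once and deliberately never freed, which sidesteps static-destruction-order problems at process exit. A minimal, self-contained sketch of that idiom (the names here are illustrative, not from the TensorFlow tree):

#include <iostream>
#include <string>

// Computed at most once per process; the heap allocation is intentionally
// leaked so no destructor runs during static teardown.
const std::string& GetCachedValue() {
  static auto* result = new auto(std::string("computed exactly once"));
  return *result;
}

int main() {
  std::cout << GetCachedValue() << "\n";  // first call triggers the computation
  std::cout << GetCachedValue() << "\n";  // second call returns the cached copy
}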
tensorflow/stream_executor/platform/dso_loader.h

@@ -41,6 +41,12 @@ port::StatusOr<void*> GetCufftDsoHandle();
 port::StatusOr<void*> GetCurandDsoHandle();
 port::StatusOr<void*> GetCuptiDsoHandle();
 port::StatusOr<void*> GetCudnnDsoHandle();
+
+port::StatusOr<void*> GetRocblasDsoHandle();
+port::StatusOr<void*> GetMiopenDsoHandle();
+port::StatusOr<void*> GetRocfftDsoHandle();
+port::StatusOr<void*> GetRocrandDsoHandle();
+port::StatusOr<void*> GetHipDsoHandle();
 }  // namespace DsoLoader

 // Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs
@@ -54,7 +60,14 @@ port::StatusOr<void*> GetCufftDsoHandle();
 port::StatusOr<void*> GetCurandDsoHandle();
 port::StatusOr<void*> GetCuptiDsoHandle();
 port::StatusOr<void*> GetCudnnDsoHandle();
+
+port::StatusOr<void*> GetRocblasDsoHandle();
+port::StatusOr<void*> GetMiopenDsoHandle();
+port::StatusOr<void*> GetRocfftDsoHandle();
+port::StatusOr<void*> GetRocrandDsoHandle();
+port::StatusOr<void*> GetHipDsoHandle();
 }  // namespace CachedDsoLoader

 }  // namespace internal
 }  // namespace stream_executor
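A hedged sketch of how a caller would typically consume one of these getters; the error handling shown is illustrative and not lifted from this commit, though GetSymbolFromLibrary is the same Env API the wrapper header later in this diff uses:

#include "tensorflow/stream_executor/platform/dso_loader.h"

// Resolve a named symbol out of the (cached) rocBLAS DSO, or nullptr.
void* ResolveRocblasSymbol(const char* symbol_name) {
  namespace loader = stream_executor::internal::CachedDsoLoader;
  auto handle_or = loader::GetRocblasDsoHandle();  // dlopen happens at most once
  if (!handle_or.ok()) {
    LOG(ERROR) << "rocBLAS DSO not found: " << handle_or.status();
    return nullptr;
  }
  void* symbol = nullptr;
  auto status = stream_executor::port::Env::Default()->GetSymbolFromLibrary(
      handle_or.ValueOrDie(), symbol_name, &symbol);
  return status.ok() ? symbol : nullptr;
}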
tensorflow/stream_executor/rocm/BUILD

@@ -33,7 +33,7 @@ filegroup(
 cc_library(
     name = "rocm_diagnostics",
     srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]),
-    hdrs = [],
+    hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]),
     deps = if_rocm_is_configured([
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
@@ -47,7 +47,7 @@ cc_library(
 cc_library(
     name = "rocm_driver",
     srcs = if_rocm_is_configured(["rocm_driver.cc"]),
-    hdrs = [],
+    hdrs = if_rocm_is_configured(["rocm_driver_wrapper.h"]),
     deps = if_rocm_is_configured([
         ":rocm_diagnostics",
         "@com_google_absl//absl/base",
@@ -57,6 +57,7 @@ cc_library(
         "//tensorflow/stream_executor/gpu:gpu_driver_header",
         "//tensorflow/stream_executor/lib",
         "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
         "@local_config_rocm//rocm:rocm_headers",
     ]),
 )
@@ -140,55 +141,62 @@ cc_library(
     deps = ["//tensorflow/stream_executor:platform"],
 )

-# FIXME: enable in future PRs
-#cc_library(
-#    name = "rocblas_plugin",
-#    srcs = ["rocm_blas.cc"],
-#    hdrs = ["rocm_blas.h"],
-#    visibility = ["//visibility:public"],
-#    deps = [
-#        ":rocm_gpu_executor",
-#        ":rocm_platform_id",
-#        "//third_party/eigen3",
-#        "//tensorflow/core:lib_internal",
-#        "//tensorflow/stream_executor",
-#        "//tensorflow/stream_executor:event",
-#        "//tensorflow/stream_executor:host_or_device_scalar",
-#        "//tensorflow/stream_executor:plugin_registry",
-#        "//tensorflow/stream_executor:scratch_allocator",
-#        "//tensorflow/stream_executor:timer",
-#        "//tenosrflow/stream_executor/gpu:gpu_activation_header",
-#        "//tenosrflow/stream_executor/gpu:gpu_stream_header",
-#        "//tenosrflow/stream_executor/gpu:gpu_timer_header",
-#        "//tensorflow/stream_executor/lib",
-#        "//tensorflow/stream_executor/platform",
-#        "//tensorflow/stream_executor/platform:dso_loader",
-#        "@com_google_absl//absl/strings",
-#        "@local_config_rocm//rocm:rocm_headers",
-#    ] + if_static(["@local_config_rocm//rocm:rocblas"]),
-#    alwayslink = True,
-#)
+cc_library(
+    name = "rocblas_plugin",
+    srcs = if_rocm_is_configured(["rocm_blas.cc"]),
+    hdrs = if_rocm_is_configured(["rocm_blas.h"]),
+    visibility = ["//visibility:public"],
+    deps = if_rocm_is_configured([
+        ":rocm_gpu_executor",
+        ":rocm_platform_id",
+        "//third_party/eigen3",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/stream_executor",
+        "//tensorflow/stream_executor:event",
+        "//tensorflow/stream_executor:host_or_device_scalar",
+        "//tensorflow/stream_executor:plugin_registry",
+        "//tensorflow/stream_executor:scratch_allocator",
+        "//tensorflow/stream_executor:timer",
+        "//tensorflow/stream_executor/gpu:gpu_activation",
+        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
+        "//tensorflow/stream_executor/gpu:gpu_stream_header",
+        "//tensorflow/stream_executor/gpu:gpu_timer_header",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
+        "@com_google_absl//absl/strings",
+        "@local_config_rocm//rocm:rocm_headers",
+    ] + if_static([
+        "@local_config_rocm//rocm:rocblas",
+    ])),
+    alwayslink = True,
+)

-# FIXME: enable in future PRs
-#cc_library(
-#    name = "rocfft_plugin",
-#    srcs = ["rocm_fft.cc"],
-#    hdrs = [],
-#    visibility = ["//visibility:public"],
-#    deps = [
-#        ":rocm_platform_id",
-#        "//tensorflow/stream_executor:event",
-#        "//tensorflow/stream_executor:fft",
-#        "//tensorflow/stream_executor:plugin_registry",
-#        "//tensorflow/stream_executor:scratch_allocator",
-#        "//tenosrflow/stream_executor/gpu:gpu_stream_header",
-#        "//tensorflow/stream_executor/lib",
-#        "//tensorflow/stream_executor/platform",
-#        "//tensorflow/stream_executor/platform:dso_loader",
-#        "@local_config_rocm//rocm:rocm_headers",
-#    ] + if_static(["@local_config_rocm//rocm:rocfft"]),
-#    alwayslink = True,
-#)
+cc_library(
+    name = "rocfft_plugin",
+    srcs = if_rocm_is_configured(["rocm_fft.cc"]),
+    hdrs = if_rocm_is_configured(["rocm_fft.h"]),
+    visibility = ["//visibility:public"],
+    deps = if_rocm_is_configured([
+        ":rocm_platform_id",
+        "//tensorflow/stream_executor:event",
+        "//tensorflow/stream_executor:fft",
+        "//tensorflow/stream_executor:plugin_registry",
+        "//tensorflow/stream_executor:scratch_allocator",
+        "//tensorflow/stream_executor/gpu:gpu_activation",
+        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
+        "//tensorflow/stream_executor/gpu:gpu_executor_header",
+        "//tensorflow/stream_executor/gpu:gpu_stream_header",
+        "//tensorflow/stream_executor/gpu:gpu_kernel_header",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
+        "@local_config_rocm//rocm:rocm_headers",
+    ] + if_static([
+        "@local_config_rocm//rocm:rocfft",
+    ])),
+    alwayslink = True,
+)

 # FIXME: enable in future PRs
 #cc_library(
@@ -229,26 +237,30 @@ cc_library(
 #    alwayslink = True,
 #)

-# FIXME: enable in future PRs
-#cc_library(
-#    name = "rocrand_plugin",
-#    srcs = ["rocm_rng.cc"],
-#    hdrs = [],
-#    deps = [
-#        ":rocm_gpu_executor",
-#        ":rocm_platform_id",
-#        "@local_config_rocm//rocm:rocm_headers",
-#        "//tensorflow/stream_executor:event",
-#        "//tensorflow/stream_executor:plugin_registry",
-#        "//tensorflow/stream_executor:rng",
-#        "//tenosrflow/stream_executor/gpu:gpu_activation_header",
-#        "//tenosrflow/stream_executor/gpu:gpu_stream_header",
-#        "//tensorflow/stream_executor/lib",
-#        "//tensorflow/stream_executor/platform",
-#        "//tensorflow/stream_executor/platform:dso_loader",
-#    ] + if_static(["@local_config_rocm//rocm:curand"]),
-#    alwayslink = True,
-#)
+cc_library(
+    name = "rocrand_plugin",
+    srcs = if_rocm_is_configured(["rocm_rng.cc"]),
+    hdrs = if_rocm_is_configured([]),
+    deps = if_rocm_is_configured([
+        ":rocm_gpu_executor",
+        ":rocm_platform_id",
+        "@local_config_rocm//rocm:rocm_headers",
+        "//tensorflow/stream_executor:event",
+        "//tensorflow/stream_executor:plugin_registry",
+        "//tensorflow/stream_executor:rng",
+        "//tensorflow/stream_executor/gpu:gpu_activation_header",
+        "//tensorflow/stream_executor/gpu:gpu_helpers_header",
+        "//tensorflow/stream_executor/gpu:gpu_executor_header",
+        "//tensorflow/stream_executor/gpu:gpu_rng_header",
+        "//tensorflow/stream_executor/gpu:gpu_stream_header",
+        "//tensorflow/stream_executor/lib",
+        "//tensorflow/stream_executor/platform",
+        "//tensorflow/stream_executor/platform:dso_loader",
+    ] + if_static([
+        "@local_config_rocm//rocm:hiprand",
+    ])),
+    alwayslink = True,
+)

 cc_library(
     name = "all_runtime",
@@ -257,9 +269,9 @@ cc_library(
     deps = if_rocm_is_configured([
         # FIXME: enable in future PRs
         #":miopen_plugin",
-        #":rocfft_plugin",
-        #":rocblas_plugin",
-        #":rocrand_plugin",
+        ":rocfft_plugin",
+        ":rocblas_plugin",
+        ":rocrand_plugin",
         ":rocm_driver",
         ":rocm_platform",
     ]),
tensorflow/stream_executor/rocm/rocm_blas.cc (new file, 2374 lines)
File diff suppressed because it is too large.

tensorflow/stream_executor/rocm/rocm_blas.h (new file, 159 lines)

@@ -0,0 +1,159 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// ROCM-specific support for BLAS functionality -- this wraps the rocBLAS
+// library capabilities, and is only included into ROCM implementation code --
+// it will not introduce rocm headers into other code.
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
+
+#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
+#include "tensorflow/stream_executor/plugin_registry.h"
+
+namespace stream_executor {
+
+class Stream;
+
+namespace gpu {
+
+// Opaque and unique identifier for the rocBLAS plugin.
+extern const PluginId kRocBlasPlugin;
+
+class GpuExecutor;
+
+// BLAS plugin for ROCM platform via rocBLAS library.
+//
+// This satisfies the platform-agnostic BlasSupport interface.
+//
+// Note that the rocBLAS handle that this encapsulates is implicitly tied to
+// the context (and, as a result, the device) that the parent GpuExecutor is
+// tied to. This simply happens as an artifact of creating the rocBLAS handle
+// when a ROCM context is active.
+//
+// Thread-safe post-initialization.
+class ROCMBlas : public blas::BlasSupport {
+ public:
+  explicit ROCMBlas(GpuExecutor *parent);
+
+  // Allocates a rocBLAS handle.
+  bool Init();
+
+  // Releases the rocBLAS handle, if present.
+  ~ROCMBlas() override;
+
+  TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
+
+ private:
+  // Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
+  //
+  // rocBLAS is stateful, and can only be associated with one stream (in order
+  // to enqueue dispatch) at a given time. As a result, this generally must be
+  // invoked before calling into rocBLAS.
+  bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // A helper function that calls the real rocBLAS function together with
+  // error handling.
+  //
+  // rocblas_func:       rocBLAS function pointer.
+  // rocblas_name:       rocBLAS function name.
+  // stream:             Stream to enqueue the BLAS operation onto.
+  // pointer_mode_host:  Indicate if the pointer to a scalar value is from host
+  //                     (true) or device (false).
+  // err_on_failure:     Whether to print an error if the rocBLAS function
+  //                     fails.
+  // args:               Arguments of rocBLAS function.
+  template <typename FuncT, typename... Args>
+  bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
+                          bool pointer_mode_host, bool err_on_failure,
+                          Args... args);
+
+  // Convenience functions that call DoBlasInternalImpl with different values
+  // for err_on_failure.
+  template <typename FuncT, typename... Args>
+  bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
+                      bool pointer_mode_host, Args... args) {
+    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
+                              /*err_on_failure=*/true, args...);
+  }
+  template <typename FuncT, typename... Args>
+  bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
+                               bool pointer_mode_host, Args... args) {
+    return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
+                              /*err_on_failure=*/false, args...);
+  }
+
+  // A helper function to implement DoBlasGemmBatched interfaces for generic
+  // types.
+  template <typename T, typename FuncT>
+  port::Status DoBlasGemmBatchedInternal(
+      FuncT rocblas_func, Stream *stream, blas::Transpose transa,
+      blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
+      const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
+      const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
+      const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
+      int batch_count, ScratchAllocator *scratch_allocator);
+
+  // Helper function for implementing DoBlasGemmWithAlgorithm.
+  //
+  // We take alpha and beta by const reference because T might be Eigen::half,
+  // and we want to avoid pulling in a dependency on Eigen. When we pass the
+  // references to rocBLAS, we essentially reinterpret_cast to __half, which
+  // is safe because Eigen::half inherits from __half.
+  template <typename InT, typename OutT, typename CompT>
+  bool DoBlasGemmWithAlgorithmImpl(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
+      int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
+      DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
+      blas::AlgorithmType algorithm,
+      blas::ProfileResult *output_profile_result);
+
+  // Helper function for implementing DoBlasGemmWithProfiling.
+  template <typename T, typename ParamType>
+  bool DoBlasGemmWithProfilingImpl(
+      Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
+      uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
+      int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
+      DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
+
+  // Helper function for implementing DoBlasGemvWithProfiling.
+  template <typename T>
+  bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
+                                   uint64 m, uint64 n, const T &alpha,
+                                   const DeviceMemory<T> &a, int lda,
+                                   const DeviceMemory<T> &x, int incx,
+                                   const T &beta, DeviceMemory<T> *y, int incy,
+                                   blas::ProfileResult *output_profile_result);
+
+  // mutex that guards the rocBLAS handle for this device.
+  mutex mu_;
+
+  // GpuExecutor which instantiated this ROCMBlas.
+  // Immutable post-initialization.
+  GpuExecutor *parent_;
+
+  // rocBLAS library handle on the device.
+  rocblas_handle blas_ GUARDED_BY(mu_);
+
+  SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
+};
+
+}  // namespace gpu
+}  // namespace stream_executor
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
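To make the DoBlasInternal* helpers above concrete, here is a hedged sketch of how one BlasSupport override is typically implemented in rocm_blas.cc. The wrapper name wrap::rocblas_sscal and the GpuMemoryMutable helper are assumptions based on surrounding stream_executor conventions, not lines quoted from this commit:

// Hedged sketch (not from this commit): a typical override built on
// DoBlasInternal. Assumes wrap::rocblas_sscal wraps the rocBLAS entry point
// and GpuMemoryMutable unwraps DeviceMemory<float> into a raw float*.
bool ROCMBlas::DoBlasScal(Stream *stream, uint64 elem_count, float alpha,
                          DeviceMemory<float> *x, int incx) {
  return DoBlasInternal(wrap::rocblas_sscal, stream,
                        /*pointer_mode_host=*/true, elem_count, &alpha,
                        GpuMemoryMutable(x), incx);
}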
tensorflow/stream_executor/rocm/rocm_diagnostics.cc

@@ -30,7 +30,6 @@ limitations under the License.
 #include "absl/container/inlined_vector.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
-#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
 #include "tensorflow/stream_executor/lib/error.h"
 #include "tensorflow/stream_executor/lib/numbers.h"
 #include "tensorflow/stream_executor/lib/process_state.h"
@@ -38,9 +37,10 @@ limitations under the License.
 #include "tensorflow/stream_executor/lib/str_util.h"
 #include "tensorflow/stream_executor/lib/stringprintf.h"
 #include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"

 namespace stream_executor {
-namespace gpu {
+namespace rocm {

 string DriverVersionToString(DriverVersion version) {
   return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version),
@@ -95,6 +95,12 @@ port::StatusOr<DriverVersion> StringToDriverVersion(const string& value) {
   return result;
 }

+}  // namespace rocm
+}  // namespace stream_executor
+
+namespace stream_executor {
+namespace gpu {
+
 // -- class Diagnostician

 string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
@@ -133,11 +139,11 @@ void Diagnostician::LogDiagnosticInformation() {
   }
   port::StatusOr<DriverVersion> dso_version = FindDsoVersion();
   LOG(INFO) << "librocm reported version is: "
-            << DriverVersionStatusToString(dso_version);
+            << rocm::DriverVersionStatusToString(dso_version);

   port::StatusOr<DriverVersion> kernel_version = FindKernelDriverVersion();
   LOG(INFO) << "kernel reported version is: "
-            << DriverVersionStatusToString(kernel_version);
+            << rocm::DriverVersionStatusToString(kernel_version);

   if (kernel_version.ok() && dso_version.ok()) {
     WarnOnDsoKernelMismatch(dso_version, kernel_version);
@@ -175,7 +181,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
       // TODO(b/22689637): Eliminate the explicit namespace if possible.
       auto stripped_dso_version = port::StripSuffixString(dso_version, ".ld64");
       auto result = static_cast<port::StatusOr<DriverVersion>*>(data);
-      *result = StringToDriverVersion(stripped_dso_version);
+      *result = rocm::StringToDriverVersion(stripped_dso_version);
       return 1;
     }
     return 0;
@@ -205,7 +211,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion(
   // TODO(b/22689637): Eliminate the explicit namespace if possible.
   auto stripped_kernel_version =
       port::StripSuffixString(kernel_version, ".ld64");
-  return StringToDriverVersion(stripped_kernel_version);
+  return rocm::StringToDriverVersion(stripped_kernel_version);
 }

 void Diagnostician::WarnOnDsoKernelMismatch(
@@ -214,12 +220,12 @@ void Diagnostician::WarnOnDsoKernelMismatch(
   if (kernel_version.ok() && dso_version.ok() &&
       dso_version.ValueOrDie() == kernel_version.ValueOrDie()) {
     LOG(INFO) << "kernel version seems to match DSO: "
-              << DriverVersionToString(kernel_version.ValueOrDie());
+              << rocm::DriverVersionToString(kernel_version.ValueOrDie());
   } else {
     LOG(ERROR) << "kernel version "
-               << DriverVersionStatusToString(kernel_version)
+               << rocm::DriverVersionStatusToString(kernel_version)
                << " does not match DSO version "
-               << DriverVersionStatusToString(dso_version)
+               << rocm::DriverVersionStatusToString(dso_version)
                << " -- cannot find working devices in this configuration";
   }
 }
tensorflow/stream_executor/rocm/rocm_diagnostics.h (new file, 41 lines)

@@ -0,0 +1,41 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
+#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
+
+#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
+
+namespace stream_executor {
+namespace rocm {
+
+// e.g. DriverVersion{346, 3, 4}
+using DriverVersion = gpu::DriverVersion;
+
+// Converts a parsed driver version to string form.
+string DriverVersionToString(DriverVersion version);
+
+// Converts a parsed driver version or status value to natural string form.
+string DriverVersionStatusToString(port::StatusOr<DriverVersion> version);
+
+// Converts a string of a form like "331.79" to a DriverVersion{331, 79}.
+port::StatusOr<DriverVersion> StringToDriverVersion(const string& value);
+
+using Diagnostician = gpu::Diagnostician;
+
+}  // namespace rocm
+}  // namespace stream_executor
+
+#endif  // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
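A hedged usage sketch (not from this commit) round-tripping a driver version string through the helpers declared above; the function and string literal are illustrative:

#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"

void LogRocmDriverVersion() {
  namespace rocm = stream_executor::rocm;
  auto version_or = rocm::StringToDriverVersion("331.79");
  if (version_or.ok()) {
    // Prints the parsed version tuple back in "major.minor.patch" form.
    LOG(INFO) << "parsed driver version: "
              << rocm::DriverVersionToString(version_or.ValueOrDie());
  } else {
    LOG(ERROR) << rocm::DriverVersionStatusToString(version_or);
  }
}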
tensorflow/stream_executor/rocm/rocm_driver.cc

@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/stream_executor/platform/logging.h"
 #include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/rocm/rocm_driver_wrapper.h"

 bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
 bool FLAGS_gpuexec_rocm_sync_around_driver_calls = false;
@@ -143,7 +144,7 @@ string MemorySpaceString(MemorySpace memory_space) {
 // HIP driver (e.g., this value is not our cached view of the current device).
 static int CurrentDeviceOrDie() {
   int current = -1;
-  hipError_t result = hipGetDevice(&current);
+  hipError_t result = tensorflow::wrap::hipGetDevice(&current);
   if (result != hipSuccess) {
     LOG(FATAL) << "failed to query current device: " << ToString(result);
   }
@@ -154,7 +155,7 @@ namespace {

 // Call hipDeviceSynchronize and crash if it doesn't succeed.
 void SynchronizeOrDie() {
-  auto res = hipDeviceSynchronize();
+  auto res = tensorflow::wrap::hipDeviceSynchronize();
   if (res != hipSuccess) {
     LOG(FATAL) << "Synchronize found " << ToString(res)
                << " :: " << port::CurrentStackTrace();
@@ -197,7 +198,8 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* context) {
           << tls->current_device_ordinal << " to " << context->device_ordinal();

   // Set the device and update thread local.
-  CHECK_EQ(hipSuccess, hipSetDevice(context->device_ordinal()));
+  CHECK_EQ(hipSuccess,
+           tensorflow::wrap::hipSetDevice(context->device_ordinal()));
   tls->current_device_ordinal = context->device_ordinal();
 }

@@ -225,7 +227,8 @@ ScopedActivateContext::~ScopedActivateContext() {
           << to_restore_->device_ordinal();

   // Set context and update thread local.
-  CHECK_EQ(hipSuccess, hipSetDevice(to_restore_->device_ordinal()));
+  CHECK_EQ(hipSuccess,
+           tensorflow::wrap::hipSetDevice(to_restore_->device_ordinal()));
   tls->current_device_ordinal = to_restore_->device_ordinal();
 }

@@ -261,7 +264,8 @@ string ROCMPointerToMemorySpaceString(hipDeviceptr_t pointer) {
 // in the process of querying.
 string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) {
   hipPointerAttribute_t from_pointerAttributes;
-  hipError_t result = hipPointerGetAttributes(&from_pointerAttributes, from);
+  hipError_t result =
+      tensorflow::wrap::hipPointerGetAttributes(&from_pointerAttributes, from);
   if (result != hipSuccess) {
     LOG(ERROR) << "could not retrieve source pointer's device: "
                << ToString(result);
@@ -269,7 +273,7 @@ string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) {
   }

   hipPointerAttribute_t to_pointerAttributes;
-  result = hipPointerGetAttributes(&to_pointerAttributes, to);
+  result = tensorflow::wrap::hipPointerGetAttributes(&to_pointerAttributes, to);
   if (result != hipSuccess) {
     LOG(ERROR) << "could not retrieve destination pointer's device: "
                << ToString(result);
@@ -289,7 +293,7 @@ static port::Status InternalInit() {
   if (FLAGS_gpuexec_rocm_driver_inject_init_error) {
     LOG(ERROR) << "injecting ROCM init error; initialization will fail";
   } else {
-    res = hipInit(0 /* = flags */);
+    res = tensorflow::wrap::hipInit(0 /* = flags */);
   }

   if (res == hipSuccess) {
@@ -322,7 +326,7 @@ static port::Status InternalInit() {

 /* static */ port::Status GpuDriver::GetDevice(int device_ordinal,
                                                hipDevice_t* device) {
-  hipError_t res = hipDeviceGet(device, device_ordinal);
+  hipError_t res = tensorflow::wrap::hipDeviceGet(device, device_ordinal);
   if (res == hipSuccess) {
     return port::Status::OK();
   }
@@ -336,7 +340,8 @@ static port::Status InternalInit() {
                                                string* device_name) {
   static const size_t kCharLimit = 64;
   absl::InlinedVector<char, 4> chars(kCharLimit);
-  hipError_t res = hipDeviceGetName(chars.begin(), kCharLimit - 1, device);
+  hipError_t res =
+      tensorflow::wrap::hipDeviceGetName(chars.begin(), kCharLimit - 1, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to get device name for " << device << ": "
                << ToString(res);
@@ -382,7 +387,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,

 /* static */ bool GpuDriver::FuncSetCacheConfig(hipFunction_t function,
                                                 hipFuncCache_t cache_config) {
-  hipError_t res = hipFuncSetCacheConfig(function, cache_config);
+  hipError_t res =
+      tensorflow::wrap::hipFuncSetCacheConfig(function, cache_config);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to set ROCM kernel cache config. kernel: " << function
                << ", config: " << cache_config << ", result: " << ToString(res);
@@ -396,7 +402,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
 GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   hipSharedMemConfig shared_mem_config;
   ScopedActivateContext activation{context};
-  hipError_t result = hipDeviceGetSharedMemConfig(&shared_mem_config);
+  hipError_t result =
+      tensorflow::wrap::hipDeviceGetSharedMemConfig(&shared_mem_config);
   if (result != hipSuccess) {
     LOG(ERROR) << "failed to get ROCM device shared memory config. "
                << "Context device ID: " << context->device_ordinal()
@@ -411,7 +418,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::Status GpuDriver::ContextSetSharedMemConfig(
     GpuContext* context, hipSharedMemConfig shared_mem_config) {
   ScopedActivateContext activation{context};
-  hipError_t result = hipDeviceSetSharedMemConfig(shared_mem_config);
+  hipError_t result =
+      tensorflow::wrap::hipDeviceSetSharedMemConfig(shared_mem_config);
   if (result != hipSuccess) {
     LOG(ERROR) << "failed to set ROCM device shared memory config. "
                << "Context device ID: " << context->device_ordinal()
@@ -435,7 +443,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
           << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
           << " bdx: " << block_dim_x << " bdy: " << block_dim_y
           << " bdz: " << block_dim_z << " smem: " << shared_mem_bytes;
-  hipError_t res = hipModuleLaunchKernel(
+  hipError_t res = tensorflow::wrap::hipModuleLaunchKernel(
       function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
       block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
   if (res != hipSuccess) {
@@ -471,7 +479,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   ScopedActivateContext activation{context};
   void* hsaco_data = const_cast<char*>(hsaco_contents);

-  hipError_t res = hipModuleLoadData(module, hsaco_data);
+  hipError_t res =
+      tensorflow::wrap::hipModuleLoadData(module, hsaco_data);

   if (res != hipSuccess) {
     LOG(ERROR) << "failed to load HSACO: " << ToString(res);
@@ -491,7 +500,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                 hipDeviceptr_t location,
                                                 uint8 value, size_t size) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemset(location, value, size);
+  hipError_t res = tensorflow::wrap::hipMemset(location, value, size);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to memset memory: " << ToString(res);
     return false;
@@ -512,8 +521,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
     LOG(ERROR) << "failed to memset memory";
     return false;
   }
-  hipError_t res =
-      hipMemset(pointer, static_cast<int>(value), uint32_count * 4);
+  hipError_t res = tensorflow::wrap::hipMemset(pointer, static_cast<int>(value),
+                                               uint32_count * 4);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to memset memory: " << ToString(res);
     return false;
@@ -527,7 +536,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                  size_t uint32_count,
                                                  GpuStreamHandle stream) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemsetAsync(location, value, uint32_count, stream);
+  hipError_t res =
+      tensorflow::wrap::hipMemsetAsync(location, value, uint32_count, stream);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
     return false;
@@ -552,7 +562,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
     LOG(ERROR) << "failed to memset memory";
     return false;
   }
-  hipError_t res = hipMemsetAsync(pointer, value, uint32_count * 4, stream);
+  hipError_t res = tensorflow::wrap::hipMemsetAsync(pointer, value,
+                                                    uint32_count * 4, stream);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
     return false;
@@ -565,8 +576,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                GpuStreamHandle stream,
                                                StreamCallback callback,
                                                void* data) {
-  hipError_t res = hipStreamAddCallback(stream, (hipStreamCallback_t)callback,
-                                        data, 0 /* = flags */);
+  hipError_t res = tensorflow::wrap::hipStreamAddCallback(
+      stream, (hipStreamCallback_t)callback, data, 0 /* = flags */);
   if (res != hipSuccess) {
     LOG(ERROR) << "unable to add host callback: " << ToString(res);
     return false;
@@ -580,7 +591,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                               hipFunction_t* function) {
   ScopedActivateContext activated{context};
   CHECK(module != nullptr && kernel_name != nullptr);
-  hipError_t res = hipModuleGetFunction(function, module, kernel_name);
+  hipError_t res =
+      tensorflow::wrap::hipModuleGetFunction(function, module, kernel_name);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to get kernel \"" << kernel_name
                << "\" from module: " << ToString(res);
@@ -598,7 +610,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   ScopedActivateContext activated{context};
   CHECK(module != nullptr && symbol_name != nullptr &&
         (dptr != nullptr || bytes != nullptr));
-  hipError_t res = hipModuleGetGlobal(dptr, bytes, module, symbol_name);
+  hipError_t res =
+      tensorflow::wrap::hipModuleGetGlobal(dptr, bytes, module, symbol_name);
   if (res != hipSuccess) {
     // symbol may not be found in the current module, but it may reside in
    // another module.
@@ -613,7 +626,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ void GpuDriver::UnloadModule(GpuContext* context,
                                           hipModule_t module) {
   ScopedActivateContext activated{context};
-  hipError_t res = hipModuleUnload(module);
+  hipError_t res = tensorflow::wrap::hipModuleUnload(module);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to unload module " << module
                << "; leaking: " << ToString(res);
@@ -623,7 +636,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ bool GpuDriver::CreateStream(GpuContext* context,
                                           GpuStreamHandle* stream) {
   ScopedActivateContext activated{context};
-  hipError_t res = hipStreamCreateWithFlags(
+  hipError_t res = tensorflow::wrap::hipStreamCreateWithFlags(
       stream, hipStreamDefault);  // switch to hipStreamNonBlocking?
   if (res != hipSuccess) {
     LOG(ERROR) << "could not allocate ROCM stream for device "
@@ -643,7 +656,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   }

   ScopedActivateContext activated{context};
-  hipError_t res = hipStreamDestroy(*stream);
+  hipError_t res = tensorflow::wrap::hipStreamDestroy(*stream);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to destroy ROCM stream for device "
                << context->device_ordinal() << ": " << ToString(res);
@@ -658,7 +671,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                              uint64 bytes) {
   ScopedActivateContext activated{context};
   hipDeviceptr_t result = 0;
-  hipError_t res = hipMalloc(&result, bytes);
+  hipError_t res = tensorflow::wrap::hipMallocVanilla(&result, bytes);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to allocate "
                << port::HumanReadableNumBytes::ToString(bytes) << " (" << bytes
@@ -675,7 +688,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                               void* location) {
   ScopedActivateContext activation{context};
   hipDeviceptr_t pointer = absl::bit_cast<hipDeviceptr_t>(location);
-  hipError_t res = hipFree(pointer);
+  hipError_t res = tensorflow::wrap::hipFree(pointer);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to free device memory at " << location
                << "; result: " << ToString(res);
@@ -704,7 +717,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   ScopedActivateContext activation{context};
   void* host_mem = nullptr;
   // "Portable" memory is visible to all ROCM contexts. Safe for our use model.
-  hipError_t res = hipHostMalloc(&host_mem, bytes, hipHostMallocPortable);
+  hipError_t res = tensorflow::wrap::hipHostMallocVanilla(
+      &host_mem, bytes, hipHostMallocPortable);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to alloc " << bytes
                << " bytes on host: " << ToString(res);
@@ -715,7 +729,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ void GpuDriver::HostDeallocate(GpuContext* context,
                                             void* location) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipHostFree(location);
+  hipError_t res = tensorflow::wrap::hipHostFree(location);
   if (res != hipSuccess) {
     LOG(ERROR) << "error deallocating host memory at " << location << ": "
                << ToString(res);
@@ -726,7 +740,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                           uint64 bytes) {
   ScopedActivateContext activation{context};
   // "Portable" memory is visible to all ROCM contexts. Safe for our use model.
-  hipError_t res = hipHostRegister(location, bytes, hipHostRegisterPortable);
+  hipError_t res = tensorflow::wrap::hipHostRegister(location, bytes,
+                                                     hipHostRegisterPortable);
   if (res != hipSuccess) {
     LOG(ERROR) << "error registering host memory at " << location << ": "
                << ToString(res);
@@ -738,7 +753,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ bool GpuDriver::HostUnregister(GpuContext* context,
                                             void* location) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipHostUnregister(location);
+  hipError_t res = tensorflow::wrap::hipHostUnregister(location);
   if (res != hipSuccess) {
     LOG(ERROR) << "error unregistering host memory at " << location << ": "
                << ToString(res);
@@ -755,7 +770,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   }

   ScopedActivateContext activated{context};
-  hipError_t res = hipEventDestroy(*event);
+  hipError_t res = tensorflow::wrap::hipEventDestroy(*event);
   *event = nullptr;

   switch (res) {
@@ -779,7 +794,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                GpuEventHandle event,
                                                GpuStreamHandle stream) {
   ScopedActivateContext activated{context};
-  hipError_t res = hipEventRecord(event, stream);
+  hipError_t res = tensorflow::wrap::hipEventRecord(event, stream);
   switch (res) {
     case hipSuccess:
       return port::Status::OK();
@@ -800,7 +815,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::StatusOr<hipError_t> GpuDriver::QueryEvent(
     GpuContext* context, GpuEventHandle event) {
   ScopedActivateContext activated{context};
-  hipError_t res = hipEventQuery(event);
+  hipError_t res = tensorflow::wrap::hipEventQuery(event);
   if (res != hipSuccess && res != hipErrorNotReady) {
     return port::Status{
         port::error::INTERNAL,
@@ -817,12 +832,13 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   ScopedActivateContext activated{context};
   // The stop event must have completed in order for hipEventElapsedTime to
   // work.
-  hipError_t res = hipEventSynchronize(stop);
+  hipError_t res = tensorflow::wrap::hipEventSynchronize(stop);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res);
     return false;
   }
-  res = hipEventElapsedTime(elapsed_milliseconds, start, stop);
+  res =
+      tensorflow::wrap::hipEventElapsedTime(elapsed_milliseconds, start, stop);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to get elapsed time between events: "
                << ToString(res);
@@ -836,7 +852,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                          GpuStreamHandle stream,
                                          GpuEventHandle event) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipStreamWaitEvent(stream, event, 0 /* = flags */);
+  hipError_t res =
+      tensorflow::wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */);
   if (res != hipSuccess) {
     LOG(ERROR) << "could not wait stream on event: " << ToString(res);
     return false;
@@ -847,7 +864,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {

 /* static */ bool GpuDriver::SynchronizeContext(GpuContext* context) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipDeviceSynchronize();
+  hipError_t res = tensorflow::wrap::hipDeviceSynchronize();
   if (res != hipSuccess) {
     LOG(ERROR) << "could not synchronize on ROCM device: " << ToString(res)
                << " :: " << port::CurrentStackTrace();
@@ -861,7 +878,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                    GpuStreamHandle stream) {
   ScopedActivateContext activated{context};
   CHECK(stream != nullptr);
-  hipError_t res = hipStreamSynchronize(stream);
+  hipError_t res = tensorflow::wrap::hipStreamSynchronize(stream);
   if (res != hipSuccess) {
     port::Status status = port::InternalError(
         absl::StrCat("could not synchronize on ROCM stream: ", ToString(res)));
@@ -877,7 +894,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                          GpuStreamHandle stream) {
   ScopedActivateContext activated{context};
   CHECK(stream != nullptr);
-  hipError_t res = hipStreamQuery(stream);
+  hipError_t res = tensorflow::wrap::hipStreamQuery(stream);
   if (res == hipSuccess) {
     return true;
   }
@@ -891,7 +908,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::Status GpuDriver::SynchronousMemcpyD2H(
     GpuContext* context, void* host_dst, hipDeviceptr_t gpu_src, uint64 size) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemcpyDtoH(host_dst, gpu_src, size);
+  hipError_t res = tensorflow::wrap::hipMemcpyDtoH(host_dst, gpu_src, size);
   if (res != hipSuccess) {
     return port::InternalError(
         absl::StrFormat("failed to synchronous memcpy from device to host: %s; "
@@ -908,7 +925,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
     GpuContext* context, hipDeviceptr_t gpu_dst, const void* host_src,
    uint64 size) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemcpyHtoD(gpu_dst, const_cast<void*>(host_src), size);
+  hipError_t res = tensorflow::wrap::hipMemcpyHtoD(
+      gpu_dst, const_cast<void*>(host_src), size);
   if (res != hipSuccess) {
     return port::InternalError(absl::StrFormat(
         "failed to synchronous memcpy from host to device: %s; Gpu dst: %p;"
@@ -924,7 +942,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
     GpuContext* context, hipDeviceptr_t gpu_dst, hipDeviceptr_t gpu_src,
     uint64 size) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemcpyDtoD(gpu_dst, gpu_src, size);
+  hipError_t res = tensorflow::wrap::hipMemcpyDtoD(gpu_dst, gpu_src, size);
   if (res != hipSuccess) {
     return port::InternalError(absl::StrFormat(
         "failed to synchronous memcpy from host to device: %s; Gpu dst: %p; "
@@ -942,7 +960,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                uint64 size,
                                                GpuStreamHandle stream) {
   ScopedActivateContext activation{context};
-  hipError_t res = hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
+  hipError_t res =
+      tensorflow::wrap::hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
   if (res != hipSuccess) {
     LOG(ERROR) << absl::StrFormat(
         "failed to enqueue async memcpy from device to host: %s; host dst: %p; "
@@ -963,8 +982,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                uint64 size,
                                                GpuStreamHandle stream) {
   ScopedActivateContext activation{context};
-  hipError_t res =
-      hipMemcpyHtoDAsync(gpu_dst, const_cast<void*>(host_src), size, stream);
+  hipError_t res = tensorflow::wrap::hipMemcpyHtoDAsync(
+      gpu_dst, const_cast<void*>(host_src), size, stream);
   if (res != hipSuccess) {
     LOG(ERROR) << absl::StrFormat(
         "failed to enqueue async memcpy from host to device: %s; Gpu dst: %p; "
@@ -984,7 +1003,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
                                                uint64 size,
                                                GpuStreamHandle stream) {
   ScopedActivateContext activation{context};
-  hipError_t result = hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
+  hipError_t result =
+      tensorflow::wrap::hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
   if (result != hipSuccess) {
     LOG(ERROR) << absl::StrFormat(
         "failed to enqueue async memcpy from device to device: %s"
@@ -1021,7 +1041,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   }

   ScopedActivateContext activated{context};
-  hipError_t res = hipEventCreateWithFlags(event, hipflags);
+  hipError_t res = tensorflow::wrap::hipEventCreateWithFlags(event, hipflags);

   if (res == hipSuccess) {
     return port::Status::OK();
@@ -1037,7 +1057,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {

 /* static */ int GpuDriver::GetDeviceCount() {
   int device_count = 0;
-  hipError_t res = hipGetDeviceCount(&device_count);
+  hipError_t res = tensorflow::wrap::hipGetDeviceCount(&device_count);
   if (res != hipSuccess) {
     LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res);
     return 0;
@@ -1061,7 +1081,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {

 /* static */ port::Status GpuDriver::GetPointerAddressRange(
     hipDeviceptr_t dptr, hipDeviceptr_t* base, size_t* size) {
-  hipError_t result = hipMemGetAddressRange(base, size, dptr);
+  hipError_t result = tensorflow::wrap::hipMemGetAddressRange(base, size, dptr);
   if (result == hipSuccess) {
     return port::Status::OK();
   } else if (result == hipErrorNotFound) {
@@ -1106,7 +1126,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::StatusOr<hipDevice_t> GpuDriver::GetPointerDevice(
     hipDeviceptr_t pointer) {
   hipPointerAttribute_t pointerAttributes;
-  hipError_t result = hipPointerGetAttributes(&pointerAttributes, pointer);
+  hipError_t result =
+      tensorflow::wrap::hipPointerGetAttributes(&pointerAttributes, pointer);
   if (result != hipSuccess) {
     return port::Status{
         port::error::INTERNAL,
@@ -1114,7 +1135,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
   }

   hipDevice_t device;
-  result = hipDeviceGet(&device, pointerAttributes.device);
+  result = tensorflow::wrap::hipDeviceGet(&device, pointerAttributes.device);
   if (result != hipSuccess) {
     return port::Status{
         port::error::INTERNAL,
@@ -1127,7 +1148,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
 /* static */ port::Status GpuDriver::GetGpuISAVersion(int* version,
                                                       hipDevice_t device) {
   hipDeviceProp_t props;
-  hipError_t result = hipGetDeviceProperties(&props, device);
+  hipError_t result = tensorflow::wrap::hipGetDeviceProperties(&props, device);
   if (result == hipSuccess) {
     *version = props.gcnArch;
     return port::Status::OK();
@@ -1145,7 +1166,8 @@ template <typename T>
 static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
                                             hipDeviceAttribute_t attribute) {
   int value = -1;
-  hipError_t result = hipDeviceGetAttribute(&value, attribute, device);
+  hipError_t result =
+      tensorflow::wrap::hipDeviceGetAttribute(&value, attribute, device);
   if (result != hipSuccess) {
     return port::Status{
         port::error::NOT_FOUND,
@@ -1199,22 +1221,24 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
 /* static */ bool GpuDriver::GetGridLimits(int* x, int* y, int* z,
                                            hipDevice_t device) {
   int value;
-  hipError_t res =
-      hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimX, device);
+  hipError_t res = tensorflow::wrap::hipDeviceGetAttribute(
+      &value, hipDeviceAttributeMaxGridDimX, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query max grid dim x: " << ToString(res);
     return false;
   }
   *x = value;

-  res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimY, device);
+  res = tensorflow::wrap::hipDeviceGetAttribute(
+      &value, hipDeviceAttributeMaxGridDimY, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query max grid dim y: " << ToString(res);
     return false;
   }
   *y = value;

-  res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimZ, device);
+  res = tensorflow::wrap::hipDeviceGetAttribute(
+      &value, hipDeviceAttributeMaxGridDimZ, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query max grid dim z: " << ToString(res);
     return false;
@@ -1224,7 +1248,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
 }

 /* static */ bool GpuDriver::GetDriverVersion(int* driver_version) {
-  hipError_t res = hipDriverGetVersion(driver_version);
+  hipError_t res = tensorflow::wrap::hipDriverGetVersion(driver_version);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query driver version: " << ToString(res);
     return false;
@@ -1235,7 +1259,8 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,

 /* static */ bool GpuDriver::GetDeviceProperties(
     hipDeviceProp_t* device_properties, int device_ordinal) {
-  hipError_t res = hipGetDeviceProperties(device_properties, device_ordinal);
+  hipError_t res = tensorflow::wrap::hipGetDeviceProperties(device_properties,
+                                                            device_ordinal);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query device properties: " << ToString(res);
     return false;
@@ -1268,7 +1293,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
   ScopedActivateContext activation{context};
   size_t free = 0;
   size_t total = 0;
-  hipError_t res = hipMemGetInfo(&free, &total);
+  hipError_t res = tensorflow::wrap::hipMemGetInfo(&free, &total);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query device memory info: " << ToString(res);
     return false;
@@ -1282,7 +1307,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
 /* static */ bool GpuDriver::GetDeviceTotalMemory(hipDevice_t device,
                                                   uint64* result) {
   size_t value = -1;
-  hipError_t res = hipDeviceTotalMem(&value, device);
+  hipError_t res = tensorflow::wrap::hipDeviceTotalMem(&value, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query total available memory: " << ToString(res);
     return false;
@@ -1297,7 +1322,8 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
   static const int kBufferSize = 64;
   absl::InlinedVector<char, 4> chars(kBufferSize);
   chars[kBufferSize - 1] = '\0';
-  hipError_t res = hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device);
+  hipError_t res = tensorflow::wrap::hipDeviceGetPCIBusId(
+      chars.begin(), kBufferSize - 1, device);
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res);
     return pci_bus_id;
@@ -1313,7 +1339,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
   }

   int can_access_peer = -1;
-  hipError_t res = hipDeviceCanAccessPeer(
+  hipError_t res = tensorflow::wrap::hipDeviceCanAccessPeer(
       &can_access_peer, from->device_ordinal(), to->device_ordinal());
   if (res != hipSuccess) {
     LOG(ERROR) << "failed to detect peer access capability: " << ToString(res);
@@ -1330,8 +1356,8 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
   }

   ScopedActivateContext activated{from};
-  hipError_t result =
-      hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */);
+  hipError_t result = tensorflow::wrap::hipDeviceEnablePeerAccess(
+      to->device_ordinal(), 0 /* = flags */);
   if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) {
     return port::Status{
         port::error::INTERNAL,
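Every hunk above applies the same mechanical transformation: a directly linked HIP call becomes a call through tensorflow::wrap::, followed by an unchanged hipSuccess check. A hedged sketch (not from this commit) of a helper capturing that repeated shape; it assumes the surrounding file's LOG and ToString facilities are in scope:

#include <utility>

// Invoke a (wrapped) HIP entry point, log on failure, and report success.
template <typename F, typename... Args>
bool HipCallOk(const char* name, F func, Args&&... args) {
  hipError_t res = func(std::forward<Args>(args)...);
  if (res != hipSuccess) {
    LOG(ERROR) << name << " failed: " << ToString(res);
    return false;
  }
  return true;
}

// Illustrative usage:
//   HipCallOk("hipDeviceSynchronize", tensorflow::wrap::hipDeviceSynchronize);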

tensorflow/stream_executor/rocm/rocm_driver_wrapper.h (new file, 146 lines)

@@ -0,0 +1,146 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file wraps rocm driver calls with dso loader so that we don't need to
// have explicit linking to librocm. All TF rocm driver usage should route
// through this wrapper.

#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_

#include "rocm/include/hip/hip_runtime.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/port.h"

#if defined(TENSORFLOW_USE_ROCM)

#endif

namespace tensorflow {
namespace wrap {
#ifdef PLATFORM_GOOGLE
// Use static linked library
#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \
  template <typename... Args> \
  auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \
    return ::hipSymbolName(args...); \
  }

#else
#define TO_STR_(x) #x
#define TO_STR(x) TO_STR_(x)

// hipMalloc and hipHostMalloc are defined as function templates in the
// HIP header files, and hence their names get mangled and the attempt
// to resolve their name when trying to dynamically load them will fail.
// Updating the HIP header files to make them C functions is underway.
// Until that change flows through, we will work around the issue by
// creating dummy wrappers for them here.

hipError_t hipMallocVanilla(void** ptr, size_t size) {
  return hipErrorNotInitialized;
}

hipError_t hipHostMallocVanilla(void** ptr, size_t size, unsigned int flags) {
  return hipErrorNotInitialized;
}
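
// For illustration only (hypothetical symbols, not part of this commit):
// why a function template cannot be resolved with dlsym. A template
// instantiation is exported under a C++-mangled name, while an extern "C"
// function keeps the literal name that dlsym looks up.
extern "C" int plain_symbol(int x) { return x; }
// dlsym(handle, "plain_symbol") finds this: the exported name is literal.
template <typename T>
T templated_symbol(T x) { return x; }
// templated_symbol<int> is exported under a mangled name (for example
// _Z16templated_symbolIiET_S0_ with the Itanium ABI), so
// dlsym(handle, "templated_symbol") finds nothing.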

// This macro wraps a global identifier, given by hipSymbolName, in a callable
// structure that loads the DLL symbol out of the DSO handle in a thread-safe
// manner on first use. This dynamic loading technique is used to avoid DSO
// dependencies on vendor libraries which may or may not be available in the
// deployed binary environment.
#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \
  template <typename... Args> \
  auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \
    using FuncPtrT = std::add_pointer<decltype(::hipSymbolName)>::type; \
    static FuncPtrT loaded = []() -> FuncPtrT { \
      static const char *kName = TO_STR(hipSymbolName); \
      void *f; \
      auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \
          stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \
              .ValueOrDie(), \
          kName, &f); \
      CHECK(s.ok()) << "could not find " << kName \
                    << " in HIP DSO; dlerror: " << s.error_message(); \
      return reinterpret_cast<FuncPtrT>(f); \
    }(); \
    return loaded(args...); \
  }
#endif

// clang-format off
#define HIP_ROUTINE_EACH(__macro) \
  __macro(hipDeviceCanAccessPeer) \
  __macro(hipDeviceEnablePeerAccess) \
  __macro(hipDeviceGet) \
  __macro(hipDeviceGetAttribute) \
  __macro(hipDeviceGetName) \
  __macro(hipDeviceGetPCIBusId) \
  __macro(hipDeviceGetSharedMemConfig) \
  __macro(hipDeviceSetSharedMemConfig) \
  __macro(hipDeviceSynchronize) \
  __macro(hipDeviceTotalMem) \
  __macro(hipDriverGetVersion) \
  __macro(hipEventCreateWithFlags) \
  __macro(hipEventElapsedTime) \
  __macro(hipEventDestroy) \
  __macro(hipEventQuery) \
  __macro(hipEventRecord) \
  __macro(hipEventSynchronize) \
  __macro(hipFree) \
  __macro(hipFuncSetCacheConfig) \
  __macro(hipGetDevice) \
  __macro(hipGetDeviceCount) \
  __macro(hipGetDeviceProperties) \
  __macro(hipHostFree) \
  __macro(hipHostRegister) \
  __macro(hipHostUnregister) \
  __macro(hipInit) \
  __macro(hipMemGetAddressRange) \
  __macro(hipMemGetInfo) \
  __macro(hipMemcpyDtoD) \
  __macro(hipMemcpyDtoDAsync) \
  __macro(hipMemcpyDtoH) \
  __macro(hipMemcpyDtoHAsync) \
  __macro(hipMemcpyHtoD) \
  __macro(hipMemcpyHtoDAsync) \
  __macro(hipMemset) \
  __macro(hipMemsetAsync) \
  __macro(hipModuleGetFunction) \
  __macro(hipModuleGetGlobal) \
  __macro(hipModuleLaunchKernel) \
  __macro(hipModuleLoadData) \
  __macro(hipModuleUnload) \
  __macro(hipPointerGetAttributes) \
  __macro(hipSetDevice) \
  __macro(hipStreamAddCallback) \
  __macro(hipStreamCreateWithFlags) \
  __macro(hipStreamDestroy) \
  __macro(hipStreamQuery) \
  __macro(hipStreamSynchronize) \
  __macro(hipStreamWaitEvent)
// clang-format on

HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP)
#undef HIP_ROUTINE_EACH
#undef STREAM_EXECUTOR_HIP_WRAP
#undef TO_STR
#undef TO_STR_
} // namespace wrap
} // namespace tensorflow

#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
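
For reference, a minimal standalone sketch of the lazy symbol-resolution pattern that STREAM_EXECUTOR_HIP_WRAP expands to, written out by hand for one routine. This is an illustration under stated assumptions, not part of the commit: it substitutes plain dlopen/dlsym and a hard-coded library name for the Env/CachedDsoLoader helpers, and models hipInit's hipError_t result as a plain int.

// sketch.cc -- illustrative only; build with -ldl.
#include <dlfcn.h>
#include <cstdio>

namespace wrap_sketch {
// Hand-written equivalent of the wrapper the macro would generate for hipInit.
int hipInit(unsigned int flags) {
  using FuncPtrT = int (*)(unsigned int);
  // C++11 guarantees this initializer runs exactly once, thread-safely.
  static FuncPtrT loaded = []() -> FuncPtrT {
    void* handle = dlopen("libhip_hcc.so", RTLD_LAZY);  // hard-coded stand-in
    if (handle == nullptr) return nullptr;
    return reinterpret_cast<FuncPtrT>(dlsym(handle, "hipInit"));
  }();
  if (loaded == nullptr) {
    std::fprintf(stderr, "could not resolve hipInit\n");
    return -1;  // stand-in for an error code such as hipErrorNotInitialized
  }
  return loaded(flags);
}
}  // namespace wrap_sketch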
618 tensorflow/stream_executor/rocm/rocm_fft.cc Normal file
@ -0,0 +1,618 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/stream_executor/rocm/rocm_fft.h"

#include <complex>

#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace stream_executor {
namespace gpu {

PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);

namespace wrap {

#ifdef PLATFORM_GOOGLE
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
  struct WrapperShim__##__name { \
    template <typename... Args> \
    hipfftResult operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent}; \
      return ::__name(args...); \
    } \
  } __name;

#else

// This macro wraps a global identifier, given by __name, in a callable
// structure that loads the DLL symbol out of the DSO handle in a thread-safe
// manner on first use. This dynamic loading technique is used to avoid DSO
// dependencies on vendor libraries which may or may not be available in the
// deployed binary environment.
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
  struct DynLoadShim__##__name { \
    static const char *kName; \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
    static void *GetDsoHandle() { \
      auto s = internal::CachedDsoLoader::GetRocfftDsoHandle(); \
      return s.ValueOrDie(); \
    } \
    static FuncPtrT LoadOrDie() { \
      void *f; \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f); \
      CHECK(s.ok()) << "could not find " << kName \
                    << " in rocfft DSO; dlerror: " << s.error_message(); \
      return reinterpret_cast<FuncPtrT>(f); \
    } \
    static FuncPtrT DynLoad() { \
      static FuncPtrT f = LoadOrDie(); \
      return f; \
    } \
    template <typename... Args> \
    hipfftResult operator()(GpuExecutor *parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent}; \
      return DynLoad()(args...); \
    } \
  } __name; \
  const char *DynLoadShim__##__name::kName = #__name;

#endif

#define ROCFFT_ROUTINE_EACH(__macro) \
  __macro(hipfftDestroy) __macro(hipfftSetStream) __macro(hipfftPlan1d) \
  __macro(hipfftPlan2d) __macro(hipfftPlan3d) __macro(hipfftPlanMany) \
  __macro(hipfftCreate) __macro(hipfftSetAutoAllocation) \
  __macro(hipfftSetWorkArea) __macro(hipfftGetSize1d) \
  __macro(hipfftMakePlan1d) __macro(hipfftGetSize2d) \
  __macro(hipfftMakePlan2d) __macro(hipfftGetSize3d) \
  __macro(hipfftMakePlan3d) __macro(hipfftGetSizeMany) \
  __macro(hipfftMakePlanMany) \
  __macro(hipfftExecD2Z) \
  __macro(hipfftExecZ2D) \
  __macro(hipfftExecC2C) \
  __macro(hipfftExecC2R) \
  __macro(hipfftExecZ2Z) \
  __macro(hipfftExecR2C)

ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)

} // namespace wrap

namespace {

// A helper function transforming gpu_fft arguments into rocFFT arguments.
hipfftType ROCMFftType(fft::Type type) {
  switch (type) {
    case fft::Type::kC2CForward:
    case fft::Type::kC2CInverse:
      return HIPFFT_C2C;
    case fft::Type::kC2R:
      return HIPFFT_C2R;
    case fft::Type::kR2C:
      return HIPFFT_R2C;
    case fft::Type::kZ2ZForward:
    case fft::Type::kZ2ZInverse:
      return HIPFFT_Z2Z;
    case fft::Type::kZ2D:
      return HIPFFT_Z2D;
    case fft::Type::kD2Z:
      return HIPFFT_D2Z;
    default:
      LOG(FATAL) << "Invalid value of fft::Type.";
  }
}

// Associates the given stream with the given rocFFT plan.
bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) {
  auto ret = wrap::hipfftSetStream(parent, plan, AsGpuStreamValue(stream));
  if (ret != HIPFFT_SUCCESS) {
    LOG(ERROR) << "failed to run rocFFT routine hipfftSetStream: " << ret;
    return false;
  }
  return true;
}

} // namespace

port::Status ROCMFftPlan::Initialize(
    GpuExecutor *parent, Stream *stream, int rank, uint64 *elem_count,
    uint64 *input_embed, uint64 input_stride, uint64 input_distance,
    uint64 *output_embed, uint64 output_stride, uint64 output_distance,
    fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) {
  if (IsInitialized()) {
    LOG(FATAL) << "Try to repeatedly initialize.";
  }
  is_initialized_ = true;
  int elem_count_[3], input_embed_[3], output_embed_[3];
  for (int i = 0; i < rank; ++i) {
    elem_count_[i] = elem_count[i];
    if (input_embed) {
      input_embed_[i] = input_embed[i];
    }
    if (output_embed) {
      output_embed_[i] = output_embed[i];
    }
  }
  parent_ = parent;
  fft_type_ = type;
  if (batch_count == 1 && input_embed == nullptr && output_embed == nullptr) {
    hipfftResult_t ret;
    if (scratch_allocator == nullptr) {
      switch (rank) {
        case 1:
          // hipfftPlan1d
          ret = wrap::hipfftPlan1d(parent, &plan_, elem_count_[0],
                                   ROCMFftType(type), 1 /* = batch */);
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to create rocFFT 1d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to create rocFFT 1d plan."};
          }
          return port::Status::OK();
        case 2:
          // hipfftPlan2d
          ret = wrap::hipfftPlan2d(parent, &plan_, elem_count_[0],
                                   elem_count_[1], ROCMFftType(type));
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to create rocFFT 2d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to create rocFFT 2d plan."};
          }
          return port::Status::OK();
        case 3:
          // hipfftPlan3d
          ret =
              wrap::hipfftPlan3d(parent, &plan_, elem_count_[0], elem_count_[1],
                                 elem_count_[2], ROCMFftType(type));
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to create rocFFT 3d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to create rocFFT 3d plan."};
          }
          return port::Status::OK();
        default:
          LOG(ERROR) << "Invalid rank value for hipfftPlan. "
                        "Requested 1, 2, or 3, given: "
                     << rank;
          return port::Status{port::error::INVALID_ARGUMENT,
                              "hipfftPlan only takes rank 1, 2, or 3."};
      }
    } else {
      ret = wrap::hipfftCreate(parent, &plan_);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to create rocFFT plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to create rocFFT plan."};
      }
      ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to set auto allocation for rocFFT plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to set auto allocation for rocFFT plan."};
      }
      size_t size_in_bytes;
      switch (rank) {
        case 1:
          ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
                                       ROCMFftType(type), /*batch=*/1,
                                       &size_in_bytes);
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to make rocFFT 1d plan."};
          }
          break;
        case 2:
          ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
                                       elem_count_[1], ROCMFftType(type),
                                       &size_in_bytes);
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to make rocFFT 2d plan."};
          }
          break;
        case 3:
          ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
                                       elem_count_[1], elem_count_[2],
                                       ROCMFftType(type), &size_in_bytes);
          if (ret != HIPFFT_SUCCESS) {
            LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
            return port::Status{port::error::INTERNAL,
                                "Failed to make rocFFT 3d plan."};
          }
          break;
        default:
          LOG(ERROR) << "Invalid rank value for hipfftPlan. "
                        "Requested 1, 2, or 3, given: "
                     << rank;
          return port::Status{port::error::INVALID_ARGUMENT,
                              "hipfftPlan only takes rank 1, 2, or 3."};
      }
      // TODO(yangzihao): refactor this code and the one with the same function
      // in the batch mode.
      if (size_in_bytes != 0) {
        auto allocated =
            scratch_allocator->AllocateBytes(stream, size_in_bytes);
        if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
          LOG(ERROR) << "failed to allocate work area.";
          return allocated.status();
        }
      }
      // Connect work area with allocated space.
      ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to set work area for rocFFT plan."};
      }
      return port::Status::OK();
    }
  } else {
    // For either multiple batches or rank higher than 3, use hipfftPlanMany().
    if (scratch_allocator == nullptr) {
      auto ret = wrap::hipfftPlanMany(
          parent, &plan_, rank, elem_count_,
          input_embed ? input_embed_ : nullptr, input_stride, input_distance,
          output_embed ? output_embed_ : nullptr, output_stride,
          output_distance, ROCMFftType(type), batch_count);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to create rocFFT batched plan."};
      }
    } else {
      auto ret = wrap::hipfftCreate(parent, &plan_);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to create rocFFT batched plan."};
      }
      ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to set auto allocation for rocFFT batched plan:"
                   << ret;
        return port::Status{
            port::error::INTERNAL,
            "Failed to set auto allocation for rocFFT batched plan."};
      }
      size_t size_in_bytes;
      ret = wrap::hipfftMakePlanMany(
          parent, plan_, rank, elem_count_,
          input_embed ? input_embed_ : nullptr, input_stride, input_distance,
          output_embed ? output_embed_ : nullptr, output_stride,
          output_distance, ROCMFftType(type), batch_count, &size_in_bytes);
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to make rocFFT batched plan."};
      }
      if (size_in_bytes != 0) {
        auto allocated =
            scratch_allocator->AllocateBytes(stream, size_in_bytes);
        if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
          LOG(ERROR) << "failed to allocate work area.";
          return allocated.status();
        }
      }
      // Connect work area with allocated space.
      ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
      if (ret != HIPFFT_SUCCESS) {
        LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret;
        return port::Status{port::error::INTERNAL,
                            "Failed to set work area for rocFFT batched plan."};
      }
    }
  }
  return port::Status::OK();
}
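
// Illustrative sketch, not part of this commit: the raw hipFFT call sequence
// that the scratch-allocator path above performs, with hipMalloc standing in
// for ScratchAllocator. The helper name is hypothetical; the hipfft* and
// hipMalloc calls are the libraries' own.
static bool MakePlan1dWithExternalWorkArea(int nx, hipfftType type,
                                           hipfftHandle *plan) {
  if (hipfftCreate(plan) != HIPFFT_SUCCESS) return false;
  // Keep hipFFT from allocating its own work area.
  if (hipfftSetAutoAllocation(*plan, /*autoAllocate=*/0) != HIPFFT_SUCCESS)
    return false;
  size_t work_bytes = 0;
  if (hipfftMakePlan1d(*plan, nx, type, /*batch=*/1, &work_bytes) !=
      HIPFFT_SUCCESS)
    return false;
  void *work = nullptr;
  if (work_bytes != 0 && hipMalloc(&work, work_bytes) != hipSuccess)
    return false;
  // Hand the externally owned device buffer to the plan, as
  // hipfftSetWorkArea does for scratch_ above.
  return hipfftSetWorkArea(*plan, work) == HIPFFT_SUCCESS;
}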

port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
                                     int rank, uint64 *elem_count,
                                     fft::Type type,
                                     ScratchAllocator *scratch_allocator) {
  return Initialize(parent, stream, rank, elem_count,
                    /*input_embed=*/nullptr, /*input_stride=*/0,
                    /*input_distance=*/0,
                    /*output_embed=*/nullptr, /*output_stride=*/0,
                    /*output_distance=*/0, type, 1, scratch_allocator);
}

ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }

int ROCMFftPlan::GetFftDirection() const {
  if (!IsInitialized()) {
    LOG(FATAL) << "Try to get fft direction before initialization.";
  } else {
    switch (fft_type_) {
      case fft::Type::kC2CForward:
      case fft::Type::kZ2ZForward:
      case fft::Type::kR2C:
      case fft::Type::kD2Z:
        return HIPFFT_FORWARD;
      case fft::Type::kC2CInverse:
      case fft::Type::kZ2ZInverse:
      case fft::Type::kC2R:
      case fft::Type::kZ2D:
        return HIPFFT_BACKWARD;
      default:
        LOG(FATAL) << "Invalid value of fft::Type.";
    }
  }
}

std::unique_ptr<fft::Plan> ROCMFft::Create1dPlan(Stream *stream, uint64 num_x,
                                                 fft::Type type,
                                                 bool in_place_fft) {
  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
  uint64 elem_count[1] = {num_x};
  port::Status status = fft_plan_ptr->Initialize(
      parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
  // TODO(yangzihao): In the future, send error msg back to TensorFlow
  // so it can fail gracefully.
  if (!status.ok()) {
    LOG(FATAL) << "failed to initialize hipfft 1d plan: "
               << status.error_message();
  }
  return std::move(fft_plan_ptr);
}

std::unique_ptr<fft::Plan> ROCMFft::Create1dPlanWithScratchAllocator(
    Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft,
    ScratchAllocator *scratch_allocator) {
  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
  uint64 elem_count[1] = {num_x};
  port::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count,
                                                 type, scratch_allocator);
  if (!status.ok()) {
    LOG(FATAL)
        << "failed to initialize hipfft 1d plan with customized allocator: "
        << status.error_message();
  }
  return std::move(fft_plan_ptr);
}

std::unique_ptr<fft::Plan> ROCMFft::Create2dPlan(Stream *stream, uint64 num_x,
                                                 uint64 num_y, fft::Type type,
                                                 bool in_place_fft) {
  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
  uint64 elem_count[2] = {num_x, num_y};
  port::Status status = fft_plan_ptr->Initialize(
      parent_, stream, 2, elem_count, type, /*scratch_allocator=*/nullptr);
  if (!status.ok()) {
    LOG(FATAL) << "failed to initialize hipfft 2d plan: "
               << status.error_message();
  }
  return std::move(fft_plan_ptr);
}

std::unique_ptr<fft::Plan> ROCMFft::Create2dPlanWithScratchAllocator(
    Stream *stream, uint64 num_x, uint64 num_y, fft::Type type,
    bool in_place_fft, ScratchAllocator *scratch_allocator) {
  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
  uint64 elem_count[2] = {num_x, num_y};
  port::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count,
                                                 type, scratch_allocator);
  if (!status.ok()) {
    LOG(FATAL)
        << "failed to initialize hipfft 2d plan with customized allocator: "
        << status.error_message();
  }
  return std::move(fft_plan_ptr);
}

std::unique_ptr<fft::Plan> ROCMFft::Create3dPlan(Stream *stream, uint64 num_x,
                                                 uint64 num_y, uint64 num_z,
                                                 fft::Type type,
                                                 bool in_place_fft) {
  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
  uint64 elem_count[3] = {num_x, num_y, num_z};
  port::Status status = fft_plan_ptr->Initialize(
      parent_, stream, 3, elem_count, type, /*scratch_allocator=*/nullptr);
  if (!status.ok()) {
    LOG(FATAL) << "failed to initialize hipfft 3d plan: "
               << status.error_message();
  }
  return std::move(fft_plan_ptr);
}

std::unique_ptr<fft::Plan> ROCMFft::Create3dPlanWithScratchAllocator(
    Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, fft::Type type,
    bool in_place_fft, ScratchAllocator *scratch_allocator) {
  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
  uint64 elem_count[3] = {num_x, num_y, num_z};
  port::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count,
                                                 type, scratch_allocator);
  if (!status.ok()) {
    LOG(FATAL)
        << "failed to initialize hipfft 3d plan with customized allocator: "
        << status.error_message();
  }
  return std::move(fft_plan_ptr);
}

std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlan(
    Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
    uint64 input_stride, uint64 input_distance, uint64 *output_embed,
    uint64 output_stride, uint64 output_distance, fft::Type type,
    bool in_place_fft, int batch_count) {
  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
  port::Status status = fft_plan_ptr->Initialize(
      parent_, stream, rank, elem_count, input_embed, input_stride,
      input_distance, output_embed, output_stride, output_distance, type,
      batch_count, /*scratch_allocator=*/nullptr);
  if (!status.ok()) {
    LOG(FATAL) << "failed to initialize batched hipfft plan: "
               << status.error_message();
  }

  return std::move(fft_plan_ptr);
}

std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
    Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
    uint64 input_stride, uint64 input_distance, uint64 *output_embed,
    uint64 output_stride, uint64 output_distance, fft::Type type,
    bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) {
  std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
  port::Status status = fft_plan_ptr->Initialize(
      parent_, stream, rank, elem_count, input_embed, input_stride,
      input_distance, output_embed, output_stride, output_distance, type,
      batch_count, scratch_allocator);
  if (!status.ok()) {
    LOG(FATAL) << "failed to initialize batched hipfft plan with customized "
                  "allocator: "
               << status.error_message();
  }
  return std::move(fft_plan_ptr);
}

void ROCMFft::UpdatePlanWithScratchAllocator(
    Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
  LOG(ERROR) << "update plan with scratch allocator not implemented";
}

template <typename FuncT, typename InputT, typename OutputT>
bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
                            const DeviceMemory<InputT> &input,
                            DeviceMemory<OutputT> *output) {
  ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
  if (rocm_fft_plan == nullptr) {
    LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
    return false;
  }

  if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
    return false;
  }

  auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
                        GpuComplex(const_cast<InputT *>(GpuMemory(input))),
                        GpuComplex(GpuMemoryMutable(output)));

  if (ret != HIPFFT_SUCCESS) {
    LOG(ERROR) << "failed to run rocFFT routine: " << ret;
    return false;
  }

  return true;
}

template <typename FuncT, typename InputT, typename OutputT>
bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
                                         FuncT hipfftExec,
                                         const DeviceMemory<InputT> &input,
                                         DeviceMemory<OutputT> *output) {
  ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
  if (rocm_fft_plan == nullptr) {
    LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
    return false;
  }

  if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
    return false;
  }

  auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
                        GpuComplex(const_cast<InputT *>(GpuMemory(input))),
                        GpuComplex(GpuMemoryMutable(output)),
                        rocm_fft_plan->GetFftDirection());

  if (ret != HIPFFT_SUCCESS) {
    LOG(ERROR) << "failed to run rocFFT routine: " << ret;
    return false;
  }

  return true;
}

#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
                                        __fft_type3) \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
                      const DeviceMemory<std::complex<__type>> &input, \
                      DeviceMemory<std::complex<__type>> *output) { \
    return DoFftWithDirectionInternal( \
        stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
  } \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
                      const DeviceMemory<__type> &input, \
                      DeviceMemory<std::complex<__type>> *output) { \
    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
                         output); \
  } \
  bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
                      const DeviceMemory<std::complex<__type>> &input, \
                      DeviceMemory<__type> *output) { \
    return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
                         output); \
  }

STREAM_EXECUTOR_ROCM_DEFINE_FFT(float, C2C, R2C, C2R)
STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)

#undef STREAM_EXECUTOR_ROCM_DEFINE_FFT

} // namespace gpu

void initialize_rocfft() {
  auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
      rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);

  if (!rocFftAlreadyRegistered) {
    port::Status status =
        PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
            rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
            [](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
              gpu::GpuExecutor *rocm_executor =
                  dynamic_cast<gpu::GpuExecutor *>(parent);
              if (rocm_executor == nullptr) {
                LOG(ERROR)
                    << "Attempting to initialize an instance of the rocFFT "
                    << "support library with a non-ROCM StreamExecutor";
                return nullptr;
              }

              return new gpu::ROCMFft(rocm_executor);
            });
    if (!status.ok()) {
      LOG(ERROR) << "Unable to register rocFFT factory: "
                 << status.error_message();
    }

    PluginRegistry::Instance()->SetDefaultFactory(
        rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
  }
}

} // namespace stream_executor

REGISTER_MODULE_INITIALIZER(register_rocfft,
                            { stream_executor::initialize_rocfft(); });
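
The registration above follows the StreamExecutor plugin-registry pattern: check HasFactory, register a factory lambda under a plugin id, then make it the default for the platform. A minimal self-contained sketch of that shape, with hypothetical names (the FftSupport stand-in, Registry, and RegisterFactory below are illustrations, not the real PluginRegistry API):

#include <functional>
#include <map>
#include <string>

struct FftSupport {};  // stand-in for stream_executor::fft::FftSupport
using FftFactory = std::function<FftSupport*()>;

// Leaked singleton map, mirroring the registry's process-lifetime storage.
std::map<std::string, FftFactory>& Registry() {
  static auto* registry = new std::map<std::string, FftFactory>();
  return *registry;
}

// Register once; refuse to overwrite, like the HasFactory guard above.
bool RegisterFactory(const std::string& name, FftFactory factory) {
  return Registry().emplace(name, std::move(factory)).second;
}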
132 tensorflow/stream_executor/rocm/rocm_fft.h Normal file
@ -0,0 +1,132 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// ROCM-specific support for FFT functionality -- this wraps the rocFFT library
// capabilities, and is only included into ROCM implementation code -- it will
// not introduce rocm headers into other code.

#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_

#include "rocm/include/rocfft/hipfft.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/scratch_allocator.h"

namespace stream_executor {

class Stream;

namespace gpu {

class GpuExecutor;

// Opaque and unique identifier for the rocFFT plugin.
extern const PluginId kRocFftPlugin;

// ROCMFftPlan uses deferred initialization. Only a single call of
// Initialize() is allowed to properly create a hipfft plan and set the member
// variable is_initialized_ to true. Newly added interfaces that use member
// variables should first check is_initialized_ to make sure that the values
// of member variables are valid.
class ROCMFftPlan : public fft::Plan {
 public:
  ROCMFftPlan()
      : parent_(nullptr),
        plan_(),
        fft_type_(fft::Type::kInvalid),
        scratch_(nullptr),
        is_initialized_(false) {}
  ~ROCMFftPlan() override;

  // Get FFT direction in hipFFT based on FFT type.
  int GetFftDirection() const;
  hipfftHandle GetPlan() const {
    if (IsInitialized()) {
      return plan_;
    } else {
      LOG(FATAL) << "Try to get hipfftHandle value before initialization.";
    }
  }

  // Initialize function for batched plan
  port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
                          uint64 *elem_count, uint64 *input_embed,
                          uint64 input_stride, uint64 input_distance,
                          uint64 *output_embed, uint64 output_stride,
                          uint64 output_distance, fft::Type type,
                          int batch_count, ScratchAllocator *scratch_allocator);

  // Initialize function for 1d, 2d, and 3d plan
  port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
                          uint64 *elem_count, fft::Type type,
                          ScratchAllocator *scratch_allocator);

 protected:
  bool IsInitialized() const { return is_initialized_; }

 private:
  GpuExecutor *parent_;
  hipfftHandle plan_;
  fft::Type fft_type_;
  DeviceMemory<uint8> scratch_;
  bool is_initialized_;
};
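
// A minimal sketch (hypothetical class, not part of this commit) of the
// deferred-initialization contract described above: construct cheaply,
// allow exactly one Initialize() call, and gate accessors on the flag.
class DeferredInitSketch {
 public:
  bool Initialize(int n) {
    if (is_initialized_) return false;  // a second call is a usage error
    size_ = n;
    is_initialized_ = true;
    return true;
  }
  // Accessors check the flag before trusting member values.
  int size() const { return is_initialized_ ? size_ : -1; }

 private:
  int size_ = 0;
  bool is_initialized_ = false;
};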

// FFT support for ROCM platform via rocFFT library.
//
// This satisfies the platform-agnostic FftSupport interface.
//
// Note that the hipFFT handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent GpuExecutor is tied
// to. This simply happens as an artifact of creating the hipFFT handle when a
// ROCM context is active.
//
// Thread-safe. The ROCM context associated with all operations is the ROCM
// context of parent_, so all context is explicit.
class ROCMFft : public fft::FftSupport {
 public:
  explicit ROCMFft(GpuExecutor *parent) : parent_(parent) {}
  ~ROCMFft() override {}

  TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES

 private:
  GpuExecutor *parent_;

  // Two helper functions that execute dynload::hipfftExec?2?.

  // This is for complex to complex FFT, when the direction is required.
  template <typename FuncT, typename InputT, typename OutputT>
  bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
                                  FuncT hipfft_exec,
                                  const DeviceMemory<InputT> &input,
                                  DeviceMemory<OutputT> *output);

  // This is for complex to real or real to complex FFT, when the direction
  // is implied.
  template <typename FuncT, typename InputT, typename OutputT>
  bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfft_exec,
                     const DeviceMemory<InputT> &input,
                     DeviceMemory<OutputT> *output);

  SE_DISALLOW_COPY_AND_ASSIGN(ROCMFft);
};

} // namespace gpu
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
@ -18,7 +18,6 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
#include "tensorflow/stream_executor/gpu/gpu_event.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
@ -41,6 +40,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
@ -655,7 +655,7 @@ port::Status GpuExecutor::BlockHostUntilDone(Stream* stream) {
blas::BlasSupport* GpuExecutor::CreateBlas() {
  PluginRegistry* registry = PluginRegistry::Instance();
  port::StatusOr<PluginRegistry::BlasFactory> status =
      registry->GetFactory<PluginRegistry::BlasFactory>(kROCmPlatformId,
      registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId,
                                                        plugin_config_.blas());
  if (!status.ok()) {
    LOG(ERROR) << "Unable to retrieve BLAS factory: "
@ -669,7 +669,7 @@ blas::BlasSupport* GpuExecutor::CreateBlas() {
dnn::DnnSupport* GpuExecutor::CreateDnn() {
  PluginRegistry* registry = PluginRegistry::Instance();
  port::StatusOr<PluginRegistry::DnnFactory> status =
      registry->GetFactory<PluginRegistry::DnnFactory>(kROCmPlatformId,
      registry->GetFactory<PluginRegistry::DnnFactory>(rocm::kROCmPlatformId,
                                                       plugin_config_.dnn());
  if (!status.ok()) {
    LOG(ERROR) << "Unable to retrieve DNN factory: "
@ -683,7 +683,7 @@ dnn::DnnSupport* GpuExecutor::CreateDnn() {
fft::FftSupport* GpuExecutor::CreateFft() {
  PluginRegistry* registry = PluginRegistry::Instance();
  port::StatusOr<PluginRegistry::FftFactory> status =
      registry->GetFactory<PluginRegistry::FftFactory>(kROCmPlatformId,
      registry->GetFactory<PluginRegistry::FftFactory>(rocm::kROCmPlatformId,
                                                       plugin_config_.fft());
  if (!status.ok()) {
    LOG(ERROR) << "Unable to retrieve FFT factory: "
@ -697,7 +697,7 @@ fft::FftSupport* GpuExecutor::CreateFft() {
rng::RngSupport* GpuExecutor::CreateRng() {
  PluginRegistry* registry = PluginRegistry::Instance();
  port::StatusOr<PluginRegistry::RngFactory> status =
      registry->GetFactory<PluginRegistry::RngFactory>(kROCmPlatformId,
      registry->GetFactory<PluginRegistry::RngFactory>(rocm::kROCmPlatformId,
                                                       plugin_config_.rng());
  if (!status.ok()) {
    LOG(ERROR) << "Unable to retrieve RNG factory: "
@ -878,12 +878,10 @@ DeviceDescription* GpuExecutor::PopulateDeviceDescription() const {
  {
    int driver_version = 0;
    (void)GpuDriver::GetDriverVersion(&driver_version);
    string augmented_driver_version =
        absl::StrFormat("%d (%s)", driver_version, "__FIXME__");
    // FIXME:
    // uncomment the line below once the "DriverVersionStatusToString"
    // routine is moved from the "cuda" namespace to the "gpu" namespace
    // DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
    string augmented_driver_version = absl::StrFormat(
        "%d (%s)", driver_version,
        rocm::DriverVersionStatusToString(Diagnostician::FindDsoVersion())
            .c_str());
    builder.set_driver_version(augmented_driver_version);
  }

@ -94,7 +94,7 @@ port::StatusOr<StreamExecutor*> ROCmPlatform::FirstExecutorForBus(
        absl::StrFormat("Executor for bus %d not found.", bus_ordinal)};
}

Platform::Id ROCmPlatform::id() const { return kROCmPlatformId; }
Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; }

int ROCmPlatform::VisibleDeviceCount() const {
  // Throw away the result - it logs internally, and this [containing] function
@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"

namespace stream_executor {
namespace gpu {
namespace rocm {

PLATFORM_DEFINE_ID(kROCmPlatformId);

} // namespace gpu
} // namespace rocm
} // namespace stream_executor
@ -19,16 +19,16 @@ limitations under the License.
#include "tensorflow/stream_executor/platform.h"

namespace stream_executor {
namespace gpu {
namespace rocm {

// Opaque and unique identifier for the ROCm platform.
// This is needed so that plugins can refer to/identify this platform without
// instantiating a ROCmPlatform object.
// This is broken out here to avoid a circular dependency between ROCmPlatform
// and GpuExecutor.
// and ROCmExecutor.
extern const Platform::Id kROCmPlatformId;

} // namespace gpu
} // namespace rocm
} // namespace stream_executor

#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_ID_H_
@ -14,15 +14,17 @@ limitations under the License.
==============================================================================*/

#include "rocm/include/hiprand/hiprand.h"
#include "tensorflow/stream_executor/gpu/gpu_rng.h"

#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
@ -60,7 +62,9 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);

namespace wrap {

#define PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(__name) \
#ifdef PLATFORM_GOOGLE

#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
  struct WrapperShim__##__name { \
    template <typename... Args> \
    hiprandStatus_t operator()(GpuExecutor* parent, Args... args) { \
@ -69,15 +73,47 @@ namespace wrap {
    } \
  } __name;

PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandCreateGenerator);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandDestroyGenerator);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetStream);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniform);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniformDouble);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetGeneratorOffset);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormal);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormalDouble);
#else

#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
  struct DynLoadShim__##__name { \
    static const char* kName; \
    using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
    static void* GetDsoHandle() { \
      auto s = internal::CachedDsoLoader::GetRocrandDsoHandle(); \
      return s.ValueOrDie(); \
    } \
    static FuncPtrT LoadOrDie() { \
      void* f; \
      auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
                                                          kName, &f); \
      CHECK(s.ok()) << "could not find " << kName \
                    << " in rocrand DSO; dlerror: " << s.error_message(); \
      return reinterpret_cast<FuncPtrT>(f); \
    } \
    static FuncPtrT DynLoad() { \
      static FuncPtrT f = LoadOrDie(); \
      return f; \
    } \
    template <typename... Args> \
    hiprandStatus operator()(GpuExecutor* parent, Args... args) { \
      gpu::ScopedActivateExecutorContext sac{parent}; \
      return DynLoad()(args...); \
    } \
  } __name; \
  const char* DynLoadShim__##__name::kName = #__name;

#endif

STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniform);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniformDouble);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetGeneratorOffset);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormal);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormalDouble);

} // namespace wrap

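// Illustrative sketch, not part of this commit: the hipRAND call sequence the
// wrappers above forward to, invoked directly. The helper name is
// hypothetical; hiprand* and HIPRAND_RNG_PSEUDO_DEFAULT are the library's
// own. `out` must point to device memory holding at least n floats.
static bool FillUniform(float* out, size_t n, unsigned long long seed) {
  hiprandGenerator_t gen;
  if (hiprandCreateGenerator(&gen, HIPRAND_RNG_PSEUDO_DEFAULT) !=
      HIPRAND_STATUS_SUCCESS) {
    return false;
  }
  bool ok = hiprandSetPseudoRandomGeneratorSeed(gen, seed) ==
                HIPRAND_STATUS_SUCCESS &&
            hiprandGenerateUniform(gen, out, n) == HIPRAND_STATUS_SUCCESS;
  hiprandDestroyGenerator(gen);
  return ok;
}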
@ -245,40 +281,45 @@ bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
}

} // namespace gpu

void initialize_rocrand() {
  auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
      rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);

  if (!rocRandAlreadyRegistered) {
    port::Status status =
        PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
            rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
            [](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
              gpu::GpuExecutor* rocm_executor =
                  dynamic_cast<gpu::GpuExecutor*>(parent);
              if (rocm_executor == nullptr) {
                LOG(ERROR)
                    << "Attempting to initialize an instance of the hipRAND "
                    << "support library with a non-ROCM StreamExecutor";
                return nullptr;
              }

              gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
              if (!rng->Init()) {
                // Note: Init() will log a more specific error.
                delete rng;
                return nullptr;
              }
              return rng;
            });

    if (!status.ok()) {
      LOG(ERROR) << "Unable to register rocRAND factory: "
                 << status.error_message();
    }

    PluginRegistry::Instance()->SetDefaultFactory(
        rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
  }
}

} // namespace stream_executor

namespace se = ::stream_executor;

REGISTER_MODULE_INITIALIZER(register_hiprand, {
  se::port::Status status =
      se::PluginRegistry::Instance()
          ->RegisterFactory<se::PluginRegistry::RngFactory>(
              se::gpu::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
              [](se::internal::StreamExecutorInterface* parent)
                  -> se::rng::RngSupport* {
                se::gpu::GpuExecutor* rocm_executor =
                    dynamic_cast<se::gpu::GpuExecutor*>(parent);
                if (rocm_executor == nullptr) {
                  LOG(ERROR)
                      << "Attempting to initialize an instance of the hipRAND "
                      << "support library with a non-ROCM StreamExecutor";
                  return nullptr;
                }

                se::gpu::GpuRng* rng = new se::gpu::GpuRng(rocm_executor);
                if (!rng->Init()) {
                  // Note: Init() will log a more specific error.
                  delete rng;
                  return nullptr;
                }
                return rng;
              });

  if (!status.ok()) {
    LOG(ERROR) << "Unable to register hipRAND factory: "
               << status.error_message();
  }

  se::PluginRegistry::Instance()->SetDefaultFactory(
      se::gpu::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
});
REGISTER_MODULE_INITIALIZER(register_rocrand,
                            { stream_executor::initialize_rocrand(); });