Merge pull request from ROCmSoftwarePlatform:google_upstream_add_rocm_blas_fft_rand

PiperOrigin-RevId: 234978728
This commit is contained in:
TensorFlower Gardener 2019-02-21 05:39:56 -08:00
commit 6a25b9d561
16 changed files with 3824 additions and 213 deletions

View File

@ -39,7 +39,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
port::Status status =
port::Env::Default()->LoadLibrary(filename.c_str(), &dso_handle);
if (status.ok()) {
LOG(INFO) << "Successfully opened CUDA library " << filename;
LOG(INFO) << "Successfully opened dynamic library " << filename;
return dso_handle;
}
@ -99,6 +99,25 @@ port::StatusOr<void*> GetCuptiDsoHandle() {
port::StatusOr<void*> GetCudnnDsoHandle() {
return GetDsoHandle("cudnn", GetCudnnVersion());
}
port::StatusOr<void*> GetRocblasDsoHandle() {
return GetDsoHandle("rocblas", "");
}
port::StatusOr<void*> GetMiopenDsoHandle() {
return GetDsoHandle("MIOpen", "");
}
port::StatusOr<void*> GetRocfftDsoHandle() {
return GetDsoHandle("rocfft", "");
}
port::StatusOr<void*> GetRocrandDsoHandle() {
return GetDsoHandle("rocrand", "");
}
port::StatusOr<void*> GetHipDsoHandle() { return GetDsoHandle("hip_hcc", ""); }
} // namespace DsoLoader
namespace CachedDsoLoader {
@ -136,6 +155,32 @@ port::StatusOr<void*> GetCudnnDsoHandle() {
static auto result = new auto(DsoLoader::GetCudnnDsoHandle());
return *result;
}
port::StatusOr<void*> GetRocblasDsoHandle() {
static auto result = new auto(DsoLoader::GetRocblasDsoHandle());
return result;
}
port::StatusOr<void*> GetMiopenDsoHandle() {
static auto result = new auto(DsoLoader::GetMiopenDsoHandle());
return result;
}
port::StatusOr<void*> GetRocfftDsoHandle() {
static auto result = new auto(DsoLoader::GetRocfftDsoHandle());
return result;
}
port::StatusOr<void*> GetRocrandDsoHandle() {
static auto result = new auto(DsoLoader::GetRocrandDsoHandle());
return result;
}
port::StatusOr<void*> GetHipDsoHandle() {
static auto result = new auto(DsoLoader::GetHipDsoHandle());
return result;
}
} // namespace CachedDsoLoader
} // namespace internal
} // namespace stream_executor

View File

@ -41,6 +41,12 @@ port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCuptiDsoHandle();
port::StatusOr<void*> GetCudnnDsoHandle();
port::StatusOr<void*> GetRocblasDsoHandle();
port::StatusOr<void*> GetMiopenDsoHandle();
port::StatusOr<void*> GetRocfftDsoHandle();
port::StatusOr<void*> GetRocrandDsoHandle();
port::StatusOr<void*> GetHipDsoHandle();
} // namespace DsoLoader
// Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs
@ -54,7 +60,14 @@ port::StatusOr<void*> GetCufftDsoHandle();
port::StatusOr<void*> GetCurandDsoHandle();
port::StatusOr<void*> GetCuptiDsoHandle();
port::StatusOr<void*> GetCudnnDsoHandle();
port::StatusOr<void*> GetRocblasDsoHandle();
port::StatusOr<void*> GetMiopenDsoHandle();
port::StatusOr<void*> GetRocfftDsoHandle();
port::StatusOr<void*> GetRocrandDsoHandle();
port::StatusOr<void*> GetHipDsoHandle();
} // namespace CachedDsoLoader
} // namespace internal
} // namespace stream_executor

View File

@ -33,7 +33,7 @@ filegroup(
cc_library(
name = "rocm_diagnostics",
srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]),
hdrs = [],
hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]),
deps = if_rocm_is_configured([
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
@ -47,7 +47,7 @@ cc_library(
cc_library(
name = "rocm_driver",
srcs = if_rocm_is_configured(["rocm_driver.cc"]),
hdrs = [],
hdrs = if_rocm_is_configured(["rocm_driver_wrapper.h"]),
deps = if_rocm_is_configured([
":rocm_diagnostics",
"@com_google_absl//absl/base",
@ -57,6 +57,7 @@ cc_library(
"//tensorflow/stream_executor/gpu:gpu_driver_header",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
"@local_config_rocm//rocm:rocm_headers",
]),
)
@ -140,55 +141,62 @@ cc_library(
deps = ["//tensorflow/stream_executor:platform"],
)
# FIXME: enable in future PRs
#cc_library(
# name = "rocblas_plugin",
# srcs = ["rocm_blas.cc"],
# hdrs = ["rocm_blas.h"],
# visibility = ["//visibility:public"],
# deps = [
# ":rocm_gpu_executor",
# ":rocm_platform_id",
# "//third_party/eigen3",
# "//tensorflow/core:lib_internal",
# "//tensorflow/stream_executor",
# "//tensorflow/stream_executor:event",
# "//tensorflow/stream_executor:host_or_device_scalar",
# "//tensorflow/stream_executor:plugin_registry",
# "//tensorflow/stream_executor:scratch_allocator",
# "//tensorflow/stream_executor:timer",
# "//tenosrflow/stream_executor/gpu:gpu_activation_header",
# "//tenosrflow/stream_executor/gpu:gpu_stream_header",
# "//tenosrflow/stream_executor/gpu:gpu_timer_header",
# "//tensorflow/stream_executor/lib",
# "//tensorflow/stream_executor/platform",
# "//tensorflow/stream_executor/platform:dso_loader",
# "@com_google_absl//absl/strings",
# "@local_config_rocm//rocm:rocm_headers",
# ] + if_static(["@local_config_rocm//rocm:rocblas"]),
# alwayslink = True,
#)
cc_library(
name = "rocblas_plugin",
srcs = if_rocm_is_configured(["rocm_blas.cc"]),
hdrs = if_rocm_is_configured(["rocm_blas.h"]),
visibility = ["//visibility:public"],
deps = if_rocm_is_configured([
":rocm_gpu_executor",
":rocm_platform_id",
"//third_party/eigen3",
"//tensorflow/core:lib_internal",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:host_or_device_scalar",
"//tensorflow/stream_executor:plugin_registry",
"//tensorflow/stream_executor:scratch_allocator",
"//tensorflow/stream_executor:timer",
"//tensorflow/stream_executor/gpu:gpu_activation",
"//tensorflow/stream_executor/gpu:gpu_helpers_header",
"//tensorflow/stream_executor/gpu:gpu_stream_header",
"//tensorflow/stream_executor/gpu:gpu_timer_header",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
"@com_google_absl//absl/strings",
"@local_config_rocm//rocm:rocm_headers",
] + if_static([
"@local_config_rocm//rocm:rocblas",
])),
alwayslink = True,
)
# FIXME: enable in future PRs
#cc_library(
# name = "rocfft_plugin",
# srcs = ["rocm_fft.cc"],
# hdrs = [],
# visibility = ["//visibility:public"],
# deps = [
# ":rocm_platform_id",
# "//tensorflow/stream_executor:event",
# "//tensorflow/stream_executor:fft",
# "//tensorflow/stream_executor:plugin_registry",
# "//tensorflow/stream_executor:scratch_allocator",
# "//tenosrflow/stream_executor/gpu:gpu_stream_header",
# "//tensorflow/stream_executor/lib",
# "//tensorflow/stream_executor/platform",
# "//tensorflow/stream_executor/platform:dso_loader",
# "@local_config_rocm//rocm:rocm_headers",
# ] + if_static(["@local_config_rocm//rocm:rocfft"]),
# alwayslink = True,
#)
cc_library(
name = "rocfft_plugin",
srcs = if_rocm_is_configured(["rocm_fft.cc"]),
hdrs = if_rocm_is_configured(["rocm_fft.h"]),
visibility = ["//visibility:public"],
deps = if_rocm_is_configured([
":rocm_platform_id",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:fft",
"//tensorflow/stream_executor:plugin_registry",
"//tensorflow/stream_executor:scratch_allocator",
"//tensorflow/stream_executor/gpu:gpu_activation",
"//tensorflow/stream_executor/gpu:gpu_helpers_header",
"//tensorflow/stream_executor/gpu:gpu_executor_header",
"//tensorflow/stream_executor/gpu:gpu_stream_header",
"//tensorflow/stream_executor/gpu:gpu_kernel_header",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
"@local_config_rocm//rocm:rocm_headers",
] + if_static([
"@local_config_rocm//rocm:rocfft",
])),
alwayslink = True,
)
# FIXME: enable in future PRs
#cc_library(
@ -229,26 +237,30 @@ cc_library(
# alwayslink = True,
#)
# FIXME: enable in future PRs
#cc_library(
# name = "rocrand_plugin",
# srcs = ["rocm_rng.cc"],
# hdrs = [],
# deps = [
# ":rocm_gpu_executor",
# ":rocm_platform_id",
# "@local_config_rocm//rocm:rocm_headers",
# "//tensorflow/stream_executor:event",
# "//tensorflow/stream_executor:plugin_registry",
# "//tensorflow/stream_executor:rng",
# "//tenosrflow/stream_executor/gpu:gpu_activation_header",
# "//tenosrflow/stream_executor/gpu:gpu_stream_header",
# "//tensorflow/stream_executor/lib",
# "//tensorflow/stream_executor/platform",
# "//tensorflow/stream_executor/platform:dso_loader",
# ] + if_static(["@local_config_rocm//rocm:curand"]),
# alwayslink = True,
#)
cc_library(
name = "rocrand_plugin",
srcs = if_rocm_is_configured(["rocm_rng.cc"]),
hdrs = if_rocm_is_configured([]),
deps = if_rocm_is_configured([
":rocm_gpu_executor",
":rocm_platform_id",
"@local_config_rocm//rocm:rocm_headers",
"//tensorflow/stream_executor:event",
"//tensorflow/stream_executor:plugin_registry",
"//tensorflow/stream_executor:rng",
"//tensorflow/stream_executor/gpu:gpu_activation_header",
"//tensorflow/stream_executor/gpu:gpu_helpers_header",
"//tensorflow/stream_executor/gpu:gpu_executor_header",
"//tensorflow/stream_executor/gpu:gpu_rng_header",
"//tensorflow/stream_executor/gpu:gpu_stream_header",
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
] + if_static([
"@local_config_rocm//rocm:hiprand",
])),
alwayslink = True,
)
cc_library(
name = "all_runtime",
@ -257,9 +269,9 @@ cc_library(
deps = if_rocm_is_configured([
# FIXME: enable in future PRs
#":miopen_plugin",
#":rocfft_plugin",
#":rocblas_plugin",
#":rocrand_plugin",
":rocfft_plugin",
":rocblas_plugin",
":rocrand_plugin",
":rocm_driver",
":rocm_platform",
]),

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,159 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// ROCM-specific support for BLAS functionality -- this wraps the rocBLAS
// library capabilities, and is only included into ROCM implementation code --
// it will not introduce rocm headers into other code.
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/plugin_registry.h"
namespace stream_executor {
class Stream;
namespace gpu {
// Opaque and unique identifier for the rocBLAS plugin.
extern const PluginId kRocBlasPlugin;
class GpuExecutor;
// BLAS plugin for ROCM platform via rocBLAS library.
//
// This satisfies the platform-agnostic BlasSupport interface.
//
// Note that the rocBLAS handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent GpuExecutor is tied
// to. This simply happens as an artifact of creating the rocBLAS handle when a
// ROCM context is active.
//
// Thread-safe post-initialization.
class ROCMBlas : public blas::BlasSupport {
public:
explicit ROCMBlas(GpuExecutor *parent);
// Allocates a rocBLAS handle.
bool Init();
// Releases the rocBLAS handle, if present.
~ROCMBlas() override;
TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
private:
// Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
//
// rocBLAS is stateful, and only be associated with one stream (in order to
// enqueue dispatch) at a given time. As a result, this generally must be
// invoked before calling into rocBLAS.
bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
// A helper function that calls the real rocBLAS function together with error
// handling.
//
// rocblas_func: rocBLAS function pointer.
// rocblas_name: rocBLAS function name.
// stream: Stream to enqueue the BLAS operation onto.
// pointer_mode_host: Indicate if the pointer to a scalar value is from host
// (true) or device (false).
// err_on_failure: Whether to print an error if the rocBLAS function
// fails. args: Arguments of rocBLAS function.
template <typename FuncT, typename... Args>
bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
bool pointer_mode_host, bool err_on_failure,
Args... args);
// Convenience functions that call DoBlasInternalImpl with different values
// for err_on_failure.
template <typename FuncT, typename... Args>
bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
bool pointer_mode_host, Args... args) {
return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
/*err_on_failure=*/true, args...);
}
template <typename FuncT, typename... Args>
bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
bool pointer_mode_host, Args... args) {
return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
/*err_on_failure=*/false, args...);
}
// A helper function to implement DoBlasGemmBatched interfaces for generic
// types.
template <typename T, typename FuncT>
port::Status DoBlasGemmBatchedInternal(
FuncT rocblas_func, Stream *stream, blas::Transpose transa,
blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
int batch_count, ScratchAllocator *scratch_allocator);
// Helper function for implementing DoBlasGemmWithAlgorithm.
//
// We take alpha and beta by const reference because T might be Eigen::half,
// and we want to avoid pulling in a dependency on Eigen. When we pass the
// references to rocBLAS, we essentially reinterpret_cast to __half, which is
// safe because Eigen::half inherits from __half.
template <typename InT, typename OutT, typename CompT>
bool DoBlasGemmWithAlgorithmImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result);
// Helper function for implementing DoBlasGemmWithProfiling.
template <typename T, typename ParamType>
bool DoBlasGemmWithProfilingImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
// Helper function for implementing DoBlasGemvWithProfiling.
template <typename T>
bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
uint64 m, uint64 n, const T &alpha,
const DeviceMemory<T> &a, int lda,
const DeviceMemory<T> &x, int incx,
const T &beta, DeviceMemory<T> *y, int incy,
blas::ProfileResult *output_profile_result);
// mutex that guards the rocBLAS handle for this device.
mutex mu_;
// GpuExecutor which instantiated this ROCMBlas.
// Immutable post-initialization.
GpuExecutor *parent_;
// rocBLAS library handle on the device.
rocblas_handle blas_ GUARDED_BY(mu_);
SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
};
} // namespace gpu
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_

View File

@ -30,7 +30,6 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/numbers.h"
#include "tensorflow/stream_executor/lib/process_state.h"
@ -38,9 +37,10 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
namespace stream_executor {
namespace gpu {
namespace rocm {
string DriverVersionToString(DriverVersion version) {
return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version),
@ -95,6 +95,12 @@ port::StatusOr<DriverVersion> StringToDriverVersion(const string& value) {
return result;
}
} // namespace rocm
} // namespace stream_executor
namespace stream_executor {
namespace gpu {
// -- class Diagnostician
string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
@ -133,11 +139,11 @@ void Diagnostician::LogDiagnosticInformation() {
}
port::StatusOr<DriverVersion> dso_version = FindDsoVersion();
LOG(INFO) << "librocm reported version is: "
<< DriverVersionStatusToString(dso_version);
<< rocm::DriverVersionStatusToString(dso_version);
port::StatusOr<DriverVersion> kernel_version = FindKernelDriverVersion();
LOG(INFO) << "kernel reported version is: "
<< DriverVersionStatusToString(kernel_version);
<< rocm::DriverVersionStatusToString(kernel_version);
if (kernel_version.ok() && dso_version.ok()) {
WarnOnDsoKernelMismatch(dso_version, kernel_version);
@ -175,7 +181,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
// TODO(b/22689637): Eliminate the explicit namespace if possible.
auto stripped_dso_version = port::StripSuffixString(dso_version, ".ld64");
auto result = static_cast<port::StatusOr<DriverVersion>*>(data);
*result = StringToDriverVersion(stripped_dso_version);
*result = rocm::StringToDriverVersion(stripped_dso_version);
return 1;
}
return 0;
@ -205,7 +211,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion(
// TODO(b/22689637): Eliminate the explicit namespace if possible.
auto stripped_kernel_version =
port::StripSuffixString(kernel_version, ".ld64");
return StringToDriverVersion(stripped_kernel_version);
return rocm::StringToDriverVersion(stripped_kernel_version);
}
void Diagnostician::WarnOnDsoKernelMismatch(
@ -214,12 +220,12 @@ void Diagnostician::WarnOnDsoKernelMismatch(
if (kernel_version.ok() && dso_version.ok() &&
dso_version.ValueOrDie() == kernel_version.ValueOrDie()) {
LOG(INFO) << "kernel version seems to match DSO: "
<< DriverVersionToString(kernel_version.ValueOrDie());
<< rocm::DriverVersionToString(kernel_version.ValueOrDie());
} else {
LOG(ERROR) << "kernel version "
<< DriverVersionStatusToString(kernel_version)
<< rocm::DriverVersionStatusToString(kernel_version)
<< " does not match DSO version "
<< DriverVersionStatusToString(dso_version)
<< rocm::DriverVersionStatusToString(dso_version)
<< " -- cannot find working devices in this configuration";
}
}

View File

@ -0,0 +1,41 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
namespace stream_executor {
namespace rocm {
// e.g. DriverVersion{346, 3, 4}
using DriverVersion = gpu::DriverVersion;
// Converts a parsed driver version to string form.
string DriverVersionToString(DriverVersion version);
// Converts a parsed driver version or status value to natural string form.
string DriverVersionStatusToString(port::StatusOr<DriverVersion> version);
// Converts a string of a form like "331.79" to a DriverVersion{331, 79}.
port::StatusOr<DriverVersion> StringToDriverVersion(const string& value);
using Diagnostician = gpu::Diagnostician;
} // namespace rocm
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rocm/rocm_driver_wrapper.h"
bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
bool FLAGS_gpuexec_rocm_sync_around_driver_calls = false;
@ -143,7 +144,7 @@ string MemorySpaceString(MemorySpace memory_space) {
// HIP driver (e.g., this value is not our cached view of the current device).
static int CurrentDeviceOrDie() {
int current = -1;
hipError_t result = hipGetDevice(&current);
hipError_t result = tensorflow::wrap::hipGetDevice(&current);
if (result != hipSuccess) {
LOG(FATAL) << "failed to query current device: " << ToString(result);
}
@ -154,7 +155,7 @@ namespace {
// Call hipDeviceSynchronize and crash if it doesn't succeed.
void SynchronizeOrDie() {
auto res = hipDeviceSynchronize();
auto res = tensorflow::wrap::hipDeviceSynchronize();
if (res != hipSuccess) {
LOG(FATAL) << "Synchronize found " << ToString(res)
<< " :: " << port::CurrentStackTrace();
@ -197,7 +198,8 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* context) {
<< tls->current_device_ordinal << " to " << context->device_ordinal();
// Set the device and update thread local.
CHECK_EQ(hipSuccess, hipSetDevice(context->device_ordinal()));
CHECK_EQ(hipSuccess,
tensorflow::wrap::hipSetDevice(context->device_ordinal()));
tls->current_device_ordinal = context->device_ordinal();
}
@ -225,7 +227,8 @@ ScopedActivateContext::~ScopedActivateContext() {
<< to_restore_->device_ordinal();
// Set context and update thread local.
CHECK_EQ(hipSuccess, hipSetDevice(to_restore_->device_ordinal()));
CHECK_EQ(hipSuccess,
tensorflow::wrap::hipSetDevice(to_restore_->device_ordinal()));
tls->current_device_ordinal = to_restore_->device_ordinal();
}
@ -261,7 +264,8 @@ string ROCMPointerToMemorySpaceString(hipDeviceptr_t pointer) {
// in the process of querying.
string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) {
hipPointerAttribute_t from_pointerAttributes;
hipError_t result = hipPointerGetAttributes(&from_pointerAttributes, from);
hipError_t result =
tensorflow::wrap::hipPointerGetAttributes(&from_pointerAttributes, from);
if (result != hipSuccess) {
LOG(ERROR) << "could not retrieve source pointer's device: "
<< ToString(result);
@ -269,7 +273,7 @@ string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) {
}
hipPointerAttribute_t to_pointerAttributes;
result = hipPointerGetAttributes(&to_pointerAttributes, to);
result = tensorflow::wrap::hipPointerGetAttributes(&to_pointerAttributes, to);
if (result != hipSuccess) {
LOG(ERROR) << "could not retrieve destination pointer's device: "
<< ToString(result);
@ -289,7 +293,7 @@ static port::Status InternalInit() {
if (FLAGS_gpuexec_rocm_driver_inject_init_error) {
LOG(ERROR) << "injecting ROCM init error; initialization will fail";
} else {
res = hipInit(0 /* = flags */);
res = tensorflow::wrap::hipInit(0 /* = flags */);
}
if (res == hipSuccess) {
@ -322,7 +326,7 @@ static port::Status InternalInit() {
/* static */ port::Status GpuDriver::GetDevice(int device_ordinal,
hipDevice_t* device) {
hipError_t res = hipDeviceGet(device, device_ordinal);
hipError_t res = tensorflow::wrap::hipDeviceGet(device, device_ordinal);
if (res == hipSuccess) {
return port::Status::OK();
}
@ -336,7 +340,8 @@ static port::Status InternalInit() {
string* device_name) {
static const size_t kCharLimit = 64;
absl::InlinedVector<char, 4> chars(kCharLimit);
hipError_t res = hipDeviceGetName(chars.begin(), kCharLimit - 1, device);
hipError_t res =
tensorflow::wrap::hipDeviceGetName(chars.begin(), kCharLimit - 1, device);
if (res != hipSuccess) {
LOG(ERROR) << "failed to get device name for " << device << ": "
<< ToString(res);
@ -382,7 +387,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
/* static */ bool GpuDriver::FuncSetCacheConfig(hipFunction_t function,
hipFuncCache_t cache_config) {
hipError_t res = hipFuncSetCacheConfig(function, cache_config);
hipError_t res =
tensorflow::wrap::hipFuncSetCacheConfig(function, cache_config);
if (res != hipSuccess) {
LOG(ERROR) << "failed to set ROCM kernel cache config. kernel: " << function
<< ", config: " << cache_config << ", result: " << ToString(res);
@ -396,7 +402,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
hipSharedMemConfig shared_mem_config;
ScopedActivateContext activation{context};
hipError_t result = hipDeviceGetSharedMemConfig(&shared_mem_config);
hipError_t result =
tensorflow::wrap::hipDeviceGetSharedMemConfig(&shared_mem_config);
if (result != hipSuccess) {
LOG(ERROR) << "failed to get ROCM device shared memory config. "
<< "Context device ID: " << context->device_ordinal()
@ -411,7 +418,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ port::Status GpuDriver::ContextSetSharedMemConfig(
GpuContext* context, hipSharedMemConfig shared_mem_config) {
ScopedActivateContext activation{context};
hipError_t result = hipDeviceSetSharedMemConfig(shared_mem_config);
hipError_t result =
tensorflow::wrap::hipDeviceSetSharedMemConfig(shared_mem_config);
if (result != hipSuccess) {
LOG(ERROR) << "failed to set ROCM device shared memory config. "
<< "Context device ID: " << context->device_ordinal()
@ -435,7 +443,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
<< " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
<< " bdx: " << block_dim_x << " bdy: " << block_dim_y
<< " bdz: " << block_dim_z << " smem: " << shared_mem_bytes;
hipError_t res = hipModuleLaunchKernel(
hipError_t res = tensorflow::wrap::hipModuleLaunchKernel(
function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
if (res != hipSuccess) {
@ -471,7 +479,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
ScopedActivateContext activation{context};
void* hsaco_data = const_cast<char*>(hsaco_contents);
hipError_t res = hipModuleLoadData(module, hsaco_data);
hipError_t res =
tensorflow::wrap::hipModuleLoadData(module, hsaco_data);
if (res != hipSuccess) {
LOG(ERROR) << "failed to load HSACO: " << ToString(res);
@ -491,7 +500,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
hipDeviceptr_t location,
uint8 value, size_t size) {
ScopedActivateContext activation{context};
hipError_t res = hipMemset(location, value, size);
hipError_t res = tensorflow::wrap::hipMemset(location, value, size);
if (res != hipSuccess) {
LOG(ERROR) << "failed to memset memory: " << ToString(res);
return false;
@ -512,8 +521,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
LOG(ERROR) << "failed to memset memory";
return false;
}
hipError_t res =
hipMemset(pointer, static_cast<int>(value), uint32_count * 4);
hipError_t res = tensorflow::wrap::hipMemset(pointer, static_cast<int>(value),
uint32_count * 4);
if (res != hipSuccess) {
LOG(ERROR) << "failed to memset memory: " << ToString(res);
return false;
@ -527,7 +536,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
size_t uint32_count,
GpuStreamHandle stream) {
ScopedActivateContext activation{context};
hipError_t res = hipMemsetAsync(location, value, uint32_count, stream);
hipError_t res =
tensorflow::wrap::hipMemsetAsync(location, value, uint32_count, stream);
if (res != hipSuccess) {
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
return false;
@ -552,7 +562,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
LOG(ERROR) << "failed to memset memory";
return false;
}
hipError_t res = hipMemsetAsync(pointer, value, uint32_count * 4, stream);
hipError_t res = tensorflow::wrap::hipMemsetAsync(pointer, value,
uint32_count * 4, stream);
if (res != hipSuccess) {
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
return false;
@ -565,8 +576,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
GpuStreamHandle stream,
StreamCallback callback,
void* data) {
hipError_t res = hipStreamAddCallback(stream, (hipStreamCallback_t)callback,
data, 0 /* = flags */);
hipError_t res = tensorflow::wrap::hipStreamAddCallback(
stream, (hipStreamCallback_t)callback, data, 0 /* = flags */);
if (res != hipSuccess) {
LOG(ERROR) << "unable to add host callback: " << ToString(res);
return false;
@ -580,7 +591,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
hipFunction_t* function) {
ScopedActivateContext activated{context};
CHECK(module != nullptr && kernel_name != nullptr);
hipError_t res = hipModuleGetFunction(function, module, kernel_name);
hipError_t res =
tensorflow::wrap::hipModuleGetFunction(function, module, kernel_name);
if (res != hipSuccess) {
LOG(ERROR) << "failed to get kernel \"" << kernel_name
<< "\" from module: " << ToString(res);
@ -598,7 +610,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
ScopedActivateContext activated{context};
CHECK(module != nullptr && symbol_name != nullptr &&
(dptr != nullptr || bytes != nullptr));
hipError_t res = hipModuleGetGlobal(dptr, bytes, module, symbol_name);
hipError_t res =
tensorflow::wrap::hipModuleGetGlobal(dptr, bytes, module, symbol_name);
if (res != hipSuccess) {
// symbol may not be found in the current module, but it may reside in
// another module.
@ -613,7 +626,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ void GpuDriver::UnloadModule(GpuContext* context,
hipModule_t module) {
ScopedActivateContext activated{context};
hipError_t res = hipModuleUnload(module);
hipError_t res = tensorflow::wrap::hipModuleUnload(module);
if (res != hipSuccess) {
LOG(ERROR) << "failed to unload module " << module
<< "; leaking: " << ToString(res);
@ -623,7 +636,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ bool GpuDriver::CreateStream(GpuContext* context,
GpuStreamHandle* stream) {
ScopedActivateContext activated{context};
hipError_t res = hipStreamCreateWithFlags(
hipError_t res = tensorflow::wrap::hipStreamCreateWithFlags(
stream, hipStreamDefault); // switch to hipStreamNonBlocking?
if (res != hipSuccess) {
LOG(ERROR) << "could not allocate ROCM stream for device "
@ -643,7 +656,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
}
ScopedActivateContext activated{context};
hipError_t res = hipStreamDestroy(*stream);
hipError_t res = tensorflow::wrap::hipStreamDestroy(*stream);
if (res != hipSuccess) {
LOG(ERROR) << "failed to destroy ROCM stream for device "
<< context->device_ordinal() << ": " << ToString(res);
@ -658,7 +671,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
uint64 bytes) {
ScopedActivateContext activated{context};
hipDeviceptr_t result = 0;
hipError_t res = hipMalloc(&result, bytes);
hipError_t res = tensorflow::wrap::hipMallocVanilla(&result, bytes);
if (res != hipSuccess) {
LOG(ERROR) << "failed to allocate "
<< port::HumanReadableNumBytes::ToString(bytes) << " (" << bytes
@ -675,7 +688,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
void* location) {
ScopedActivateContext activation{context};
hipDeviceptr_t pointer = absl::bit_cast<hipDeviceptr_t>(location);
hipError_t res = hipFree(pointer);
hipError_t res = tensorflow::wrap::hipFree(pointer);
if (res != hipSuccess) {
LOG(ERROR) << "failed to free device memory at " << location
<< "; result: " << ToString(res);
@ -704,7 +717,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
ScopedActivateContext activation{context};
void* host_mem = nullptr;
// "Portable" memory is visible to all ROCM contexts. Safe for our use model.
hipError_t res = hipHostMalloc(&host_mem, bytes, hipHostMallocPortable);
hipError_t res = tensorflow::wrap::hipHostMallocVanilla(
&host_mem, bytes, hipHostMallocPortable);
if (res != hipSuccess) {
LOG(ERROR) << "failed to alloc " << bytes
<< " bytes on host: " << ToString(res);
@ -715,7 +729,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ void GpuDriver::HostDeallocate(GpuContext* context,
void* location) {
ScopedActivateContext activation{context};
hipError_t res = hipHostFree(location);
hipError_t res = tensorflow::wrap::hipHostFree(location);
if (res != hipSuccess) {
LOG(ERROR) << "error deallocating host memory at " << location << ": "
<< ToString(res);
@ -726,7 +740,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
uint64 bytes) {
ScopedActivateContext activation{context};
// "Portable" memory is visible to all ROCM contexts. Safe for our use model.
hipError_t res = hipHostRegister(location, bytes, hipHostRegisterPortable);
hipError_t res = tensorflow::wrap::hipHostRegister(location, bytes,
hipHostRegisterPortable);
if (res != hipSuccess) {
LOG(ERROR) << "error registering host memory at " << location << ": "
<< ToString(res);
@ -738,7 +753,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ bool GpuDriver::HostUnregister(GpuContext* context,
void* location) {
ScopedActivateContext activation{context};
hipError_t res = hipHostUnregister(location);
hipError_t res = tensorflow::wrap::hipHostUnregister(location);
if (res != hipSuccess) {
LOG(ERROR) << "error unregistering host memory at " << location << ": "
<< ToString(res);
@ -755,7 +770,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
}
ScopedActivateContext activated{context};
hipError_t res = hipEventDestroy(*event);
hipError_t res = tensorflow::wrap::hipEventDestroy(*event);
*event = nullptr;
switch (res) {
@ -779,7 +794,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
GpuEventHandle event,
GpuStreamHandle stream) {
ScopedActivateContext activated{context};
hipError_t res = hipEventRecord(event, stream);
hipError_t res = tensorflow::wrap::hipEventRecord(event, stream);
switch (res) {
case hipSuccess:
return port::Status::OK();
@ -800,7 +815,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ port::StatusOr<hipError_t> GpuDriver::QueryEvent(
GpuContext* context, GpuEventHandle event) {
ScopedActivateContext activated{context};
hipError_t res = hipEventQuery(event);
hipError_t res = tensorflow::wrap::hipEventQuery(event);
if (res != hipSuccess && res != hipErrorNotReady) {
return port::Status{
port::error::INTERNAL,
@ -817,12 +832,13 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
ScopedActivateContext activated{context};
// The stop event must have completed in order for hipEventElapsedTime to
// work.
hipError_t res = hipEventSynchronize(stop);
hipError_t res = tensorflow::wrap::hipEventSynchronize(stop);
if (res != hipSuccess) {
LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res);
return false;
}
res = hipEventElapsedTime(elapsed_milliseconds, start, stop);
res =
tensorflow::wrap::hipEventElapsedTime(elapsed_milliseconds, start, stop);
if (res != hipSuccess) {
LOG(ERROR) << "failed to get elapsed time between events: "
<< ToString(res);
@ -836,7 +852,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
GpuStreamHandle stream,
GpuEventHandle event) {
ScopedActivateContext activation{context};
hipError_t res = hipStreamWaitEvent(stream, event, 0 /* = flags */);
hipError_t res =
tensorflow::wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */);
if (res != hipSuccess) {
LOG(ERROR) << "could not wait stream on event: " << ToString(res);
return false;
@ -847,7 +864,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ bool GpuDriver::SynchronizeContext(GpuContext* context) {
ScopedActivateContext activation{context};
hipError_t res = hipDeviceSynchronize();
hipError_t res = tensorflow::wrap::hipDeviceSynchronize();
if (res != hipSuccess) {
LOG(ERROR) << "could not synchronize on ROCM device: " << ToString(res)
<< " :: " << port::CurrentStackTrace();
@ -861,7 +878,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
GpuStreamHandle stream) {
ScopedActivateContext activated{context};
CHECK(stream != nullptr);
hipError_t res = hipStreamSynchronize(stream);
hipError_t res = tensorflow::wrap::hipStreamSynchronize(stream);
if (res != hipSuccess) {
port::Status status = port::InternalError(
absl::StrCat("could not synchronize on ROCM stream: ", ToString(res)));
@ -877,7 +894,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
GpuStreamHandle stream) {
ScopedActivateContext activated{context};
CHECK(stream != nullptr);
hipError_t res = hipStreamQuery(stream);
hipError_t res = tensorflow::wrap::hipStreamQuery(stream);
if (res == hipSuccess) {
return true;
}
@ -891,7 +908,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ port::Status GpuDriver::SynchronousMemcpyD2H(
GpuContext* context, void* host_dst, hipDeviceptr_t gpu_src, uint64 size) {
ScopedActivateContext activation{context};
hipError_t res = hipMemcpyDtoH(host_dst, gpu_src, size);
hipError_t res = tensorflow::wrap::hipMemcpyDtoH(host_dst, gpu_src, size);
if (res != hipSuccess) {
return port::InternalError(
absl::StrFormat("failed to synchronous memcpy from device to host: %s; "
@ -908,7 +925,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
GpuContext* context, hipDeviceptr_t gpu_dst, const void* host_src,
uint64 size) {
ScopedActivateContext activation{context};
hipError_t res = hipMemcpyHtoD(gpu_dst, const_cast<void*>(host_src), size);
hipError_t res = tensorflow::wrap::hipMemcpyHtoD(
gpu_dst, const_cast<void*>(host_src), size);
if (res != hipSuccess) {
return port::InternalError(absl::StrFormat(
"failed to synchronous memcpy from host to device: %s; Gpu dst: %p;"
@ -924,7 +942,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
GpuContext* context, hipDeviceptr_t gpu_dst, hipDeviceptr_t gpu_src,
uint64 size) {
ScopedActivateContext activation{context};
hipError_t res = hipMemcpyDtoD(gpu_dst, gpu_src, size);
hipError_t res = tensorflow::wrap::hipMemcpyDtoD(gpu_dst, gpu_src, size);
if (res != hipSuccess) {
return port::InternalError(absl::StrFormat(
"failed to synchronous memcpy from host to device: %s; Gpu dst: %p; "
@ -942,7 +960,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
uint64 size,
GpuStreamHandle stream) {
ScopedActivateContext activation{context};
hipError_t res = hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
hipError_t res =
tensorflow::wrap::hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
if (res != hipSuccess) {
LOG(ERROR) << absl::StrFormat(
"failed to enqueue async memcpy from device to host: %s; host dst: %p; "
@ -963,8 +982,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
uint64 size,
GpuStreamHandle stream) {
ScopedActivateContext activation{context};
hipError_t res =
hipMemcpyHtoDAsync(gpu_dst, const_cast<void*>(host_src), size, stream);
hipError_t res = tensorflow::wrap::hipMemcpyHtoDAsync(
gpu_dst, const_cast<void*>(host_src), size, stream);
if (res != hipSuccess) {
LOG(ERROR) << absl::StrFormat(
"failed to enqueue async memcpy from host to device: %s; Gpu dst: %p; "
@ -984,7 +1003,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
uint64 size,
GpuStreamHandle stream) {
ScopedActivateContext activation{context};
hipError_t result = hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
hipError_t result =
tensorflow::wrap::hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
if (result != hipSuccess) {
LOG(ERROR) << absl::StrFormat(
"failed to enqueue async memcpy from device to device: %s"
@ -1021,7 +1041,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
}
ScopedActivateContext activated{context};
hipError_t res = hipEventCreateWithFlags(event, hipflags);
hipError_t res = tensorflow::wrap::hipEventCreateWithFlags(event, hipflags);
if (res == hipSuccess) {
return port::Status::OK();
@ -1037,7 +1057,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ int GpuDriver::GetDeviceCount() {
int device_count = 0;
hipError_t res = hipGetDeviceCount(&device_count);
hipError_t res = tensorflow::wrap::hipGetDeviceCount(&device_count);
if (res != hipSuccess) {
LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res);
return 0;
@ -1061,7 +1081,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ port::Status GpuDriver::GetPointerAddressRange(
hipDeviceptr_t dptr, hipDeviceptr_t* base, size_t* size) {
hipError_t result = hipMemGetAddressRange(base, size, dptr);
hipError_t result = tensorflow::wrap::hipMemGetAddressRange(base, size, dptr);
if (result == hipSuccess) {
return port::Status::OK();
} else if (result == hipErrorNotFound) {
@ -1106,7 +1126,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ port::StatusOr<hipDevice_t> GpuDriver::GetPointerDevice(
hipDeviceptr_t pointer) {
hipPointerAttribute_t pointerAttributes;
hipError_t result = hipPointerGetAttributes(&pointerAttributes, pointer);
hipError_t result =
tensorflow::wrap::hipPointerGetAttributes(&pointerAttributes, pointer);
if (result != hipSuccess) {
return port::Status{
port::error::INTERNAL,
@ -1114,7 +1135,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
}
hipDevice_t device;
result = hipDeviceGet(&device, pointerAttributes.device);
result = tensorflow::wrap::hipDeviceGet(&device, pointerAttributes.device);
if (result != hipSuccess) {
return port::Status{
port::error::INTERNAL,
@ -1127,7 +1148,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
/* static */ port::Status GpuDriver::GetGpuISAVersion(int* version,
hipDevice_t device) {
hipDeviceProp_t props;
hipError_t result = hipGetDeviceProperties(&props, device);
hipError_t result = tensorflow::wrap::hipGetDeviceProperties(&props, device);
if (result == hipSuccess) {
*version = props.gcnArch;
return port::Status::OK();
@ -1145,7 +1166,8 @@ template <typename T>
static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
hipDeviceAttribute_t attribute) {
int value = -1;
hipError_t result = hipDeviceGetAttribute(&value, attribute, device);
hipError_t result =
tensorflow::wrap::hipDeviceGetAttribute(&value, attribute, device);
if (result != hipSuccess) {
return port::Status{
port::error::NOT_FOUND,
@ -1199,22 +1221,24 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
/* static */ bool GpuDriver::GetGridLimits(int* x, int* y, int* z,
hipDevice_t device) {
int value;
hipError_t res =
hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimX, device);
hipError_t res = tensorflow::wrap::hipDeviceGetAttribute(
&value, hipDeviceAttributeMaxGridDimX, device);
if (res != hipSuccess) {
LOG(ERROR) << "failed to query max grid dim x: " << ToString(res);
return false;
}
*x = value;
res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimY, device);
res = tensorflow::wrap::hipDeviceGetAttribute(
&value, hipDeviceAttributeMaxGridDimY, device);
if (res != hipSuccess) {
LOG(ERROR) << "failed to query max grid dim y: " << ToString(res);
return false;
}
*y = value;
res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimZ, device);
res = tensorflow::wrap::hipDeviceGetAttribute(
&value, hipDeviceAttributeMaxGridDimZ, device);
if (res != hipSuccess) {
LOG(ERROR) << "failed to query max grid dim z: " << ToString(res);
return false;
@ -1224,7 +1248,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
}
/* static */ bool GpuDriver::GetDriverVersion(int* driver_version) {
hipError_t res = hipDriverGetVersion(driver_version);
hipError_t res = tensorflow::wrap::hipDriverGetVersion(driver_version);
if (res != hipSuccess) {
LOG(ERROR) << "failed to query driver version: " << ToString(res);
return false;
@ -1235,7 +1259,8 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
/* static */ bool GpuDriver::GetDeviceProperties(
hipDeviceProp_t* device_properties, int device_ordinal) {
hipError_t res = hipGetDeviceProperties(device_properties, device_ordinal);
hipError_t res = tensorflow::wrap::hipGetDeviceProperties(device_properties,
device_ordinal);
if (res != hipSuccess) {
LOG(ERROR) << "failed to query device properties: " << ToString(res);
return false;
@ -1268,7 +1293,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
ScopedActivateContext activation{context};
size_t free = 0;
size_t total = 0;
hipError_t res = hipMemGetInfo(&free, &total);
hipError_t res = tensorflow::wrap::hipMemGetInfo(&free, &total);
if (res != hipSuccess) {
LOG(ERROR) << "failed to query device memory info: " << ToString(res);
return false;
@ -1282,7 +1307,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
/* static */ bool GpuDriver::GetDeviceTotalMemory(hipDevice_t device,
uint64* result) {
size_t value = -1;
hipError_t res = hipDeviceTotalMem(&value, device);
hipError_t res = tensorflow::wrap::hipDeviceTotalMem(&value, device);
if (res != hipSuccess) {
LOG(ERROR) << "failed to query total available memory: " << ToString(res);
return false;
@ -1297,7 +1322,8 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
static const int kBufferSize = 64;
absl::InlinedVector<char, 4> chars(kBufferSize);
chars[kBufferSize - 1] = '\0';
hipError_t res = hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device);
hipError_t res = tensorflow::wrap::hipDeviceGetPCIBusId(
chars.begin(), kBufferSize - 1, device);
if (res != hipSuccess) {
LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res);
return pci_bus_id;
@ -1313,7 +1339,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
}
int can_access_peer = -1;
hipError_t res = hipDeviceCanAccessPeer(
hipError_t res = tensorflow::wrap::hipDeviceCanAccessPeer(
&can_access_peer, from->device_ordinal(), to->device_ordinal());
if (res != hipSuccess) {
LOG(ERROR) << "failed to detect peer access capability: " << ToString(res);
@ -1330,8 +1356,8 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
}
ScopedActivateContext activated{from};
hipError_t result =
hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */);
hipError_t result = tensorflow::wrap::hipDeviceEnablePeerAccess(
to->device_ordinal(), 0 /* = flags */);
if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) {
return port::Status{
port::error::INTERNAL,

View File

@ -0,0 +1,146 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file wraps rocm driver calls with dso loader so that we don't need to
// have explicit linking to librocm. All TF rocm driver usage should route
// through this wrapper.
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
#include "rocm/include/hip/hip_runtime.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/port.h"
#if defined(TENSORFLOW_USE_ROCM)
#endif
namespace tensorflow {
namespace wrap {
#ifdef PLATFORM_GOOGLE
// Use static linked library
#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \
template <typename... Args> \
auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \
return ::hipSymbolName(args...); \
}
// This macro wraps a global identifier, given by hipSymbolName, in a callable
// structure that loads the DLL symbol out of the DSO handle in a thread-safe
// manner on first use. This dynamic loading technique is used to avoid DSO
// dependencies on vendor libraries which may or may not be available in the
// deployed binary environment.
#else
#define TO_STR_(x) #x
#define TO_STR(x) TO_STR_(x)
// hipMalloc and hipHostMalloc are defined as funtion templates in the
// HIP header files, and hence their names get mangled and the attempt
// to resolve their name when trying to dynamically load them will fail
// Updating the HIP header files to make them C functions is underway.
// Until that change flows through, we will workaround the issue by
// creating dummy wrappers for them here
hipError_t hipMallocVanilla(void** ptr, size_t size) {
return hipErrorNotInitialized;
}
hipError_t hipHostMallocVanilla(void** ptr, size_t size, unsigned int flags) {
return hipErrorNotInitialized;
}
#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \
template <typename... Args> \
auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \
using FuncPtrT = std::add_pointer<decltype(::hipSymbolName)>::type; \
static FuncPtrT loaded = []() -> FuncPtrT { \
static const char *kName = TO_STR(hipSymbolName); \
void *f; \
auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \
stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \
.ValueOrDie(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in HIP DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
}(); \
return loaded(args...); \
}
#endif
// clang-format off
#define HIP_ROUTINE_EACH(__macro) \
__macro(hipDeviceCanAccessPeer) \
__macro(hipDeviceEnablePeerAccess) \
__macro(hipDeviceGet) \
__macro(hipDeviceGetAttribute) \
__macro(hipDeviceGetName) \
__macro(hipDeviceGetPCIBusId) \
__macro(hipDeviceGetSharedMemConfig) \
__macro(hipDeviceSetSharedMemConfig) \
__macro(hipDeviceSynchronize) \
__macro(hipDeviceTotalMem) \
__macro(hipDriverGetVersion) \
__macro(hipEventCreateWithFlags) \
__macro(hipEventElapsedTime) \
__macro(hipEventDestroy) \
__macro(hipEventQuery) \
__macro(hipEventRecord) \
__macro(hipEventSynchronize) \
__macro(hipFree) \
__macro(hipFuncSetCacheConfig) \
__macro(hipGetDevice) \
__macro(hipGetDeviceCount) \
__macro(hipGetDeviceProperties) \
__macro(hipHostFree) \
__macro(hipHostRegister) \
__macro(hipHostUnregister) \
__macro(hipInit) \
__macro(hipMemGetAddressRange) \
__macro(hipMemGetInfo) \
__macro(hipMemcpyDtoD) \
__macro(hipMemcpyDtoDAsync) \
__macro(hipMemcpyDtoH) \
__macro(hipMemcpyDtoHAsync) \
__macro(hipMemcpyHtoD) \
__macro(hipMemcpyHtoDAsync) \
__macro(hipMemset) \
__macro(hipMemsetAsync) \
__macro(hipModuleGetFunction) \
__macro(hipModuleGetGlobal) \
__macro(hipModuleLaunchKernel) \
__macro(hipModuleLoadData) \
__macro(hipModuleUnload) \
__macro(hipPointerGetAttributes) \
__macro(hipSetDevice) \
__macro(hipStreamAddCallback) \
__macro(hipStreamCreateWithFlags) \
__macro(hipStreamDestroy) \
__macro(hipStreamQuery) \
__macro(hipStreamSynchronize) \
__macro(hipStreamWaitEvent) \
// clang-format on
HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP)
#undef HIP_ROUTINE_EACH
#undef STREAM_EXECUTOR_HIP_WRAP
#undef TO_STR
#undef TO_STR_
} // namespace wrap
} // namespace tensorflow
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_

View File

@ -0,0 +1,618 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/rocm/rocm_fft.h"
#include <complex>
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
namespace stream_executor {
namespace gpu {
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);
namespace wrap {
#ifdef PLATFORM_GOOGLE
// This macro wraps a global identifier, given by __name, in a callable
// structure that loads the DLL symbol out of the DSO handle in a thread-safe
// manner on first use. This dynamic loading technique is used to avoid DSO
// dependencies on vendor libraries which may or may not be available in the
// deployed binary environment.
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
hipfftResult operator()(GpuExecutor *parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return ::__name(args...); \
} \
} __name;
#else
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char *kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void *GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetRocfftDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void *f; \
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in rocfft DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
hipfftResult operator()(GpuExecutor *parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char *DynLoadShim__##__name::kName = #__name;
#endif
#define ROCFFT_ROUTINE_EACH(__macro) \
__macro(hipfftDestroy) __macro(hipfftSetStream) __macro(hipfftPlan1d) \
__macro(hipfftPlan2d) __macro(hipfftPlan3d) __macro(hipfftPlanMany) \
__macro(hipfftCreate) __macro(hipfftSetAutoAllocation) \
__macro(hipfftSetWorkArea) __macro(hipfftGetSize1d) \
__macro(hipfftMakePlan1d) __macro(hipfftGetSize2d) \
__macro(hipfftMakePlan2d) __macro(hipfftGetSize3d) \
__macro(hipfftMakePlan3d) __macro(hipfftGetSizeMany) \
__macro(hipfftMakePlanMany) \
__macro(hipfftExecD2Z) \
__macro(hipfftExecZ2D) \
__macro(hipfftExecC2C) \
__macro(hipfftExecC2R) \
__macro(hipfftExecZ2Z) \
__macro(hipfftExecR2C)
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)
} // namespace wrap
namespace {
// A helper function transforming gpu_fft arguments into rocFFT arguments.
hipfftType ROCMFftType(fft::Type type) {
switch (type) {
case fft::Type::kC2CForward:
case fft::Type::kC2CInverse:
return HIPFFT_C2C;
case fft::Type::kC2R:
return HIPFFT_C2R;
case fft::Type::kR2C:
return HIPFFT_R2C;
case fft::Type::kZ2ZForward:
case fft::Type::kZ2ZInverse:
return HIPFFT_Z2Z;
case fft::Type::kZ2D:
return HIPFFT_Z2D;
case fft::Type::kD2Z:
return HIPFFT_D2Z;
default:
LOG(FATAL) << "Invalid value of fft::Type.";
}
}
// Associates the given stream with the given rocFFT plan.
bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) {
auto ret = wrap::hipfftSetStream(parent, plan, AsGpuStreamValue(stream));
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to run rocFFT routine hipfftSetStream: " << ret;
return false;
}
return true;
}
} // namespace
port::Status ROCMFftPlan::Initialize(
GpuExecutor *parent, Stream *stream, int rank, uint64 *elem_count,
uint64 *input_embed, uint64 input_stride, uint64 input_distance,
uint64 *output_embed, uint64 output_stride, uint64 output_distance,
fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) {
if (IsInitialized()) {
LOG(FATAL) << "Try to repeatedly initialize.";
}
is_initialized_ = true;
int elem_count_[3], input_embed_[3], output_embed_[3];
for (int i = 0; i < rank; ++i) {
elem_count_[i] = elem_count[i];
if (input_embed) {
input_embed_[i] = input_embed[i];
}
if (output_embed) {
output_embed_[i] = output_embed[i];
}
}
parent_ = parent;
fft_type_ = type;
if (batch_count == 1 && input_embed == nullptr && output_embed == nullptr) {
hipfftResult_t ret;
if (scratch_allocator == nullptr) {
switch (rank) {
case 1:
// hipfftPlan1d
ret = wrap::hipfftPlan1d(parent, &plan_, elem_count_[0],
ROCMFftType(type), 1 /* = batch */);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to create rocFFT 1d plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to create rocFFT 1d plan."};
}
return port::Status::OK();
case 2:
// hipfftPlan2d
ret = wrap::hipfftPlan2d(parent, &plan_, elem_count_[0],
elem_count_[1], ROCMFftType(type));
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to create rocFFT 2d plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to create rocFFT 2d plan."};
}
return port::Status::OK();
case 3:
// hipfftPlan3d
ret =
wrap::hipfftPlan3d(parent, &plan_, elem_count_[0], elem_count_[1],
elem_count_[2], ROCMFftType(type));
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to create rocFFT 3d plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to create rocFFT 3d plan."};
}
return port::Status::OK();
default:
LOG(ERROR) << "Invalid rank value for hipfftPlan. "
"Requested 1, 2, or 3, given: "
<< rank;
return port::Status{port::error::INVALID_ARGUMENT,
"hipfftPlan only takes rank 1, 2, or 3."};
}
} else {
ret = wrap::hipfftCreate(parent, &plan_);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to create rocFFT plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to create rocFFT plan."};
}
ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to set auto allocation for rocFFT plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to set auto allocation for rocFFT plan."};
}
size_t size_in_bytes;
switch (rank) {
case 1:
ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
ROCMFftType(type), /*batch=*/1,
&size_in_bytes);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to make rocFFT 1d plan."};
}
break;
case 2:
ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
elem_count_[1], ROCMFftType(type),
&size_in_bytes);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to make rocFFT 2d plan."};
}
break;
case 3:
ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
elem_count_[1], elem_count_[2],
ROCMFftType(type), &size_in_bytes);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to make rocFFT 3d plan."};
}
break;
default:
LOG(ERROR) << "Invalid rank value for hipfftPlan. "
"Requested 1, 2, or 3, given: "
<< rank;
return port::Status{port::error::INVALID_ARGUMENT,
"hipfftPlan only takes rank 1, 2, or 3."};
}
// TODO(yangzihao): refactor this code and the one with the same function
// in the batch mode.
if (size_in_bytes != 0) {
auto allocated =
scratch_allocator->AllocateBytes(stream, size_in_bytes);
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "failed to allocate work area.";
return allocated.status();
}
}
// Connect work area with allocated space.
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to set work area for rocFFT plan."};
}
return port::Status::OK();
}
} else {
// For either multiple batches or rank higher than 3, use hipfftPlanMany().
if (scratch_allocator == nullptr) {
auto ret = wrap::hipfftPlanMany(
parent, &plan_, rank, elem_count_,
input_embed ? input_embed_ : nullptr, input_stride, input_distance,
output_embed ? output_embed_ : nullptr, output_stride,
output_distance, ROCMFftType(type), batch_count);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to create rocFFT bacthed plan."};
}
} else {
auto ret = wrap::hipfftCreate(parent, &plan_);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to create rocFFT bacthed plan."};
}
ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to set auto allocation for rocFFT batched plan:"
<< ret;
return port::Status{
port::error::INTERNAL,
"Failed to set auto allocation for rocFFT bacthed plan."};
}
size_t size_in_bytes;
ret = wrap::hipfftMakePlanMany(
parent, plan_, rank, elem_count_,
input_embed ? input_embed_ : nullptr, input_stride, input_distance,
output_embed ? output_embed_ : nullptr, output_stride,
output_distance, ROCMFftType(type), batch_count, &size_in_bytes);
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to make rocFFT bacthed plan."};
}
if (size_in_bytes != 0) {
auto allocated =
scratch_allocator->AllocateBytes(stream, size_in_bytes);
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
LOG(ERROR) << "failed to allocate work area.";
return allocated.status();
}
}
// Connect work area with allocated space.
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret;
return port::Status{port::error::INTERNAL,
"Failed to set work area for rocFFT bacthed plan."};
}
}
}
return port::Status::OK();
}
port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
int rank, uint64 *elem_count,
fft::Type type,
ScratchAllocator *scratch_allocator) {
return Initialize(parent_, stream, rank, elem_count,
/*input_embed=*/nullptr, /*input_stride=*/0,
/*input_distance=*/0,
/*output_embed=*/nullptr, /*output_stride=*/0,
/*output_distance=*/0, type, 1, scratch_allocator);
}
ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
int ROCMFftPlan::GetFftDirection() const {
if (!IsInitialized()) {
LOG(FATAL) << "Try to get fft direction before initialization.";
} else {
switch (fft_type_) {
case fft::Type::kC2CForward:
case fft::Type::kZ2ZForward:
case fft::Type::kR2C:
case fft::Type::kD2Z:
return HIPFFT_FORWARD;
case fft::Type::kC2CInverse:
case fft::Type::kZ2ZInverse:
case fft::Type::kC2R:
case fft::Type::kZ2D:
return HIPFFT_BACKWARD;
default:
LOG(FATAL) << "Invalid value of fft::Type.";
}
}
}
std::unique_ptr<fft::Plan> ROCMFft::Create1dPlan(Stream *stream, uint64 num_x,
fft::Type type,
bool in_place_fft) {
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
uint64 elem_count[1] = {num_x};
port::Status status = fft_plan_ptr->Initialize(
parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
// TODO(yangzihao): In the future, send error msg back to TensorFlow
// so it can fail gracefully,
if (!status.ok()) {
LOG(FATAL) << "failed to initialize hipfft 1d plan: "
<< status.error_message();
}
return std::move(fft_plan_ptr);
}
std::unique_ptr<fft::Plan> ROCMFft::Create1dPlanWithScratchAllocator(
Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft,
ScratchAllocator *scratch_allocator) {
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
uint64 elem_count[1] = {num_x};
port::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count,
type, scratch_allocator);
if (!status.ok()) {
LOG(FATAL)
<< "failed to initialize hipfft 1d plan with customized allocator: "
<< status.error_message();
}
return std::move(fft_plan_ptr);
}
std::unique_ptr<fft::Plan> ROCMFft::Create2dPlan(Stream *stream, uint64 num_x,
uint64 num_y, fft::Type type,
bool in_place_fft) {
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
uint64 elem_count[2] = {num_x, num_y};
port::Status status = fft_plan_ptr->Initialize(
parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
if (!status.ok()) {
LOG(FATAL) << "failed to initialize hipfft 2d plan: "
<< status.error_message();
}
return std::move(fft_plan_ptr);
}
std::unique_ptr<fft::Plan> ROCMFft::Create2dPlanWithScratchAllocator(
Stream *stream, uint64 num_x, uint64 num_y, fft::Type type,
bool in_place_fft, ScratchAllocator *scratch_allocator) {
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
uint64 elem_count[2] = {num_x, num_y};
port::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count,
type, scratch_allocator);
if (!status.ok()) {
LOG(FATAL)
<< "failed to initialize hipfft 2d plan with customized allocator: "
<< status.error_message();
}
return std::move(fft_plan_ptr);
}
std::unique_ptr<fft::Plan> ROCMFft::Create3dPlan(Stream *stream, uint64 num_x,
uint64 num_y, uint64 num_z,
fft::Type type,
bool in_place_fft) {
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
uint64 elem_count[3] = {num_x, num_y, num_z};
port::Status status = fft_plan_ptr->Initialize(
parent_, stream, 3, elem_count, type, /*scratch_allocator=*/nullptr);
if (!status.ok()) {
LOG(FATAL) << "failed to initialize hipfft 3d plan: "
<< status.error_message();
}
return std::move(fft_plan_ptr);
}
std::unique_ptr<fft::Plan> ROCMFft::Create3dPlanWithScratchAllocator(
Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, fft::Type type,
bool in_place_fft, ScratchAllocator *scratch_allocator) {
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
uint64 elem_count[3] = {num_x, num_y, num_z};
port::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count,
type, scratch_allocator);
if (!status.ok()) {
LOG(FATAL)
<< "failed to initialize hipfft 3d plan with customized allocator: "
<< status.error_message();
}
return std::move(fft_plan_ptr);
}
std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlan(
Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
uint64 input_stride, uint64 input_distance, uint64 *output_embed,
uint64 output_stride, uint64 output_distance, fft::Type type,
bool in_place_fft, int batch_count) {
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
port::Status status = fft_plan_ptr->Initialize(
parent_, stream, rank, elem_count, input_embed, input_stride,
input_distance, output_embed, output_stride, output_distance, type,
batch_count, /*scratch_allocator=*/nullptr);
if (!status.ok()) {
LOG(FATAL) << "failed to initialize batched hipfft plan: "
<< status.error_message();
}
return std::move(fft_plan_ptr);
}
std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
uint64 input_stride, uint64 input_distance, uint64 *output_embed,
uint64 output_stride, uint64 output_distance, fft::Type type,
bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) {
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
port::Status status = fft_plan_ptr->Initialize(
parent_, stream, rank, elem_count, input_embed, input_stride,
input_distance, output_embed, output_stride, output_distance, type,
batch_count, scratch_allocator);
if (!status.ok()) {
LOG(FATAL) << "failed to initialize batched hipfft plan with customized "
"allocator: "
<< status.error_message();
}
return std::move(fft_plan_ptr);
}
void ROCMFft::UpdatePlanWithScratchAllocator(
Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
LOG(ERROR) << "update plan with scratch allocator not implemented";
}
template <typename FuncT, typename InputT, typename OutputT>
bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
const DeviceMemory<InputT> &input,
DeviceMemory<OutputT> *output) {
ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
if (rocm_fft_plan == nullptr) {
LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
return false;
}
if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
return false;
}
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
GpuComplex(GpuMemoryMutable(output)));
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to run rocFFT routine: " << ret;
return false;
}
return true;
}
template <typename FuncT, typename InputT, typename OutputT>
bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
FuncT hipfftExec,
const DeviceMemory<InputT> &input,
DeviceMemory<OutputT> *output) {
ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
if (rocm_fft_plan == nullptr) {
LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
return false;
}
if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
return false;
}
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
GpuComplex(GpuMemoryMutable(output)),
rocm_fft_plan->GetFftDirection());
if (ret != HIPFFT_SUCCESS) {
LOG(ERROR) << "failed to run rocFFT routine: " << ret;
return false;
}
return true;
}
#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
__fft_type3) \
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<std::complex<__type>> &input, \
DeviceMemory<std::complex<__type>> *output) { \
return DoFftWithDirectionInternal( \
stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
} \
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<__type> &input, \
DeviceMemory<std::complex<__type>> *output) { \
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
output); \
} \
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
const DeviceMemory<std::complex<__type>> &input, \
DeviceMemory<__type> *output) { \
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
output); \
}
STREAM_EXECUTOR_ROCM_DEFINE_FFT(float, C2C, R2C, C2R)
STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
#undef STREAM_EXECUTOR_ROCM_DEFINE_FFT
} // namespace gpu
void initialize_rocfft() {
auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
if (!rocFftAlreadyRegistered) {
port::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
[](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
gpu::GpuExecutor *rocm_executor =
dynamic_cast<gpu::GpuExecutor *>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the rocFFT "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
return new gpu::ROCMFft(rocm_executor);
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocFFT factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
}
}
} // namespace stream_executor
REGISTER_MODULE_INITIALIZER(register_rocfft,
{ stream_executor::initialize_rocfft(); });

View File

@ -0,0 +1,132 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// ROCM-specific support for FFT functionality -- this wraps the rocFFT library
// capabilities, and is only included into ROCM implementation code -- it will
// not introduce rocm headers into other code.
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
#include "rocm/include/rocfft/hipfft.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
namespace stream_executor {
class Stream;
namespace gpu {
class GpuExecutor;
// Opaque and unique indentifier for the rocFFT plugin.
extern const PluginId kRocFftPlugin;
// ROCMFftPlan uses deferred initialization. Only a single call of
// Initialize() is allowed to properly create hipfft plan and set member
// variable is_initialized_ to true. Newly added interface that uses member
// variables should first check is_initialized_ to make sure that the values of
// member variables are valid.
class ROCMFftPlan : public fft::Plan {
public:
ROCMFftPlan()
: parent_(nullptr),
plan_(),
fft_type_(fft::Type::kInvalid),
scratch_(nullptr),
is_initialized_(false) {}
~ROCMFftPlan() override;
// Get FFT direction in hipFFT based on FFT type.
int GetFftDirection() const;
hipfftHandle GetPlan() const {
if (IsInitialized()) {
return plan_;
} else {
LOG(FATAL) << "Try to get hipfftHandle value before initialization.";
}
}
// Initialize function for batched plan
port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
uint64 *elem_count, uint64 *input_embed,
uint64 input_stride, uint64 input_distance,
uint64 *output_embed, uint64 output_stride,
uint64 output_distance, fft::Type type,
int batch_count, ScratchAllocator *scratch_allocator);
// Initialize function for 1d,2d, and 3d plan
port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
uint64 *elem_count, fft::Type type,
ScratchAllocator *scratch_allocator);
protected:
bool IsInitialized() const { return is_initialized_; }
private:
GpuExecutor *parent_;
hipfftHandle plan_;
fft::Type fft_type_;
DeviceMemory<uint8> scratch_;
bool is_initialized_;
};
// FFT support for ROCM platform via rocFFT library.
//
// This satisfies the platform-agnostic FftSupport interface.
//
// Note that the hipFFT handle that this encapsulates is implicitly tied to the
// context (and, as a result, the device) that the parent GpuExecutor is tied
// to. This simply happens as an artifact of creating the hipFFT handle when a
// ROCM context is active.
//
// Thread-safe. The ROCM context associated with all operations is the ROCM
// context of parent_, so all context is explicit.
class ROCMFft : public fft::FftSupport {
public:
explicit ROCMFft(GpuExecutor *parent) : parent_(parent) {}
~ROCMFft() override {}
TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES
private:
GpuExecutor *parent_;
// Two helper functions that execute dynload::hipfftExec?2?.
// This is for complex to complex FFT, when the direction is required.
template <typename FuncT, typename InputT, typename OutputT>
bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
FuncT hipfft_exec,
const DeviceMemory<InputT> &input,
DeviceMemory<OutputT> *output);
// This is for complex to real or real to complex FFT, when the direction
// is implied.
template <typename FuncT, typename InputT, typename OutputT>
bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfft_exec,
const DeviceMemory<InputT> &input,
DeviceMemory<OutputT> *output);
SE_DISALLOW_COPY_AND_ASSIGN(ROCMFft);
};
} // namespace gpu
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "absl/base/casts.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
#include "tensorflow/stream_executor/gpu/gpu_event.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
@ -41,6 +40,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
@ -655,7 +655,7 @@ port::Status GpuExecutor::BlockHostUntilDone(Stream* stream) {
blas::BlasSupport* GpuExecutor::CreateBlas() {
PluginRegistry* registry = PluginRegistry::Instance();
port::StatusOr<PluginRegistry::BlasFactory> status =
registry->GetFactory<PluginRegistry::BlasFactory>(kROCmPlatformId,
registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId,
plugin_config_.blas());
if (!status.ok()) {
LOG(ERROR) << "Unable to retrieve BLAS factory: "
@ -669,7 +669,7 @@ blas::BlasSupport* GpuExecutor::CreateBlas() {
dnn::DnnSupport* GpuExecutor::CreateDnn() {
PluginRegistry* registry = PluginRegistry::Instance();
port::StatusOr<PluginRegistry::DnnFactory> status =
registry->GetFactory<PluginRegistry::DnnFactory>(kROCmPlatformId,
registry->GetFactory<PluginRegistry::DnnFactory>(rocm::kROCmPlatformId,
plugin_config_.dnn());
if (!status.ok()) {
LOG(ERROR) << "Unable to retrieve DNN factory: "
@ -683,7 +683,7 @@ dnn::DnnSupport* GpuExecutor::CreateDnn() {
fft::FftSupport* GpuExecutor::CreateFft() {
PluginRegistry* registry = PluginRegistry::Instance();
port::StatusOr<PluginRegistry::FftFactory> status =
registry->GetFactory<PluginRegistry::FftFactory>(kROCmPlatformId,
registry->GetFactory<PluginRegistry::FftFactory>(rocm::kROCmPlatformId,
plugin_config_.fft());
if (!status.ok()) {
LOG(ERROR) << "Unable to retrieve FFT factory: "
@ -697,7 +697,7 @@ fft::FftSupport* GpuExecutor::CreateFft() {
rng::RngSupport* GpuExecutor::CreateRng() {
PluginRegistry* registry = PluginRegistry::Instance();
port::StatusOr<PluginRegistry::RngFactory> status =
registry->GetFactory<PluginRegistry::RngFactory>(kROCmPlatformId,
registry->GetFactory<PluginRegistry::RngFactory>(rocm::kROCmPlatformId,
plugin_config_.rng());
if (!status.ok()) {
LOG(ERROR) << "Unable to retrieve RNG factory: "
@ -878,12 +878,10 @@ DeviceDescription* GpuExecutor::PopulateDeviceDescription() const {
{
int driver_version = 0;
(void)GpuDriver::GetDriverVersion(&driver_version);
string augmented_driver_version =
absl::StrFormat("%d (%s)", driver_version, "__FIXME__");
// FIXME:
// uncomment the line below once the "DriverVersionStatusToString"
// routine is moved from the "cuda" namespace to the "gpu" naemspace
// DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
string augmented_driver_version = absl::StrFormat(
"%d (%s)", driver_version,
rocm::DriverVersionStatusToString(Diagnostician::FindDsoVersion())
.c_str());
builder.set_driver_version(augmented_driver_version);
}

View File

@ -94,7 +94,7 @@ port::StatusOr<StreamExecutor*> ROCmPlatform::FirstExecutorForBus(
absl::StrFormat("Executor for bus %d not found.", bus_ordinal)};
}
Platform::Id ROCmPlatform::id() const { return kROCmPlatformId; }
Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; }
int ROCmPlatform::VisibleDeviceCount() const {
// Throw away the result - it logs internally, and this [containing] function

View File

@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
namespace stream_executor {
namespace gpu {
namespace rocm {
PLATFORM_DEFINE_ID(kROCmPlatformId);
} // namespace gpu
} // namespace rocm
} // namespace stream_executor

View File

@ -19,16 +19,16 @@ limitations under the License.
#include "tensorflow/stream_executor/platform.h"
namespace stream_executor {
namespace gpu {
namespace rocm {
// Opaque and unique identifier for the ROCm platform.
// This is needed so that plugins can refer to/identify this platform without
// instantiating a ROCmPlatform object.
// This is broken out here to avoid a circular dependency between ROCmPlatform
// and GpuExecutor.
// and ROCmExecutor.
extern const Platform::Id kROCmPlatformId;
} // namespace gpu
} // namespace rocm
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_ID_H_

View File

@ -14,15 +14,17 @@ limitations under the License.
==============================================================================*/
#include "rocm/include/hiprand/hiprand.h"
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
@ -60,7 +62,9 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);
namespace wrap {
#define PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(__name) \
#ifdef PLATFORM_GOOGLE
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
hiprandStatus_t operator()(GpuExecutor* parent, Args... args) { \
@ -69,15 +73,47 @@ namespace wrap {
} \
} __name;
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandCreateGenerator);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandDestroyGenerator);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetStream);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniform);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniformDouble);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetGeneratorOffset);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormal);
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormalDouble);
#else
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char* kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void* GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetRocrandDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void* f; \
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in rocrand DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
hiprandStatus operator()(GpuExecutor* parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char* DynLoadShim__##__name::kName = #__name;
#endif
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniform);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniformDouble);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetGeneratorOffset);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormal);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormalDouble);
} // namespace wrap
@ -245,40 +281,45 @@ bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
}
} // namespace gpu
void initialize_rocrand() {
auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
if (!rocRandAlreadyRegistered) {
port::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
[](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
gpu::GpuExecutor* rocm_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the hipRAND "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
if (!rng->Init()) {
// Note: Init() will log a more specific error.
delete rng;
return nullptr;
}
return rng;
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocRAND factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
}
}
} // namespace stream_executor
namespace se = ::stream_executor;
REGISTER_MODULE_INITIALIZER(register_hiprand, {
se::port::Status status =
se::PluginRegistry::Instance()
->RegisterFactory<se::PluginRegistry::RngFactory>(
se::gpu::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
[](se::internal::StreamExecutorInterface* parent)
-> se::rng::RngSupport* {
se::gpu::GpuExecutor* rocm_executor =
dynamic_cast<se::gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the hipRAND "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
se::gpu::GpuRng* rng = new se::gpu::GpuRng(rocm_executor);
if (!rng->Init()) {
// Note: Init() will log a more specific error.
delete rng;
return nullptr;
}
return rng;
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register hipRAND factory: "
<< status.error_message();
}
se::PluginRegistry::Instance()->SetDefaultFactory(
se::gpu::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
});
REGISTER_MODULE_INITIALIZER(register_rocrand,
{ stream_executor::initialize_rocrand(); });