Merge pull request #25424 from ROCmSoftwarePlatform:google_upstream_add_rocm_blas_fft_rand
PiperOrigin-RevId: 234978728
This commit is contained in:
commit
6a25b9d561
@ -39,7 +39,7 @@ port::StatusOr<void*> GetDsoHandle(const string& name, const string& version) {
|
|||||||
port::Status status =
|
port::Status status =
|
||||||
port::Env::Default()->LoadLibrary(filename.c_str(), &dso_handle);
|
port::Env::Default()->LoadLibrary(filename.c_str(), &dso_handle);
|
||||||
if (status.ok()) {
|
if (status.ok()) {
|
||||||
LOG(INFO) << "Successfully opened CUDA library " << filename;
|
LOG(INFO) << "Successfully opened dynamic library " << filename;
|
||||||
return dso_handle;
|
return dso_handle;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,6 +99,25 @@ port::StatusOr<void*> GetCuptiDsoHandle() {
|
|||||||
port::StatusOr<void*> GetCudnnDsoHandle() {
|
port::StatusOr<void*> GetCudnnDsoHandle() {
|
||||||
return GetDsoHandle("cudnn", GetCudnnVersion());
|
return GetDsoHandle("cudnn", GetCudnnVersion());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetRocblasDsoHandle() {
|
||||||
|
return GetDsoHandle("rocblas", "");
|
||||||
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetMiopenDsoHandle() {
|
||||||
|
return GetDsoHandle("MIOpen", "");
|
||||||
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetRocfftDsoHandle() {
|
||||||
|
return GetDsoHandle("rocfft", "");
|
||||||
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetRocrandDsoHandle() {
|
||||||
|
return GetDsoHandle("rocrand", "");
|
||||||
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetHipDsoHandle() { return GetDsoHandle("hip_hcc", ""); }
|
||||||
|
|
||||||
} // namespace DsoLoader
|
} // namespace DsoLoader
|
||||||
|
|
||||||
namespace CachedDsoLoader {
|
namespace CachedDsoLoader {
|
||||||
@ -136,6 +155,32 @@ port::StatusOr<void*> GetCudnnDsoHandle() {
|
|||||||
static auto result = new auto(DsoLoader::GetCudnnDsoHandle());
|
static auto result = new auto(DsoLoader::GetCudnnDsoHandle());
|
||||||
return *result;
|
return *result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetRocblasDsoHandle() {
|
||||||
|
static auto result = new auto(DsoLoader::GetRocblasDsoHandle());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetMiopenDsoHandle() {
|
||||||
|
static auto result = new auto(DsoLoader::GetMiopenDsoHandle());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetRocfftDsoHandle() {
|
||||||
|
static auto result = new auto(DsoLoader::GetRocfftDsoHandle());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetRocrandDsoHandle() {
|
||||||
|
static auto result = new auto(DsoLoader::GetRocrandDsoHandle());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetHipDsoHandle() {
|
||||||
|
static auto result = new auto(DsoLoader::GetHipDsoHandle());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace CachedDsoLoader
|
} // namespace CachedDsoLoader
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
@ -41,6 +41,12 @@ port::StatusOr<void*> GetCufftDsoHandle();
|
|||||||
port::StatusOr<void*> GetCurandDsoHandle();
|
port::StatusOr<void*> GetCurandDsoHandle();
|
||||||
port::StatusOr<void*> GetCuptiDsoHandle();
|
port::StatusOr<void*> GetCuptiDsoHandle();
|
||||||
port::StatusOr<void*> GetCudnnDsoHandle();
|
port::StatusOr<void*> GetCudnnDsoHandle();
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetRocblasDsoHandle();
|
||||||
|
port::StatusOr<void*> GetMiopenDsoHandle();
|
||||||
|
port::StatusOr<void*> GetRocfftDsoHandle();
|
||||||
|
port::StatusOr<void*> GetRocrandDsoHandle();
|
||||||
|
port::StatusOr<void*> GetHipDsoHandle();
|
||||||
} // namespace DsoLoader
|
} // namespace DsoLoader
|
||||||
|
|
||||||
// Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs
|
// Wrapper around the DsoLoader that prevents us from dlopen'ing any of the DSOs
|
||||||
@ -54,7 +60,14 @@ port::StatusOr<void*> GetCufftDsoHandle();
|
|||||||
port::StatusOr<void*> GetCurandDsoHandle();
|
port::StatusOr<void*> GetCurandDsoHandle();
|
||||||
port::StatusOr<void*> GetCuptiDsoHandle();
|
port::StatusOr<void*> GetCuptiDsoHandle();
|
||||||
port::StatusOr<void*> GetCudnnDsoHandle();
|
port::StatusOr<void*> GetCudnnDsoHandle();
|
||||||
|
|
||||||
|
port::StatusOr<void*> GetRocblasDsoHandle();
|
||||||
|
port::StatusOr<void*> GetMiopenDsoHandle();
|
||||||
|
port::StatusOr<void*> GetRocfftDsoHandle();
|
||||||
|
port::StatusOr<void*> GetRocrandDsoHandle();
|
||||||
|
port::StatusOr<void*> GetHipDsoHandle();
|
||||||
} // namespace CachedDsoLoader
|
} // namespace CachedDsoLoader
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ filegroup(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "rocm_diagnostics",
|
name = "rocm_diagnostics",
|
||||||
srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]),
|
srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]),
|
||||||
hdrs = [],
|
hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]),
|
||||||
deps = if_rocm_is_configured([
|
deps = if_rocm_is_configured([
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
@ -47,7 +47,7 @@ cc_library(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "rocm_driver",
|
name = "rocm_driver",
|
||||||
srcs = if_rocm_is_configured(["rocm_driver.cc"]),
|
srcs = if_rocm_is_configured(["rocm_driver.cc"]),
|
||||||
hdrs = [],
|
hdrs = if_rocm_is_configured(["rocm_driver_wrapper.h"]),
|
||||||
deps = if_rocm_is_configured([
|
deps = if_rocm_is_configured([
|
||||||
":rocm_diagnostics",
|
":rocm_diagnostics",
|
||||||
"@com_google_absl//absl/base",
|
"@com_google_absl//absl/base",
|
||||||
@ -57,6 +57,7 @@ cc_library(
|
|||||||
"//tensorflow/stream_executor/gpu:gpu_driver_header",
|
"//tensorflow/stream_executor/gpu:gpu_driver_header",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"//tensorflow/stream_executor/platform",
|
"//tensorflow/stream_executor/platform",
|
||||||
|
"//tensorflow/stream_executor/platform:dso_loader",
|
||||||
"@local_config_rocm//rocm:rocm_headers",
|
"@local_config_rocm//rocm:rocm_headers",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
@ -140,55 +141,62 @@ cc_library(
|
|||||||
deps = ["//tensorflow/stream_executor:platform"],
|
deps = ["//tensorflow/stream_executor:platform"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# FIXME: enable in future PRs
|
cc_library(
|
||||||
#cc_library(
|
name = "rocblas_plugin",
|
||||||
# name = "rocblas_plugin",
|
srcs = if_rocm_is_configured(["rocm_blas.cc"]),
|
||||||
# srcs = ["rocm_blas.cc"],
|
hdrs = if_rocm_is_configured(["rocm_blas.h"]),
|
||||||
# hdrs = ["rocm_blas.h"],
|
visibility = ["//visibility:public"],
|
||||||
# visibility = ["//visibility:public"],
|
deps = if_rocm_is_configured([
|
||||||
# deps = [
|
":rocm_gpu_executor",
|
||||||
# ":rocm_gpu_executor",
|
":rocm_platform_id",
|
||||||
# ":rocm_platform_id",
|
"//third_party/eigen3",
|
||||||
# "//third_party/eigen3",
|
"//tensorflow/core:lib_internal",
|
||||||
# "//tensorflow/core:lib_internal",
|
"//tensorflow/stream_executor",
|
||||||
# "//tensorflow/stream_executor",
|
"//tensorflow/stream_executor:event",
|
||||||
# "//tensorflow/stream_executor:event",
|
"//tensorflow/stream_executor:host_or_device_scalar",
|
||||||
# "//tensorflow/stream_executor:host_or_device_scalar",
|
"//tensorflow/stream_executor:plugin_registry",
|
||||||
# "//tensorflow/stream_executor:plugin_registry",
|
"//tensorflow/stream_executor:scratch_allocator",
|
||||||
# "//tensorflow/stream_executor:scratch_allocator",
|
"//tensorflow/stream_executor:timer",
|
||||||
# "//tensorflow/stream_executor:timer",
|
"//tensorflow/stream_executor/gpu:gpu_activation",
|
||||||
# "//tenosrflow/stream_executor/gpu:gpu_activation_header",
|
"//tensorflow/stream_executor/gpu:gpu_helpers_header",
|
||||||
# "//tenosrflow/stream_executor/gpu:gpu_stream_header",
|
"//tensorflow/stream_executor/gpu:gpu_stream_header",
|
||||||
# "//tenosrflow/stream_executor/gpu:gpu_timer_header",
|
"//tensorflow/stream_executor/gpu:gpu_timer_header",
|
||||||
# "//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
# "//tensorflow/stream_executor/platform",
|
"//tensorflow/stream_executor/platform",
|
||||||
# "//tensorflow/stream_executor/platform:dso_loader",
|
"//tensorflow/stream_executor/platform:dso_loader",
|
||||||
# "@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
# "@local_config_rocm//rocm:rocm_headers",
|
"@local_config_rocm//rocm:rocm_headers",
|
||||||
# ] + if_static(["@local_config_rocm//rocm:rocblas"]),
|
] + if_static([
|
||||||
# alwayslink = True,
|
"@local_config_rocm//rocm:rocblas",
|
||||||
#)
|
])),
|
||||||
|
alwayslink = True,
|
||||||
|
)
|
||||||
|
|
||||||
# FIXME: enable in future PRs
|
cc_library(
|
||||||
#cc_library(
|
name = "rocfft_plugin",
|
||||||
# name = "rocfft_plugin",
|
srcs = if_rocm_is_configured(["rocm_fft.cc"]),
|
||||||
# srcs = ["rocm_fft.cc"],
|
hdrs = if_rocm_is_configured(["rocm_fft.h"]),
|
||||||
# hdrs = [],
|
visibility = ["//visibility:public"],
|
||||||
# visibility = ["//visibility:public"],
|
deps = if_rocm_is_configured([
|
||||||
# deps = [
|
":rocm_platform_id",
|
||||||
# ":rocm_platform_id",
|
"//tensorflow/stream_executor:event",
|
||||||
# "//tensorflow/stream_executor:event",
|
"//tensorflow/stream_executor:fft",
|
||||||
# "//tensorflow/stream_executor:fft",
|
"//tensorflow/stream_executor:plugin_registry",
|
||||||
# "//tensorflow/stream_executor:plugin_registry",
|
"//tensorflow/stream_executor:scratch_allocator",
|
||||||
# "//tensorflow/stream_executor:scratch_allocator",
|
"//tensorflow/stream_executor/gpu:gpu_activation",
|
||||||
# "//tenosrflow/stream_executor/gpu:gpu_stream_header",
|
"//tensorflow/stream_executor/gpu:gpu_helpers_header",
|
||||||
# "//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/gpu:gpu_executor_header",
|
||||||
# "//tensorflow/stream_executor/platform",
|
"//tensorflow/stream_executor/gpu:gpu_stream_header",
|
||||||
# "//tensorflow/stream_executor/platform:dso_loader",
|
"//tensorflow/stream_executor/gpu:gpu_kernel_header",
|
||||||
# "@local_config_rocm//rocm:rocm_headers",
|
"//tensorflow/stream_executor/lib",
|
||||||
# ] + if_static(["@local_config_rocm//rocm:rocfft"]),
|
"//tensorflow/stream_executor/platform",
|
||||||
# alwayslink = True,
|
"//tensorflow/stream_executor/platform:dso_loader",
|
||||||
#)
|
"@local_config_rocm//rocm:rocm_headers",
|
||||||
|
] + if_static([
|
||||||
|
"@local_config_rocm//rocm:rocfft",
|
||||||
|
])),
|
||||||
|
alwayslink = True,
|
||||||
|
)
|
||||||
|
|
||||||
# FIXME: enable in future PRs
|
# FIXME: enable in future PRs
|
||||||
#cc_library(
|
#cc_library(
|
||||||
@ -229,26 +237,30 @@ cc_library(
|
|||||||
# alwayslink = True,
|
# alwayslink = True,
|
||||||
#)
|
#)
|
||||||
|
|
||||||
# FIXME: enable in future PRs
|
cc_library(
|
||||||
#cc_library(
|
name = "rocrand_plugin",
|
||||||
# name = "rocrand_plugin",
|
srcs = if_rocm_is_configured(["rocm_rng.cc"]),
|
||||||
# srcs = ["rocm_rng.cc"],
|
hdrs = if_rocm_is_configured([]),
|
||||||
# hdrs = [],
|
deps = if_rocm_is_configured([
|
||||||
# deps = [
|
":rocm_gpu_executor",
|
||||||
# ":rocm_gpu_executor",
|
":rocm_platform_id",
|
||||||
# ":rocm_platform_id",
|
"@local_config_rocm//rocm:rocm_headers",
|
||||||
# "@local_config_rocm//rocm:rocm_headers",
|
"//tensorflow/stream_executor:event",
|
||||||
# "//tensorflow/stream_executor:event",
|
"//tensorflow/stream_executor:plugin_registry",
|
||||||
# "//tensorflow/stream_executor:plugin_registry",
|
"//tensorflow/stream_executor:rng",
|
||||||
# "//tensorflow/stream_executor:rng",
|
"//tensorflow/stream_executor/gpu:gpu_activation_header",
|
||||||
# "//tenosrflow/stream_executor/gpu:gpu_activation_header",
|
"//tensorflow/stream_executor/gpu:gpu_helpers_header",
|
||||||
# "//tenosrflow/stream_executor/gpu:gpu_stream_header",
|
"//tensorflow/stream_executor/gpu:gpu_executor_header",
|
||||||
# "//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/gpu:gpu_rng_header",
|
||||||
# "//tensorflow/stream_executor/platform",
|
"//tensorflow/stream_executor/gpu:gpu_stream_header",
|
||||||
# "//tensorflow/stream_executor/platform:dso_loader",
|
"//tensorflow/stream_executor/lib",
|
||||||
# ] + if_static(["@local_config_rocm//rocm:curand"]),
|
"//tensorflow/stream_executor/platform",
|
||||||
# alwayslink = True,
|
"//tensorflow/stream_executor/platform:dso_loader",
|
||||||
#)
|
] + if_static([
|
||||||
|
"@local_config_rocm//rocm:hiprand",
|
||||||
|
])),
|
||||||
|
alwayslink = True,
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "all_runtime",
|
name = "all_runtime",
|
||||||
@ -257,9 +269,9 @@ cc_library(
|
|||||||
deps = if_rocm_is_configured([
|
deps = if_rocm_is_configured([
|
||||||
# FIXME: enable in future PRs
|
# FIXME: enable in future PRs
|
||||||
#":miopen_plugin",
|
#":miopen_plugin",
|
||||||
#":rocfft_plugin",
|
":rocfft_plugin",
|
||||||
#":rocblas_plugin",
|
":rocblas_plugin",
|
||||||
#":rocrand_plugin",
|
":rocrand_plugin",
|
||||||
":rocm_driver",
|
":rocm_driver",
|
||||||
":rocm_platform",
|
":rocm_platform",
|
||||||
]),
|
]),
|
||||||
|
2374
tensorflow/stream_executor/rocm/rocm_blas.cc
Normal file
2374
tensorflow/stream_executor/rocm/rocm_blas.cc
Normal file
File diff suppressed because it is too large
Load Diff
159
tensorflow/stream_executor/rocm/rocm_blas.h
Normal file
159
tensorflow/stream_executor/rocm/rocm_blas.h
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// ROCM-specific support for BLAS functionality -- this wraps the rocBLAS
|
||||||
|
// library capabilities, and is only included into ROCM implementation code --
|
||||||
|
// it will not introduce rocm headers into other code.
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
|
||||||
|
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/stream_executor/blas.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/mutex.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/thread_annotations.h"
|
||||||
|
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||||
|
|
||||||
|
namespace stream_executor {
|
||||||
|
|
||||||
|
class Stream;
|
||||||
|
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
// Opaque and unique identifier for the rocBLAS plugin.
|
||||||
|
extern const PluginId kRocBlasPlugin;
|
||||||
|
|
||||||
|
class GpuExecutor;
|
||||||
|
|
||||||
|
// BLAS plugin for ROCM platform via rocBLAS library.
|
||||||
|
//
|
||||||
|
// This satisfies the platform-agnostic BlasSupport interface.
|
||||||
|
//
|
||||||
|
// Note that the rocBLAS handle that this encapsulates is implicitly tied to the
|
||||||
|
// context (and, as a result, the device) that the parent GpuExecutor is tied
|
||||||
|
// to. This simply happens as an artifact of creating the rocBLAS handle when a
|
||||||
|
// ROCM context is active.
|
||||||
|
//
|
||||||
|
// Thread-safe post-initialization.
|
||||||
|
class ROCMBlas : public blas::BlasSupport {
|
||||||
|
public:
|
||||||
|
explicit ROCMBlas(GpuExecutor *parent);
|
||||||
|
|
||||||
|
// Allocates a rocBLAS handle.
|
||||||
|
bool Init();
|
||||||
|
|
||||||
|
// Releases the rocBLAS handle, if present.
|
||||||
|
~ROCMBlas() override;
|
||||||
|
|
||||||
|
TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Tells rocBLAS to enqueue the BLAS operation onto a particular Stream.
|
||||||
|
//
|
||||||
|
// rocBLAS is stateful, and only be associated with one stream (in order to
|
||||||
|
// enqueue dispatch) at a given time. As a result, this generally must be
|
||||||
|
// invoked before calling into rocBLAS.
|
||||||
|
bool SetStream(Stream *stream) EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||||
|
|
||||||
|
// A helper function that calls the real rocBLAS function together with error
|
||||||
|
// handling.
|
||||||
|
//
|
||||||
|
// rocblas_func: rocBLAS function pointer.
|
||||||
|
// rocblas_name: rocBLAS function name.
|
||||||
|
// stream: Stream to enqueue the BLAS operation onto.
|
||||||
|
// pointer_mode_host: Indicate if the pointer to a scalar value is from host
|
||||||
|
// (true) or device (false).
|
||||||
|
// err_on_failure: Whether to print an error if the rocBLAS function
|
||||||
|
// fails. args: Arguments of rocBLAS function.
|
||||||
|
template <typename FuncT, typename... Args>
|
||||||
|
bool DoBlasInternalImpl(FuncT rocblas_func, Stream *stream,
|
||||||
|
bool pointer_mode_host, bool err_on_failure,
|
||||||
|
Args... args);
|
||||||
|
|
||||||
|
// Convenience functions that call DoBlasInternalImpl with different values
|
||||||
|
// for err_on_failure.
|
||||||
|
template <typename FuncT, typename... Args>
|
||||||
|
bool DoBlasInternal(FuncT rocblas_func, Stream *stream,
|
||||||
|
bool pointer_mode_host, Args... args) {
|
||||||
|
return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
|
||||||
|
/*err_on_failure=*/true, args...);
|
||||||
|
}
|
||||||
|
template <typename FuncT, typename... Args>
|
||||||
|
bool DoBlasInternalFailureOK(FuncT rocblas_func, Stream *stream,
|
||||||
|
bool pointer_mode_host, Args... args) {
|
||||||
|
return DoBlasInternalImpl(rocblas_func, stream, pointer_mode_host,
|
||||||
|
/*err_on_failure=*/false, args...);
|
||||||
|
}
|
||||||
|
|
||||||
|
// A helper function to implement DoBlasGemmBatched interfaces for generic
|
||||||
|
// types.
|
||||||
|
template <typename T, typename FuncT>
|
||||||
|
port::Status DoBlasGemmBatchedInternal(
|
||||||
|
FuncT rocblas_func, Stream *stream, blas::Transpose transa,
|
||||||
|
blas::Transpose transb, uint64 m, uint64 n, uint64 k, T alpha,
|
||||||
|
const port::ArraySlice<DeviceMemory<T> *> &a_array, int lda,
|
||||||
|
const port::ArraySlice<DeviceMemory<T> *> &b_array, int ldb, T beta,
|
||||||
|
const port::ArraySlice<DeviceMemory<T> *> &c_array, int ldc,
|
||||||
|
int batch_count, ScratchAllocator *scratch_allocator);
|
||||||
|
|
||||||
|
// Helper function for implementing DoBlasGemmWithAlgorithm.
|
||||||
|
//
|
||||||
|
// We take alpha and beta by const reference because T might be Eigen::half,
|
||||||
|
// and we want to avoid pulling in a dependency on Eigen. When we pass the
|
||||||
|
// references to rocBLAS, we essentially reinterpret_cast to __half, which is
|
||||||
|
// safe because Eigen::half inherits from __half.
|
||||||
|
template <typename InT, typename OutT, typename CompT>
|
||||||
|
bool DoBlasGemmWithAlgorithmImpl(
|
||||||
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
|
uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
|
||||||
|
int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
|
||||||
|
DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
|
||||||
|
blas::AlgorithmType algorithm,
|
||||||
|
blas::ProfileResult *output_profile_result);
|
||||||
|
|
||||||
|
// Helper function for implementing DoBlasGemmWithProfiling.
|
||||||
|
template <typename T, typename ParamType>
|
||||||
|
bool DoBlasGemmWithProfilingImpl(
|
||||||
|
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
|
||||||
|
uint64 n, uint64 k, const ParamType &alpha, const DeviceMemory<T> &a,
|
||||||
|
int lda, const DeviceMemory<T> &b, int ldb, const ParamType &beta,
|
||||||
|
DeviceMemory<T> *c, int ldc, blas::ProfileResult *output_profile_result);
|
||||||
|
|
||||||
|
// Helper function for implementing DoBlasGemvWithProfiling.
|
||||||
|
template <typename T>
|
||||||
|
bool DoBlasGemvWithProfilingImpl(Stream *stream, blas::Transpose trans,
|
||||||
|
uint64 m, uint64 n, const T &alpha,
|
||||||
|
const DeviceMemory<T> &a, int lda,
|
||||||
|
const DeviceMemory<T> &x, int incx,
|
||||||
|
const T &beta, DeviceMemory<T> *y, int incy,
|
||||||
|
blas::ProfileResult *output_profile_result);
|
||||||
|
|
||||||
|
// mutex that guards the rocBLAS handle for this device.
|
||||||
|
mutex mu_;
|
||||||
|
|
||||||
|
// GpuExecutor which instantiated this ROCMBlas.
|
||||||
|
// Immutable post-initialization.
|
||||||
|
GpuExecutor *parent_;
|
||||||
|
|
||||||
|
// rocBLAS library handle on the device.
|
||||||
|
rocblas_handle blas_ GUARDED_BY(mu_);
|
||||||
|
|
||||||
|
SE_DISALLOW_COPY_AND_ASSIGN(ROCMBlas);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace stream_executor
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_BLAS_H_
|
@ -30,7 +30,6 @@ limitations under the License.
|
|||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
|
|
||||||
#include "tensorflow/stream_executor/lib/error.h"
|
#include "tensorflow/stream_executor/lib/error.h"
|
||||||
#include "tensorflow/stream_executor/lib/numbers.h"
|
#include "tensorflow/stream_executor/lib/numbers.h"
|
||||||
#include "tensorflow/stream_executor/lib/process_state.h"
|
#include "tensorflow/stream_executor/lib/process_state.h"
|
||||||
@ -38,9 +37,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/lib/str_util.h"
|
#include "tensorflow/stream_executor/lib/str_util.h"
|
||||||
#include "tensorflow/stream_executor/lib/stringprintf.h"
|
#include "tensorflow/stream_executor/lib/stringprintf.h"
|
||||||
#include "tensorflow/stream_executor/platform/logging.h"
|
#include "tensorflow/stream_executor/platform/logging.h"
|
||||||
|
#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace gpu {
|
namespace rocm {
|
||||||
|
|
||||||
string DriverVersionToString(DriverVersion version) {
|
string DriverVersionToString(DriverVersion version) {
|
||||||
return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version),
|
return absl::StrFormat("%d.%d.%d", std::get<0>(version), std::get<1>(version),
|
||||||
@ -95,6 +95,12 @@ port::StatusOr<DriverVersion> StringToDriverVersion(const string& value) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace rocm
|
||||||
|
} // namespace stream_executor
|
||||||
|
|
||||||
|
namespace stream_executor {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
// -- class Diagnostician
|
// -- class Diagnostician
|
||||||
|
|
||||||
string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
|
string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
|
||||||
@ -133,11 +139,11 @@ void Diagnostician::LogDiagnosticInformation() {
|
|||||||
}
|
}
|
||||||
port::StatusOr<DriverVersion> dso_version = FindDsoVersion();
|
port::StatusOr<DriverVersion> dso_version = FindDsoVersion();
|
||||||
LOG(INFO) << "librocm reported version is: "
|
LOG(INFO) << "librocm reported version is: "
|
||||||
<< DriverVersionStatusToString(dso_version);
|
<< rocm::DriverVersionStatusToString(dso_version);
|
||||||
|
|
||||||
port::StatusOr<DriverVersion> kernel_version = FindKernelDriverVersion();
|
port::StatusOr<DriverVersion> kernel_version = FindKernelDriverVersion();
|
||||||
LOG(INFO) << "kernel reported version is: "
|
LOG(INFO) << "kernel reported version is: "
|
||||||
<< DriverVersionStatusToString(kernel_version);
|
<< rocm::DriverVersionStatusToString(kernel_version);
|
||||||
|
|
||||||
if (kernel_version.ok() && dso_version.ok()) {
|
if (kernel_version.ok() && dso_version.ok()) {
|
||||||
WarnOnDsoKernelMismatch(dso_version, kernel_version);
|
WarnOnDsoKernelMismatch(dso_version, kernel_version);
|
||||||
@ -175,7 +181,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
|
|||||||
// TODO(b/22689637): Eliminate the explicit namespace if possible.
|
// TODO(b/22689637): Eliminate the explicit namespace if possible.
|
||||||
auto stripped_dso_version = port::StripSuffixString(dso_version, ".ld64");
|
auto stripped_dso_version = port::StripSuffixString(dso_version, ".ld64");
|
||||||
auto result = static_cast<port::StatusOr<DriverVersion>*>(data);
|
auto result = static_cast<port::StatusOr<DriverVersion>*>(data);
|
||||||
*result = StringToDriverVersion(stripped_dso_version);
|
*result = rocm::StringToDriverVersion(stripped_dso_version);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
@ -205,7 +211,7 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelModuleVersion(
|
|||||||
// TODO(b/22689637): Eliminate the explicit namespace if possible.
|
// TODO(b/22689637): Eliminate the explicit namespace if possible.
|
||||||
auto stripped_kernel_version =
|
auto stripped_kernel_version =
|
||||||
port::StripSuffixString(kernel_version, ".ld64");
|
port::StripSuffixString(kernel_version, ".ld64");
|
||||||
return StringToDriverVersion(stripped_kernel_version);
|
return rocm::StringToDriverVersion(stripped_kernel_version);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Diagnostician::WarnOnDsoKernelMismatch(
|
void Diagnostician::WarnOnDsoKernelMismatch(
|
||||||
@ -214,12 +220,12 @@ void Diagnostician::WarnOnDsoKernelMismatch(
|
|||||||
if (kernel_version.ok() && dso_version.ok() &&
|
if (kernel_version.ok() && dso_version.ok() &&
|
||||||
dso_version.ValueOrDie() == kernel_version.ValueOrDie()) {
|
dso_version.ValueOrDie() == kernel_version.ValueOrDie()) {
|
||||||
LOG(INFO) << "kernel version seems to match DSO: "
|
LOG(INFO) << "kernel version seems to match DSO: "
|
||||||
<< DriverVersionToString(kernel_version.ValueOrDie());
|
<< rocm::DriverVersionToString(kernel_version.ValueOrDie());
|
||||||
} else {
|
} else {
|
||||||
LOG(ERROR) << "kernel version "
|
LOG(ERROR) << "kernel version "
|
||||||
<< DriverVersionStatusToString(kernel_version)
|
<< rocm::DriverVersionStatusToString(kernel_version)
|
||||||
<< " does not match DSO version "
|
<< " does not match DSO version "
|
||||||
<< DriverVersionStatusToString(dso_version)
|
<< rocm::DriverVersionStatusToString(dso_version)
|
||||||
<< " -- cannot find working devices in this configuration";
|
<< " -- cannot find working devices in this configuration";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
41
tensorflow/stream_executor/rocm/rocm_diagnostics.h
Normal file
41
tensorflow/stream_executor/rocm/rocm_diagnostics.h
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
|
||||||
|
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
|
||||||
|
|
||||||
|
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
|
||||||
|
|
||||||
|
namespace stream_executor {
|
||||||
|
namespace rocm {
|
||||||
|
|
||||||
|
// e.g. DriverVersion{346, 3, 4}
|
||||||
|
using DriverVersion = gpu::DriverVersion;
|
||||||
|
|
||||||
|
// Converts a parsed driver version to string form.
|
||||||
|
string DriverVersionToString(DriverVersion version);
|
||||||
|
|
||||||
|
// Converts a parsed driver version or status value to natural string form.
|
||||||
|
string DriverVersionStatusToString(port::StatusOr<DriverVersion> version);
|
||||||
|
|
||||||
|
// Converts a string of a form like "331.79" to a DriverVersion{331, 79}.
|
||||||
|
port::StatusOr<DriverVersion> StringToDriverVersion(const string& value);
|
||||||
|
|
||||||
|
using Diagnostician = gpu::Diagnostician;
|
||||||
|
|
||||||
|
} // namespace rocm
|
||||||
|
} // namespace stream_executor
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DIAGNOSTICS_H_
|
@ -36,6 +36,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/platform/logging.h"
|
#include "tensorflow/stream_executor/platform/logging.h"
|
||||||
#include "tensorflow/stream_executor/platform/mutex.h"
|
#include "tensorflow/stream_executor/platform/mutex.h"
|
||||||
#include "tensorflow/stream_executor/platform/port.h"
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
#include "tensorflow/stream_executor/rocm/rocm_driver_wrapper.h"
|
||||||
|
|
||||||
bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
|
bool FLAGS_gpuexec_rocm_driver_inject_init_error = false;
|
||||||
bool FLAGS_gpuexec_rocm_sync_around_driver_calls = false;
|
bool FLAGS_gpuexec_rocm_sync_around_driver_calls = false;
|
||||||
@ -143,7 +144,7 @@ string MemorySpaceString(MemorySpace memory_space) {
|
|||||||
// HIP driver (e.g., this value is not our cached view of the current device).
|
// HIP driver (e.g., this value is not our cached view of the current device).
|
||||||
static int CurrentDeviceOrDie() {
|
static int CurrentDeviceOrDie() {
|
||||||
int current = -1;
|
int current = -1;
|
||||||
hipError_t result = hipGetDevice(¤t);
|
hipError_t result = tensorflow::wrap::hipGetDevice(¤t);
|
||||||
if (result != hipSuccess) {
|
if (result != hipSuccess) {
|
||||||
LOG(FATAL) << "failed to query current device: " << ToString(result);
|
LOG(FATAL) << "failed to query current device: " << ToString(result);
|
||||||
}
|
}
|
||||||
@ -154,7 +155,7 @@ namespace {
|
|||||||
|
|
||||||
// Call hipDeviceSynchronize and crash if it doesn't succeed.
|
// Call hipDeviceSynchronize and crash if it doesn't succeed.
|
||||||
void SynchronizeOrDie() {
|
void SynchronizeOrDie() {
|
||||||
auto res = hipDeviceSynchronize();
|
auto res = tensorflow::wrap::hipDeviceSynchronize();
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(FATAL) << "Synchronize found " << ToString(res)
|
LOG(FATAL) << "Synchronize found " << ToString(res)
|
||||||
<< " :: " << port::CurrentStackTrace();
|
<< " :: " << port::CurrentStackTrace();
|
||||||
@ -197,7 +198,8 @@ ScopedActivateContext::ScopedActivateContext(GpuContext* context) {
|
|||||||
<< tls->current_device_ordinal << " to " << context->device_ordinal();
|
<< tls->current_device_ordinal << " to " << context->device_ordinal();
|
||||||
|
|
||||||
// Set the device and update thread local.
|
// Set the device and update thread local.
|
||||||
CHECK_EQ(hipSuccess, hipSetDevice(context->device_ordinal()));
|
CHECK_EQ(hipSuccess,
|
||||||
|
tensorflow::wrap::hipSetDevice(context->device_ordinal()));
|
||||||
tls->current_device_ordinal = context->device_ordinal();
|
tls->current_device_ordinal = context->device_ordinal();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -225,7 +227,8 @@ ScopedActivateContext::~ScopedActivateContext() {
|
|||||||
<< to_restore_->device_ordinal();
|
<< to_restore_->device_ordinal();
|
||||||
|
|
||||||
// Set context and update thread local.
|
// Set context and update thread local.
|
||||||
CHECK_EQ(hipSuccess, hipSetDevice(to_restore_->device_ordinal()));
|
CHECK_EQ(hipSuccess,
|
||||||
|
tensorflow::wrap::hipSetDevice(to_restore_->device_ordinal()));
|
||||||
tls->current_device_ordinal = to_restore_->device_ordinal();
|
tls->current_device_ordinal = to_restore_->device_ordinal();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -261,7 +264,8 @@ string ROCMPointerToMemorySpaceString(hipDeviceptr_t pointer) {
|
|||||||
// in the process of querying.
|
// in the process of querying.
|
||||||
string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) {
|
string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) {
|
||||||
hipPointerAttribute_t from_pointerAttributes;
|
hipPointerAttribute_t from_pointerAttributes;
|
||||||
hipError_t result = hipPointerGetAttributes(&from_pointerAttributes, from);
|
hipError_t result =
|
||||||
|
tensorflow::wrap::hipPointerGetAttributes(&from_pointerAttributes, from);
|
||||||
if (result != hipSuccess) {
|
if (result != hipSuccess) {
|
||||||
LOG(ERROR) << "could not retrieve source pointer's device: "
|
LOG(ERROR) << "could not retrieve source pointer's device: "
|
||||||
<< ToString(result);
|
<< ToString(result);
|
||||||
@ -269,7 +273,7 @@ string ROCMPointersToCanAccessString(hipDeviceptr_t from, hipDeviceptr_t to) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hipPointerAttribute_t to_pointerAttributes;
|
hipPointerAttribute_t to_pointerAttributes;
|
||||||
result = hipPointerGetAttributes(&to_pointerAttributes, to);
|
result = tensorflow::wrap::hipPointerGetAttributes(&to_pointerAttributes, to);
|
||||||
if (result != hipSuccess) {
|
if (result != hipSuccess) {
|
||||||
LOG(ERROR) << "could not retrieve destination pointer's device: "
|
LOG(ERROR) << "could not retrieve destination pointer's device: "
|
||||||
<< ToString(result);
|
<< ToString(result);
|
||||||
@ -289,7 +293,7 @@ static port::Status InternalInit() {
|
|||||||
if (FLAGS_gpuexec_rocm_driver_inject_init_error) {
|
if (FLAGS_gpuexec_rocm_driver_inject_init_error) {
|
||||||
LOG(ERROR) << "injecting ROCM init error; initialization will fail";
|
LOG(ERROR) << "injecting ROCM init error; initialization will fail";
|
||||||
} else {
|
} else {
|
||||||
res = hipInit(0 /* = flags */);
|
res = tensorflow::wrap::hipInit(0 /* = flags */);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (res == hipSuccess) {
|
if (res == hipSuccess) {
|
||||||
@ -322,7 +326,7 @@ static port::Status InternalInit() {
|
|||||||
|
|
||||||
/* static */ port::Status GpuDriver::GetDevice(int device_ordinal,
|
/* static */ port::Status GpuDriver::GetDevice(int device_ordinal,
|
||||||
hipDevice_t* device) {
|
hipDevice_t* device) {
|
||||||
hipError_t res = hipDeviceGet(device, device_ordinal);
|
hipError_t res = tensorflow::wrap::hipDeviceGet(device, device_ordinal);
|
||||||
if (res == hipSuccess) {
|
if (res == hipSuccess) {
|
||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
}
|
}
|
||||||
@ -336,7 +340,8 @@ static port::Status InternalInit() {
|
|||||||
string* device_name) {
|
string* device_name) {
|
||||||
static const size_t kCharLimit = 64;
|
static const size_t kCharLimit = 64;
|
||||||
absl::InlinedVector<char, 4> chars(kCharLimit);
|
absl::InlinedVector<char, 4> chars(kCharLimit);
|
||||||
hipError_t res = hipDeviceGetName(chars.begin(), kCharLimit - 1, device);
|
hipError_t res =
|
||||||
|
tensorflow::wrap::hipDeviceGetName(chars.begin(), kCharLimit - 1, device);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to get device name for " << device << ": "
|
LOG(ERROR) << "failed to get device name for " << device << ": "
|
||||||
<< ToString(res);
|
<< ToString(res);
|
||||||
@ -382,7 +387,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
|
|||||||
|
|
||||||
/* static */ bool GpuDriver::FuncSetCacheConfig(hipFunction_t function,
|
/* static */ bool GpuDriver::FuncSetCacheConfig(hipFunction_t function,
|
||||||
hipFuncCache_t cache_config) {
|
hipFuncCache_t cache_config) {
|
||||||
hipError_t res = hipFuncSetCacheConfig(function, cache_config);
|
hipError_t res =
|
||||||
|
tensorflow::wrap::hipFuncSetCacheConfig(function, cache_config);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to set ROCM kernel cache config. kernel: " << function
|
LOG(ERROR) << "failed to set ROCM kernel cache config. kernel: " << function
|
||||||
<< ", config: " << cache_config << ", result: " << ToString(res);
|
<< ", config: " << cache_config << ", result: " << ToString(res);
|
||||||
@ -396,7 +402,8 @@ bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
|
|||||||
GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
||||||
hipSharedMemConfig shared_mem_config;
|
hipSharedMemConfig shared_mem_config;
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t result = hipDeviceGetSharedMemConfig(&shared_mem_config);
|
hipError_t result =
|
||||||
|
tensorflow::wrap::hipDeviceGetSharedMemConfig(&shared_mem_config);
|
||||||
if (result != hipSuccess) {
|
if (result != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to get ROCM device shared memory config. "
|
LOG(ERROR) << "failed to get ROCM device shared memory config. "
|
||||||
<< "Context device ID: " << context->device_ordinal()
|
<< "Context device ID: " << context->device_ordinal()
|
||||||
@ -411,7 +418,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
/* static */ port::Status GpuDriver::ContextSetSharedMemConfig(
|
/* static */ port::Status GpuDriver::ContextSetSharedMemConfig(
|
||||||
GpuContext* context, hipSharedMemConfig shared_mem_config) {
|
GpuContext* context, hipSharedMemConfig shared_mem_config) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t result = hipDeviceSetSharedMemConfig(shared_mem_config);
|
hipError_t result =
|
||||||
|
tensorflow::wrap::hipDeviceSetSharedMemConfig(shared_mem_config);
|
||||||
if (result != hipSuccess) {
|
if (result != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to set ROCM device shared memory config. "
|
LOG(ERROR) << "failed to set ROCM device shared memory config. "
|
||||||
<< "Context device ID: " << context->device_ordinal()
|
<< "Context device ID: " << context->device_ordinal()
|
||||||
@ -435,7 +443,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
<< " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
|
<< " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
|
||||||
<< " bdx: " << block_dim_x << " bdy: " << block_dim_y
|
<< " bdx: " << block_dim_x << " bdy: " << block_dim_y
|
||||||
<< " bdz: " << block_dim_z << " smem: " << shared_mem_bytes;
|
<< " bdz: " << block_dim_z << " smem: " << shared_mem_bytes;
|
||||||
hipError_t res = hipModuleLaunchKernel(
|
hipError_t res = tensorflow::wrap::hipModuleLaunchKernel(
|
||||||
function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
|
function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y,
|
||||||
block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
|
block_dim_z, shared_mem_bytes, stream, kernel_params, extra);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
@ -471,7 +479,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
void* hsaco_data = const_cast<char*>(hsaco_contents);
|
void* hsaco_data = const_cast<char*>(hsaco_contents);
|
||||||
|
|
||||||
hipError_t res = hipModuleLoadData(module, hsaco_data);
|
hipError_t res =
|
||||||
|
tensorflow::wrap::hipModuleLoadData(module, hsaco_data);
|
||||||
|
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to load HSACO: " << ToString(res);
|
LOG(ERROR) << "failed to load HSACO: " << ToString(res);
|
||||||
@ -491,7 +500,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
hipDeviceptr_t location,
|
hipDeviceptr_t location,
|
||||||
uint8 value, size_t size) {
|
uint8 value, size_t size) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipMemset(location, value, size);
|
hipError_t res = tensorflow::wrap::hipMemset(location, value, size);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to memset memory: " << ToString(res);
|
LOG(ERROR) << "failed to memset memory: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -512,8 +521,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
LOG(ERROR) << "failed to memset memory";
|
LOG(ERROR) << "failed to memset memory";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
hipError_t res =
|
hipError_t res = tensorflow::wrap::hipMemset(pointer, static_cast<int>(value),
|
||||||
hipMemset(pointer, static_cast<int>(value), uint32_count * 4);
|
uint32_count * 4);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to memset memory: " << ToString(res);
|
LOG(ERROR) << "failed to memset memory: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -527,7 +536,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
size_t uint32_count,
|
size_t uint32_count,
|
||||||
GpuStreamHandle stream) {
|
GpuStreamHandle stream) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipMemsetAsync(location, value, uint32_count, stream);
|
hipError_t res =
|
||||||
|
tensorflow::wrap::hipMemsetAsync(location, value, uint32_count, stream);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
|
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -552,7 +562,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
LOG(ERROR) << "failed to memset memory";
|
LOG(ERROR) << "failed to memset memory";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
hipError_t res = hipMemsetAsync(pointer, value, uint32_count * 4, stream);
|
hipError_t res = tensorflow::wrap::hipMemsetAsync(pointer, value,
|
||||||
|
uint32_count * 4, stream);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
|
LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -565,8 +576,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
GpuStreamHandle stream,
|
GpuStreamHandle stream,
|
||||||
StreamCallback callback,
|
StreamCallback callback,
|
||||||
void* data) {
|
void* data) {
|
||||||
hipError_t res = hipStreamAddCallback(stream, (hipStreamCallback_t)callback,
|
hipError_t res = tensorflow::wrap::hipStreamAddCallback(
|
||||||
data, 0 /* = flags */);
|
stream, (hipStreamCallback_t)callback, data, 0 /* = flags */);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "unable to add host callback: " << ToString(res);
|
LOG(ERROR) << "unable to add host callback: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -580,7 +591,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
hipFunction_t* function) {
|
hipFunction_t* function) {
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
CHECK(module != nullptr && kernel_name != nullptr);
|
CHECK(module != nullptr && kernel_name != nullptr);
|
||||||
hipError_t res = hipModuleGetFunction(function, module, kernel_name);
|
hipError_t res =
|
||||||
|
tensorflow::wrap::hipModuleGetFunction(function, module, kernel_name);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to get kernel \"" << kernel_name
|
LOG(ERROR) << "failed to get kernel \"" << kernel_name
|
||||||
<< "\" from module: " << ToString(res);
|
<< "\" from module: " << ToString(res);
|
||||||
@ -598,7 +610,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
CHECK(module != nullptr && symbol_name != nullptr &&
|
CHECK(module != nullptr && symbol_name != nullptr &&
|
||||||
(dptr != nullptr || bytes != nullptr));
|
(dptr != nullptr || bytes != nullptr));
|
||||||
hipError_t res = hipModuleGetGlobal(dptr, bytes, module, symbol_name);
|
hipError_t res =
|
||||||
|
tensorflow::wrap::hipModuleGetGlobal(dptr, bytes, module, symbol_name);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
// symbol may not be found in the current module, but it may reside in
|
// symbol may not be found in the current module, but it may reside in
|
||||||
// another module.
|
// another module.
|
||||||
@ -613,7 +626,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
/* static */ void GpuDriver::UnloadModule(GpuContext* context,
|
/* static */ void GpuDriver::UnloadModule(GpuContext* context,
|
||||||
hipModule_t module) {
|
hipModule_t module) {
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
hipError_t res = hipModuleUnload(module);
|
hipError_t res = tensorflow::wrap::hipModuleUnload(module);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to unload module " << module
|
LOG(ERROR) << "failed to unload module " << module
|
||||||
<< "; leaking: " << ToString(res);
|
<< "; leaking: " << ToString(res);
|
||||||
@ -623,7 +636,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
/* static */ bool GpuDriver::CreateStream(GpuContext* context,
|
/* static */ bool GpuDriver::CreateStream(GpuContext* context,
|
||||||
GpuStreamHandle* stream) {
|
GpuStreamHandle* stream) {
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
hipError_t res = hipStreamCreateWithFlags(
|
hipError_t res = tensorflow::wrap::hipStreamCreateWithFlags(
|
||||||
stream, hipStreamDefault); // switch to hipStreamNonBlocking?
|
stream, hipStreamDefault); // switch to hipStreamNonBlocking?
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "could not allocate ROCM stream for device "
|
LOG(ERROR) << "could not allocate ROCM stream for device "
|
||||||
@ -643,7 +656,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
hipError_t res = hipStreamDestroy(*stream);
|
hipError_t res = tensorflow::wrap::hipStreamDestroy(*stream);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to destroy ROCM stream for device "
|
LOG(ERROR) << "failed to destroy ROCM stream for device "
|
||||||
<< context->device_ordinal() << ": " << ToString(res);
|
<< context->device_ordinal() << ": " << ToString(res);
|
||||||
@ -658,7 +671,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
uint64 bytes) {
|
uint64 bytes) {
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
hipDeviceptr_t result = 0;
|
hipDeviceptr_t result = 0;
|
||||||
hipError_t res = hipMalloc(&result, bytes);
|
hipError_t res = tensorflow::wrap::hipMallocVanilla(&result, bytes);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to allocate "
|
LOG(ERROR) << "failed to allocate "
|
||||||
<< port::HumanReadableNumBytes::ToString(bytes) << " (" << bytes
|
<< port::HumanReadableNumBytes::ToString(bytes) << " (" << bytes
|
||||||
@ -675,7 +688,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
void* location) {
|
void* location) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipDeviceptr_t pointer = absl::bit_cast<hipDeviceptr_t>(location);
|
hipDeviceptr_t pointer = absl::bit_cast<hipDeviceptr_t>(location);
|
||||||
hipError_t res = hipFree(pointer);
|
hipError_t res = tensorflow::wrap::hipFree(pointer);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to free device memory at " << location
|
LOG(ERROR) << "failed to free device memory at " << location
|
||||||
<< "; result: " << ToString(res);
|
<< "; result: " << ToString(res);
|
||||||
@ -704,7 +717,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
void* host_mem = nullptr;
|
void* host_mem = nullptr;
|
||||||
// "Portable" memory is visible to all ROCM contexts. Safe for our use model.
|
// "Portable" memory is visible to all ROCM contexts. Safe for our use model.
|
||||||
hipError_t res = hipHostMalloc(&host_mem, bytes, hipHostMallocPortable);
|
hipError_t res = tensorflow::wrap::hipHostMallocVanilla(
|
||||||
|
&host_mem, bytes, hipHostMallocPortable);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to alloc " << bytes
|
LOG(ERROR) << "failed to alloc " << bytes
|
||||||
<< " bytes on host: " << ToString(res);
|
<< " bytes on host: " << ToString(res);
|
||||||
@ -715,7 +729,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
/* static */ void GpuDriver::HostDeallocate(GpuContext* context,
|
/* static */ void GpuDriver::HostDeallocate(GpuContext* context,
|
||||||
void* location) {
|
void* location) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipHostFree(location);
|
hipError_t res = tensorflow::wrap::hipHostFree(location);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "error deallocating host memory at " << location << ": "
|
LOG(ERROR) << "error deallocating host memory at " << location << ": "
|
||||||
<< ToString(res);
|
<< ToString(res);
|
||||||
@ -726,7 +740,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
uint64 bytes) {
|
uint64 bytes) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
// "Portable" memory is visible to all ROCM contexts. Safe for our use model.
|
// "Portable" memory is visible to all ROCM contexts. Safe for our use model.
|
||||||
hipError_t res = hipHostRegister(location, bytes, hipHostRegisterPortable);
|
hipError_t res = tensorflow::wrap::hipHostRegister(location, bytes,
|
||||||
|
hipHostRegisterPortable);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "error registering host memory at " << location << ": "
|
LOG(ERROR) << "error registering host memory at " << location << ": "
|
||||||
<< ToString(res);
|
<< ToString(res);
|
||||||
@ -738,7 +753,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
/* static */ bool GpuDriver::HostUnregister(GpuContext* context,
|
/* static */ bool GpuDriver::HostUnregister(GpuContext* context,
|
||||||
void* location) {
|
void* location) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipHostUnregister(location);
|
hipError_t res = tensorflow::wrap::hipHostUnregister(location);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "error unregistering host memory at " << location << ": "
|
LOG(ERROR) << "error unregistering host memory at " << location << ": "
|
||||||
<< ToString(res);
|
<< ToString(res);
|
||||||
@ -755,7 +770,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
hipError_t res = hipEventDestroy(*event);
|
hipError_t res = tensorflow::wrap::hipEventDestroy(*event);
|
||||||
*event = nullptr;
|
*event = nullptr;
|
||||||
|
|
||||||
switch (res) {
|
switch (res) {
|
||||||
@ -779,7 +794,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
GpuEventHandle event,
|
GpuEventHandle event,
|
||||||
GpuStreamHandle stream) {
|
GpuStreamHandle stream) {
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
hipError_t res = hipEventRecord(event, stream);
|
hipError_t res = tensorflow::wrap::hipEventRecord(event, stream);
|
||||||
switch (res) {
|
switch (res) {
|
||||||
case hipSuccess:
|
case hipSuccess:
|
||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
@ -800,7 +815,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
/* static */ port::StatusOr<hipError_t> GpuDriver::QueryEvent(
|
/* static */ port::StatusOr<hipError_t> GpuDriver::QueryEvent(
|
||||||
GpuContext* context, GpuEventHandle event) {
|
GpuContext* context, GpuEventHandle event) {
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
hipError_t res = hipEventQuery(event);
|
hipError_t res = tensorflow::wrap::hipEventQuery(event);
|
||||||
if (res != hipSuccess && res != hipErrorNotReady) {
|
if (res != hipSuccess && res != hipErrorNotReady) {
|
||||||
return port::Status{
|
return port::Status{
|
||||||
port::error::INTERNAL,
|
port::error::INTERNAL,
|
||||||
@ -817,12 +832,13 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
// The stop event must have completed in order for hipEventElapsedTime to
|
// The stop event must have completed in order for hipEventElapsedTime to
|
||||||
// work.
|
// work.
|
||||||
hipError_t res = hipEventSynchronize(stop);
|
hipError_t res = tensorflow::wrap::hipEventSynchronize(stop);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res);
|
LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
res = hipEventElapsedTime(elapsed_milliseconds, start, stop);
|
res =
|
||||||
|
tensorflow::wrap::hipEventElapsedTime(elapsed_milliseconds, start, stop);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to get elapsed time between events: "
|
LOG(ERROR) << "failed to get elapsed time between events: "
|
||||||
<< ToString(res);
|
<< ToString(res);
|
||||||
@ -836,7 +852,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
GpuStreamHandle stream,
|
GpuStreamHandle stream,
|
||||||
GpuEventHandle event) {
|
GpuEventHandle event) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipStreamWaitEvent(stream, event, 0 /* = flags */);
|
hipError_t res =
|
||||||
|
tensorflow::wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "could not wait stream on event: " << ToString(res);
|
LOG(ERROR) << "could not wait stream on event: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -847,7 +864,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
|
|
||||||
/* static */ bool GpuDriver::SynchronizeContext(GpuContext* context) {
|
/* static */ bool GpuDriver::SynchronizeContext(GpuContext* context) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipDeviceSynchronize();
|
hipError_t res = tensorflow::wrap::hipDeviceSynchronize();
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "could not synchronize on ROCM device: " << ToString(res)
|
LOG(ERROR) << "could not synchronize on ROCM device: " << ToString(res)
|
||||||
<< " :: " << port::CurrentStackTrace();
|
<< " :: " << port::CurrentStackTrace();
|
||||||
@ -861,7 +878,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
GpuStreamHandle stream) {
|
GpuStreamHandle stream) {
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
CHECK(stream != nullptr);
|
CHECK(stream != nullptr);
|
||||||
hipError_t res = hipStreamSynchronize(stream);
|
hipError_t res = tensorflow::wrap::hipStreamSynchronize(stream);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
port::Status status = port::InternalError(
|
port::Status status = port::InternalError(
|
||||||
absl::StrCat("could not synchronize on ROCM stream: ", ToString(res)));
|
absl::StrCat("could not synchronize on ROCM stream: ", ToString(res)));
|
||||||
@ -877,7 +894,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
GpuStreamHandle stream) {
|
GpuStreamHandle stream) {
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
CHECK(stream != nullptr);
|
CHECK(stream != nullptr);
|
||||||
hipError_t res = hipStreamQuery(stream);
|
hipError_t res = tensorflow::wrap::hipStreamQuery(stream);
|
||||||
if (res == hipSuccess) {
|
if (res == hipSuccess) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -891,7 +908,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
/* static */ port::Status GpuDriver::SynchronousMemcpyD2H(
|
/* static */ port::Status GpuDriver::SynchronousMemcpyD2H(
|
||||||
GpuContext* context, void* host_dst, hipDeviceptr_t gpu_src, uint64 size) {
|
GpuContext* context, void* host_dst, hipDeviceptr_t gpu_src, uint64 size) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipMemcpyDtoH(host_dst, gpu_src, size);
|
hipError_t res = tensorflow::wrap::hipMemcpyDtoH(host_dst, gpu_src, size);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
return port::InternalError(
|
return port::InternalError(
|
||||||
absl::StrFormat("failed to synchronous memcpy from device to host: %s; "
|
absl::StrFormat("failed to synchronous memcpy from device to host: %s; "
|
||||||
@ -908,7 +925,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
GpuContext* context, hipDeviceptr_t gpu_dst, const void* host_src,
|
GpuContext* context, hipDeviceptr_t gpu_dst, const void* host_src,
|
||||||
uint64 size) {
|
uint64 size) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipMemcpyHtoD(gpu_dst, const_cast<void*>(host_src), size);
|
hipError_t res = tensorflow::wrap::hipMemcpyHtoD(
|
||||||
|
gpu_dst, const_cast<void*>(host_src), size);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
return port::InternalError(absl::StrFormat(
|
return port::InternalError(absl::StrFormat(
|
||||||
"failed to synchronous memcpy from host to device: %s; Gpu dst: %p;"
|
"failed to synchronous memcpy from host to device: %s; Gpu dst: %p;"
|
||||||
@ -924,7 +942,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
GpuContext* context, hipDeviceptr_t gpu_dst, hipDeviceptr_t gpu_src,
|
GpuContext* context, hipDeviceptr_t gpu_dst, hipDeviceptr_t gpu_src,
|
||||||
uint64 size) {
|
uint64 size) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipMemcpyDtoD(gpu_dst, gpu_src, size);
|
hipError_t res = tensorflow::wrap::hipMemcpyDtoD(gpu_dst, gpu_src, size);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
return port::InternalError(absl::StrFormat(
|
return port::InternalError(absl::StrFormat(
|
||||||
"failed to synchronous memcpy from host to device: %s; Gpu dst: %p; "
|
"failed to synchronous memcpy from host to device: %s; Gpu dst: %p; "
|
||||||
@ -942,7 +960,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
uint64 size,
|
uint64 size,
|
||||||
GpuStreamHandle stream) {
|
GpuStreamHandle stream) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res = hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
|
hipError_t res =
|
||||||
|
tensorflow::wrap::hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << absl::StrFormat(
|
LOG(ERROR) << absl::StrFormat(
|
||||||
"failed to enqueue async memcpy from device to host: %s; host dst: %p; "
|
"failed to enqueue async memcpy from device to host: %s; host dst: %p; "
|
||||||
@ -963,8 +982,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
uint64 size,
|
uint64 size,
|
||||||
GpuStreamHandle stream) {
|
GpuStreamHandle stream) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t res =
|
hipError_t res = tensorflow::wrap::hipMemcpyHtoDAsync(
|
||||||
hipMemcpyHtoDAsync(gpu_dst, const_cast<void*>(host_src), size, stream);
|
gpu_dst, const_cast<void*>(host_src), size, stream);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << absl::StrFormat(
|
LOG(ERROR) << absl::StrFormat(
|
||||||
"failed to enqueue async memcpy from host to device: %s; Gpu dst: %p; "
|
"failed to enqueue async memcpy from host to device: %s; Gpu dst: %p; "
|
||||||
@ -984,7 +1003,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
uint64 size,
|
uint64 size,
|
||||||
GpuStreamHandle stream) {
|
GpuStreamHandle stream) {
|
||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
hipError_t result = hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
|
hipError_t result =
|
||||||
|
tensorflow::wrap::hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
|
||||||
if (result != hipSuccess) {
|
if (result != hipSuccess) {
|
||||||
LOG(ERROR) << absl::StrFormat(
|
LOG(ERROR) << absl::StrFormat(
|
||||||
"failed to enqueue async memcpy from device to device: %s"
|
"failed to enqueue async memcpy from device to device: %s"
|
||||||
@ -1021,7 +1041,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ScopedActivateContext activated{context};
|
ScopedActivateContext activated{context};
|
||||||
hipError_t res = hipEventCreateWithFlags(event, hipflags);
|
hipError_t res = tensorflow::wrap::hipEventCreateWithFlags(event, hipflags);
|
||||||
|
|
||||||
if (res == hipSuccess) {
|
if (res == hipSuccess) {
|
||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
@ -1037,7 +1057,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
|
|
||||||
/* static */ int GpuDriver::GetDeviceCount() {
|
/* static */ int GpuDriver::GetDeviceCount() {
|
||||||
int device_count = 0;
|
int device_count = 0;
|
||||||
hipError_t res = hipGetDeviceCount(&device_count);
|
hipError_t res = tensorflow::wrap::hipGetDeviceCount(&device_count);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res);
|
LOG(ERROR) << "could not retrieve ROCM device count: " << ToString(res);
|
||||||
return 0;
|
return 0;
|
||||||
@ -1061,7 +1081,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
|
|
||||||
/* static */ port::Status GpuDriver::GetPointerAddressRange(
|
/* static */ port::Status GpuDriver::GetPointerAddressRange(
|
||||||
hipDeviceptr_t dptr, hipDeviceptr_t* base, size_t* size) {
|
hipDeviceptr_t dptr, hipDeviceptr_t* base, size_t* size) {
|
||||||
hipError_t result = hipMemGetAddressRange(base, size, dptr);
|
hipError_t result = tensorflow::wrap::hipMemGetAddressRange(base, size, dptr);
|
||||||
if (result == hipSuccess) {
|
if (result == hipSuccess) {
|
||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
} else if (result == hipErrorNotFound) {
|
} else if (result == hipErrorNotFound) {
|
||||||
@ -1106,7 +1126,8 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
/* static */ port::StatusOr<hipDevice_t> GpuDriver::GetPointerDevice(
|
/* static */ port::StatusOr<hipDevice_t> GpuDriver::GetPointerDevice(
|
||||||
hipDeviceptr_t pointer) {
|
hipDeviceptr_t pointer) {
|
||||||
hipPointerAttribute_t pointerAttributes;
|
hipPointerAttribute_t pointerAttributes;
|
||||||
hipError_t result = hipPointerGetAttributes(&pointerAttributes, pointer);
|
hipError_t result =
|
||||||
|
tensorflow::wrap::hipPointerGetAttributes(&pointerAttributes, pointer);
|
||||||
if (result != hipSuccess) {
|
if (result != hipSuccess) {
|
||||||
return port::Status{
|
return port::Status{
|
||||||
port::error::INTERNAL,
|
port::error::INTERNAL,
|
||||||
@ -1114,7 +1135,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hipDevice_t device;
|
hipDevice_t device;
|
||||||
result = hipDeviceGet(&device, pointerAttributes.device);
|
result = tensorflow::wrap::hipDeviceGet(&device, pointerAttributes.device);
|
||||||
if (result != hipSuccess) {
|
if (result != hipSuccess) {
|
||||||
return port::Status{
|
return port::Status{
|
||||||
port::error::INTERNAL,
|
port::error::INTERNAL,
|
||||||
@ -1127,7 +1148,7 @@ GpuDriver::ContextGetSharedMemConfig(GpuContext* context) {
|
|||||||
/* static */ port::Status GpuDriver::GetGpuISAVersion(int* version,
|
/* static */ port::Status GpuDriver::GetGpuISAVersion(int* version,
|
||||||
hipDevice_t device) {
|
hipDevice_t device) {
|
||||||
hipDeviceProp_t props;
|
hipDeviceProp_t props;
|
||||||
hipError_t result = hipGetDeviceProperties(&props, device);
|
hipError_t result = tensorflow::wrap::hipGetDeviceProperties(&props, device);
|
||||||
if (result == hipSuccess) {
|
if (result == hipSuccess) {
|
||||||
*version = props.gcnArch;
|
*version = props.gcnArch;
|
||||||
return port::Status::OK();
|
return port::Status::OK();
|
||||||
@ -1145,7 +1166,8 @@ template <typename T>
|
|||||||
static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
||||||
hipDeviceAttribute_t attribute) {
|
hipDeviceAttribute_t attribute) {
|
||||||
int value = -1;
|
int value = -1;
|
||||||
hipError_t result = hipDeviceGetAttribute(&value, attribute, device);
|
hipError_t result =
|
||||||
|
tensorflow::wrap::hipDeviceGetAttribute(&value, attribute, device);
|
||||||
if (result != hipSuccess) {
|
if (result != hipSuccess) {
|
||||||
return port::Status{
|
return port::Status{
|
||||||
port::error::NOT_FOUND,
|
port::error::NOT_FOUND,
|
||||||
@ -1199,22 +1221,24 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
|||||||
/* static */ bool GpuDriver::GetGridLimits(int* x, int* y, int* z,
|
/* static */ bool GpuDriver::GetGridLimits(int* x, int* y, int* z,
|
||||||
hipDevice_t device) {
|
hipDevice_t device) {
|
||||||
int value;
|
int value;
|
||||||
hipError_t res =
|
hipError_t res = tensorflow::wrap::hipDeviceGetAttribute(
|
||||||
hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimX, device);
|
&value, hipDeviceAttributeMaxGridDimX, device);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to query max grid dim x: " << ToString(res);
|
LOG(ERROR) << "failed to query max grid dim x: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
*x = value;
|
*x = value;
|
||||||
|
|
||||||
res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimY, device);
|
res = tensorflow::wrap::hipDeviceGetAttribute(
|
||||||
|
&value, hipDeviceAttributeMaxGridDimY, device);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to query max grid dim y: " << ToString(res);
|
LOG(ERROR) << "failed to query max grid dim y: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
*y = value;
|
*y = value;
|
||||||
|
|
||||||
res = hipDeviceGetAttribute(&value, hipDeviceAttributeMaxGridDimZ, device);
|
res = tensorflow::wrap::hipDeviceGetAttribute(
|
||||||
|
&value, hipDeviceAttributeMaxGridDimZ, device);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to query max grid dim z: " << ToString(res);
|
LOG(ERROR) << "failed to query max grid dim z: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -1224,7 +1248,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* static */ bool GpuDriver::GetDriverVersion(int* driver_version) {
|
/* static */ bool GpuDriver::GetDriverVersion(int* driver_version) {
|
||||||
hipError_t res = hipDriverGetVersion(driver_version);
|
hipError_t res = tensorflow::wrap::hipDriverGetVersion(driver_version);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to query driver version: " << ToString(res);
|
LOG(ERROR) << "failed to query driver version: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -1235,7 +1259,8 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
|||||||
|
|
||||||
/* static */ bool GpuDriver::GetDeviceProperties(
|
/* static */ bool GpuDriver::GetDeviceProperties(
|
||||||
hipDeviceProp_t* device_properties, int device_ordinal) {
|
hipDeviceProp_t* device_properties, int device_ordinal) {
|
||||||
hipError_t res = hipGetDeviceProperties(device_properties, device_ordinal);
|
hipError_t res = tensorflow::wrap::hipGetDeviceProperties(device_properties,
|
||||||
|
device_ordinal);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to query device properties: " << ToString(res);
|
LOG(ERROR) << "failed to query device properties: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -1268,7 +1293,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
|||||||
ScopedActivateContext activation{context};
|
ScopedActivateContext activation{context};
|
||||||
size_t free = 0;
|
size_t free = 0;
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
hipError_t res = hipMemGetInfo(&free, &total);
|
hipError_t res = tensorflow::wrap::hipMemGetInfo(&free, &total);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to query device memory info: " << ToString(res);
|
LOG(ERROR) << "failed to query device memory info: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -1282,7 +1307,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
|||||||
/* static */ bool GpuDriver::GetDeviceTotalMemory(hipDevice_t device,
|
/* static */ bool GpuDriver::GetDeviceTotalMemory(hipDevice_t device,
|
||||||
uint64* result) {
|
uint64* result) {
|
||||||
size_t value = -1;
|
size_t value = -1;
|
||||||
hipError_t res = hipDeviceTotalMem(&value, device);
|
hipError_t res = tensorflow::wrap::hipDeviceTotalMem(&value, device);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to query total available memory: " << ToString(res);
|
LOG(ERROR) << "failed to query total available memory: " << ToString(res);
|
||||||
return false;
|
return false;
|
||||||
@ -1297,7 +1322,8 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
|||||||
static const int kBufferSize = 64;
|
static const int kBufferSize = 64;
|
||||||
absl::InlinedVector<char, 4> chars(kBufferSize);
|
absl::InlinedVector<char, 4> chars(kBufferSize);
|
||||||
chars[kBufferSize - 1] = '\0';
|
chars[kBufferSize - 1] = '\0';
|
||||||
hipError_t res = hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device);
|
hipError_t res = tensorflow::wrap::hipDeviceGetPCIBusId(
|
||||||
|
chars.begin(), kBufferSize - 1, device);
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res);
|
LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res);
|
||||||
return pci_bus_id;
|
return pci_bus_id;
|
||||||
@ -1313,7 +1339,7 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
|||||||
}
|
}
|
||||||
|
|
||||||
int can_access_peer = -1;
|
int can_access_peer = -1;
|
||||||
hipError_t res = hipDeviceCanAccessPeer(
|
hipError_t res = tensorflow::wrap::hipDeviceCanAccessPeer(
|
||||||
&can_access_peer, from->device_ordinal(), to->device_ordinal());
|
&can_access_peer, from->device_ordinal(), to->device_ordinal());
|
||||||
if (res != hipSuccess) {
|
if (res != hipSuccess) {
|
||||||
LOG(ERROR) << "failed to detect peer access capability: " << ToString(res);
|
LOG(ERROR) << "failed to detect peer access capability: " << ToString(res);
|
||||||
@ -1330,8 +1356,8 @@ static port::StatusOr<T> GetSimpleAttribute(hipDevice_t device,
|
|||||||
}
|
}
|
||||||
|
|
||||||
ScopedActivateContext activated{from};
|
ScopedActivateContext activated{from};
|
||||||
hipError_t result =
|
hipError_t result = tensorflow::wrap::hipDeviceEnablePeerAccess(
|
||||||
hipDeviceEnablePeerAccess(to->device_ordinal(), 0 /* = flags */);
|
to->device_ordinal(), 0 /* = flags */);
|
||||||
if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) {
|
if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) {
|
||||||
return port::Status{
|
return port::Status{
|
||||||
port::error::INTERNAL,
|
port::error::INTERNAL,
|
||||||
|
146
tensorflow/stream_executor/rocm/rocm_driver_wrapper.h
Normal file
146
tensorflow/stream_executor/rocm/rocm_driver_wrapper.h
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file wraps rocm driver calls with dso loader so that we don't need to
|
||||||
|
// have explicit linking to librocm. All TF rocm driver usage should route
|
||||||
|
// through this wrapper.
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
|
||||||
|
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
|
||||||
|
|
||||||
|
#include "rocm/include/hip/hip_runtime.h"
|
||||||
|
#include "tensorflow/stream_executor/lib/env.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
|
||||||
|
#if defined(TENSORFLOW_USE_ROCM)
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace wrap {
|
||||||
|
#ifdef PLATFORM_GOOGLE
|
||||||
|
// Use static linked library
|
||||||
|
#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \
|
||||||
|
template <typename... Args> \
|
||||||
|
auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \
|
||||||
|
return ::hipSymbolName(args...); \
|
||||||
|
}
|
||||||
|
|
||||||
|
// This macro wraps a global identifier, given by hipSymbolName, in a callable
|
||||||
|
// structure that loads the DLL symbol out of the DSO handle in a thread-safe
|
||||||
|
// manner on first use. This dynamic loading technique is used to avoid DSO
|
||||||
|
// dependencies on vendor libraries which may or may not be available in the
|
||||||
|
// deployed binary environment.
|
||||||
|
#else
|
||||||
|
#define TO_STR_(x) #x
|
||||||
|
#define TO_STR(x) TO_STR_(x)
|
||||||
|
|
||||||
|
// hipMalloc and hipHostMalloc are defined as funtion templates in the
|
||||||
|
// HIP header files, and hence their names get mangled and the attempt
|
||||||
|
// to resolve their name when trying to dynamically load them will fail
|
||||||
|
// Updating the HIP header files to make them C functions is underway.
|
||||||
|
// Until that change flows through, we will workaround the issue by
|
||||||
|
// creating dummy wrappers for them here
|
||||||
|
|
||||||
|
hipError_t hipMallocVanilla(void** ptr, size_t size) {
|
||||||
|
return hipErrorNotInitialized;
|
||||||
|
}
|
||||||
|
|
||||||
|
hipError_t hipHostMallocVanilla(void** ptr, size_t size, unsigned int flags) {
|
||||||
|
return hipErrorNotInitialized;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \
|
||||||
|
template <typename... Args> \
|
||||||
|
auto hipSymbolName(Args... args)->decltype(::hipSymbolName(args...)) { \
|
||||||
|
using FuncPtrT = std::add_pointer<decltype(::hipSymbolName)>::type; \
|
||||||
|
static FuncPtrT loaded = []() -> FuncPtrT { \
|
||||||
|
static const char *kName = TO_STR(hipSymbolName); \
|
||||||
|
void *f; \
|
||||||
|
auto s = stream_executor::port::Env::Default()->GetSymbolFromLibrary( \
|
||||||
|
stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \
|
||||||
|
.ValueOrDie(), \
|
||||||
|
kName, &f); \
|
||||||
|
CHECK(s.ok()) << "could not find " << kName \
|
||||||
|
<< " in HIP DSO; dlerror: " << s.error_message(); \
|
||||||
|
return reinterpret_cast<FuncPtrT>(f); \
|
||||||
|
}(); \
|
||||||
|
return loaded(args...); \
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// clang-format off
|
||||||
|
#define HIP_ROUTINE_EACH(__macro) \
|
||||||
|
__macro(hipDeviceCanAccessPeer) \
|
||||||
|
__macro(hipDeviceEnablePeerAccess) \
|
||||||
|
__macro(hipDeviceGet) \
|
||||||
|
__macro(hipDeviceGetAttribute) \
|
||||||
|
__macro(hipDeviceGetName) \
|
||||||
|
__macro(hipDeviceGetPCIBusId) \
|
||||||
|
__macro(hipDeviceGetSharedMemConfig) \
|
||||||
|
__macro(hipDeviceSetSharedMemConfig) \
|
||||||
|
__macro(hipDeviceSynchronize) \
|
||||||
|
__macro(hipDeviceTotalMem) \
|
||||||
|
__macro(hipDriverGetVersion) \
|
||||||
|
__macro(hipEventCreateWithFlags) \
|
||||||
|
__macro(hipEventElapsedTime) \
|
||||||
|
__macro(hipEventDestroy) \
|
||||||
|
__macro(hipEventQuery) \
|
||||||
|
__macro(hipEventRecord) \
|
||||||
|
__macro(hipEventSynchronize) \
|
||||||
|
__macro(hipFree) \
|
||||||
|
__macro(hipFuncSetCacheConfig) \
|
||||||
|
__macro(hipGetDevice) \
|
||||||
|
__macro(hipGetDeviceCount) \
|
||||||
|
__macro(hipGetDeviceProperties) \
|
||||||
|
__macro(hipHostFree) \
|
||||||
|
__macro(hipHostRegister) \
|
||||||
|
__macro(hipHostUnregister) \
|
||||||
|
__macro(hipInit) \
|
||||||
|
__macro(hipMemGetAddressRange) \
|
||||||
|
__macro(hipMemGetInfo) \
|
||||||
|
__macro(hipMemcpyDtoD) \
|
||||||
|
__macro(hipMemcpyDtoDAsync) \
|
||||||
|
__macro(hipMemcpyDtoH) \
|
||||||
|
__macro(hipMemcpyDtoHAsync) \
|
||||||
|
__macro(hipMemcpyHtoD) \
|
||||||
|
__macro(hipMemcpyHtoDAsync) \
|
||||||
|
__macro(hipMemset) \
|
||||||
|
__macro(hipMemsetAsync) \
|
||||||
|
__macro(hipModuleGetFunction) \
|
||||||
|
__macro(hipModuleGetGlobal) \
|
||||||
|
__macro(hipModuleLaunchKernel) \
|
||||||
|
__macro(hipModuleLoadData) \
|
||||||
|
__macro(hipModuleUnload) \
|
||||||
|
__macro(hipPointerGetAttributes) \
|
||||||
|
__macro(hipSetDevice) \
|
||||||
|
__macro(hipStreamAddCallback) \
|
||||||
|
__macro(hipStreamCreateWithFlags) \
|
||||||
|
__macro(hipStreamDestroy) \
|
||||||
|
__macro(hipStreamQuery) \
|
||||||
|
__macro(hipStreamSynchronize) \
|
||||||
|
__macro(hipStreamWaitEvent) \
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
HIP_ROUTINE_EACH(STREAM_EXECUTOR_HIP_WRAP)
|
||||||
|
#undef HIP_ROUTINE_EACH
|
||||||
|
#undef STREAM_EXECUTOR_HIP_WRAP
|
||||||
|
#undef TO_STR
|
||||||
|
#undef TO_STR_
|
||||||
|
} // namespace wrap
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_
|
618
tensorflow/stream_executor/rocm/rocm_fft.cc
Normal file
618
tensorflow/stream_executor/rocm/rocm_fft.cc
Normal file
@ -0,0 +1,618 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/stream_executor/rocm/rocm_fft.h"
|
||||||
|
|
||||||
|
#include <complex>
|
||||||
|
|
||||||
|
#include "tensorflow/stream_executor/device_memory.h"
|
||||||
|
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
|
||||||
|
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
||||||
|
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
||||||
|
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
|
||||||
|
#include "tensorflow/stream_executor/lib/env.h"
|
||||||
|
#include "tensorflow/stream_executor/lib/initialize.h"
|
||||||
|
#include "tensorflow/stream_executor/lib/status.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/logging.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||||
|
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
||||||
|
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
||||||
|
|
||||||
|
namespace stream_executor {
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kRocFftPlugin);
|
||||||
|
|
||||||
|
namespace wrap {
|
||||||
|
|
||||||
|
#ifdef PLATFORM_GOOGLE
|
||||||
|
// This macro wraps a global identifier, given by __name, in a callable
|
||||||
|
// structure that loads the DLL symbol out of the DSO handle in a thread-safe
|
||||||
|
// manner on first use. This dynamic loading technique is used to avoid DSO
|
||||||
|
// dependencies on vendor libraries which may or may not be available in the
|
||||||
|
// deployed binary environment.
|
||||||
|
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
|
||||||
|
struct WrapperShim__##__name { \
|
||||||
|
template <typename... Args> \
|
||||||
|
hipfftResult operator()(GpuExecutor *parent, Args... args) { \
|
||||||
|
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||||
|
return ::__name(args...); \
|
||||||
|
} \
|
||||||
|
} __name;
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \
|
||||||
|
struct DynLoadShim__##__name { \
|
||||||
|
static const char *kName; \
|
||||||
|
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
|
||||||
|
static void *GetDsoHandle() { \
|
||||||
|
auto s = internal::CachedDsoLoader::GetRocfftDsoHandle(); \
|
||||||
|
return s.ValueOrDie(); \
|
||||||
|
} \
|
||||||
|
static FuncPtrT LoadOrDie() { \
|
||||||
|
void *f; \
|
||||||
|
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
|
||||||
|
kName, &f); \
|
||||||
|
CHECK(s.ok()) << "could not find " << kName \
|
||||||
|
<< " in rocfft DSO; dlerror: " << s.error_message(); \
|
||||||
|
return reinterpret_cast<FuncPtrT>(f); \
|
||||||
|
} \
|
||||||
|
static FuncPtrT DynLoad() { \
|
||||||
|
static FuncPtrT f = LoadOrDie(); \
|
||||||
|
return f; \
|
||||||
|
} \
|
||||||
|
template <typename... Args> \
|
||||||
|
hipfftResult operator()(GpuExecutor *parent, Args... args) { \
|
||||||
|
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||||
|
return DynLoad()(args...); \
|
||||||
|
} \
|
||||||
|
} __name; \
|
||||||
|
const char *DynLoadShim__##__name::kName = #__name;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define ROCFFT_ROUTINE_EACH(__macro) \
|
||||||
|
__macro(hipfftDestroy) __macro(hipfftSetStream) __macro(hipfftPlan1d) \
|
||||||
|
__macro(hipfftPlan2d) __macro(hipfftPlan3d) __macro(hipfftPlanMany) \
|
||||||
|
__macro(hipfftCreate) __macro(hipfftSetAutoAllocation) \
|
||||||
|
__macro(hipfftSetWorkArea) __macro(hipfftGetSize1d) \
|
||||||
|
__macro(hipfftMakePlan1d) __macro(hipfftGetSize2d) \
|
||||||
|
__macro(hipfftMakePlan2d) __macro(hipfftGetSize3d) \
|
||||||
|
__macro(hipfftMakePlan3d) __macro(hipfftGetSizeMany) \
|
||||||
|
__macro(hipfftMakePlanMany) \
|
||||||
|
__macro(hipfftExecD2Z) \
|
||||||
|
__macro(hipfftExecZ2D) \
|
||||||
|
__macro(hipfftExecC2C) \
|
||||||
|
__macro(hipfftExecC2R) \
|
||||||
|
__macro(hipfftExecZ2Z) \
|
||||||
|
__macro(hipfftExecR2C)
|
||||||
|
|
||||||
|
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)
|
||||||
|
|
||||||
|
} // namespace wrap
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// A helper function transforming gpu_fft arguments into rocFFT arguments.
|
||||||
|
hipfftType ROCMFftType(fft::Type type) {
|
||||||
|
switch (type) {
|
||||||
|
case fft::Type::kC2CForward:
|
||||||
|
case fft::Type::kC2CInverse:
|
||||||
|
return HIPFFT_C2C;
|
||||||
|
case fft::Type::kC2R:
|
||||||
|
return HIPFFT_C2R;
|
||||||
|
case fft::Type::kR2C:
|
||||||
|
return HIPFFT_R2C;
|
||||||
|
case fft::Type::kZ2ZForward:
|
||||||
|
case fft::Type::kZ2ZInverse:
|
||||||
|
return HIPFFT_Z2Z;
|
||||||
|
case fft::Type::kZ2D:
|
||||||
|
return HIPFFT_Z2D;
|
||||||
|
case fft::Type::kD2Z:
|
||||||
|
return HIPFFT_D2Z;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Invalid value of fft::Type.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Associates the given stream with the given rocFFT plan.
|
||||||
|
bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) {
|
||||||
|
auto ret = wrap::hipfftSetStream(parent, plan, AsGpuStreamValue(stream));
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to run rocFFT routine hipfftSetStream: " << ret;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
port::Status ROCMFftPlan::Initialize(
|
||||||
|
GpuExecutor *parent, Stream *stream, int rank, uint64 *elem_count,
|
||||||
|
uint64 *input_embed, uint64 input_stride, uint64 input_distance,
|
||||||
|
uint64 *output_embed, uint64 output_stride, uint64 output_distance,
|
||||||
|
fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||||
|
if (IsInitialized()) {
|
||||||
|
LOG(FATAL) << "Try to repeatedly initialize.";
|
||||||
|
}
|
||||||
|
is_initialized_ = true;
|
||||||
|
int elem_count_[3], input_embed_[3], output_embed_[3];
|
||||||
|
for (int i = 0; i < rank; ++i) {
|
||||||
|
elem_count_[i] = elem_count[i];
|
||||||
|
if (input_embed) {
|
||||||
|
input_embed_[i] = input_embed[i];
|
||||||
|
}
|
||||||
|
if (output_embed) {
|
||||||
|
output_embed_[i] = output_embed[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parent_ = parent;
|
||||||
|
fft_type_ = type;
|
||||||
|
if (batch_count == 1 && input_embed == nullptr && output_embed == nullptr) {
|
||||||
|
hipfftResult_t ret;
|
||||||
|
if (scratch_allocator == nullptr) {
|
||||||
|
switch (rank) {
|
||||||
|
case 1:
|
||||||
|
// hipfftPlan1d
|
||||||
|
ret = wrap::hipfftPlan1d(parent, &plan_, elem_count_[0],
|
||||||
|
ROCMFftType(type), 1 /* = batch */);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to create rocFFT 1d plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to create rocFFT 1d plan."};
|
||||||
|
}
|
||||||
|
return port::Status::OK();
|
||||||
|
case 2:
|
||||||
|
// hipfftPlan2d
|
||||||
|
ret = wrap::hipfftPlan2d(parent, &plan_, elem_count_[0],
|
||||||
|
elem_count_[1], ROCMFftType(type));
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to create rocFFT 2d plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to create rocFFT 2d plan."};
|
||||||
|
}
|
||||||
|
return port::Status::OK();
|
||||||
|
case 3:
|
||||||
|
// hipfftPlan3d
|
||||||
|
ret =
|
||||||
|
wrap::hipfftPlan3d(parent, &plan_, elem_count_[0], elem_count_[1],
|
||||||
|
elem_count_[2], ROCMFftType(type));
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to create rocFFT 3d plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to create rocFFT 3d plan."};
|
||||||
|
}
|
||||||
|
return port::Status::OK();
|
||||||
|
default:
|
||||||
|
LOG(ERROR) << "Invalid rank value for hipfftPlan. "
|
||||||
|
"Requested 1, 2, or 3, given: "
|
||||||
|
<< rank;
|
||||||
|
return port::Status{port::error::INVALID_ARGUMENT,
|
||||||
|
"hipfftPlan only takes rank 1, 2, or 3."};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ret = wrap::hipfftCreate(parent, &plan_);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to create rocFFT plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to create rocFFT plan."};
|
||||||
|
}
|
||||||
|
ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to set auto allocation for rocFFT plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to set auto allocation for rocFFT plan."};
|
||||||
|
}
|
||||||
|
size_t size_in_bytes;
|
||||||
|
switch (rank) {
|
||||||
|
case 1:
|
||||||
|
ret = wrap::hipfftMakePlan1d(parent, plan_, elem_count_[0],
|
||||||
|
ROCMFftType(type), /*batch=*/1,
|
||||||
|
&size_in_bytes);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to make rocFFT 1d plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to make rocFFT 1d plan."};
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
ret = wrap::hipfftMakePlan2d(parent, plan_, elem_count_[0],
|
||||||
|
elem_count_[1], ROCMFftType(type),
|
||||||
|
&size_in_bytes);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to make rocFFT 2d plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to make rocFFT 2d plan."};
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
ret = wrap::hipfftMakePlan3d(parent, plan_, elem_count_[0],
|
||||||
|
elem_count_[1], elem_count_[2],
|
||||||
|
ROCMFftType(type), &size_in_bytes);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to make rocFFT 3d plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to make rocFFT 3d plan."};
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(ERROR) << "Invalid rank value for hipfftPlan. "
|
||||||
|
"Requested 1, 2, or 3, given: "
|
||||||
|
<< rank;
|
||||||
|
return port::Status{port::error::INVALID_ARGUMENT,
|
||||||
|
"hipfftPlan only takes rank 1, 2, or 3."};
|
||||||
|
}
|
||||||
|
// TODO(yangzihao): refactor this code and the one with the same function
|
||||||
|
// in the batch mode.
|
||||||
|
if (size_in_bytes != 0) {
|
||||||
|
auto allocated =
|
||||||
|
scratch_allocator->AllocateBytes(stream, size_in_bytes);
|
||||||
|
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
||||||
|
LOG(ERROR) << "failed to allocate work area.";
|
||||||
|
return allocated.status();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Connect work area with allocated space.
|
||||||
|
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to set work area for rocFFT plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to set work area for rocFFT plan."};
|
||||||
|
}
|
||||||
|
return port::Status::OK();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For either multiple batches or rank higher than 3, use hipfftPlanMany().
|
||||||
|
if (scratch_allocator == nullptr) {
|
||||||
|
auto ret = wrap::hipfftPlanMany(
|
||||||
|
parent, &plan_, rank, elem_count_,
|
||||||
|
input_embed ? input_embed_ : nullptr, input_stride, input_distance,
|
||||||
|
output_embed ? output_embed_ : nullptr, output_stride,
|
||||||
|
output_distance, ROCMFftType(type), batch_count);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to create rocFFT bacthed plan."};
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto ret = wrap::hipfftCreate(parent, &plan_);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to create rocFFT batched plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to create rocFFT bacthed plan."};
|
||||||
|
}
|
||||||
|
ret = wrap::hipfftSetAutoAllocation(parent, plan_, 0);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to set auto allocation for rocFFT batched plan:"
|
||||||
|
<< ret;
|
||||||
|
return port::Status{
|
||||||
|
port::error::INTERNAL,
|
||||||
|
"Failed to set auto allocation for rocFFT bacthed plan."};
|
||||||
|
}
|
||||||
|
size_t size_in_bytes;
|
||||||
|
ret = wrap::hipfftMakePlanMany(
|
||||||
|
parent, plan_, rank, elem_count_,
|
||||||
|
input_embed ? input_embed_ : nullptr, input_stride, input_distance,
|
||||||
|
output_embed ? output_embed_ : nullptr, output_stride,
|
||||||
|
output_distance, ROCMFftType(type), batch_count, &size_in_bytes);
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to make rocFFT batched plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to make rocFFT bacthed plan."};
|
||||||
|
}
|
||||||
|
if (size_in_bytes != 0) {
|
||||||
|
auto allocated =
|
||||||
|
scratch_allocator->AllocateBytes(stream, size_in_bytes);
|
||||||
|
if (!allocated.ok() || (scratch_ = allocated.ValueOrDie()) == nullptr) {
|
||||||
|
LOG(ERROR) << "failed to allocate work area.";
|
||||||
|
return allocated.status();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Connect work area with allocated space.
|
||||||
|
ret = wrap::hipfftSetWorkArea(parent, plan_, scratch_.opaque());
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to set work area for rocFFT batched plan:" << ret;
|
||||||
|
return port::Status{port::error::INTERNAL,
|
||||||
|
"Failed to set work area for rocFFT bacthed plan."};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return port::Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
port::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream,
|
||||||
|
int rank, uint64 *elem_count,
|
||||||
|
fft::Type type,
|
||||||
|
ScratchAllocator *scratch_allocator) {
|
||||||
|
return Initialize(parent_, stream, rank, elem_count,
|
||||||
|
/*input_embed=*/nullptr, /*input_stride=*/0,
|
||||||
|
/*input_distance=*/0,
|
||||||
|
/*output_embed=*/nullptr, /*output_stride=*/0,
|
||||||
|
/*output_distance=*/0, type, 1, scratch_allocator);
|
||||||
|
}
|
||||||
|
|
||||||
|
ROCMFftPlan::~ROCMFftPlan() { wrap::hipfftDestroy(parent_, plan_); }
|
||||||
|
|
||||||
|
int ROCMFftPlan::GetFftDirection() const {
|
||||||
|
if (!IsInitialized()) {
|
||||||
|
LOG(FATAL) << "Try to get fft direction before initialization.";
|
||||||
|
} else {
|
||||||
|
switch (fft_type_) {
|
||||||
|
case fft::Type::kC2CForward:
|
||||||
|
case fft::Type::kZ2ZForward:
|
||||||
|
case fft::Type::kR2C:
|
||||||
|
case fft::Type::kD2Z:
|
||||||
|
return HIPFFT_FORWARD;
|
||||||
|
case fft::Type::kC2CInverse:
|
||||||
|
case fft::Type::kZ2ZInverse:
|
||||||
|
case fft::Type::kC2R:
|
||||||
|
case fft::Type::kZ2D:
|
||||||
|
return HIPFFT_BACKWARD;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Invalid value of fft::Type.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<fft::Plan> ROCMFft::Create1dPlan(Stream *stream, uint64 num_x,
|
||||||
|
fft::Type type,
|
||||||
|
bool in_place_fft) {
|
||||||
|
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
|
||||||
|
uint64 elem_count[1] = {num_x};
|
||||||
|
port::Status status = fft_plan_ptr->Initialize(
|
||||||
|
parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
|
||||||
|
// TODO(yangzihao): In the future, send error msg back to TensorFlow
|
||||||
|
// so it can fail gracefully,
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL) << "failed to initialize hipfft 1d plan: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
return std::move(fft_plan_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<fft::Plan> ROCMFft::Create1dPlanWithScratchAllocator(
|
||||||
|
Stream *stream, uint64 num_x, fft::Type type, bool in_place_fft,
|
||||||
|
ScratchAllocator *scratch_allocator) {
|
||||||
|
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
|
||||||
|
uint64 elem_count[1] = {num_x};
|
||||||
|
port::Status status = fft_plan_ptr->Initialize(parent_, stream, 1, elem_count,
|
||||||
|
type, scratch_allocator);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL)
|
||||||
|
<< "failed to initialize hipfft 1d plan with customized allocator: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
return std::move(fft_plan_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<fft::Plan> ROCMFft::Create2dPlan(Stream *stream, uint64 num_x,
|
||||||
|
uint64 num_y, fft::Type type,
|
||||||
|
bool in_place_fft) {
|
||||||
|
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
|
||||||
|
uint64 elem_count[2] = {num_x, num_y};
|
||||||
|
port::Status status = fft_plan_ptr->Initialize(
|
||||||
|
parent_, stream, 1, elem_count, type, /*scratch_allocator=*/nullptr);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL) << "failed to initialize hipfft 2d plan: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
return std::move(fft_plan_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<fft::Plan> ROCMFft::Create2dPlanWithScratchAllocator(
|
||||||
|
Stream *stream, uint64 num_x, uint64 num_y, fft::Type type,
|
||||||
|
bool in_place_fft, ScratchAllocator *scratch_allocator) {
|
||||||
|
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
|
||||||
|
uint64 elem_count[2] = {num_x, num_y};
|
||||||
|
port::Status status = fft_plan_ptr->Initialize(parent_, stream, 2, elem_count,
|
||||||
|
type, scratch_allocator);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL)
|
||||||
|
<< "failed to initialize hipfft 2d plan with customized allocator: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
return std::move(fft_plan_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<fft::Plan> ROCMFft::Create3dPlan(Stream *stream, uint64 num_x,
|
||||||
|
uint64 num_y, uint64 num_z,
|
||||||
|
fft::Type type,
|
||||||
|
bool in_place_fft) {
|
||||||
|
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
|
||||||
|
uint64 elem_count[3] = {num_x, num_y, num_z};
|
||||||
|
port::Status status = fft_plan_ptr->Initialize(
|
||||||
|
parent_, stream, 3, elem_count, type, /*scratch_allocator=*/nullptr);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL) << "failed to initialize hipfft 3d plan: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
return std::move(fft_plan_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<fft::Plan> ROCMFft::Create3dPlanWithScratchAllocator(
|
||||||
|
Stream *stream, uint64 num_x, uint64 num_y, uint64 num_z, fft::Type type,
|
||||||
|
bool in_place_fft, ScratchAllocator *scratch_allocator) {
|
||||||
|
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
|
||||||
|
uint64 elem_count[3] = {num_x, num_y, num_z};
|
||||||
|
port::Status status = fft_plan_ptr->Initialize(parent_, stream, 3, elem_count,
|
||||||
|
type, scratch_allocator);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL)
|
||||||
|
<< "failed to initialize hipfft 3d plan with customized allocator: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
return std::move(fft_plan_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlan(
|
||||||
|
Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
|
||||||
|
uint64 input_stride, uint64 input_distance, uint64 *output_embed,
|
||||||
|
uint64 output_stride, uint64 output_distance, fft::Type type,
|
||||||
|
bool in_place_fft, int batch_count) {
|
||||||
|
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
|
||||||
|
port::Status status = fft_plan_ptr->Initialize(
|
||||||
|
parent_, stream, rank, elem_count, input_embed, input_stride,
|
||||||
|
input_distance, output_embed, output_stride, output_distance, type,
|
||||||
|
batch_count, /*scratch_allocator=*/nullptr);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL) << "failed to initialize batched hipfft plan: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::move(fft_plan_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<fft::Plan> ROCMFft::CreateBatchedPlanWithScratchAllocator(
|
||||||
|
Stream *stream, int rank, uint64 *elem_count, uint64 *input_embed,
|
||||||
|
uint64 input_stride, uint64 input_distance, uint64 *output_embed,
|
||||||
|
uint64 output_stride, uint64 output_distance, fft::Type type,
|
||||||
|
bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) {
|
||||||
|
std::unique_ptr<ROCMFftPlan> fft_plan_ptr{new ROCMFftPlan()};
|
||||||
|
port::Status status = fft_plan_ptr->Initialize(
|
||||||
|
parent_, stream, rank, elem_count, input_embed, input_stride,
|
||||||
|
input_distance, output_embed, output_stride, output_distance, type,
|
||||||
|
batch_count, scratch_allocator);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(FATAL) << "failed to initialize batched hipfft plan with customized "
|
||||||
|
"allocator: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
return std::move(fft_plan_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ROCMFft::UpdatePlanWithScratchAllocator(
|
||||||
|
Stream *stream, fft::Plan *plan, ScratchAllocator *scratch_allocator) {
|
||||||
|
LOG(ERROR) << "update plan with scratch allocator not implemented";
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename FuncT, typename InputT, typename OutputT>
|
||||||
|
bool ROCMFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfftExec,
|
||||||
|
const DeviceMemory<InputT> &input,
|
||||||
|
DeviceMemory<OutputT> *output) {
|
||||||
|
ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
|
||||||
|
if (rocm_fft_plan == nullptr) {
|
||||||
|
LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
|
||||||
|
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
|
||||||
|
GpuComplex(GpuMemoryMutable(output)));
|
||||||
|
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to run rocFFT routine: " << ret;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename FuncT, typename InputT, typename OutputT>
|
||||||
|
bool ROCMFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
|
||||||
|
FuncT hipfftExec,
|
||||||
|
const DeviceMemory<InputT> &input,
|
||||||
|
DeviceMemory<OutputT> *output) {
|
||||||
|
ROCMFftPlan *rocm_fft_plan = dynamic_cast<ROCMFftPlan *>(plan);
|
||||||
|
if (rocm_fft_plan == nullptr) {
|
||||||
|
LOG(ERROR) << "the passed-in plan is not a ROCMFftPlan object.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!SetStream(parent_, rocm_fft_plan->GetPlan(), stream)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = hipfftExec(parent_, rocm_fft_plan->GetPlan(),
|
||||||
|
GpuComplex(const_cast<InputT *>(GpuMemory(input))),
|
||||||
|
GpuComplex(GpuMemoryMutable(output)),
|
||||||
|
rocm_fft_plan->GetFftDirection());
|
||||||
|
|
||||||
|
if (ret != HIPFFT_SUCCESS) {
|
||||||
|
LOG(ERROR) << "failed to run rocFFT routine: " << ret;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define STREAM_EXECUTOR_ROCM_DEFINE_FFT(__type, __fft_type1, __fft_type2, \
|
||||||
|
__fft_type3) \
|
||||||
|
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
|
||||||
|
const DeviceMemory<std::complex<__type>> &input, \
|
||||||
|
DeviceMemory<std::complex<__type>> *output) { \
|
||||||
|
return DoFftWithDirectionInternal( \
|
||||||
|
stream, plan, wrap::hipfftExec##__fft_type1, input, output); \
|
||||||
|
} \
|
||||||
|
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
|
||||||
|
const DeviceMemory<__type> &input, \
|
||||||
|
DeviceMemory<std::complex<__type>> *output) { \
|
||||||
|
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type2, input, \
|
||||||
|
output); \
|
||||||
|
} \
|
||||||
|
bool ROCMFft::DoFft(Stream *stream, fft::Plan *plan, \
|
||||||
|
const DeviceMemory<std::complex<__type>> &input, \
|
||||||
|
DeviceMemory<__type> *output) { \
|
||||||
|
return DoFftInternal(stream, plan, wrap::hipfftExec##__fft_type3, input, \
|
||||||
|
output); \
|
||||||
|
}
|
||||||
|
|
||||||
|
STREAM_EXECUTOR_ROCM_DEFINE_FFT(float, C2C, R2C, C2R)
|
||||||
|
STREAM_EXECUTOR_ROCM_DEFINE_FFT(double, Z2Z, D2Z, Z2D)
|
||||||
|
|
||||||
|
#undef STREAM_EXECUTOR_ROCM_DEFINE_FFT
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
|
||||||
|
void initialize_rocfft() {
|
||||||
|
auto rocFftAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
|
||||||
|
rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
|
||||||
|
|
||||||
|
if (!rocFftAlreadyRegistered) {
|
||||||
|
port::Status status =
|
||||||
|
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::FftFactory>(
|
||||||
|
rocm::kROCmPlatformId, gpu::kRocFftPlugin, "rocFFT",
|
||||||
|
[](internal::StreamExecutorInterface *parent) -> fft::FftSupport * {
|
||||||
|
gpu::GpuExecutor *rocm_executor =
|
||||||
|
dynamic_cast<gpu::GpuExecutor *>(parent);
|
||||||
|
if (rocm_executor == nullptr) {
|
||||||
|
LOG(ERROR)
|
||||||
|
<< "Attempting to initialize an instance of the rocFFT "
|
||||||
|
<< "support library with a non-ROCM StreamExecutor";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return new gpu::ROCMFft(rocm_executor);
|
||||||
|
});
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << "Unable to register rocFFT factory: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
|
PluginRegistry::Instance()->SetDefaultFactory(
|
||||||
|
rocm::kROCmPlatformId, PluginKind::kFft, gpu::kRocFftPlugin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace stream_executor
|
||||||
|
|
||||||
|
REGISTER_MODULE_INITIALIZER(register_rocfft,
|
||||||
|
{ stream_executor::initialize_rocfft(); });
|
132
tensorflow/stream_executor/rocm/rocm_fft.h
Normal file
132
tensorflow/stream_executor/rocm/rocm_fft.h
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// ROCM-specific support for FFT functionality -- this wraps the rocFFT library
|
||||||
|
// capabilities, and is only included into ROCM implementation code -- it will
|
||||||
|
// not introduce rocm headers into other code.
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
|
||||||
|
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
|
||||||
|
|
||||||
|
#include "rocm/include/rocfft/hipfft.h"
|
||||||
|
#include "tensorflow/stream_executor/fft.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
|
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||||
|
#include "tensorflow/stream_executor/scratch_allocator.h"
|
||||||
|
|
||||||
|
namespace stream_executor {
|
||||||
|
|
||||||
|
class Stream;
|
||||||
|
|
||||||
|
namespace gpu {
|
||||||
|
|
||||||
|
class GpuExecutor;
|
||||||
|
|
||||||
|
// Opaque and unique indentifier for the rocFFT plugin.
|
||||||
|
extern const PluginId kRocFftPlugin;
|
||||||
|
|
||||||
|
// ROCMFftPlan uses deferred initialization. Only a single call of
|
||||||
|
// Initialize() is allowed to properly create hipfft plan and set member
|
||||||
|
// variable is_initialized_ to true. Newly added interface that uses member
|
||||||
|
// variables should first check is_initialized_ to make sure that the values of
|
||||||
|
// member variables are valid.
|
||||||
|
class ROCMFftPlan : public fft::Plan {
|
||||||
|
public:
|
||||||
|
ROCMFftPlan()
|
||||||
|
: parent_(nullptr),
|
||||||
|
plan_(),
|
||||||
|
fft_type_(fft::Type::kInvalid),
|
||||||
|
scratch_(nullptr),
|
||||||
|
is_initialized_(false) {}
|
||||||
|
~ROCMFftPlan() override;
|
||||||
|
|
||||||
|
// Get FFT direction in hipFFT based on FFT type.
|
||||||
|
int GetFftDirection() const;
|
||||||
|
hipfftHandle GetPlan() const {
|
||||||
|
if (IsInitialized()) {
|
||||||
|
return plan_;
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Try to get hipfftHandle value before initialization.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize function for batched plan
|
||||||
|
port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
|
||||||
|
uint64 *elem_count, uint64 *input_embed,
|
||||||
|
uint64 input_stride, uint64 input_distance,
|
||||||
|
uint64 *output_embed, uint64 output_stride,
|
||||||
|
uint64 output_distance, fft::Type type,
|
||||||
|
int batch_count, ScratchAllocator *scratch_allocator);
|
||||||
|
|
||||||
|
// Initialize function for 1d,2d, and 3d plan
|
||||||
|
port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
|
||||||
|
uint64 *elem_count, fft::Type type,
|
||||||
|
ScratchAllocator *scratch_allocator);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
bool IsInitialized() const { return is_initialized_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
GpuExecutor *parent_;
|
||||||
|
hipfftHandle plan_;
|
||||||
|
fft::Type fft_type_;
|
||||||
|
DeviceMemory<uint8> scratch_;
|
||||||
|
bool is_initialized_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// FFT support for ROCM platform via rocFFT library.
|
||||||
|
//
|
||||||
|
// This satisfies the platform-agnostic FftSupport interface.
|
||||||
|
//
|
||||||
|
// Note that the hipFFT handle that this encapsulates is implicitly tied to the
|
||||||
|
// context (and, as a result, the device) that the parent GpuExecutor is tied
|
||||||
|
// to. This simply happens as an artifact of creating the hipFFT handle when a
|
||||||
|
// ROCM context is active.
|
||||||
|
//
|
||||||
|
// Thread-safe. The ROCM context associated with all operations is the ROCM
|
||||||
|
// context of parent_, so all context is explicit.
|
||||||
|
class ROCMFft : public fft::FftSupport {
|
||||||
|
public:
|
||||||
|
explicit ROCMFft(GpuExecutor *parent) : parent_(parent) {}
|
||||||
|
~ROCMFft() override {}
|
||||||
|
|
||||||
|
TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES
|
||||||
|
|
||||||
|
private:
|
||||||
|
GpuExecutor *parent_;
|
||||||
|
|
||||||
|
// Two helper functions that execute dynload::hipfftExec?2?.
|
||||||
|
|
||||||
|
// This is for complex to complex FFT, when the direction is required.
|
||||||
|
template <typename FuncT, typename InputT, typename OutputT>
|
||||||
|
bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
|
||||||
|
FuncT hipfft_exec,
|
||||||
|
const DeviceMemory<InputT> &input,
|
||||||
|
DeviceMemory<OutputT> *output);
|
||||||
|
|
||||||
|
// This is for complex to real or real to complex FFT, when the direction
|
||||||
|
// is implied.
|
||||||
|
template <typename FuncT, typename InputT, typename OutputT>
|
||||||
|
bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfft_exec,
|
||||||
|
const DeviceMemory<InputT> &input,
|
||||||
|
DeviceMemory<OutputT> *output);
|
||||||
|
|
||||||
|
SE_DISALLOW_COPY_AND_ASSIGN(ROCMFft);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gpu
|
||||||
|
} // namespace stream_executor
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
|
@ -18,7 +18,6 @@ limitations under the License.
|
|||||||
#include "absl/base/casts.h"
|
#include "absl/base/casts.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_diagnostics.h"
|
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
|
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_event.h"
|
#include "tensorflow/stream_executor/gpu/gpu_event.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
||||||
@ -41,6 +40,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/platform/logging.h"
|
#include "tensorflow/stream_executor/platform/logging.h"
|
||||||
#include "tensorflow/stream_executor/platform/port.h"
|
#include "tensorflow/stream_executor/platform/port.h"
|
||||||
#include "tensorflow/stream_executor/plugin_registry.h"
|
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||||
|
#include "tensorflow/stream_executor/rocm/rocm_diagnostics.h"
|
||||||
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
||||||
#include "tensorflow/stream_executor/stream.h"
|
#include "tensorflow/stream_executor/stream.h"
|
||||||
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
#include "tensorflow/stream_executor/stream_executor_internal.h"
|
||||||
@ -655,7 +655,7 @@ port::Status GpuExecutor::BlockHostUntilDone(Stream* stream) {
|
|||||||
blas::BlasSupport* GpuExecutor::CreateBlas() {
|
blas::BlasSupport* GpuExecutor::CreateBlas() {
|
||||||
PluginRegistry* registry = PluginRegistry::Instance();
|
PluginRegistry* registry = PluginRegistry::Instance();
|
||||||
port::StatusOr<PluginRegistry::BlasFactory> status =
|
port::StatusOr<PluginRegistry::BlasFactory> status =
|
||||||
registry->GetFactory<PluginRegistry::BlasFactory>(kROCmPlatformId,
|
registry->GetFactory<PluginRegistry::BlasFactory>(rocm::kROCmPlatformId,
|
||||||
plugin_config_.blas());
|
plugin_config_.blas());
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
LOG(ERROR) << "Unable to retrieve BLAS factory: "
|
LOG(ERROR) << "Unable to retrieve BLAS factory: "
|
||||||
@ -669,7 +669,7 @@ blas::BlasSupport* GpuExecutor::CreateBlas() {
|
|||||||
dnn::DnnSupport* GpuExecutor::CreateDnn() {
|
dnn::DnnSupport* GpuExecutor::CreateDnn() {
|
||||||
PluginRegistry* registry = PluginRegistry::Instance();
|
PluginRegistry* registry = PluginRegistry::Instance();
|
||||||
port::StatusOr<PluginRegistry::DnnFactory> status =
|
port::StatusOr<PluginRegistry::DnnFactory> status =
|
||||||
registry->GetFactory<PluginRegistry::DnnFactory>(kROCmPlatformId,
|
registry->GetFactory<PluginRegistry::DnnFactory>(rocm::kROCmPlatformId,
|
||||||
plugin_config_.dnn());
|
plugin_config_.dnn());
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
LOG(ERROR) << "Unable to retrieve DNN factory: "
|
LOG(ERROR) << "Unable to retrieve DNN factory: "
|
||||||
@ -683,7 +683,7 @@ dnn::DnnSupport* GpuExecutor::CreateDnn() {
|
|||||||
fft::FftSupport* GpuExecutor::CreateFft() {
|
fft::FftSupport* GpuExecutor::CreateFft() {
|
||||||
PluginRegistry* registry = PluginRegistry::Instance();
|
PluginRegistry* registry = PluginRegistry::Instance();
|
||||||
port::StatusOr<PluginRegistry::FftFactory> status =
|
port::StatusOr<PluginRegistry::FftFactory> status =
|
||||||
registry->GetFactory<PluginRegistry::FftFactory>(kROCmPlatformId,
|
registry->GetFactory<PluginRegistry::FftFactory>(rocm::kROCmPlatformId,
|
||||||
plugin_config_.fft());
|
plugin_config_.fft());
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
LOG(ERROR) << "Unable to retrieve FFT factory: "
|
LOG(ERROR) << "Unable to retrieve FFT factory: "
|
||||||
@ -697,7 +697,7 @@ fft::FftSupport* GpuExecutor::CreateFft() {
|
|||||||
rng::RngSupport* GpuExecutor::CreateRng() {
|
rng::RngSupport* GpuExecutor::CreateRng() {
|
||||||
PluginRegistry* registry = PluginRegistry::Instance();
|
PluginRegistry* registry = PluginRegistry::Instance();
|
||||||
port::StatusOr<PluginRegistry::RngFactory> status =
|
port::StatusOr<PluginRegistry::RngFactory> status =
|
||||||
registry->GetFactory<PluginRegistry::RngFactory>(kROCmPlatformId,
|
registry->GetFactory<PluginRegistry::RngFactory>(rocm::kROCmPlatformId,
|
||||||
plugin_config_.rng());
|
plugin_config_.rng());
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
LOG(ERROR) << "Unable to retrieve RNG factory: "
|
LOG(ERROR) << "Unable to retrieve RNG factory: "
|
||||||
@ -878,12 +878,10 @@ DeviceDescription* GpuExecutor::PopulateDeviceDescription() const {
|
|||||||
{
|
{
|
||||||
int driver_version = 0;
|
int driver_version = 0;
|
||||||
(void)GpuDriver::GetDriverVersion(&driver_version);
|
(void)GpuDriver::GetDriverVersion(&driver_version);
|
||||||
string augmented_driver_version =
|
string augmented_driver_version = absl::StrFormat(
|
||||||
absl::StrFormat("%d (%s)", driver_version, "__FIXME__");
|
"%d (%s)", driver_version,
|
||||||
// FIXME:
|
rocm::DriverVersionStatusToString(Diagnostician::FindDsoVersion())
|
||||||
// uncomment the line below once the "DriverVersionStatusToString"
|
.c_str());
|
||||||
// routine is moved from the "cuda" namespace to the "gpu" naemspace
|
|
||||||
// DriverVersionStatusToString(Diagnostician::FindDsoVersion()).c_str());
|
|
||||||
builder.set_driver_version(augmented_driver_version);
|
builder.set_driver_version(augmented_driver_version);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ port::StatusOr<StreamExecutor*> ROCmPlatform::FirstExecutorForBus(
|
|||||||
absl::StrFormat("Executor for bus %d not found.", bus_ordinal)};
|
absl::StrFormat("Executor for bus %d not found.", bus_ordinal)};
|
||||||
}
|
}
|
||||||
|
|
||||||
Platform::Id ROCmPlatform::id() const { return kROCmPlatformId; }
|
Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; }
|
||||||
|
|
||||||
int ROCmPlatform::VisibleDeviceCount() const {
|
int ROCmPlatform::VisibleDeviceCount() const {
|
||||||
// Throw away the result - it logs internally, and this [containing] function
|
// Throw away the result - it logs internally, and this [containing] function
|
||||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace gpu {
|
namespace rocm {
|
||||||
|
|
||||||
PLATFORM_DEFINE_ID(kROCmPlatformId);
|
PLATFORM_DEFINE_ID(kROCmPlatformId);
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace rocm
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
@ -19,16 +19,16 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/platform.h"
|
#include "tensorflow/stream_executor/platform.h"
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace gpu {
|
namespace rocm {
|
||||||
|
|
||||||
// Opaque and unique identifier for the ROCm platform.
|
// Opaque and unique identifier for the ROCm platform.
|
||||||
// This is needed so that plugins can refer to/identify this platform without
|
// This is needed so that plugins can refer to/identify this platform without
|
||||||
// instantiating a ROCmPlatform object.
|
// instantiating a ROCmPlatform object.
|
||||||
// This is broken out here to avoid a circular dependency between ROCmPlatform
|
// This is broken out here to avoid a circular dependency between ROCmPlatform
|
||||||
// and GpuExecutor.
|
// and ROCmExecutor.
|
||||||
extern const Platform::Id kROCmPlatformId;
|
extern const Platform::Id kROCmPlatformId;
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace rocm
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
|
||||||
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_ID_H_
|
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_ID_H_
|
||||||
|
@ -14,15 +14,17 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "rocm/include/hiprand/hiprand.h"
|
#include "rocm/include/hiprand/hiprand.h"
|
||||||
|
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
|
||||||
|
|
||||||
#include "tensorflow/stream_executor/device_memory.h"
|
#include "tensorflow/stream_executor/device_memory.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
|
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
|
|
||||||
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
|
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
|
||||||
#include "tensorflow/stream_executor/lib/env.h"
|
#include "tensorflow/stream_executor/lib/env.h"
|
||||||
#include "tensorflow/stream_executor/lib/initialize.h"
|
#include "tensorflow/stream_executor/lib/initialize.h"
|
||||||
#include "tensorflow/stream_executor/lib/status.h"
|
#include "tensorflow/stream_executor/lib/status.h"
|
||||||
|
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||||
#include "tensorflow/stream_executor/platform/logging.h"
|
#include "tensorflow/stream_executor/platform/logging.h"
|
||||||
#include "tensorflow/stream_executor/rng.h"
|
#include "tensorflow/stream_executor/rng.h"
|
||||||
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
|
||||||
@ -60,7 +62,9 @@ PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);
|
|||||||
|
|
||||||
namespace wrap {
|
namespace wrap {
|
||||||
|
|
||||||
#define PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(__name) \
|
#ifdef PLATFORM_GOOGLE
|
||||||
|
|
||||||
|
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
|
||||||
struct WrapperShim__##__name { \
|
struct WrapperShim__##__name { \
|
||||||
template <typename... Args> \
|
template <typename... Args> \
|
||||||
hiprandStatus_t operator()(GpuExecutor* parent, Args... args) { \
|
hiprandStatus_t operator()(GpuExecutor* parent, Args... args) { \
|
||||||
@ -69,15 +73,47 @@ namespace wrap {
|
|||||||
} \
|
} \
|
||||||
} __name;
|
} __name;
|
||||||
|
|
||||||
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandCreateGenerator);
|
#else
|
||||||
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandDestroyGenerator);
|
|
||||||
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetStream);
|
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
|
||||||
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniform);
|
struct DynLoadShim__##__name { \
|
||||||
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateUniformDouble);
|
static const char* kName; \
|
||||||
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
|
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
|
||||||
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandSetGeneratorOffset);
|
static void* GetDsoHandle() { \
|
||||||
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormal);
|
auto s = internal::CachedDsoLoader::GetRocrandDsoHandle(); \
|
||||||
PERFTOOLS_GPUTOOLS_HIPRAND_WRAP(hiprandGenerateNormalDouble);
|
return s.ValueOrDie(); \
|
||||||
|
} \
|
||||||
|
static FuncPtrT LoadOrDie() { \
|
||||||
|
void* f; \
|
||||||
|
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
|
||||||
|
kName, &f); \
|
||||||
|
CHECK(s.ok()) << "could not find " << kName \
|
||||||
|
<< " in rocrand DSO; dlerror: " << s.error_message(); \
|
||||||
|
return reinterpret_cast<FuncPtrT>(f); \
|
||||||
|
} \
|
||||||
|
static FuncPtrT DynLoad() { \
|
||||||
|
static FuncPtrT f = LoadOrDie(); \
|
||||||
|
return f; \
|
||||||
|
} \
|
||||||
|
template <typename... Args> \
|
||||||
|
hiprandStatus operator()(GpuExecutor* parent, Args... args) { \
|
||||||
|
gpu::ScopedActivateExecutorContext sac{parent}; \
|
||||||
|
return DynLoad()(args...); \
|
||||||
|
} \
|
||||||
|
} __name; \
|
||||||
|
const char* DynLoadShim__##__name::kName = #__name;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
|
||||||
|
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
|
||||||
|
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);
|
||||||
|
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniform);
|
||||||
|
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniformDouble);
|
||||||
|
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
|
||||||
|
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetGeneratorOffset);
|
||||||
|
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormal);
|
||||||
|
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormalDouble);
|
||||||
|
|
||||||
} // namespace wrap
|
} // namespace wrap
|
||||||
|
|
||||||
@ -245,40 +281,45 @@ bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
|
||||||
|
void initialize_rocrand() {
|
||||||
|
auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
|
||||||
|
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
|
||||||
|
|
||||||
|
if (!rocRandAlreadyRegistered) {
|
||||||
|
port::Status status =
|
||||||
|
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
|
||||||
|
rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
|
||||||
|
[](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
|
||||||
|
gpu::GpuExecutor* rocm_executor =
|
||||||
|
dynamic_cast<gpu::GpuExecutor*>(parent);
|
||||||
|
if (rocm_executor == nullptr) {
|
||||||
|
LOG(ERROR)
|
||||||
|
<< "Attempting to initialize an instance of the hipRAND "
|
||||||
|
<< "support library with a non-ROCM StreamExecutor";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
|
||||||
|
if (!rng->Init()) {
|
||||||
|
// Note: Init() will log a more specific error.
|
||||||
|
delete rng;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return rng;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << "Unable to register rocRAND factory: "
|
||||||
|
<< status.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
|
PluginRegistry::Instance()->SetDefaultFactory(
|
||||||
|
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
|
||||||
namespace se = ::stream_executor;
|
REGISTER_MODULE_INITIALIZER(register_rocrand,
|
||||||
|
{ stream_executor::initialize_rocrand(); });
|
||||||
REGISTER_MODULE_INITIALIZER(register_hiprand, {
|
|
||||||
se::port::Status status =
|
|
||||||
se::PluginRegistry::Instance()
|
|
||||||
->RegisterFactory<se::PluginRegistry::RngFactory>(
|
|
||||||
se::gpu::kROCmPlatformId, se::gpu::kGpuRandPlugin, "hipRAND",
|
|
||||||
[](se::internal::StreamExecutorInterface* parent)
|
|
||||||
-> se::rng::RngSupport* {
|
|
||||||
se::gpu::GpuExecutor* rocm_executor =
|
|
||||||
dynamic_cast<se::gpu::GpuExecutor*>(parent);
|
|
||||||
if (rocm_executor == nullptr) {
|
|
||||||
LOG(ERROR)
|
|
||||||
<< "Attempting to initialize an instance of the hipRAND "
|
|
||||||
<< "support library with a non-ROCM StreamExecutor";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
se::gpu::GpuRng* rng = new se::gpu::GpuRng(rocm_executor);
|
|
||||||
if (!rng->Init()) {
|
|
||||||
// Note: Init() will log a more specific error.
|
|
||||||
delete rng;
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return rng;
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << "Unable to register hipRAND factory: "
|
|
||||||
<< status.error_message();
|
|
||||||
}
|
|
||||||
|
|
||||||
se::PluginRegistry::Instance()->SetDefaultFactory(
|
|
||||||
se::gpu::kROCmPlatformId, se::PluginKind::kRng, se::gpu::kGpuRandPlugin);
|
|
||||||
});
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user