cosmetic / formating changes
This commit is contained in:
parent
e1f2191a31
commit
d75510457b
@ -1815,14 +1815,14 @@ bool ROCMBlas::DoBlasGemmWithAlgorithm(
|
||||
|
||||
template <typename T>
|
||||
port::Status ROCMBlas::AllocateStridedBuffer(
|
||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
|
||||
&raw_ptrs,
|
||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
|
||||
raw_ptrs,
|
||||
int batch_count, uint64_t batch_stride, ScratchAllocator* scratch_allocator,
|
||||
Stream* stream,
|
||||
std::unique_ptr<TemporaryDeviceMemory<
|
||||
typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
|
||||
*device_memory) {
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
|
||||
device_memory) {
|
||||
assert(device_memory != nullptr);
|
||||
|
||||
using MAPPED_T = typename RocBlasTypeConversionHelper<T>::mapped_type;
|
||||
|
@ -114,14 +114,14 @@ class ROCMBlas : public blas::BlasSupport {
|
||||
// strided flavor
|
||||
template <typename T>
|
||||
port::Status AllocateStridedBuffer(
|
||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type *>
|
||||
&raw_ptrs,
|
||||
const std::vector<typename RocBlasTypeConversionHelper<T>::mapped_type*>&
|
||||
raw_ptrs,
|
||||
int batch_count, uint64_t batch_stride,
|
||||
ScratchAllocator* scratch_allocator, Stream* stream,
|
||||
std::unique_ptr<TemporaryDeviceMemory<
|
||||
typename RocBlasTypeConversionHelper<T>::mapped_type>>* temp_memory,
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>
|
||||
*device_memory);
|
||||
DeviceMemory<typename RocBlasTypeConversionHelper<T>::mapped_type>*
|
||||
device_memory);
|
||||
|
||||
// A helper function to implement DoBlasGemmBatched interfaces for generic
|
||||
// types.
|
||||
|
@ -27,10 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/platform/dso_loader.h"
|
||||
#include "tensorflow/stream_executor/platform/port.h"
|
||||
|
||||
#if defined(TENSORFLOW_USE_ROCM)
|
||||
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
namespace wrap {
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
@ -83,8 +79,8 @@ namespace wrap {
|
||||
__macro(hipDeviceTotalMem) \
|
||||
__macro(hipDriverGetVersion) \
|
||||
__macro(hipEventCreateWithFlags) \
|
||||
__macro(hipEventElapsedTime) \
|
||||
__macro(hipEventDestroy) \
|
||||
__macro(hipEventElapsedTime) \
|
||||
__macro(hipEventQuery) \
|
||||
__macro(hipEventRecord) \
|
||||
__macro(hipEventSynchronize) \
|
||||
|
@ -86,21 +86,33 @@ namespace wrap {
|
||||
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
#define ROCFFT_ROUTINE_EACH(__macro) \
|
||||
__macro(hipfftDestroy) __macro(hipfftSetStream) __macro(hipfftPlan1d) \
|
||||
__macro(hipfftPlan2d) __macro(hipfftPlan3d) __macro(hipfftPlanMany) \
|
||||
__macro(hipfftCreate) __macro(hipfftSetAutoAllocation) \
|
||||
__macro(hipfftSetWorkArea) __macro(hipfftGetSize1d) \
|
||||
__macro(hipfftMakePlan1d) __macro(hipfftGetSize2d) \
|
||||
__macro(hipfftMakePlan2d) __macro(hipfftGetSize3d) \
|
||||
__macro(hipfftMakePlan3d) __macro(hipfftGetSizeMany) \
|
||||
__macro(hipfftDestroy) \
|
||||
__macro(hipfftSetStream) \
|
||||
__macro(hipfftPlan1d) \
|
||||
__macro(hipfftPlan2d) \
|
||||
__macro(hipfftPlan3d) \
|
||||
__macro(hipfftPlanMany) \
|
||||
__macro(hipfftCreate) \
|
||||
__macro(hipfftSetAutoAllocation) \
|
||||
__macro(hipfftSetWorkArea) \
|
||||
__macro(hipfftGetSize1d) \
|
||||
__macro(hipfftMakePlan1d) \
|
||||
__macro(hipfftGetSize2d) \
|
||||
__macro(hipfftMakePlan2d) \
|
||||
__macro(hipfftGetSize3d) \
|
||||
__macro(hipfftMakePlan3d) \
|
||||
__macro(hipfftGetSizeMany) \
|
||||
__macro(hipfftMakePlanMany) \
|
||||
__macro(hipfftExecD2Z) \
|
||||
__macro(hipfftExecZ2D) \
|
||||
__macro(hipfftExecC2C) \
|
||||
__macro(hipfftExecC2R) \
|
||||
__macro(hipfftExecZ2Z) \
|
||||
__macro(hipfftExecR2C)
|
||||
__macro(hipfftExecR2C) \
|
||||
|
||||
// clang-format on
|
||||
|
||||
ROCFFT_ROUTINE_EACH(STREAM_EXECUTOR_ROCFFT_WRAP)
|
||||
|
||||
|
@ -14,12 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "rocm/include/hiprand/hiprand.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
|
||||
|
||||
#include "tensorflow/stream_executor/device_memory.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
|
||||
#include "tensorflow/stream_executor/lib/env.h"
|
||||
#include "tensorflow/stream_executor/lib/initialize.h"
|
||||
|
Loading…
Reference in New Issue
Block a user