138 lines
4.9 KiB
C++
138 lines
4.9 KiB
C++
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
// ROCM-specific support for FFT functionality -- this wraps the rocFFT library
|
|
// capabilities, and is only included into ROCM implementation code -- it will
|
|
// not introduce rocm headers into other code.
|
|
|
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
|
|
#define TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
|
|
|
|
#include "rocm/include/rocfft/hipfft.h"
|
|
#include "tensorflow/stream_executor/fft.h"
|
|
#include "tensorflow/stream_executor/platform/port.h"
|
|
#include "tensorflow/stream_executor/plugin_registry.h"
|
|
#include "tensorflow/stream_executor/scratch_allocator.h"
|
|
|
|
namespace stream_executor {
|
|
|
|
class Stream;
|
|
|
|
namespace gpu {
|
|
|
|
class GpuExecutor;
|
|
|
|
// Opaque and unique identifier for the rocFFT plugin.
|
|
extern const PluginId kRocFftPlugin;
|
|
|
|
// ROCMFftPlan uses deferred initialization. Only a single call to
// Initialize() is allowed; it creates the hipFFT plan and sets the member
// variable is_initialized_ to true. Any newly added interface that uses member
// variables should first check is_initialized_ to make sure that the values of
// those member variables are valid.
|
|
class ROCMFftPlan : public fft::Plan {
|
|
public:
|
|
ROCMFftPlan()
|
|
: parent_(nullptr),
|
|
plan_(),
|
|
fft_type_(fft::Type::kInvalid),
|
|
scratch_(nullptr),
|
|
scratch_size_bytes_(0),
|
|
is_initialized_(false) {}
|
|
~ROCMFftPlan() override;
|
|
|
|
// Get FFT direction in hipFFT based on FFT type.
|
|
int GetFftDirection() const;
|
|
hipfftHandle GetPlan() const {
|
|
if (IsInitialized()) {
|
|
return plan_;
|
|
} else {
|
|
LOG(FATAL) << "Try to get hipfftHandle value before initialization.";
|
|
}
|
|
}
|
|
|
|
// Initialize function for batched plan
|
|
port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
|
|
uint64 *elem_count, uint64 *input_embed,
|
|
uint64 input_stride, uint64 input_distance,
|
|
uint64 *output_embed, uint64 output_stride,
|
|
uint64 output_distance, fft::Type type,
|
|
int batch_count, ScratchAllocator *scratch_allocator);
|
|
|
|
// Initialize function for 1d,2d, and 3d plan
|
|
port::Status Initialize(GpuExecutor *parent, Stream *stream, int rank,
|
|
uint64 *elem_count, fft::Type type,
|
|
ScratchAllocator *scratch_allocator);
|
|
|
|
port::Status UpdateScratchAllocator(Stream *stream,
|
|
ScratchAllocator *scratch_allocator);
|
|
|
|
protected:
|
|
bool IsInitialized() const { return is_initialized_; }
|
|
|
|
private:
|
|
GpuExecutor *parent_;
|
|
hipfftHandle plan_;
|
|
fft::Type fft_type_;
|
|
DeviceMemory<uint8> scratch_;
|
|
size_t scratch_size_bytes_;
|
|
bool is_initialized_;
|
|
};
|
|
|
|
// FFT support for ROCM platform via rocFFT library.
|
|
//
|
|
// This satisfies the platform-agnostic FftSupport interface.
|
|
//
|
|
// Note that the hipFFT handle that this encapsulates is implicitly tied to the
|
|
// context (and, as a result, the device) that the parent GpuExecutor is tied
|
|
// to. This simply happens as an artifact of creating the hipFFT handle when a
|
|
// ROCM context is active.
|
|
//
|
|
// Thread-safe. The ROCM context associated with all operations is the ROCM
|
|
// context of parent_, so all context is explicit.
|
|
class ROCMFft : public fft::FftSupport {
|
|
public:
|
|
explicit ROCMFft(GpuExecutor *parent) : parent_(parent) {}
|
|
~ROCMFft() override {}
|
|
|
|
TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES
|
|
|
|
private:
|
|
GpuExecutor *parent_;
|
|
|
|
// Two helper functions that execute dynload::hipfftExec?2?.
|
|
|
|
// This is for complex to complex FFT, when the direction is required.
|
|
template <typename FuncT, typename InputT, typename OutputT>
|
|
bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan,
|
|
FuncT hipfft_exec,
|
|
const DeviceMemory<InputT> &input,
|
|
DeviceMemory<OutputT> *output);
|
|
|
|
// This is for complex to real or real to complex FFT, when the direction
|
|
// is implied.
|
|
template <typename FuncT, typename InputT, typename OutputT>
|
|
bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT hipfft_exec,
|
|
const DeviceMemory<InputT> &input,
|
|
DeviceMemory<OutputT> *output);
|
|
|
|
SE_DISALLOW_COPY_AND_ASSIGN(ROCMFft);
|
|
};
|
|
|
|
} // namespace gpu
|
|
} // namespace stream_executor
|
|
|
|
#endif // TENSORFLOW_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_
|