Remove vestigial code from cuDNN 4 in stream executor.

PiperOrigin-RevId: 293618779
Change-Id: I7e4ee38425bff93bf89b09b0e17c5cf79da98473
A. Unique TensorFlower authored on 2020-02-06 10:12:37 -08:00; committed by TensorFlower Gardener
parent a4064a389e
commit 38447b2d3b
11 changed files with 22 additions and 169 deletions
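Context: cuDNN v4's batch-normalization entry points consumed and produced an *inverse* variance, so the stream executor API threaded two conversion callbacks through every DoBatchNormalizationForward overload. cuDNN v5 and later take the variance directly, so on every supported version the hooks were dead weight (the XLA runner below simply passed nullptr for both). A minimal sketch of the two deleted hook types, with names taken from the hunks below:

    #include <functional>
    // Sketch only. On cuDNN v4, var_to_inv_var supplied
    // 1/sqrt(variance + epsilon) for forward inference, and inv_var_to_var
    // converted the inverse variance produced by forward training back into
    // a variance for the running-statistics update. On v5+ both were unused.
    std::function<const stream_executor::DeviceMemory<float> &()> var_to_inv_var;
    std::function<void()> inv_var_to_var;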

tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_runner.cc

@@ -155,8 +155,6 @@ void RunCudnnBatchNormForwardInferenceImpl(
/*saved_mean=*/nullptr, //
/*saved_inv_var=*/nullptr, //
/*is_training=*/false, //
- /*var_to_inv_var=*/nullptr, //
- /*inv_var_to_var=*/nullptr, //
/*reserve_space_allocator=*/nullptr, //
/*workspace_allocator=*/nullptr);
}
@@ -186,8 +184,6 @@ void RunCudnnBatchNormForwardTrainingImpl(
/*saved_mean=*/&params->output_mean, //
/*saved_inv_var=*/&params->output_inv_stddev, //
/*is_training=*/true, //
- /*var_to_inv_var=*/nullptr, //
- /*inv_var_to_var=*/nullptr, //
/*reserve_space_allocator=*/nullptr, //
/*workspace_allocator=*/nullptr);
}

tensorflow/core/kernels/fused_batch_norm_op.cc

@@ -827,33 +827,6 @@ struct FusedBatchNorm<GPUDevice, T, U> {
auto saved_inv_var_ptr =
StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);
- GPUDevice d = context->eigen_device<GPUDevice>();
- using se::DeviceMemory;
- Tensor inv_var;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<U>::value,
- estimated_variance.shape(), &inv_var));
- auto inv_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_var);
- std::function<const DeviceMemory<U>&()> var_to_inv_var =
- [d, epsilon, estimated_variance,
- &inv_var_ptr]() -> const DeviceMemory<U>& {
- auto estimated_variance_ptr =
- StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
- const U* variance =
- static_cast<const U*>(estimated_variance_ptr.opaque());
- U* inv_variance = static_cast<U*>(inv_var_ptr.opaque());
- int channels = inv_var_ptr.ElementCount();
- VarianceToInvVariance<U>()(d, variance, epsilon, channels, inv_variance);
- return inv_var_ptr;
- };
- const int64 sample_size = batch_size * height * width;
- std::function<void()> inv_var_to_var = [d, &batch_var_ptr, epsilon,
- sample_size]() {
- U* variance = static_cast<U*>(batch_var_ptr.opaque());
- int channels = batch_var_ptr.ElementCount();
- InvVarianceToVariance<U>()(d, epsilon, sample_size, channels, variance);
- };
std::unique_ptr<functor::CudnnBatchNormAllocatorInOutput<U>>
reserve_space_allocator;
std::unique_ptr<functor::CudnnBatchNormAllocatorInTemp<uint8>>
@@ -875,8 +848,7 @@ struct FusedBatchNorm<GPUDevice, T, U> {
exponential_average_factor,
AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr,
&batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
- is_training, std::move(var_to_inv_var),
- std::move(inv_var_to_var), reserve_space_allocator.get(),
+ is_training, reserve_space_allocator.get(),
workspace_allocator.get())
.ok();

tensorflow/core/kernels/fused_batch_norm_op.cu.cc

@@ -104,50 +104,6 @@ struct FusedBatchNormFreezeGrad<GPUDevice, T, U> {
template struct FusedBatchNormFreezeGrad<GPUDevice, float, float>;
template struct FusedBatchNormFreezeGrad<GPUDevice, Eigen::half, float>;
- template <class T>
- __global__ void VarianceToInvVarianceKernel(int nthreads,
- const T* __restrict__ input,
- double epsilon,
- T* __restrict__ output) {
- GPU_1D_KERNEL_LOOP(index, nthreads) {
- output[index] = rsqrt(input[index] + T(epsilon));
- }
- }
- template <class T>
- void VarianceToInvVariance<T>::operator()(const Eigen::GpuDevice& d,
- const T* variance, double epsilon,
- int channels, T* inv_variance) {
- GpuLaunchConfig config = GetGpuLaunchConfig(channels, d);
- TF_CHECK_OK(GpuLaunchKernel(VarianceToInvVarianceKernel<T>,
- config.block_count, config.thread_per_block, 0,
- d.stream(), config.virtual_thread_count, variance,
- epsilon, inv_variance));
- }
- template <class T>
- __global__ void InvVarianceToVarianceKernel(int nthreads, double epsilon,
- int sample_size, T* variance) {
- GPU_1D_KERNEL_LOOP(index, nthreads) {
- T inv_var = variance[index];
- T var = __fdividef(1, inv_var * inv_var) - T(epsilon);
- // This is for Bessel's correction
- var *= T(sample_size) / T((sample_size > 1) ? sample_size - 1 : 1);
- variance[index] = (var > 0) ? var : 0;
- }
- }
- template <class T>
- void InvVarianceToVariance<T>::operator()(const Eigen::GpuDevice& d,
- double epsilon, int sample_size,
- int channels, T* variance) {
- GpuLaunchConfig config = GetGpuLaunchConfig(channels, d);
- TF_CHECK_OK(GpuLaunchKernel(InvVarianceToVarianceKernel<T>,
- config.block_count, config.thread_per_block, 0,
- d.stream(), config.virtual_thread_count, epsilon,
- sample_size, variance));
- }
template <class T>
void SetNanFunctor<T>::operator()(const Eigen::GpuDevice& d,
typename TTypes<T>::Flat out) {
@@ -155,8 +111,6 @@ void SetNanFunctor<T>::operator()(const Eigen::GpuDevice& d,
To32Bit(out).constant(Eigen::NumTraits<T>::quiet_NaN());
}
- template class VarianceToInvVariance<float>;
- template class InvVarianceToVariance<float>;
template class SetNanFunctor<float>;
// -------------------------------------------------------------------------- //
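For reference, a host-side C++ sketch (illustration only, not part of the commit; the *Ref function names are hypothetical) of what the two deleted kernels computed, mirroring the CUDA code above:

    #include <algorithm>
    #include <cmath>

    // Mirrors VarianceToInvVarianceKernel: inv_var[i] = rsqrt(var[i] + eps).
    void VarianceToInvVarianceRef(const float* variance, double epsilon,
                                  int channels, float* inv_variance) {
      for (int i = 0; i < channels; ++i) {
        inv_variance[i] =
            1.0f / std::sqrt(variance[i] + static_cast<float>(epsilon));
      }
    }

    // Mirrors InvVarianceToVarianceKernel: invert the rsqrt, subtract
    // epsilon, then rescale by n/(n-1) (Bessel's correction) so the running
    // variance is the unbiased estimate; clamp at zero like the kernel.
    void InvVarianceToVarianceRef(double epsilon, int sample_size,
                                  int channels, float* variance) {
      for (int i = 0; i < channels; ++i) {
        const float inv_var = variance[i];
        float var = 1.0f / (inv_var * inv_var) - static_cast<float>(epsilon);
        var *= static_cast<float>(sample_size) /
               static_cast<float>(sample_size > 1 ? sample_size - 1 : 1);
        variance[i] = std::max(var, 0.0f);
      }
    }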

tensorflow/core/kernels/fused_batch_norm_op.h

@@ -37,30 +37,6 @@ Status ParseActivationMode(OpKernelConstruction* context,
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
- // There is a behavior difference between cuDNN v4 and v5 with regard to the
- // scaling factor for function cudnnBatchNormalizationForwardInference.
- // This function corrects the scaling factor if cuDNN v4 is used, so that
- // this behavior inconsistency is hidden from TensorFlow users.
- // Details: in cuDNN v4, y = bnScale * (x - mean) * variance + bnBias;
- // in v5, y = bnScale * (x - mean) / sqrt(variance + epsilon) + bnBias
- // The template is instantiated with T as float in batch_norm_ops.cu.cc; for
- // other types, the instantiation needs to be added accordingly.
- template <class T>
- struct VarianceToInvVariance {
- void operator()(const Eigen::GpuDevice& d, const T* variance, double epsilon,
- int channels, T* inv_variance);
- };
- // This function converts the inverted variance of the cuDNN forward training
- // output to variance for TensorFlow to calculate the running variance.
- // The template is instantiated with T as float in batch_norm_ops.cu.cc; for
- // other types, the instantiation needs to be added accordingly.
- template <class T>
- struct InvVarianceToVariance {
- void operator()(const Eigen::GpuDevice& d, double epsilon, int sample_size,
- int channels, T* variance);
- };
// This function sets a GPU tensor to NaNs.
template <class T>
struct SetNanFunctor {
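Restating the v4/v5 behavior difference from the deleted comment above as equations (gamma = bnScale, beta = bnBias; the "variance" argument in v4 is really an inverse standard deviation):

    % cuDNN v4 applied the supplied value directly as a multiplier:
    y_{v4} = \gamma \, (x - \mu) \, \hat{\sigma}^{-1} + \beta
    % cuDNN v5+ takes the true variance and stabilizes it internally:
    y_{v5} = \gamma \, \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
    % so the v4 shim only had to precompute:
    \hat{\sigma}^{-1} = 1 / \sqrt{\sigma^2 + \epsilon}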

tensorflow/stream_executor/cuda/cuda_dnn.cc

@@ -3468,17 +3468,14 @@ bool CudnnSupport::DoBatchNormalizationForward(
DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) {
+ ScratchAllocator* workspace_allocator) {
return IsStatusOk(
DoBatchNormalizationForwardImpl<float, float>(
stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale,
offset, estimated_mean, estimated_variance, side_input, x_desc,
scale_offset_desc, epsilon, exponential_average_factor,
activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
- is_training, reserve_space_allocator, workspace_allocator,
- std::move(var_to_inv_var), std::move(inv_var_to_var)),
+ is_training, reserve_space_allocator, workspace_allocator),
/*report_error=*/true);
}
@@ -3494,17 +3491,14 @@ bool CudnnSupport::DoBatchNormalizationForward(
DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) {
+ ScratchAllocator* workspace_allocator) {
return IsStatusOk(
DoBatchNormalizationForwardImpl<Eigen::half, float>(
stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
estimated_mean, estimated_variance, side_input, x_desc,
scale_offset_desc, epsilon, exponential_average_factor,
activation_mode, y, batch_mean, batch_var, saved_mean, saved_inv_var,
- is_training, reserve_space_allocator, workspace_allocator,
- std::move(var_to_inv_var), std::move(inv_var_to_var)),
+ is_training, reserve_space_allocator, workspace_allocator),
/*report_error=*/true);
}
@@ -3522,9 +3516,7 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<U>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) {
+ ScratchAllocator* workspace_allocator) {
CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
CudnnTensorDescriptor scale_offset_descriptor(
scale_offset_desc, ToCudnnDataType(scale_data_type));

tensorflow/stream_executor/cuda/cuda_dnn.h

@@ -219,7 +219,7 @@ class CudnnSupport : public dnn::DnnSupport {
Stream* stream, const DeviceMemory<float>& x,
const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
const DeviceMemory<float>& estimated_mean,
- const DeviceMemory<float>& estimated_variance,
+ const DeviceMemory<float>& estimated_var_iance,
const DeviceMemory<float>& side_input, const dnn::BatchDescriptor& x_desc,
const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
const double exponential_average_factor,
@@ -227,9 +227,7 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) override;
+ ScratchAllocator* workspace_allocator) override;
bool DoBatchNormalizationForward(
Stream* stream, const DeviceMemory<Eigen::half>& x,
@@ -243,9 +241,7 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) override;
+ ScratchAllocator* workspace_allocator) override;
bool DoBatchNormalizationBackward(
Stream* stream, const DeviceMemory<float>& y_backprop,
@@ -603,9 +599,7 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<U>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var);
+ ScratchAllocator* workspace_allocator);
template <class T, class U>
port::Status DoBatchNormalizationBackwardImpl(

tensorflow/stream_executor/dnn.h

@@ -1031,11 +1031,6 @@ class DnnSupport {
// reserve_space_2: saved inv_var (1/sqrt(epsilon + variance)), to be reused
// in the backward gradient computation.
// is_training: Set to true for training, false for inference.
- // var_to_inv_var: a function to convert the variance to inverted variance
- // for cuDNN v4 forward inference.
- // inv_var_to_var: a function to convert the inverted variance to
- // variance for cuDNN v4 forward training, to be used for TensorFlow
- // to calculate the running variance.
virtual bool DoBatchNormalizationForward(
Stream* stream, const DeviceMemory<float>& x,
const DeviceMemory<float>& scale, const DeviceMemory<float>& offset,
@@ -1049,9 +1044,7 @@ class DnnSupport {
DeviceMemory<float>* reserve_space_1,
DeviceMemory<float>* reserve_space_2, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) {
+ ScratchAllocator* workspace_allocator) {
return false;
}
@@ -1070,9 +1063,7 @@ class DnnSupport {
DeviceMemory<float>* reserve_space_1,
DeviceMemory<float>* reserve_space_2, bool is_training,
ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) {
+ ScratchAllocator* workspace_allocator) {
return false;
}
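As a small illustration of the reserve_space_2 quantity documented above (the helper name is hypothetical):

    #include <cmath>
    // reserve_space_2 caches the inverse standard deviation computed in the
    // forward pass, 1/sqrt(epsilon + variance), so the backward gradient
    // computation can reuse it instead of recomputing the rsqrt.
    inline float SavedInvVar(float epsilon, float variance) {
      return 1.0f / std::sqrt(epsilon + variance);
    }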

tensorflow/stream_executor/rocm/rocm_dnn.cc

@@ -3443,15 +3443,12 @@ bool MIOpenSupport::DoBatchNormalizationForward(
DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) {
+ ScratchAllocator* workspace_allocator) {
return DoBatchNormalizationForwardImpl<Eigen::half, float>(
stream, dnn::DataType::kHalf, dnn::DataType::kFloat, x, scale, offset,
estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc,
epsilon, exponential_average_factor, activation_mode, y, batch_mean,
- batch_var, saved_mean, saved_inv_var, is_training,
- std::move(var_to_inv_var), std::move(inv_var_to_var));
+ batch_var, saved_mean, saved_inv_var, is_training);
}
bool MIOpenSupport::DoBatchNormalizationForward(
@@ -3466,15 +3463,12 @@ bool MIOpenSupport::DoBatchNormalizationForward(
DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) {
+ ScratchAllocator* workspace_allocator) {
return DoBatchNormalizationForwardImpl<float, float>(
stream, dnn::DataType::kFloat, dnn::DataType::kFloat, x, scale, offset,
estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc,
epsilon, exponential_average_factor, activation_mode, y, batch_mean,
- batch_var, saved_mean, saved_inv_var, is_training,
- std::move(var_to_inv_var), std::move(inv_var_to_var));
+ batch_var, saved_mean, saved_inv_var, is_training);
}
template <class T, class U>
@@ -3490,8 +3484,7 @@ bool MIOpenSupport::DoBatchNormalizationForwardImpl(
dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
- bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) {
+ bool is_training) {
auto miopen = miopen_->GetHandle(parent_, stream);
ScopedTensorDescriptor x_descriptor{x_desc,

tensorflow/stream_executor/rocm/rocm_dnn.h

@@ -228,9 +228,7 @@ class MIOpenSupport : public dnn::DnnSupport {
DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) override;
+ ScratchAllocator* workspace_allocator) override;
bool DoBatchNormalizationForward(
Stream* stream, const DeviceMemory<Eigen::half>& x,
@@ -244,9 +242,7 @@ class MIOpenSupport : public dnn::DnnSupport {
DeviceMemory<float>* batch_mean, DeviceMemory<float>* batch_var,
DeviceMemory<float>* saved_mean, DeviceMemory<float>* saved_inv_var,
bool is_training, ScratchAllocator* reserve_space_allocator,
- ScratchAllocator* workspace_allocator,
- std::function<const DeviceMemory<float>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var) override;
+ ScratchAllocator* workspace_allocator) override;
bool DoBatchNormalizationBackward(
Stream* stream, const DeviceMemory<float>& y_backprop,
@@ -670,8 +666,7 @@ class MIOpenSupport : public dnn::DnnSupport {
dnn::ActivationMode activation_mode, DeviceMemory<T>* y,
DeviceMemory<U>* batch_mean, DeviceMemory<U>* batch_var,
DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
- bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
- std::function<void()> inv_var_to_var);
+ bool is_training);
template <class T, class U>
bool DoBatchNormalizationBackwardImpl(

tensorflow/stream_executor/stream.cc

@@ -349,8 +349,6 @@ Stream &Stream::ThenBatchNormalizationForward(
DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
bool is_training,
- std::function<const DeviceMemory<float> &()> var_to_inv_var,
- std::function<void()> inv_var_to_var,
ScratchAllocator *reserve_space_allocator,
ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
@@ -362,8 +360,7 @@ Stream &Stream::ThenBatchNormalizationForward(
side_input, x_desc, scale_offset_desc, epsilon,
exponential_average_factor, activation_mode, y, batch_mean, batch_var,
saved_mean, saved_inv_var, is_training, reserve_space_allocator,
- workspace_allocator, std::move(var_to_inv_var),
- std::move(inv_var_to_var)));
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
@@ -408,8 +405,6 @@ Stream &Stream::ThenBatchNormalizationForward(
DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
bool is_training,
- std::function<const DeviceMemory<float> &()> var_to_inv_var,
- std::function<void()> inv_var_to_var,
ScratchAllocator *reserve_space_allocator,
ScratchAllocator *workspace_allocator) {
VLOG_CALL(PARAM(x), PARAM(scale), PARAM(offset), PARAM(x_desc),
@@ -421,8 +416,7 @@ Stream &Stream::ThenBatchNormalizationForward(
side_input, x_desc, scale_offset_desc, epsilon,
exponential_average_factor, activation_mode, y, batch_mean, batch_var,
saved_mean, saved_inv_var, is_training, reserve_space_allocator,
- workspace_allocator, std::move(var_to_inv_var),
- std::move(inv_var_to_var)));
+ workspace_allocator));
} else {
SetErrorAndLogNoDnnSupport();
}
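A caller-side sketch of the trimmed Stream entry point (illustrative, not code from the commit; all tensors, descriptors, and allocators are assumed to be set up elsewhere). The two std::function arguments that used to sit next to is_training are simply gone:

    // Forward batch norm after this change; no conversion callbacks appear
    // between is_training and the allocator arguments anymore.
    stream.ThenBatchNormalizationForward(
        x, scale, offset, estimated_mean, estimated_variance, side_input,
        x_desc, scale_offset_desc, epsilon, exponential_average_factor,
        activation_mode, &y, &batch_mean, &batch_var, &saved_mean,
        &saved_inv_var, /*is_training=*/true,
        /*reserve_space_allocator=*/nullptr, /*workspace_allocator=*/nullptr);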

tensorflow/stream_executor/stream.h

@@ -241,8 +241,6 @@ class Stream {
DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
bool is_training,
- std::function<const DeviceMemory<float> &()> var_to_inv_var,
- std::function<void()> inv_var_to_var,
ScratchAllocator *reserve_space_allocator,
ScratchAllocator *workspace_allocator);
@@ -268,8 +266,6 @@ class Stream {
DeviceMemory<float> *batch_mean, DeviceMemory<float> *batch_var,
DeviceMemory<float> *saved_mean, DeviceMemory<float> *saved_inv_var,
bool is_training,
- std::function<const DeviceMemory<float> &()> var_to_inv_var,
- std::function<void()> inv_var_to_var,
ScratchAllocator *reserve_space_allocator,
ScratchAllocator *workspace_allocator);