Optimize FusedBatchNormGrad on CPU device.

20% speedup in tensorflow_models/official/resnet/keras:keras_cifar_main:

BEFORE:
  {'num_batches': 200, 'time_taken': 19.408517, 'examples_per_second': 329.752141}
  {'num_batches': 300, 'time_taken': 19.280430, 'examples_per_second': 331.942807}
  {'num_batches': 400, 'time_taken': 19.173295, 'examples_per_second': 333.797607}

AFTER:
  {'num_batches': 200, 'time_taken': 16.136061, 'examples_per_second': 396.627158}
  {'num_batches': 300, 'time_taken': 15.969341, 'examples_per_second': 400.767946}
  {'num_batches': 400, 'time_taken': 15.745600, 'examples_per_second': 406.462758}

PiperOrigin-RevId: 257211709
Commit 9c7ddffd97 (parent 45925ad3d0)
@@ -4348,6 +4348,7 @@ tf_kernel_library(
    prefix = "fused_batch_norm_op",
    deps = NN_DEPS + [
        ":fill_functor",
        ":redux_functor",
    ] + if_cuda([
        "//tensorflow/core:stream_executor",
    ]),
@@ -18,6 +18,7 @@ limitations under the License.
#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/bias_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
@@ -273,7 +274,7 @@ class BiasGradOp : public OpKernel {
      using AccumT = typename AccumulatorType<T>::type;
      if (data_format_ == FORMAT_NCHW) {
        const functor::ReduceMiddleDimensions<
            T, AccumT, Eigen::internal::scalar_sum_op<AccumT>,
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>,
            Eigen::internal::SumReducer<T>>
            redux;
        Eigen::DSizes<Eigen::Index, 3> three_dims(batch, channel,
@@ -282,7 +283,7 @@ class BiasGradOp : public OpKernel {
                                                  output, 1);
      } else {
        const functor::ReduceOuterDimensions<
            T, AccumT, Eigen::internal::scalar_sum_op<AccumT>>
            T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>>
            redux;

        Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width * depth,
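The new third template argument (T, AccumT, T rather than T, AccumT) reflects the redux functors now taking an explicit output type in addition to the input and accumulator types, so a reduction can read one precision, accumulate in a wider one, and emit yet another. A minimal standalone sketch of that idea in plain C++, with illustrative names rather than the TensorFlow functor:

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: reduce the outer dimension of a row-major
// [outer, inner] buffer, accumulating in AccumT and emitting OutputT.
template <typename InputT, typename AccumT, typename OutputT>
void ReduceOuter(const InputT* in, std::size_t outer, std::size_t inner,
                 OutputT* out) {
  std::vector<AccumT> acc(inner, AccumT(0));
  for (std::size_t i = 0; i < outer; ++i)
    for (std::size_t j = 0; j < inner; ++j)
      acc[j] += static_cast<AccumT>(in[i * inner + j]);
  for (std::size_t j = 0; j < inner; ++j)
    out[j] = static_cast<OutputT>(acc[j]);
}

int main() {
  // 4 x 3 input; accumulate in double, emit float.
  std::vector<float> x(12, 0.5f);
  std::vector<float> sums(3);
  ReduceOuter<float, double, float>(x.data(), 4, 3, sums.data());
  for (float s : sums) std::cout << s << " ";  // prints "2 2 2"
  std::cout << "\n";
}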
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/kernels/redux_functor.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/tensor_format.h"

@@ -358,7 +359,6 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
    typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
    typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    // Note: the following formulas are used to compute the gradients for
@@ -378,12 +378,10 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduce_dims({0});
    Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
    bcast_spec.set(0, rest_size);
#endif
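The one_by_depth and bcast_spec index lists set up the pattern used throughout this kernel: a per-channel [depth] vector is reshaped to [1, depth] and broadcast across the first dimension so it lines up with the [rest_size, depth] view of the input. A small standalone Eigen sketch of that reshape-and-broadcast step, assuming Eigen's unsupported Tensor module is available (illustrative values, not the kernel code):

#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  const Eigen::Index rest_size = 4, depth = 3;

  // A per-channel vector, e.g. the saved mean of each of the depth channels.
  Eigen::Tensor<float, 1> mean(depth);
  mean.setValues({1.f, 2.f, 3.f});

  // Reshape [depth] -> [1, depth], then broadcast the first dimension
  // rest_size times to get a [rest_size, depth] expression.
  Eigen::array<Eigen::Index, 2> one_by_depth{1, depth};
  Eigen::array<Eigen::Index, 2> bcast_spec{rest_size, 1};
  Eigen::Tensor<float, 2> tiled =
      mean.reshape(one_by_depth).broadcast(bcast_spec);

  std::cout << tiled << "\n";  // four identical rows: 1 2 3
}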
@@ -391,41 +389,182 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));

    // Eigen is notoriously bad at reducing outer dimension, so we materialize
    // all temporary tensors that require reduction, and then use Eigen redux
    // functor, that is optimized for this particular task.
    //
    // All reductions are of this type: [rest_size, depth] -> [depth].
    using ScalarSum = Eigen::internal::scalar_sum_op<U>;
    const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
    const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;

    auto scratch_dtype = DataTypeToEnum<U>::value;

    // Allocate a temporary workspace of [depth] shape.
    Tensor scratch_one_by_depth;
    OP_REQUIRES_OK(context, context->allocate_temp(scratch_dtype, {depth},
                                                   &scratch_one_by_depth));

    // Maybe allocate a temporary workspace of [rest_size, depth] shape.
    Tensor scratch_rest_by_depth;
    if (std::is_same<T, U>::value) {
      OP_REQUIRES(context,
                  scratch_rest_by_depth.CopyFrom(*x_backprop_output,
                                                 {rest_size, depth}),
                  errors::Internal("Failed to copy a tensor"));
    } else {
      OP_REQUIRES_OK(context,
                     context->allocate_temp(scratch_dtype, {rest_size, depth},
                                            &scratch_rest_by_depth));
    }

    typename TTypes<U, 2>::Tensor scratch_tensor(
        scratch_rest_by_depth.tensor<U, 2>());
    typename TTypes<U>::Vec scratch_vector(scratch_one_by_depth.vec<U>());

    auto x_mean_rest_by_depth =
        mean.reshape(one_by_depth).broadcast(bcast_spec);
    auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth).eval();
    auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth);
    auto coef0 = (variance + epsilon).rsqrt();
    auto coef0_rest_by_depth =
        coef0.eval().reshape(one_by_depth).broadcast(bcast_spec);
        coef0.reshape(one_by_depth).broadcast(bcast_spec);
    auto x_scaled = x_centered * coef0_rest_by_depth;

    auto y_backprop_rest_by_depth =
        y_backprop.eval().reshape(rest_by_depth).template cast<U>();
    scale_backprop.device(d) =
        (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims);
    auto y_backprop_sum = y_backprop_rest_by_depth.sum(reduce_dims);
    offset_backprop.device(d) = y_backprop_sum;
        y_backprop.reshape(rest_by_depth).template cast<U>();

    auto y_backprop_sum_one_by_depth =
        y_backprop_sum.eval().reshape(one_by_depth);
    // Compute `scale_backprop_output`:
    //   scale_backprop =
    //     (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims)
    scratch_tensor.device(d) = y_backprop_rest_by_depth * x_scaled;
    redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, scale_backprop_output);

    // Compute 'offset_backprop_output':
    //   offset_backprop =
    //     y_backprop_rest_by_depth.sum(reduce_dims)
    redux_sum_t(d, rest_by_depth, y_backprop_input, offset_backprop_output);
    auto y_backprop_sum = offset_backprop;

    auto y_backprop_sum_one_by_depth = y_backprop_sum.reshape(one_by_depth);
    auto y_backprop_mean_one_by_depth =
        y_backprop_sum_one_by_depth * rest_size_inv;
    auto y_backprop_mean_rest_by_depth =
        y_backprop_mean_one_by_depth.broadcast(bcast_spec);
    auto y_backprop_centered =
        y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
    auto coef1 =
        (scale * coef0).eval().reshape(one_by_depth).broadcast(bcast_spec);
    auto coef2 = (coef0.square() *
                  (y_backprop_rest_by_depth * x_centered).mean(reduce_dims))
                     .eval()

    // Compute expression:
    //   y_backprop_centered_mean =
    //     (y_backprop_rest_by_depth * x_centered).mean(reduce_dims)
    scratch_tensor.device(d) = y_backprop_rest_by_depth * x_centered;
    redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, &scratch_one_by_depth);
    auto y_backprop_centered_mean = scratch_vector / static_cast<U>(rest_size);

    auto coef1 = (scale * coef0).reshape(one_by_depth).broadcast(bcast_spec);
    auto coef2 = (coef0.square() * y_backprop_centered_mean)
                     .reshape(one_by_depth)
                     .eval()
                     .broadcast(bcast_spec);

    x_backprop.reshape(rest_by_depth).device(d) =
        (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
  }
};
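The comment block above states the key idea behind the speedup: Eigen is slow at reducing the outer (row) dimension of a row-major [rest_size, depth] tensor, so the kernel first materializes each product into a scratch tensor and then reduces it with the redux functor, which sweeps rows contiguously while keeping a small [depth] accumulator hot in cache. A rough standalone sketch of the access-pattern difference, in plain C++ with illustrative names (not the TensorFlow code):

#include <cstddef>
#include <vector>

// Column sums of a row-major [rows, depth] matrix.

// Strided: fixes a column and strides through memory by `depth` per step.
std::vector<float> ColumnSumsStrided(const std::vector<float>& m,
                                     std::size_t rows, std::size_t depth) {
  std::vector<float> sums(depth, 0.f);
  for (std::size_t d = 0; d < depth; ++d)
    for (std::size_t r = 0; r < rows; ++r) sums[d] += m[r * depth + d];
  return sums;
}

// Row-sweep: reads memory contiguously and updates a small [depth]
// accumulator that stays resident in cache -- the pattern the redux
// functor relies on.
std::vector<float> ColumnSumsRowSweep(const std::vector<float>& m,
                                      std::size_t rows, std::size_t depth) {
  std::vector<float> sums(depth, 0.f);
  for (std::size_t r = 0; r < rows; ++r) {
    const float* row = &m[r * depth];
    for (std::size_t d = 0; d < depth; ++d) sums[d] += row[d];
  }
  return sums;
}

int main() {
  std::vector<float> m(1024 * 64, 1.f);
  auto a = ColumnSumsStrided(m, 1024, 64);
  auto b = ColumnSumsRowSweep(m, 1024, 64);
  return (a == b && a[0] == 1024.f) ? 0 : 1;
}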

template <typename T, typename U>
struct FusedBatchNormFreezeGrad<CPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& pop_mean_input,
                  const Tensor& pop_variance_input, U epsilon,
                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
                  Tensor* offset_backprop_output) {
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());

    const int depth = pop_mean.dimension(0);
    const int rest_size = input.size() / depth;

    const CPUDevice& d = context->eigen_device<CPUDevice>();

    // Allocate two temporary workspaces of [depth] shape.
    Tensor scratch1_vec, scratch2_vec;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch1_vec));
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch2_vec));

    // Maybe allocate a temporary workspace of [rest_size, depth] shape.
    Tensor scratch3_tensor;
    if (std::is_same<T, U>::value) {
      OP_REQUIRES(
          context,
          scratch3_tensor.CopyFrom(*x_backprop_output, {rest_size, depth}),
          errors::Internal("Failed to copy a tensor"));
    } else {
      OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                     {rest_size, depth},
                                                     &scratch3_tensor));
    }

    typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
    typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());
    typename TTypes<U, 2>::Tensor scratch3(scratch3_tensor.tensor<U, 2>());

    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 2> rest_by_one({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> rest_by_one;
    rest_by_one.set(0, rest_size);
#endif

    // Sum reduction along the 0th dimension using custom CPU functor.
    using ScalarSum = Eigen::internal::scalar_sum_op<U>;
    const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
    const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;

    // offset_backprop = sum(y_backprop)
    // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))

    // NOTE: DEFAULT DEVICE comment is added to expression assignments that
    // we don't want to be executed in a thread pool.

    auto y_backprop_rest_by_depth =
        y_backprop.reshape(rest_by_depth).template cast<U>();
    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();

    // offset_backprop = sum(y_backprop)
    redux_sum_t(d, rest_by_depth, y_backprop_input, offset_backprop_output);

    // scratch1 = rsqrt(pop_var + epsilon)
    scratch1 = (pop_var + pop_var.constant(epsilon)).rsqrt();  // DEFAULT DEVICE

    // scratch2 = sum(y_backprop * (x - mean))
    scratch3.device(d) =
        y_backprop_rest_by_depth *
        (input_rest_by_depth -
         pop_mean.reshape(one_by_depth).broadcast(rest_by_one));
    redux_sum_u(d, rest_by_depth, scratch3_tensor, &scratch2_vec);

    x_backprop.reshape(rest_by_depth).device(d) =
        (y_backprop_rest_by_depth *
         ((scratch1 * scale).reshape(one_by_depth).broadcast(rest_by_one)))
            .template cast<T>();
    scale_backprop = scratch2 * scratch1;  // DEFAULT DEVICE
  }
};

#if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
namespace {
// See implementation under GOOGLE_CUDA #ifdef below.
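For the frozen (is_training=False) path, the three formulas in the comments above describe the whole computation. A tiny reference implementation over a flat row-major [rest_size, depth] view, written as a sketch with illustrative names rather than the kernel itself:

#include <cmath>
#include <cstddef>
#include <vector>

// Gradients of batch norm with frozen statistics, per the formulas:
//   offset_backprop = sum(dy)
//   scale_backprop  = sum(dy * (x - pop_mean) * rsqrt(pop_var + eps))
//   x_backprop      = dy * scale * rsqrt(pop_var + eps)
void FrozenBatchNormGrad(const std::vector<float>& dy,
                         const std::vector<float>& x,
                         const std::vector<float>& scale,
                         const std::vector<float>& pop_mean,
                         const std::vector<float>& pop_var, float eps,
                         std::size_t rest_size, std::size_t depth,
                         std::vector<float>& dx, std::vector<float>& dscale,
                         std::vector<float>& doffset) {
  dx.assign(rest_size * depth, 0.f);
  dscale.assign(depth, 0.f);
  doffset.assign(depth, 0.f);
  for (std::size_t r = 0; r < rest_size; ++r) {
    for (std::size_t d = 0; d < depth; ++d) {
      const std::size_t i = r * depth + d;
      const float inv_std = 1.f / std::sqrt(pop_var[d] + eps);
      doffset[d] += dy[i];
      dscale[d] += dy[i] * (x[i] - pop_mean[d]) * inv_std;
      dx[i] = dy[i] * scale[d] * inv_std;
    }
  }
}

int main() {
  const std::size_t rest_size = 2, depth = 2;
  std::vector<float> dy = {1.f, 1.f, 1.f, 1.f};
  std::vector<float> x = {0.f, 2.f, 4.f, 6.f};
  std::vector<float> scale = {1.f, 1.f}, mean = {2.f, 4.f}, var = {1.f, 1.f};
  std::vector<float> dx, dscale, doffset;
  FrozenBatchNormGrad(dy, x, scale, mean, var, 0.f, rest_size, depth, dx,
                      dscale, doffset);
  // doffset = {2, 2}; dscale = {(0-2)+(4-2), (2-4)+(6-4)} = {0, 0}.
  return 0;
}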
@@ -827,12 +966,11 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
#define DECLARE_GPU_SPEC(T, U)                                            \
  template <>                                                             \
  void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()(             \
      const GPUDevice& d, const Tensor& y_backprop_input,                 \
      OpKernelContext* context, const Tensor& y_backprop_input,           \
      const Tensor& x_input, const Tensor& scale_input,                   \
      const Tensor& mean_input, const Tensor& variance_input, U epsilon,  \
      Tensor* x_backprop_output, Tensor* scale_backprop_output,           \
      Tensor* offset_backprop_output, typename TTypes<U>::Vec scratch1,   \
      typename TTypes<U>::Vec scratch2);                                  \
      Tensor* offset_backprop_output);                                    \
  extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;       \
  template <>                                                             \
  void FusedBatchNormInferenceFunctor<GPUDevice, T, U>::operator()(       \
@@ -1152,18 +1290,10 @@ class FusedBatchNormGradOpBase : public OpKernel {
          << "The implementation of FusedBatchNormGrad with is_training=False "
             "only support "
          << "NHWC tensor format for now.";
      Tensor scratch1, scratch2;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch1));
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch2));
      functor::FusedBatchNormFreezeGrad<Device, T, U>()(
          context->eigen_device<Device>(), y_backprop, x, scale,
          saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
          x_backprop, scale_backprop, offset_backprop, scratch1.vec<U>(),
          scratch2.vec<U>());
          context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop);
    }
  }

@@ -26,6 +26,83 @@ typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// TODO(ezhulenev): Use CUB reductions on GPU.
template <typename T, typename U>
struct FusedBatchNormFreezeGrad<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& pop_mean_input,
                  const Tensor& pop_variance_input, U epsilon,
                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
                  Tensor* offset_backprop_output) {
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    const int depth = pop_mean.dimension(0);
    const int rest_size = input.size() / depth;

    // Allocate two temporary workspaces of [depth] shape.
    Tensor scratch1_vec, scratch2_vec;
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch1_vec));
    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
                                                   {depth}, &scratch2_vec));

    typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
    typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());

    const GPUDevice& d = context->eigen_device<GPUDevice>();

    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduction_axis{0};
    Eigen::array<int, 2> rest_by_one({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
#endif

    // offset_backprop = sum(y_backprop)
    // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))

    auto y_backprop_rest_by_depth =
        y_backprop.reshape(rest_by_depth).template cast<U>();
    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();

    offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);

    // scratch1 = rsqrt(pop_var + epsilon)
    scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();

    // scratch2 = sum(y_backprop * (x - mean))
    scratch2.device(d) =
        (y_backprop_rest_by_depth *
         (input_rest_by_depth -
          pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
            .sum(reduction_axis);

    x_backprop.reshape(rest_by_depth).device(d) =
        (y_backprop_rest_by_depth * ((scratch1 * scale)
                                         .eval()
                                         .reshape(one_by_depth)
                                         .broadcast(rest_by_one)))
            .template cast<T>();
    scale_backprop.device(d) = scratch2 * scratch1;
  }
};

template struct FusedBatchNormFreezeGrad<GPUDevice, float, float>;
template struct FusedBatchNormFreezeGrad<GPUDevice, Eigen::half, float>;

@@ -85,71 +85,15 @@ struct FusedBatchNormInferenceFunctor {
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=False. Both CPU and GPU will use this functor.
// is_training=False.
template <typename Device, typename T, typename U>
struct FusedBatchNormFreezeGrad {
  void operator()(const Device& d, const Tensor& y_backprop_input,
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& pop_mean_input,
                  const Tensor& pop_variance_input, U epsilon,
                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
                  Tensor* offset_backprop_output,
                  typename TTypes<U>::Vec scratch1,
                  typename TTypes<U>::Vec scratch2) {
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    const int depth = pop_mean.dimension(0);
    const int rest_size = input.size() / depth;

    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduction_axis{0};
    Eigen::array<int, 2> rest_by_one({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
#endif

    // offset_backprop = sum(y_backprop)
    // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))

    auto y_backprop_rest_by_depth =
        y_backprop.reshape(rest_by_depth).template cast<U>();
    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();

    offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);

    // scratch1 = rsqrt(pop_var + epsilon)
    scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();

    // scratch2 = sum(y_backprop * (x - mean))
    scratch2.device(d) =
        (y_backprop_rest_by_depth *
         (input_rest_by_depth -
          pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
            .sum(reduction_axis);

    x_backprop.reshape(rest_by_depth).device(d) =
        (y_backprop_rest_by_depth * ((scratch1 * scale)
                                         .eval()
                                         .reshape(one_by_depth)
                                         .broadcast(rest_by_one)))
            .template cast<T>();
    scale_backprop.device(d) = scratch2 * scratch1;
  }
                  Tensor* offset_backprop_output) {}
};

}  // namespace functor

@@ -269,6 +269,22 @@ BM_FusedBatchNorm(64, 14, 14, 256, fp16, true, NCHW, gpu);
  BENCHMARK(BM_NAME(FusedBatchNormGrad, N, H, W, C, T, IS_TRAINING, FORMAT,  \
                    DEVICE));

#define BM_FusedBatchNormGradResnetShapes(T, IS_TRAINING, FORMAT, DEVICE)  \
  BM_FusedBatchNormGrad(64, 56, 56, 64, T, IS_TRAINING, FORMAT, DEVICE);   \
  BM_FusedBatchNormGrad(64, 56, 56, 128, T, IS_TRAINING, FORMAT, DEVICE);  \
  BM_FusedBatchNormGrad(64, 56, 56, 256, T, IS_TRAINING, FORMAT, DEVICE);  \
                                                                           \
  BM_FusedBatchNormGrad(64, 28, 28, 128, T, IS_TRAINING, FORMAT, DEVICE);  \
  BM_FusedBatchNormGrad(64, 28, 28, 256, T, IS_TRAINING, FORMAT, DEVICE);  \
  BM_FusedBatchNormGrad(64, 28, 28, 512, T, IS_TRAINING, FORMAT, DEVICE);  \
                                                                           \
  BM_FusedBatchNormGrad(64, 14, 14, 128, T, IS_TRAINING, FORMAT, DEVICE);  \
  BM_FusedBatchNormGrad(64, 14, 14, 256, T, IS_TRAINING, FORMAT, DEVICE);  \
  BM_FusedBatchNormGrad(64, 14, 14, 1024, T, IS_TRAINING, FORMAT, DEVICE)

BM_FusedBatchNormGradResnetShapes(fp32, true, NHWC, cpu);
BM_FusedBatchNormGradResnetShapes(fp32, false, NHWC, cpu);

#ifdef GOOGLE_CUDA
BM_FusedBatchNormGrad(64, 14, 14, 256, fp32, true, NHWC, gpu);
BM_FusedBatchNormGrad(64, 14, 14, 256, fp16, true, NHWC, gpu);
@@ -35,16 +35,18 @@ namespace functor {
// input: [D1, D2, ... , DN]
// ->
// output: [Di, ... , DN] where i belongs to set [1,N]
template <typename T, typename AccumT, typename BinaryFunctor>
template <typename InputT, typename AccumT, typename OutputT,
          typename BinaryFunctor>
struct ReduceOuterDimensions {
  ReduceOuterDimensions(){};
  ReduceOuterDimensions() {}

  template <int num_dims>
  void operator()(const CPUDevice& device,
                  const Eigen::DSizes<Eigen::Index, num_dims>& input_dims,
                  const Tensor& input, Tensor* output) const {
    // Compute inner and outer dim after reshaping into 2d tensor.
    const int num_output_dims = output->dims();
    auto output_dims = output->template flat<T>().dimensions();
    auto output_dims = output->template flat<OutputT>().dimensions();

    Eigen::Index inner_dim = 1, outer_dim = 1;
    for (int i = 0; i < num_dims - num_output_dims; ++i)
@@ -54,8 +56,8 @@ struct ReduceOuterDimensions {

    if (1 == outer_dim) {
      // Nothing to do but passing input to output.
      output->template flat<T>() =
          input.template flat<T>().reshape(output_dims);
      output->template flat<OutputT>() =
          input.template flat<OutputT>().reshape(output_dims);
      return;
    }

@@ -63,13 +65,15 @@ struct ReduceOuterDimensions {
    const Eigen::Index num_threads = device.numThreads();

    // If the inner dim parallelism is large enough
    if (inner_dim > num_threads * 16) {
    // TODO(ezhulenev): There seems to be no benefits in going this route. Check
    // if this can be improved, or use better heuristic?
    if (inner_dim > num_threads * 32) {
      // Do not create more blocks than there are threads in a pool.
      const Eigen::Index num_blocks = num_threads;

      // Block size along the outer dimension.
      const Eigen::Index inner_block_size = Eigen::divup(inner_dim, num_blocks);
      const T* input_data = input.template flat<T>().data();
      const InputT* input_data = input.template flat<InputT>().data();

      // Allocate temporary buffer for partial reductions.
      Eigen::Tensor<AccumT, 1, Eigen::RowMajor, Eigen::Index> buffer(
@@ -82,7 +86,7 @@ struct ReduceOuterDimensions {
          Eigen::Unaligned>;

      using Input = Eigen::TensorMap<
          Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::Index>,
          Eigen::Tensor<const InputT, 1, Eigen::RowMajor, Eigen::Index>,
          Eigen::Unaligned>;

      const auto compute = [inner_dim, outer_dim, num_blocks, inner_block_size,
@@ -94,7 +98,7 @@ struct ReduceOuterDimensions {
        inner_dim_limit = std::min(inner_dim, inner_dim_limit);
        Eigen::Index my_job_len = inner_dim_limit - inner_dim_start;

        const T* my_job_start = input_data + inner_dim_start;
        const InputT* my_job_start = input_data + inner_dim_start;
        Buffer buf(buffer_data + inner_dim_start, my_job_len);

        for (Eigen::Index i = 0; i < outer_dim; ++i) {
@@ -107,7 +111,7 @@ struct ReduceOuterDimensions {

      // Compute cost of reducing a single block.
      const Eigen::Index compute_size = outer_dim * inner_block_size;
      const Eigen::Index compute_input_bytes = compute_size * sizeof(T);
      const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
      const Eigen::TensorOpCost cost(
          compute_input_bytes,
          0,  // We'll be mostly writing to L1, assume store cost is 0
@@ -116,8 +120,8 @@ struct ReduceOuterDimensions {
      device.parallelFor(num_blocks, cost, compute);

      // Write final result to the output.
      output->template flat<T>() =
          buffer.template cast<T>().reshape(output_dims);
      output->template flat<OutputT>() =
          buffer.template cast<OutputT>().reshape(output_dims);
    } else {
      // Compute block size along the outer dimension for efficiency.
      const Eigen::Index parallel_cell_size = inner_dim;
@@ -136,7 +140,7 @@ struct ReduceOuterDimensions {
      // Block size along the outer dimension.
      const Eigen::Index outer_block_size = Eigen::divup(outer_dim, num_blocks);

      const T* input_data = input.template flat<T>().data();
      const InputT* input_data = input.template flat<InputT>().data();

      // Allocate temporary buffer for partial reductions.
      Tensor buffer(DataTypeToEnum<AccumT>::v(), {num_blocks, inner_dim});
@@ -148,7 +152,7 @@ struct ReduceOuterDimensions {
          Eigen::Unaligned>;

      using Input = Eigen::TensorMap<
          Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::Index>,
          Eigen::Tensor<const InputT, 1, Eigen::RowMajor, Eigen::Index>,
          Eigen::Unaligned>;

      const auto compute = [inner_dim, num_blocks, outer_block_size,
@@ -170,7 +174,7 @@ struct ReduceOuterDimensions {

      // Compute cost of reducing a single block.
      const Eigen::Index compute_size = outer_block_size * inner_dim;
      const Eigen::Index compute_input_bytes = compute_size * sizeof(T);
      const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
      const Eigen::TensorOpCost cost(
          compute_input_bytes,
          0,  // We'll be mostly writing to L1, assume store cost is 0
@@ -187,7 +191,8 @@ struct ReduceOuterDimensions {
                              const decltype(buf)>(buf0, buf);
      }
      // Write final result to the output.
      output->template flat<T>() = buf0.template cast<T>().reshape(output_dims);
      output->template flat<OutputT>() =
          buf0.template cast<OutputT>().reshape(output_dims);
    }
  }
};
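In the common branch above (large outer dimension, modest inner dimension), ReduceOuterDimensions splits the outer dimension into at most num_threads blocks, lets each block accumulate its rows into a private [inner_dim] partial buffer, and then folds the partial buffers together, so every worker streams through contiguous memory exactly once. A compact standalone sketch of that scheme using std::thread, with hypothetical helper names rather than the Eigen thread-pool machinery used above:

#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Sum a row-major [outer, inner] matrix over the outer dimension by giving
// each thread a contiguous block of rows and a private partial accumulator.
std::vector<float> BlockedColumnSums(const std::vector<float>& m,
                                     std::size_t outer, std::size_t inner,
                                     std::size_t num_blocks) {
  num_blocks = std::max<std::size_t>(1, std::min(num_blocks, outer));
  const std::size_t block_rows = (outer + num_blocks - 1) / num_blocks;
  std::vector<std::vector<float>> partial(num_blocks,
                                          std::vector<float>(inner, 0.f));
  std::vector<std::thread> workers;
  for (std::size_t b = 0; b < num_blocks; ++b) {
    workers.emplace_back([&, b] {
      const std::size_t begin = b * block_rows;
      const std::size_t end = std::min(outer, begin + block_rows);
      for (std::size_t r = begin; r < end; ++r)
        for (std::size_t d = 0; d < inner; ++d)
          partial[b][d] += m[r * inner + d];
    });
  }
  for (auto& t : workers) t.join();

  // Combine the per-block partial sums into the final [inner] result.
  std::vector<float> sums(inner, 0.f);
  for (const auto& p : partial)
    for (std::size_t d = 0; d < inner; ++d) sums[d] += p[d];
  return sums;
}

int main() {
  std::vector<float> m(8 * 4, 1.f);
  auto sums = BlockedColumnSums(m, 8, 4, 3);  // each entry == 8
  return sums[0] == 8.f ? 0 : 1;
}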
@@ -197,9 +202,11 @@ struct ReduceOuterDimensions {
// input: [D1, D2, ... , DN]
// ->
// output: [Di, ... , Dj] where i & j belongs to set [1,N].
template <typename T, typename AccumT, typename BinaryFunctor, typename Reducer>
template <typename InputT, typename AccumT, typename OutputT,
          typename BinaryFunctor, typename Reducer>
struct ReduceMiddleDimensions {
  ReduceMiddleDimensions(){};
  ReduceMiddleDimensions() {}

  template <int num_dims>
  void operator()(const CPUDevice& device,
                  const Eigen::DSizes<Eigen::Index, num_dims>& input_dims,
@@ -207,7 +214,7 @@ struct ReduceMiddleDimensions {
                  const int axis_begin_dim) const {
    // Compute dims after reshaping into 3d tensor.
    const int num_output_dims = output->dims();
    auto output_dims = output->template flat<T>().dimensions();
    auto output_dims = output->template flat<OutputT>().dimensions();

    Eigen::Index inner_dim = 1, middle_dim = 1, outer_dim = 1;
    for (int i = 0; i < axis_begin_dim; ++i) outer_dim *= input_dims[i];
@@ -218,12 +225,12 @@ struct ReduceMiddleDimensions {

    if ((1 == inner_dim * outer_dim)) {
      // Nothing to do.
      output->template flat<T>() =
          input.template flat<T>().reshape(output_dims);
      output->template flat<OutputT>() =
          input.template flat<OutputT>().reshape(output_dims);
      return;
    } else if (1 == inner_dim) {
      // Equivalent to ReduceOuterDimensions.
      const ReduceOuterDimensions<T, AccumT, BinaryFunctor> redux;
      const ReduceOuterDimensions<InputT, AccumT, OutputT, BinaryFunctor> redux;
      redux(device, input_dims, input, output);
      return;
    }
@@ -247,7 +254,7 @@ struct ReduceMiddleDimensions {
    const Eigen::Index outer_block_size =
        Eigen::divup(total_workload, num_blocks);

    const T* input_data = input.template flat<T>().data();
    const InputT* input_data = input.template flat<InputT>().data();

    // Allocate temporary buffer for partial reductions.
    Eigen::Tensor<AccumT, 2> buffer(num_blocks, middle_dim);
@@ -255,7 +262,7 @@ struct ReduceMiddleDimensions {
    AccumT* buffer_data = buffer.data();

    using Buffer = Eigen::TensorMap<Eigen::Tensor<AccumT, 1>>;
    using Input = Eigen::TensorMap<Eigen::Tensor<const T, 1>>;
    using Input = Eigen::TensorMap<Eigen::Tensor<const InputT, 1>>;

    Eigen::array<Eigen::Index, 1> reduction_axis = {0};
    Reducer reducer;
@@ -301,7 +308,7 @@ struct ReduceMiddleDimensions {

    // Compute cost of reducing a single block.
    const Eigen::Index compute_size = outer_block_size * inner_dim;
    const Eigen::Index compute_input_bytes = compute_size * sizeof(T);
    const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
    const Eigen::TensorOpCost cost(
        compute_input_bytes,
        0,  // We'll be mostly writing to L1, assume store cost is 0
@@ -322,7 +329,8 @@ struct ReduceMiddleDimensions {
    }

    // Write final result to the output.
    output->template flat<T>() = buf0.template cast<T>().reshape(output_dims);
    output->template flat<OutputT>() =
        buf0.template cast<OutputT>().reshape(output_dims);
  }
};