diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 364a05bd901..35d6a81751f 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4348,6 +4348,7 @@ tf_kernel_library(
     prefix = "fused_batch_norm_op",
     deps = NN_DEPS + [
         ":fill_functor",
+        ":redux_functor",
     ] + if_cuda([
         "//tensorflow/core:stream_executor",
     ]),
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 08979666d2f..6a407f5551a 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 
 #include "tensorflow/core/kernels/bias_op.h"
+
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/numeric_op.h"
@@ -273,7 +274,7 @@ class BiasGradOp : public OpKernel {
     using AccumT = typename AccumulatorType<T>::type;
     if (data_format_ == FORMAT_NCHW) {
       const functor::ReduceMiddleDimensions<
-          T, AccumT, Eigen::internal::scalar_sum_op<AccumT>,
+          T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>,
          Eigen::internal::SumReducer<T>>
          redux;
       Eigen::DSizes<Eigen::Index, 3> three_dims(batch, channel,
@@ -282,7 +283,7 @@
             output, 1);
     } else {
       const functor::ReduceOuterDimensions<
-          T, AccumT, Eigen::internal::scalar_sum_op<AccumT>>
+          T, AccumT, T, Eigen::internal::scalar_sum_op<AccumT>>
          redux;
 
       Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width * depth,
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 4179f17deee..70bd659be66 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/kernels/fused_batch_norm_op.h"
+#include "tensorflow/core/kernels/redux_functor.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/tensor_format.h"
 
@@ -358,7 +359,6 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
     typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
     typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
     typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
-    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
     typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
 
     // Note: the following formulas are used to compute the gradients for
@@ -378,12 +378,10 @@
 
 #if !defined(EIGEN_HAS_INDEX_LIST)
     Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
-    Eigen::array<int, 1> reduce_dims({0});
     Eigen::array<int, 2> bcast_spec({rest_size, 1});
 #else
     Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
     one_by_depth.set(1, depth);
-    Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
     Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
     bcast_spec.set(0, rest_size);
 #endif
@@ -391,41 +389,182 @@
     auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
     U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
 
+    // Eigen is notoriously bad at reducing outer dimension, so we materialize
+    // all temporary tensors that require reduction, and then use Eigen redux
+    // functor, that is optimized for this particular task.
+    //
+    // All reductions are of this type: [rest_size, depth] -> [depth].
+    using ScalarSum = Eigen::internal::scalar_sum_op<U>;
+    const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
+    const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
+
+    auto scratch_dtype = DataTypeToEnum<U>::value;
+
+    // Allocate a temporary workspace of [depth] shape.
+    Tensor scratch_one_by_depth;
+    OP_REQUIRES_OK(context, context->allocate_temp(scratch_dtype, {depth},
+                                                   &scratch_one_by_depth));
+
+    // Maybe allocate a temporary workspace of [rest_size, depth] shape.
+    Tensor scratch_rest_by_depth;
+    if (std::is_same<T, U>::value) {
+      OP_REQUIRES(context,
+                  scratch_rest_by_depth.CopyFrom(*x_backprop_output,
+                                                 {rest_size, depth}),
+                  errors::Internal("Failed to copy a tensor"));
+    } else {
+      OP_REQUIRES_OK(context,
+                     context->allocate_temp(scratch_dtype, {rest_size, depth},
+                                            &scratch_rest_by_depth));
+    }
+
+    typename TTypes<U>::Tensor scratch_tensor(
+        scratch_rest_by_depth.tensor<U, 2>());
+    typename TTypes<U>::Vec scratch_vector(scratch_one_by_depth.vec<U>());
+
     auto x_mean_rest_by_depth =
         mean.reshape(one_by_depth).broadcast(bcast_spec);
-    auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth).eval();
+    auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth);
     auto coef0 = (variance + epsilon).rsqrt();
     auto coef0_rest_by_depth =
-        coef0.eval().reshape(one_by_depth).broadcast(bcast_spec);
+        coef0.reshape(one_by_depth).broadcast(bcast_spec);
     auto x_scaled = x_centered * coef0_rest_by_depth;
 
     auto y_backprop_rest_by_depth =
-        y_backprop.eval().reshape(rest_by_depth).template cast<U>();
-    scale_backprop.device(d) =
-        (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims);
-    auto y_backprop_sum = y_backprop_rest_by_depth.sum(reduce_dims);
-    offset_backprop.device(d) = y_backprop_sum;
+        y_backprop.reshape(rest_by_depth).template cast<U>();
 
-    auto y_backprop_sum_one_by_depth =
-        y_backprop_sum.eval().reshape(one_by_depth);
+    // Compute `scale_backprop_output`:
+    //   scale_backprop =
+    //     (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims)
+    scratch_tensor.device(d) = y_backprop_rest_by_depth * x_scaled;
+    redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, scale_backprop_output);
+
+    // Compute 'offset_backprop_output':
+    //   offset_backprop =
+    //     y_backprop_rest_by_depth.sum(reduce_dims)
+    redux_sum_t(d, rest_by_depth, y_backprop_input, offset_backprop_output);
+    auto y_backprop_sum = offset_backprop;
+
+    auto y_backprop_sum_one_by_depth = y_backprop_sum.reshape(one_by_depth);
     auto y_backprop_mean_one_by_depth =
         y_backprop_sum_one_by_depth * rest_size_inv;
     auto y_backprop_mean_rest_by_depth =
         y_backprop_mean_one_by_depth.broadcast(bcast_spec);
     auto y_backprop_centered =
         y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
-    auto coef1 =
-        (scale * coef0).eval().reshape(one_by_depth).broadcast(bcast_spec);
-    auto coef2 = (coef0.square() *
-                  (y_backprop_rest_by_depth * x_centered).mean(reduce_dims))
-                     .eval()
+
+    // Compute expression:
+    //   y_backprop_centered_mean =
+    //     (y_backprop_rest_by_depth * x_centered).mean(reduce_dims)
+    scratch_tensor.device(d) = y_backprop_rest_by_depth * x_centered;
+    redux_sum_u(d, rest_by_depth, scratch_rest_by_depth, &scratch_one_by_depth);
+    auto y_backprop_centered_mean = scratch_vector / static_cast<U>(rest_size);
+
+    auto coef1 = (scale * coef0).reshape(one_by_depth).broadcast(bcast_spec);
+    auto coef2 = (coef0.square() * y_backprop_centered_mean)
                      .reshape(one_by_depth)
+                     .eval()
                      .broadcast(bcast_spec);
+
     x_backprop.reshape(rest_by_depth).device(d) =
         (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
   }
 };
 
+template <typename T, typename U>
+struct FusedBatchNormFreezeGrad<CPUDevice, T, U> {
+  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
+                  const Tensor& x_input, const Tensor& scale_input,
+                  const Tensor& pop_mean_input,
+                  const Tensor& pop_variance_input, U epsilon,
+                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
+                  Tensor* offset_backprop_output) {
+    typename TTypes<T, 4>::ConstTensor y_backprop(
+        y_backprop_input.tensor<T, 4>());
+    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
+    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
+    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
+    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
+    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
+    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
+
+    const int depth = pop_mean.dimension(0);
+    const int rest_size = input.size() / depth;
+
+    const CPUDevice& d = context->eigen_device<CPUDevice>();
+
+    // Allocate two temporary workspaces of [depth] shape.
+    Tensor scratch1_vec, scratch2_vec;
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
+                                                   {depth}, &scratch1_vec));
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
+                                                   {depth}, &scratch2_vec));
+
+    // Maybe allocate a temporary workspace of [rest_size, depth] shape.
+    Tensor scratch3_tensor;
+    if (std::is_same<T, U>::value) {
+      OP_REQUIRES(
+          context,
+          scratch3_tensor.CopyFrom(*x_backprop_output, {rest_size, depth}),
+          errors::Internal("Failed to copy a tensor"));
+    } else {
+      OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
+                                                     {rest_size, depth},
+                                                     &scratch3_tensor));
+    }
+
+    typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
+    typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());
+    typename TTypes<U>::Tensor scratch3(scratch3_tensor.tensor<U, 2>());
+
+    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
+#if !defined(EIGEN_HAS_INDEX_LIST)
+    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
+    Eigen::array<int, 2> rest_by_one({rest_size, 1});
+#else
+    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
+    one_by_depth.set(1, depth);
+    Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> rest_by_one;
+    rest_by_one.set(0, rest_size);
+#endif
+
+    // Sum reduction along the 0th dimension using custom CPU functor.
+    using ScalarSum = Eigen::internal::scalar_sum_op<U>;
+    const functor::ReduceOuterDimensions<T, U, U, ScalarSum> redux_sum_t;
+    const functor::ReduceOuterDimensions<U, U, U, ScalarSum> redux_sum_u;
+
+    // offset_backprop = sum(y_backprop)
+    // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
+    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
+
+    // NOTE: DEFAULT DEVICE comment is added to expression assignments that
+    // we don't want to be executed in a thread pool.
+
+    auto y_backprop_rest_by_depth =
+        y_backprop.reshape(rest_by_depth).template cast<U>();
+    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
+
+    // offset_backprop = sum(y_backprop)
+    redux_sum_t(d, rest_by_depth, y_backprop_input, offset_backprop_output);
+
+    // scratch1 = rsqrt(pop_var + epsilon)
+    scratch1 = (pop_var + pop_var.constant(epsilon)).rsqrt();  // DEFAULT DEVICE
+
+    // scratch2 = sum(y_backprop * (x - mean))
+    scratch3.device(d) =
+        y_backprop_rest_by_depth *
+        (input_rest_by_depth -
+         pop_mean.reshape(one_by_depth).broadcast(rest_by_one));
+    redux_sum_u(d, rest_by_depth, scratch3_tensor, &scratch2_vec);
+
+    x_backprop.reshape(rest_by_depth).device(d) =
+        (y_backprop_rest_by_depth *
+         ((scratch1 * scale).reshape(one_by_depth).broadcast(rest_by_one)))
+            .template cast<T>();
+    scale_backprop = scratch2 * scratch1;  // DEFAULT DEVICE
+  }
+};
+
 #if !GOOGLE_CUDA && !TENSORFLOW_USE_ROCM
 namespace {
 // See implementation under GOOGLE_CUDA #ifdef below.
@@ -827,12 +966,11 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
 #define DECLARE_GPU_SPEC(T, U)                                            \
   template <>                                                             \
   void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()(             \
-      const GPUDevice& d, const Tensor& y_backprop_input,                 \
+      OpKernelContext* context, const Tensor& y_backprop_input,           \
       const Tensor& x_input, const Tensor& scale_input,                   \
       const Tensor& mean_input, const Tensor& variance_input, U epsilon,  \
       Tensor* x_backprop_output, Tensor* scale_backprop_output,           \
-      Tensor* offset_backprop_output, typename TTypes<U>::Vec scratch1,   \
-      typename TTypes<U>::Vec scratch2);                                  \
+      Tensor* offset_backprop_output);                                    \
   extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;       \
   template <>                                                             \
   void FusedBatchNormInferenceFunctor<GPUDevice, T, U>::operator()(       \
@@ -1152,18 +1290,10 @@ class FusedBatchNormGradOpBase : public OpKernel {
           << "The implementation of FusedBatchNormGrad with is_training=False "
             "only support "
          << "NHWC tensor format for now.";
-      Tensor scratch1, scratch2;
-      OP_REQUIRES_OK(context,
-                     context->allocate_temp(DataTypeToEnum<U>::value,
-                                            scale_offset_shape, &scratch1));
-      OP_REQUIRES_OK(context,
-                     context->allocate_temp(DataTypeToEnum<U>::value,
-                                            scale_offset_shape, &scratch2));
       functor::FusedBatchNormFreezeGrad<Device, T, U>()(
-          context->eigen_device<Device>(), y_backprop, x, scale,
-          saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
-          x_backprop, scale_backprop, offset_backprop, scratch1.vec<U>(),
-          scratch2.vec<U>());
+          context, y_backprop, x, scale, saved_mean_or_pop_mean,
+          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
+          offset_backprop);
     }
   }
 
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cu.cc b/tensorflow/core/kernels/fused_batch_norm_op.cu.cc
index ff088bd6f88..0d2c1c4015d 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cu.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cu.cc
@@ -26,6 +26,83 @@ typedef Eigen::GpuDevice GPUDevice;
 
 namespace functor {
 
+// TODO(ezhulenev): Use CUB reductions on GPU.
+template <typename T, typename U>
+struct FusedBatchNormFreezeGrad<GPUDevice, T, U> {
+  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
+                  const Tensor& x_input, const Tensor& scale_input,
+                  const Tensor& pop_mean_input,
+                  const Tensor& pop_variance_input, U epsilon,
+                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
+                  Tensor* offset_backprop_output) {
+    typename TTypes<T, 4>::ConstTensor y_backprop(
+        y_backprop_input.tensor<T, 4>());
+    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
+    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
+    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
+    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
+    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
+    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
+    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
+
+    const int depth = pop_mean.dimension(0);
+    const int rest_size = input.size() / depth;
+
+    // Allocate two temporary workspaces of [depth] shape.
+    Tensor scratch1_vec, scratch2_vec;
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
+                                                   {depth}, &scratch1_vec));
+    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<U>::value,
+                                                   {depth}, &scratch2_vec));
+
+    typename TTypes<U>::Vec scratch1(scratch1_vec.vec<U>());
+    typename TTypes<U>::Vec scratch2(scratch2_vec.vec<U>());
+
+    const GPUDevice& d = context->eigen_device<GPUDevice>();
+
+    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
+#if !defined(EIGEN_HAS_INDEX_LIST)
+    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
+    Eigen::array<int, 1> reduction_axis{0};
+    Eigen::array<int, 2> rest_by_one({rest_size, 1});
+#else
+    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
+    one_by_depth.set(1, depth);
+    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
+    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
+    rest_by_one.set(0, rest_size);
+#endif
+
+    // offset_backprop = sum(y_backprop)
+    // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
+    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
+
+    auto y_backprop_rest_by_depth =
+        y_backprop.reshape(rest_by_depth).template cast<U>();
+    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
+
+    offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);
+
+    // scratch1 = rsqrt(pop_var + epsilon)
+    scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();
+
+    // scratch2 = sum(y_backprop * (x - mean))
+    scratch2.device(d) =
+        (y_backprop_rest_by_depth *
+         (input_rest_by_depth -
+          pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
+            .sum(reduction_axis);
+
+    x_backprop.reshape(rest_by_depth).device(d) =
+        (y_backprop_rest_by_depth * ((scratch1 * scale)
+                                         .eval()
+                                         .reshape(one_by_depth)
+                                         .broadcast(rest_by_one)))
+            .template cast<T>();
+    scale_backprop.device(d) = scratch2 * scratch1;
+  }
+};
+
 template struct FusedBatchNormFreezeGrad<GPUDevice, float, float>;
 template struct FusedBatchNormFreezeGrad<GPUDevice, Eigen::half, float>;
 
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h
index 2cb19e15ddb..4936192377c 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.h
+++ b/tensorflow/core/kernels/fused_batch_norm_op.h
@@ -85,71 +85,15 @@ struct FusedBatchNormInferenceFunctor {
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 // Functor used by FusedBatchNormGradOp to do the computations when
-// is_training=False. Both CPU and GPU will use this functor.
+// is_training=False.
 template <typename Device, typename T, typename U>
 struct FusedBatchNormFreezeGrad {
-  void operator()(const Device& d, const Tensor& y_backprop_input,
+  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                   const Tensor& x_input, const Tensor& scale_input,
                   const Tensor& pop_mean_input,
                   const Tensor& pop_variance_input, U epsilon,
                   Tensor* x_backprop_output, Tensor* scale_backprop_output,
-                  Tensor* offset_backprop_output,
-                  typename TTypes<U>::Vec scratch1,
-                  typename TTypes<U>::Vec scratch2) {
-    typename TTypes<T, 4>::ConstTensor y_backprop(
-        y_backprop_input.tensor<T, 4>());
-    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
-    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
-    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
-    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
-    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
-    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
-    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());
-
-    const int depth = pop_mean.dimension(0);
-    const int rest_size = input.size() / depth;
-
-    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
-#if !defined(EIGEN_HAS_INDEX_LIST)
-    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
-    Eigen::array<int, 1> reduction_axis{0};
-    Eigen::array<int, 2> rest_by_one({rest_size, 1});
-#else
-    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
-    one_by_depth.set(1, depth);
-    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
-    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
-    rest_by_one.set(0, rest_size);
-#endif
-
-    // offset_backprop = sum(y_backprop)
-    // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + epsilon))
-    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))
-
-    auto y_backprop_rest_by_depth =
-        y_backprop.reshape(rest_by_depth).template cast<U>();
-    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();
-
-    offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);
-
-    // scratch1 = rsqrt(pop_var + epsilon)
-    scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();
-
-    // scratch2 = sum(y_backprop * (x - mean))
-    scratch2.device(d) =
-        (y_backprop_rest_by_depth *
-         (input_rest_by_depth -
-          pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
-            .sum(reduction_axis);
-
-    x_backprop.reshape(rest_by_depth).device(d) =
-        (y_backprop_rest_by_depth * ((scratch1 * scale)
-                                         .eval()
-                                         .reshape(one_by_depth)
-                                         .broadcast(rest_by_one)))
-            .template cast<T>();
-    scale_backprop.device(d) = scratch2 * scratch1;
-  }
+                  Tensor* offset_backprop_output) {}
 };
 
 }  // namespace functor
diff --git a/tensorflow/core/kernels/fused_batch_norm_op_test.cc b/tensorflow/core/kernels/fused_batch_norm_op_test.cc
index f765a3ee43d..5297d3ee138 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op_test.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op_test.cc
@@ -269,6 +269,22 @@ BM_FusedBatchNorm(64, 14, 14, 256, fp16, true, NCHW, gpu);
   BENCHMARK(BM_NAME(FusedBatchNormGrad, N, H, W, C, T, IS_TRAINING, FORMAT,  \
                     DEVICE));
 
+#define BM_FusedBatchNormGradResnetShapes(T, IS_TRAINING, FORMAT, DEVICE)   \
+  BM_FusedBatchNormGrad(64, 56, 56, 64, T, IS_TRAINING, FORMAT, DEVICE);    \
+  BM_FusedBatchNormGrad(64, 56, 56, 128, T, IS_TRAINING, FORMAT, DEVICE);   \
+  BM_FusedBatchNormGrad(64, 56, 56, 256, T, IS_TRAINING, FORMAT, DEVICE);   \
+                                                                            \
+  BM_FusedBatchNormGrad(64, 28, 28, 128, T, IS_TRAINING, FORMAT, DEVICE);   \
+  BM_FusedBatchNormGrad(64, 28, 28, 256, T, IS_TRAINING, FORMAT, DEVICE);   \
+  BM_FusedBatchNormGrad(64, 28, 28, 512, T, IS_TRAINING, FORMAT, DEVICE);   \
+                                                                            \
+  BM_FusedBatchNormGrad(64, 14, 14, 128, T, IS_TRAINING, FORMAT, DEVICE);   \
+  BM_FusedBatchNormGrad(64, 14, 14, 256, T, IS_TRAINING, FORMAT, DEVICE);   \
+  BM_FusedBatchNormGrad(64, 14, 14, 1024, T, IS_TRAINING, FORMAT, DEVICE)
+
+BM_FusedBatchNormGradResnetShapes(fp32, true, NHWC, cpu);
+BM_FusedBatchNormGradResnetShapes(fp32, false, NHWC, cpu);
+
 #ifdef GOOGLE_CUDA
 BM_FusedBatchNormGrad(64, 14, 14, 256, fp32, true, NHWC, gpu);
 BM_FusedBatchNormGrad(64, 14, 14, 256, fp16, true, NHWC, gpu);
diff --git a/tensorflow/core/kernels/redux_functor.h b/tensorflow/core/kernels/redux_functor.h
index 24dc876ef8e..30038c62dbd 100644
--- a/tensorflow/core/kernels/redux_functor.h
+++ b/tensorflow/core/kernels/redux_functor.h
@@ -35,16 +35,18 @@ namespace functor {
 //       input: [D1, D2, ... , DN]
 //       ->
 //       output: [Di, ... , DN] where i belongs to set [1,N]
-template <typename T, typename AccumT, typename BinaryFunctor>
+template <typename InputT, typename AccumT, typename OutputT,
+          typename BinaryFunctor>
 struct ReduceOuterDimensions {
-  ReduceOuterDimensions(){};
+  ReduceOuterDimensions() {}
+
   template <int num_dims>
   void operator()(const CPUDevice& device,
                   const Eigen::DSizes<Eigen::Index, num_dims>& input_dims,
                   const Tensor& input, Tensor* output) const {
     // Compute inner and outer dim after reshaping into 2d tensor.
     const int num_output_dims = output->dims();
-    auto output_dims = output->template flat<T>().dimensions();
+    auto output_dims = output->template flat<OutputT>().dimensions();
 
     Eigen::Index inner_dim = 1, outer_dim = 1;
     for (int i = 0; i < num_dims - num_output_dims; ++i)
@@ -54,8 +56,8 @@ struct ReduceOuterDimensions {
 
     if (1 == outer_dim) {
       // Nothing to do but passing input to output.
-      output->template flat<T>() =
-          input.template flat<T>().reshape(output_dims);
+      output->template flat<OutputT>() =
+          input.template flat<InputT>().reshape(output_dims);
       return;
     }
 
@@ -63,13 +65,15 @@
     const Eigen::Index num_threads = device.numThreads();
 
     // If the inner dim parallelism is large enough
-    if (inner_dim > num_threads * 16) {
+    // TODO(ezhulenev): There seems to be no benefits in going this route. Check
+    // if this can be improved, or use better heuristic?
+    if (inner_dim > num_threads * 32) {
       // Do not create more blocks than there are threads in a pool.
       const Eigen::Index num_blocks = num_threads;
 
       // Block size along the outer dimension.
       const Eigen::Index inner_block_size = Eigen::divup(inner_dim, num_blocks);
-      const T* input_data = input.template flat<T>().data();
+      const InputT* input_data = input.template flat<InputT>().data();
 
       // Allocate temporary buffer for partial reductions.
       Eigen::Tensor<AccumT, 1, Eigen::RowMajor, Eigen::Index> buffer(
@@ -82,7 +86,7 @@
           Eigen::Unaligned>;
 
       using Input = Eigen::TensorMap<
-          Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::Index>,
+          Eigen::Tensor<const InputT, 1, Eigen::RowMajor, Eigen::Index>,
          Eigen::Unaligned>;
 
       const auto compute = [inner_dim, outer_dim, num_blocks, inner_block_size,
@@ -94,7 +98,7 @@
          inner_dim_limit = std::min(inner_dim, inner_dim_limit);
          Eigen::Index my_job_len = inner_dim_limit - inner_dim_start;
 
-          const T* my_job_start = input_data + inner_dim_start;
+          const InputT* my_job_start = input_data + inner_dim_start;
          Buffer buf(buffer_data + inner_dim_start, my_job_len);
 
          for (Eigen::Index i = 0; i < outer_dim; ++i) {
@@ -107,7 +111,7 @@
 
      // Compute cost of reducing a single block.
      const Eigen::Index compute_size = outer_dim * inner_block_size;
-      const Eigen::Index compute_input_bytes = compute_size * sizeof(T);
+      const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
      const Eigen::TensorOpCost cost(
          compute_input_bytes,
          0,  // We'll be mostly writing to L1, assume store cost is 0
@@ -116,8 +120,8 @@
       device.parallelFor(num_blocks, cost, compute);
 
       // Write final result to the output.
-      output->template flat<T>() =
-          buffer.template cast<T>().reshape(output_dims);
+      output->template flat<OutputT>() =
+          buffer.template cast<OutputT>().reshape(output_dims);
     } else {
       // Compute block size along the outer dimension for efficiency.
       const Eigen::Index parallel_cell_size = inner_dim;
@@ -136,7 +140,7 @@
 
       // Block size along the outer dimension.
       const Eigen::Index outer_block_size = Eigen::divup(outer_dim, num_blocks);
-      const T* input_data = input.template flat<T>().data();
+      const InputT* input_data = input.template flat<InputT>().data();
 
       // Allocate temporary buffer for partial reductions.
       Tensor buffer(DataTypeToEnum<AccumT>::v(), {num_blocks, inner_dim});
@@ -148,7 +152,7 @@
           Eigen::Unaligned>;
 
       using Input = Eigen::TensorMap<
-          Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::Index>,
+          Eigen::Tensor<const InputT, 1, Eigen::RowMajor, Eigen::Index>,
          Eigen::Unaligned>;
 
       const auto compute = [inner_dim, num_blocks, outer_block_size,
@@ -170,7 +174,7 @@
 
       // Compute cost of reducing a single block.
       const Eigen::Index compute_size = outer_block_size * inner_dim;
-      const Eigen::Index compute_input_bytes = compute_size * sizeof(T);
+      const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
       const Eigen::TensorOpCost cost(
           compute_input_bytes,
           0,  // We'll be mostly writing to L1, assume store cost is 0
@@ -187,7 +191,8 @@
                                  const decltype(buf)>(buf0, buf);
       }
       // Write final result to the output.
-      output->template flat<T>() = buf0.template cast<T>().reshape(output_dims);
+      output->template flat<OutputT>() =
+          buf0.template cast<OutputT>().reshape(output_dims);
     }
   }
 };
@@ -197,9 +202,11 @@
 //       input: [D1, D2, ... , DN]
 //       ->
 //       output: [Di, ... , Dj] where i & j belongs to set [1,N].
-template <typename T, typename AccumT, typename BinaryFunctor, typename Reducer>
+template <typename InputT, typename AccumT, typename OutputT,
+          typename BinaryFunctor, typename Reducer>
 struct ReduceMiddleDimensions {
-  ReduceMiddleDimensions(){};
+  ReduceMiddleDimensions() {}
+
   template <int num_dims>
   void operator()(const CPUDevice& device,
                   const Eigen::DSizes<Eigen::Index, num_dims>& input_dims,
@@ -207,7 +214,7 @@
                   const int axis_begin_dim) const {
     // Compute dims after reshaping into 3d tensor.
     const int num_output_dims = output->dims();
-    auto output_dims = output->template flat<T>().dimensions();
+    auto output_dims = output->template flat<OutputT>().dimensions();
 
     Eigen::Index inner_dim = 1, middle_dim = 1, outer_dim = 1;
     for (int i = 0; i < axis_begin_dim; ++i) outer_dim *= input_dims[i];
@@ -218,12 +225,12 @@
 
     if ((1 == inner_dim * outer_dim)) {
       // Nothing to do.
-      output->template flat<T>() =
-          input.template flat<T>().reshape(output_dims);
+      output->template flat<OutputT>() =
+          input.template flat<InputT>().reshape(output_dims);
       return;
     } else if (1 == inner_dim) {
       // Equivalent to ReduceOuterDimensions.
-      const ReduceOuterDimensions<T, AccumT, BinaryFunctor> redux;
+      const ReduceOuterDimensions<InputT, AccumT, OutputT, BinaryFunctor> redux;
       redux(device, input_dims, input, output);
       return;
     }
@@ -247,7 +254,7 @@
     const Eigen::Index outer_block_size =
         Eigen::divup(total_workload, num_blocks);
 
-    const T* input_data = input.template flat<T>().data();
+    const InputT* input_data = input.template flat<InputT>().data();
 
     // Allocate temporary buffer for partial reductions.
     Eigen::Tensor<AccumT, 2, Eigen::RowMajor, Eigen::Index> buffer(num_blocks, middle_dim);
@@ -255,7 +262,7 @@
     AccumT* buffer_data = buffer.data();
 
     using Buffer = Eigen::TensorMap<Eigen::Tensor<AccumT, 1, Eigen::RowMajor, Eigen::Index>, Eigen::Unaligned>;
-    using Input = Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor, Eigen::Index>, Eigen::Unaligned>;
+    using Input = Eigen::TensorMap<Eigen::Tensor<const InputT, 1, Eigen::RowMajor, Eigen::Index>, Eigen::Unaligned>;
 
     Eigen::array<Eigen::Index, 1> reduction_axis = {0};
     Reducer reducer;
@@ -301,7 +308,7 @@
 
     // Compute cost of reducing a single block.
     const Eigen::Index compute_size = outer_block_size * inner_dim;
-    const Eigen::Index compute_input_bytes = compute_size * sizeof(T);
+    const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
     const Eigen::TensorOpCost cost(
         compute_input_bytes,
         0,  // We'll be mostly writing to L1, assume store cost is 0
@@ -322,7 +329,8 @@
     }
 
     // Write final result to the output.
-    output->template flat<T>() = buf0.template cast<T>().reshape(output_dims);
+    output->template flat<OutputT>() =
+        buf0.template cast<OutputT>().reshape(output_dims);
   }
 };
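Note on the reduction contract (not part of the patch): every reduction this change routes through functor::ReduceOuterDimensions collapses the outer rest_size dimension of a [rest_size, depth] tensor into a [depth] vector, as the in-code comments state. The snippet below is a minimal, standalone Eigen sketch of that shape contract only; the sizes are hypothetical and it uses a plain Eigen .sum() reduction rather than the optimized functor added here.

// Illustrative sketch of the [rest_size, depth] -> [depth] reduction contract.
// Hypothetical sizes; plain Eigen, not the TensorFlow functor from this patch.
#include <iostream>
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  const Eigen::Index rest_size = 8;  // e.g. N * H * W of an NHWC activation
  const Eigen::Index depth = 4;      // channel dimension C

  Eigen::Tensor<float, 2, Eigen::RowMajor> input(rest_size, depth);
  input.setRandom();

  // Reduce the outer (rest_size) dimension, keep the inner (depth) dimension.
  Eigen::array<Eigen::Index, 1> reduce_dims = {0};
  Eigen::Tensor<float, 1, Eigen::RowMajor> output = input.sum(reduce_dims);

  std::cout << "reduced to " << output.dimension(0) << " channels\n";
  return 0;
}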