diff --git a/tensorflow/core/kernels/conv_ops_fused_impl.h b/tensorflow/core/kernels/conv_ops_fused_impl.h index ce60afecc55..5011e8ba7a1 100644 --- a/tensorflow/core/kernels/conv_ops_fused_impl.h +++ b/tensorflow/core/kernels/conv_ops_fused_impl.h @@ -106,20 +106,15 @@ class LaunchFusedConv2DWithOutputKernel { template void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx, const Tensor& input, const Tensor& filter, Tensor* output) { - // Wrap output_kernel into type erased function to reduce the number of + // Wrap output_kernel into type erased wrapper to reduce the number of // unique template instantiations for Eigen Tensor contraction expressions. - using OutputKernelFn = - std::function&, - const Eigen::TensorContractionParams&, Eigen::Index, - Eigen::Index, Eigen::Index, Eigen::Index)>; - - OutputKernelFn output_kernel_fn = + OutputKernelWrapper output_kernel_wrapper( [&output_kernel]( const ContractionOutputMapper& output_mapper, const Eigen::TensorContractionParams& params, Eigen::Index i, Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) { output_kernel(output_mapper, params, i, j, num_rows, num_cols); - }; + }); if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride_ == 1 && col_stride_ == 1 && padding_ != EXPLICIT) { @@ -130,12 +125,12 @@ class LaunchFusedConv2DWithOutputKernel { Eigen::array, 1> dim_pair; dim_pair[0] = Eigen::IndexPair(1, 0); - functor::MatMulConvFunctor()( + functor::MatMulConvFunctor()( ctx->eigen_device(), output->shaped({conv_width, filter.dim_size(3)}), input.shaped({conv_width, filter.dim_size(2)}), filter.shaped({filter.dim_size(2), filter.dim_size(3)}), - dim_pair, std::move(output_kernel_fn)); + dim_pair, std::move(output_kernel_wrapper)); } else if (filter.dim_size(0) == input.dim_size(1) && filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 && @@ -147,16 +142,16 @@ class LaunchFusedConv2DWithOutputKernel { Eigen::array, 1> dim_pair; dim_pair[0] = Eigen::IndexPair(1, 0); 
- functor::MatMulConvFunctor()( + functor::MatMulConvFunctor()( ctx->eigen_device(), output->shaped({input.dim_size(0), filter.dim_size(3)}), input.shaped({input.dim_size(0), k}), filter.shaped({k, filter.dim_size(3)}), dim_pair, - std::move(output_kernel_fn)); + std::move(output_kernel_wrapper)); } else { if (padding_ == EXPLICIT) { - functor::SpatialConvolution()( + functor::SpatialConvolution()( ctx->eigen_device(), output->tensor(), input.tensor(), filter.tensor(), row_stride_, col_stride_, row_dilation_, col_dilation_, @@ -164,18 +159,43 @@ class LaunchFusedConv2DWithOutputKernel { static_cast(explicit_paddings_[3]), static_cast(explicit_paddings_[4]), static_cast(explicit_paddings_[5]), - std::move(output_kernel_fn)); + std::move(output_kernel_wrapper)); } else { - functor::SpatialConvolution()( + functor::SpatialConvolution()( ctx->eigen_device(), output->tensor(), input.tensor(), filter.tensor(), row_stride_, col_stride_, row_dilation_, col_dilation_, - BrainPadding2EigenPadding(padding_), std::move(output_kernel_fn)); + BrainPadding2EigenPadding(padding_), + std::move(output_kernel_wrapper)); } } } private: + // Wrap output_kernel into type erased struct to reduce the number of unique + // template instantiations for Eigen Tensor contraction expressions. + // + // We do not pass std::function directly as an output kernel because it blows + // up the binary size in debug mode with super long symbol names. 
+  struct OutputKernelWrapper {  // Named functor forwarding to a stored std::function.
+    using OutputKernelFn =  // Type of the callable held by this wrapper.
+        std::function&,
+                           const Eigen::TensorContractionParams&, Eigen::Index,
+                           Eigen::Index, Eigen::Index, Eigen::Index)>;  // NOTE(review): template arguments look stripped in this copy of the patch -- confirm against upstream.
+
+    explicit OutputKernelWrapper(OutputKernelFn fn)  // Takes the callable by value...
+        : output_kernel_fn(std::move(fn)) {}  // ...and moves it into the member.
+
+    void operator()(  // Const call operator; passes all six arguments through unchanged.
+        const ContractionOutputMapper& output_mapper,
+        const Eigen::TensorContractionParams& params, Eigen::Index i,
+        Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const {
+      output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols);
+    }
+
+    OutputKernelFn output_kernel_fn;  // The wrapped callable invoked by operator().
+  };
+
   int row_stride_;
   int col_stride_;
   int row_dilation_;