Wrap output kernel function into struct to reduce the binary size

PiperOrigin-RevId: 329417755
Change-Id: I408dd8e7209d59376684399faed9122a092606c0
This commit is contained in:
Eugene Zhulenev 2020-08-31 18:16:52 -07:00 committed by TensorFlower Gardener
parent d780cdcbe1
commit dbb4fe3fe1

View File

@ -106,20 +106,15 @@ class LaunchFusedConv2DWithOutputKernel {
template <typename OutputKernel>
void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx,
const Tensor& input, const Tensor& filter, Tensor* output) {
// Wrap output_kernel into type erased function to reduce the number of
// Wrap output_kernel into type erased wrapper to reduce the number of
// unique template instantiations for Eigen Tensor contraction expressions.
using OutputKernelFn =
std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
const Eigen::TensorContractionParams&, Eigen::Index,
Eigen::Index, Eigen::Index, Eigen::Index)>;
OutputKernelFn output_kernel_fn =
OutputKernelWrapper output_kernel_wrapper(
[&output_kernel](
const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
const Eigen::TensorContractionParams& params, Eigen::Index i,
Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) {
output_kernel(output_mapper, params, i, j, num_rows, num_cols);
};
});
if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 &&
row_stride_ == 1 && col_stride_ == 1 && padding_ != EXPLICIT) {
@ -130,12 +125,12 @@ class LaunchFusedConv2DWithOutputKernel {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
functor::MatMulConvFunctor<CPUDevice, T, OutputKernelFn>()(
functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
ctx->eigen_device<CPUDevice>(),
output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
dim_pair, std::move(output_kernel_fn));
dim_pair, std::move(output_kernel_wrapper));
} else if (filter.dim_size(0) == input.dim_size(1) &&
filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 &&
@ -147,16 +142,16 @@ class LaunchFusedConv2DWithOutputKernel {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
functor::MatMulConvFunctor<CPUDevice, T, OutputKernelFn>()(
functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
ctx->eigen_device<CPUDevice>(),
output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
input.shaped<T, 2>({input.dim_size(0), k}),
filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair,
std::move(output_kernel_fn));
std::move(output_kernel_wrapper));
} else {
if (padding_ == EXPLICIT) {
functor::SpatialConvolution<CPUDevice, T, OutputKernelFn>()(
functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
col_stride_, row_dilation_, col_dilation_,
@ -164,18 +159,43 @@ class LaunchFusedConv2DWithOutputKernel {
static_cast<int>(explicit_paddings_[3]),
static_cast<int>(explicit_paddings_[4]),
static_cast<int>(explicit_paddings_[5]),
std::move(output_kernel_fn));
std::move(output_kernel_wrapper));
} else {
functor::SpatialConvolution<CPUDevice, T, OutputKernelFn>()(
functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
col_stride_, row_dilation_, col_dilation_,
BrainPadding2EigenPadding(padding_), std::move(output_kernel_fn));
BrainPadding2EigenPadding(padding_),
std::move(output_kernel_wrapper));
}
}
}
private:
// Wrap output_kernel into type erased struct to reduce the number of unique
// template instantiations for Eigen Tensor contraction expressions.
//
// We do not pass std::function directly as an output kernel because it blows
// up the binary size in debug mode with super long symbol names.
struct OutputKernelWrapper {
using OutputKernelFn =
std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
const Eigen::TensorContractionParams&, Eigen::Index,
Eigen::Index, Eigen::Index, Eigen::Index)>;
explicit OutputKernelWrapper(OutputKernelFn fn)
: output_kernel_fn(std::move(fn)) {}
void operator()(
const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
const Eigen::TensorContractionParams& params, Eigen::Index i,
Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const {
output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols);
}
OutputKernelFn output_kernel_fn;
};
int row_stride_;
int col_stride_;
int row_dilation_;