Wrap output kernel function into struct to reduce the binary size
PiperOrigin-RevId: 329417755 Change-Id: I408dd8e7209d59376684399faed9122a092606c0
This commit is contained in:
parent
d780cdcbe1
commit
dbb4fe3fe1
@ -106,20 +106,15 @@ class LaunchFusedConv2DWithOutputKernel {
|
||||
template <typename OutputKernel>
|
||||
void operator()(const OutputKernel& output_kernel, OpKernelContext* ctx,
|
||||
const Tensor& input, const Tensor& filter, Tensor* output) {
|
||||
// Wrap output_kernel into type erased function to reduce the number of
|
||||
// Wrap output_kernel into type erased wrapper to reduce the number of
|
||||
// unique template instantiations for Eigen Tensor contraction expressions.
|
||||
using OutputKernelFn =
|
||||
std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
|
||||
const Eigen::TensorContractionParams&, Eigen::Index,
|
||||
Eigen::Index, Eigen::Index, Eigen::Index)>;
|
||||
|
||||
OutputKernelFn output_kernel_fn =
|
||||
OutputKernelWrapper output_kernel_wrapper(
|
||||
[&output_kernel](
|
||||
const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
|
||||
const Eigen::TensorContractionParams& params, Eigen::Index i,
|
||||
Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) {
|
||||
output_kernel(output_mapper, params, i, j, num_rows, num_cols);
|
||||
};
|
||||
});
|
||||
|
||||
if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 &&
|
||||
row_stride_ == 1 && col_stride_ == 1 && padding_ != EXPLICIT) {
|
||||
@ -130,12 +125,12 @@ class LaunchFusedConv2DWithOutputKernel {
|
||||
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
|
||||
dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
|
||||
functor::MatMulConvFunctor<CPUDevice, T, OutputKernelFn>()(
|
||||
functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
|
||||
ctx->eigen_device<CPUDevice>(),
|
||||
output->shaped<T, 2>({conv_width, filter.dim_size(3)}),
|
||||
input.shaped<T, 2>({conv_width, filter.dim_size(2)}),
|
||||
filter.shaped<T, 2>({filter.dim_size(2), filter.dim_size(3)}),
|
||||
dim_pair, std::move(output_kernel_fn));
|
||||
dim_pair, std::move(output_kernel_wrapper));
|
||||
|
||||
} else if (filter.dim_size(0) == input.dim_size(1) &&
|
||||
filter.dim_size(1) == input.dim_size(2) && row_dilation_ == 1 &&
|
||||
@ -147,16 +142,16 @@ class LaunchFusedConv2DWithOutputKernel {
|
||||
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
|
||||
dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
|
||||
functor::MatMulConvFunctor<CPUDevice, T, OutputKernelFn>()(
|
||||
functor::MatMulConvFunctor<CPUDevice, T, OutputKernelWrapper>()(
|
||||
ctx->eigen_device<CPUDevice>(),
|
||||
output->shaped<T, 2>({input.dim_size(0), filter.dim_size(3)}),
|
||||
input.shaped<T, 2>({input.dim_size(0), k}),
|
||||
filter.shaped<T, 2>({k, filter.dim_size(3)}), dim_pair,
|
||||
std::move(output_kernel_fn));
|
||||
std::move(output_kernel_wrapper));
|
||||
|
||||
} else {
|
||||
if (padding_ == EXPLICIT) {
|
||||
functor::SpatialConvolution<CPUDevice, T, OutputKernelFn>()(
|
||||
functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
|
||||
ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
|
||||
input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
|
||||
col_stride_, row_dilation_, col_dilation_,
|
||||
@ -164,18 +159,43 @@ class LaunchFusedConv2DWithOutputKernel {
|
||||
static_cast<int>(explicit_paddings_[3]),
|
||||
static_cast<int>(explicit_paddings_[4]),
|
||||
static_cast<int>(explicit_paddings_[5]),
|
||||
std::move(output_kernel_fn));
|
||||
std::move(output_kernel_wrapper));
|
||||
} else {
|
||||
functor::SpatialConvolution<CPUDevice, T, OutputKernelFn>()(
|
||||
functor::SpatialConvolution<CPUDevice, T, OutputKernelWrapper>()(
|
||||
ctx->eigen_device<CPUDevice>(), output->tensor<T, 4>(),
|
||||
input.tensor<T, 4>(), filter.tensor<T, 4>(), row_stride_,
|
||||
col_stride_, row_dilation_, col_dilation_,
|
||||
BrainPadding2EigenPadding(padding_), std::move(output_kernel_fn));
|
||||
BrainPadding2EigenPadding(padding_),
|
||||
std::move(output_kernel_wrapper));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Wrap output_kernel into type erased struct to reduce the number of unique
|
||||
// template instantiations for Eigen Tensor contraction expressions.
|
||||
//
|
||||
// We do not pass std::function directly as an output kernel because it blows
|
||||
// up the binary size in debug mode with super long symbol names.
|
||||
// Small callable struct that stores a std::function and forwards every call
// to it, so that code templated on the output kernel type is instantiated
// with this one wrapper type rather than with each concrete kernel type.
struct OutputKernelWrapper {
|
||||
// Signature of the stored callable: an output mapper, the contraction
// parameters and four Eigen::Index values; returns nothing.
using OutputKernelFn =
|
||||
std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
|
||||
const Eigen::TensorContractionParams&, Eigen::Index,
|
||||
Eigen::Index, Eigen::Index, Eigen::Index)>;
|
||||
|
||||
// Takes the callable by value and moves it into the member; `explicit`
// prevents implicit conversion from an arbitrary callable to the wrapper.
explicit OutputKernelWrapper(OutputKernelFn fn)
|
||||
: output_kernel_fn(std::move(fn)) {}
|
||||
|
||||
// Forwards all six arguments, unchanged and in order, to the stored
// callable. The wrapper adds no logic of its own and does not check
// whether the stored std::function is empty before invoking it.
// NOTE(review): presumably `i`/`j` are the block's row/column offsets and
// `num_rows`/`num_cols` its extent - confirm against the Eigen output
// kernel contract.
void operator()(
|
||||
const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
|
||||
const Eigen::TensorContractionParams& params, Eigen::Index i,
|
||||
Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const {
|
||||
output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols);
|
||||
}
|
||||
|
||||
// The type-erased callable invoked by operator().
OutputKernelFn output_kernel_fn;
|
||||
};
|
||||
|
||||
int row_stride_;
|
||||
int col_stride_;
|
||||
int row_dilation_;
|
||||
|
Loading…
x
Reference in New Issue
Block a user