Optimize col-major packing for PADDING_VALID convolutions

PiperOrigin-RevId: 288937975
Change-Id: I17f01d80077882e254bed6f92cfdca6b4dbdecb4
Author: Eugene Zhulenev, 2020-01-09 11:33:49 -08:00; committed by TensorFlower Gardener
Parent: 20a00d1102
Commit: 10b22bf930
3 changed files with 242 additions and 41 deletions


@@ -129,6 +129,7 @@ class TensorContractionInputMapper<
m_colStride = patch_rows;
m_outputRows = tensor.impl().outputRows();
m_outputCols = tensor.impl().outputCols();
m_row_strides = tensor.impl().userRowStride();
m_col_strides = tensor.impl().userColStride();
@@ -187,6 +188,7 @@ class TensorContractionInputMapper<
m_inputCols = base_mapper.m_inputCols;
m_outputRows = base_mapper.m_outputRows;
m_outputCols = base_mapper.m_outputCols;
m_row_strides = base_mapper.m_row_strides;
m_col_strides = base_mapper.m_col_strides;
@@ -652,7 +654,8 @@ class TensorContractionInputMapper<
Index m_inputRows; // Number of rows in the input tensor
Index m_inputCols; // Number of cols in the input tensor
- Index m_outputRows;  // Number of patch rows
+ Index m_outputRows;  // Number of convolution output rows
+ Index m_outputCols;  // Number of convolution output cols
Index m_row_strides; // User specified row stride
Index m_col_strides; // User specified col stride
@@ -872,6 +875,23 @@ class TensorContractionSubMapper<
inputIndex, mask<PacketT>(0, num_coeffs));
}
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE bool hasPadding() const {
// TODO(ezhulenev): It does seem that for an inflated filter it's still
// possible to guarantee "no padding or skipping" for non-standard packing.
if (nonStandardPatches()) return true;
// Check if output rows and columns match the PADDING_VALID case. If they
// do, it means that there is no padding for the input tensor.
const bool match_rows = m_base_mapper.m_outputRows ==
divup(m_base_mapper.m_inputRows - patchRows() + 1,
m_base_mapper.m_row_strides);
const bool match_cols = m_base_mapper.m_outputCols ==
divup(m_base_mapper.m_inputCols - patchCols() + 1,
m_base_mapper.m_col_strides);
return !match_rows || !match_cols;
}
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
const Index r = m_rowIndex + row;
return r < 0 || r >= m_base_mapper.m_inputRows;
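The hasPadding() check above distills to a small predicate: with PADDING_VALID every patch fits entirely inside the input, so the output dimensions must equal divup(input - patch + 1, stride); any mismatch implies implicit padding. A minimal standalone sketch in plain C++ (hypothetical names, not the actual Eigen mapper API):

#include <cstdint>

// Integer ceil-division, as used by the mapper code above.
int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; }

// Returns true if the given output dims cannot be explained by a
// PADDING_VALID convolution, i.e. some patches read padded elements.
bool has_padding(int64_t input_rows, int64_t input_cols, int64_t patch_rows,
                 int64_t patch_cols, int64_t row_stride, int64_t col_stride,
                 int64_t output_rows, int64_t output_cols) {
  const bool match_rows =
      output_rows == divup(input_rows - patch_rows + 1, row_stride);
  const bool match_cols =
      output_cols == divup(input_cols - patch_cols + 1, col_stride);
  return !match_rows || !match_cols;
}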
@@ -1629,16 +1649,14 @@ EIGEN_DEVICE_FUNC
case PADDING_VALID: {
const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
- out_height = numext::ceil((InputRowsEff - kernelRowsEff + 1.f) /
-              static_cast<float>(row_stride));
- out_width = numext::ceil((InputColsEff - kernelColsEff + 1.f) /
-             static_cast<float>(col_stride));
+ out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
+ out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
break;
}
case PADDING_SAME: {
eigen_assert(!padding_explicit);
- out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
- out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
+ out_height = divup(InputRows, row_stride);
+ out_width = divup(InputCols, col_stride);
break;
}
default: {
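The switch above replaces float-based numext::ceil with integer divup. For positive integers a and b, ceil(a / float(b)) == (a + b - 1) / b, so the result is unchanged while the float cast and rounding concerns go away. A tiny self-contained check of the identity (assumes positive operands, which these convolution sizes guarantee):

#include <cassert>
#include <cmath>

int main() {
  for (int a = 1; a <= 1024; ++a) {
    for (int b = 1; b <= 16; ++b) {
      const int d = (a + b - 1) / b;  // divup
      assert(d == static_cast<int>(std::ceil(a / static_cast<float>(b))));
    }
  }
  return 0;
}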


@@ -115,13 +115,23 @@ struct gemm_pack_colmajor_block<
if (standard_patches && (rhs.patchDepth() % packet_size == 0)) {
// Single packet always belongs to a single patch (row, col).
- packStandardPatches</*patch_depth_is_multiple_of_packet_size*/ true>(
-     block, rhs, rows, cols);
+ if (rhs.hasPadding()) {
+   packStandardPatches</*patch_depth_is_multiple_of_packet_size=*/true,
+                       /*has_padding=*/true>(block, rhs, rows, cols);
+ } else {
+   packStandardPatches</*patch_depth_is_multiple_of_packet_size=*/true,
+                       /*has_padding=*/false>(block, rhs, rows, cols);
+ }
} else if (standard_patches) {
// Single packet can span across multiple patch rows or columns.
- packStandardPatches</*patch_depth_is_multiple_of_packet_size*/ false>(
-     block, rhs, rows, cols);
+ if (rhs.hasPadding()) {
+   packStandardPatches</*patch_depth_is_multiple_of_packet_size=*/false,
+                       /*has_padding=*/true>(block, rhs, rows, cols);
+ } else {
+   packStandardPatches</*patch_depth_is_multiple_of_packet_size=*/false,
+                       /*has_padding=*/false>(block, rhs, rows, cols);
+ }
} else if (rhs.patchDepth() % packet_size == 0) {
// Single packet always belongs to a single patch (row, col).
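The hasPadding() dispatch above promotes a runtime predicate into the compile-time has_padding template parameter, so the packing loops are compiled twice and all padding checks vanish from the no-padding instantiation. The pattern in isolation, a sketch with simplified, hypothetical signatures:

template <bool has_padding>
void PackBlock(float* block, int rows, int cols) {
  // Inner loops go here; every `if (has_padding && ...)` test is dead code
  // in the PackBlock<false> instantiation and is removed by the compiler.
}

void Pack(float* block, int rows, int cols, bool input_has_padding) {
  if (input_has_padding) {
    PackBlock</*has_padding=*/true>(block, rows, cols);
  } else {
    PackBlock</*has_padding=*/false>(block, rows, cols);
  }
}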
@@ -138,8 +148,8 @@ struct gemm_pack_colmajor_block<
private:
// (A) Standard image patches:
//
- //   (1) in_row_stride = 1 && in_col_stide == 1
- //   (2) patch_row_inflate_strides == 1 && patch_col_inflate_strides == 1
+ //   (1) patch_row_inflate_strides == 1 AND
+ //   (2) patch_col_inflate_strides == 1
//
// Standard patches guarantee that the two innermost dimensions (depth and
// rows) are contiguous in memory, so we can try to squeeze reads from them.
@@ -154,8 +164,11 @@ struct gemm_pack_colmajor_block<
// depth dimension size to be a multiple of packet size, so we can skip all
// non-vectorized loads and checks, because it's guaranteed that the block
// size will be a multiple of the packet size (see TensorContractionBlocking).
- template <bool patch_depth_is_multiple_of_packet_size>
+ //
+ // - has_padding: Input tensor has non-zero padding. In this case, for each
+ //   patch col and row, we need to check that it doesn't correspond to the
+ //   padded region of the original input.
+ template <bool patch_depth_is_multiple_of_packet_size, bool has_padding>
EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
const DataMapper rhs,
StorageIndex rows,
@@ -177,10 +190,14 @@ struct gemm_pack_colmajor_block<
const StorageIndex start_row = (c == start_col) ? rhs.rowOffset() : 0;
const StorageIndex max_row = rhs.maxRow(peeled_k, c);
- const bool pad_col = lm.padCol(c);
+ const bool pad_col = has_padding && lm.padCol(c);
+ eigen_assert(has_padding || !lm.padCol(c));
+ eigen_assert(has_padding || !lm.padAnyRow(start_row, max_row - 1));
// We can squeeze reads for all rows in [start_row, max_row) range.
- if (!pad_col && !lm.padAnyRow(start_row, max_row - 1)) {
+ if (!has_padding ||
+     (!pad_col && !lm.padAnyRow(start_row, max_row - 1))) {
const StorageIndex start_depth =
(c == start_col) ? rhs.depthOffset() : 0;
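When the squeezed branch above is taken, no element of the patch column touches padding, so the depth and row dimensions form one contiguous run in input memory and per-row reads collapse into a linear copy. Illustration only, with plain pointers standing in for the Eigen mappers:

#include <cstring>

// Copies (max_row - start_row) rows of patch_depth coefficients in one shot;
// valid only because none of those rows is padded (hypothetical layout).
void squeeze_rows(float* block, const float* src, long start_row, long max_row,
                  long patch_depth) {
  std::memcpy(block, src, (max_row - start_row) * patch_depth * sizeof(float));
}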
@@ -196,6 +213,24 @@ struct gemm_pack_colmajor_block<
eigen_assert((max_depth - start_depth) % packet_size == 0);
StorageIndex d = start_depth;
const StorageIndex unrolled_depth = max_depth - 4 * packet_size;
for (; d <= unrolled_depth; d += 4 * packet_size) {
eigen_assert(k < peeled_k);
Packet p0 = rhs.packetNoPadding(d + 0 * packet_size, base_idx);
Packet p1 = rhs.packetNoPadding(d + 1 * packet_size, base_idx);
Packet p2 = rhs.packetNoPadding(d + 2 * packet_size, base_idx);
Packet p3 = rhs.packetNoPadding(d + 3 * packet_size, base_idx);
internal::pstoreu(block + 0 * packet_size, p0);
internal::pstoreu(block + 1 * packet_size, p1);
internal::pstoreu(block + 2 * packet_size, p2);
internal::pstoreu(block + 3 * packet_size, p3);
block += 4 * packet_size;
k += 4 * packet_size;
}
for (; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
internal::pstoreu(block, rhs.packetNoPadding(d, base_idx));
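The new loop above is a 4x unroll: four independent packet loads and stores per iteration keep more memory operations in flight, with the existing packet-at-a-time loop as the remainder. A standalone analogue (assuming float data, a hypothetical 8-lane packet width, and a depth that is a multiple of the packet size, as this branch guarantees):

constexpr int kPacketSize = 8;  // stand-in for Eigen's packet width

void pack_depth(float* block, const float* src, int max_depth) {
  int d = 0;
  const int unrolled_depth = max_depth - 4 * kPacketSize;
  for (; d <= unrolled_depth; d += 4 * kPacketSize) {  // 4 packets at a time
    for (int i = 0; i < 4 * kPacketSize; ++i) block[d + i] = src[d + i];
  }
  for (; d <= max_depth - kPacketSize; d += kPacketSize) {  // remainder
    for (int i = 0; i < kPacketSize; ++i) block[d + i] = src[d + i];
  }
}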
@@ -205,8 +240,26 @@ struct gemm_pack_colmajor_block<
} else {
StorageIndex d = start_depth;
- const StorageIndex vectorized_depth = max_depth - packet_size;
const StorageIndex unrolled_depth = max_depth - 4 * packet_size;
for (; d <= unrolled_depth; d += 4 * packet_size) {
eigen_assert(k < peeled_k);
Packet p0 = rhs.packetNoPadding(d + 0 * packet_size, base_idx);
Packet p1 = rhs.packetNoPadding(d + 1 * packet_size, base_idx);
Packet p2 = rhs.packetNoPadding(d + 2 * packet_size, base_idx);
Packet p3 = rhs.packetNoPadding(d + 3 * packet_size, base_idx);
internal::pstoreu(block + 0 * packet_size, p0);
internal::pstoreu(block + 1 * packet_size, p1);
internal::pstoreu(block + 2 * packet_size, p2);
internal::pstoreu(block + 3 * packet_size, p3);
block += 4 * packet_size;
k += 4 * packet_size;
}
+ const StorageIndex vectorized_depth = max_depth - packet_size;
for (; d <= vectorized_depth; d += packet_size) {
eigen_assert(k < peeled_k);
internal::pstoreu(block, rhs.packetNoPadding(d, base_idx));
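Note that vectorized_depth now sits between the unrolled loop and the packet remainder loop, since only the latter needs it; in this branch the depth need not be a multiple of the packet size, so a scalar tail follows. A worked bounds example with hypothetical sizes:

// With max_depth = 21 and packet_size = 8:
//   unrolled_depth   = 21 - 32 = -11  -> unrolled loop does not run
//   vectorized_depth = 21 - 8  = 13   -> packet loop runs at d = 0 and d = 8,
//                                        covering coefficients 0..15
//   scalar tail      = d = 16..20     -> handled below by CoeffFinalizer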
@@ -237,7 +290,9 @@ struct gemm_pack_colmajor_block<
const StorageIndex max_depth =
rhs.maxDepth(peeled_k - k, start_depth);
- const bool pad = pad_col || lm.padRow(r);
+ const bool pad = has_padding && (pad_col || lm.padRow(r));
+ eigen_assert(has_padding || !lm.padRow(r));
const StorageIndex base_idx = lm.baseIndex(r, c);
if (patch_depth_is_multiple_of_packet_size) {
@@ -248,7 +303,8 @@ struct gemm_pack_colmajor_block<
for (; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
- const Packet p = pad ? pset1<Packet>(Scalar(0))
+ const Packet p = (has_padding && pad)
+                  ? pset1<Packet>(Scalar(0))
: rhs.packetNoPadding(d, base_idx);
internal::pstoreu(block, p);
block += packet_size;
@@ -256,11 +312,13 @@ struct gemm_pack_colmajor_block<
}
} else {
- const StorageIndex vectorized_depth = max_depth - packet_size;
StorageIndex d = start_depth;
+ const StorageIndex vectorized_depth = max_depth - packet_size;
for (; d <= vectorized_depth; d += packet_size) {
eigen_assert(k < peeled_k);
- const Packet p = pad ? pset1<Packet>(Scalar(0))
+ const Packet p = (has_padding && pad)
+                  ? pset1<Packet>(Scalar(0))
: rhs.packetNoPadding(d, base_idx);
internal::pstoreu(block, p);
block += packet_size;
@@ -269,7 +327,7 @@ struct gemm_pack_colmajor_block<
eigen_assert(k <= peeled_k);
const Index num_coeffs = CoeffFinalizer::finalize(
- block, rhs, base_idx, d, max_depth, pad);
+ block, rhs, base_idx, d, max_depth, has_padding && pad);
k += num_coeffs;
block += num_coeffs;
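Passing `has_padding && pad` above (rather than `pad` alone) matters because has_padding is a compile-time constant: in the false instantiation the whole zero-fill path constant-folds away. In miniature, with a hypothetical load helper:

template <bool has_padding>
float load_or_zero(bool pad, const float* src, int idx) {
  // load_or_zero<false> reduces to a plain `return src[idx];` because
  // (has_padding && pad) is constant false.
  return (has_padding && pad) ? 0.0f : src[idx];
}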


@@ -1382,6 +1382,7 @@ static void PackRhsHelper(int iters,
int input_depth,
/* Filter (kernel) dimensions: */
int filter_count, int filter_cols, int filter_rows,
Eigen::PaddingType padding,
/* Input strides: */
int col_strides, int row_strides,
/* Patch inflate strides: */
@@ -1489,14 +1490,27 @@ static void PackRhsHelper(int iters,
row_strides, col_strides, //
/*in_row_strides=*/1, /*in_col_strides=*/1, //
patch_row_inflate_stride, patch_col_inflate_stride, //
- Eigen::PADDING_SAME, /*padding_value=*/0.0);
+ padding, /*padding_value=*/0.0);
// 2. Reshape extracted patches into "virtual" 2d tensor.
- // NOTE: This is valid for PADDING_SAME only.
Index input_rows_eff = (input_rows - 1) * patch_row_inflate_stride + 1;
Index input_cols_eff = (input_cols - 1) * patch_col_inflate_stride + 1;
- Index output_rows = input_rows_eff / row_strides;
- Index output_cols = input_cols_eff / col_strides;
+ Index output_rows = 0;
+ Index output_cols = 0;
+ if (padding == Eigen::PADDING_SAME) {
+   output_rows = input_rows_eff / row_strides;
+   output_cols = input_cols_eff / col_strides;
+ } else if (padding == Eigen::PADDING_VALID) {
+   output_rows =
+       numext::ceil((input_rows_eff - filter_rows + 1.f) / row_strides);
+   output_cols =
+       numext::ceil((input_cols_eff - filter_cols + 1.f) / col_strides);
+ } else {
+   eigen_assert(false && "not supported");
+ }
NewDimension reshape_dims;
reshape_dims[0] = input_depth * filter_rows * filter_cols; // patch size
reshape_dims[1] = output_rows * output_cols * input_batches; // num_patches
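A worked instance of the two sizing branches above, using one of the benchmark configurations below (64x64 input, 5x5 filter, stride 2, inflate stride 1, so input_rows_eff = 64):

// PADDING_SAME:  output_rows = 64 / 2                    = 32
// PADDING_VALID: output_rows = ceil((64 - 5 + 1) / 2.0f) = 30
// With batch = 32: num_patches = 32 * 32 * 32 = 32768 (SAME)
//                  num_patches = 30 * 30 * 32 = 28800 (VALID)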
@@ -1561,7 +1575,7 @@ static void PackRhsHelper(int iters,
tensorflow::testing::SetLabel(
absl::StrCat("patch: ", patch_rows, "x", patch_cols, " D", patch_depth,
"; num_patches=", num_patches, " patch_size=", patch_size,
" num_inputs=", num_inputs));
" num_inputs=", num_inputs, " padding=", padding));
}
template <typename T>
@@ -1755,24 +1769,24 @@ static void PackLhsHelper(int iters,
#define BM_CONCAT(a, b) a##b
- #define BM_RHS_NAME(prefix, T, N, H, W, C, FC, FH, FW, SH, SW, ISH, ISW, BR, \
-                     BC) \
-   BM_CONCAT( \
-       BM_##prefix##_##T##_##N##_##H##x##W##_IC##C##_FC##FC##_##FH##x##FW, \
-       _s##SH##x##SW##_is##ISH##x##ISW##_B##BR##x##BC)
+ #define BM_RHS_NAME(prefix, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, ISW, \
+                     BR, BC) \
+   BM_CONCAT( \
+       BM_##prefix##_##T##_##N##_##H##x##W##_IC##C##_FC##FC##_##FH##x##FW, \
+       _##PAD##_s##SH##x##SW##_is##ISH##x##ISW##_B##BR##x##BC)
- #define BM_PackRhs(T, N, H, W, C, FC, FH, FW, SH, SW, ISH, ISW, BR, BC) \
-   static void BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, SH, SW, ISH, \
-                           ISW, BR, BC)(int iters) { \
-     PackRhsHelper<T>(iters, N, H, W, C, FC, FH, FW, SH, SW, ISH, ISW, BR, BC); \
-   } \
-   BENCHMARK(BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, SH, SW, ISH, ISW, \
-                         BR, BC))
+ #define BM_PackRhs(T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, ISW, BR, BC) \
+   static void BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, \
+                           ISH, ISW, BR, BC)(int iters) { \
+     PackRhsHelper<T>(iters, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW, \
+                      ISH, ISW, BR, BC); \
+   } \
+   BENCHMARK(BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, \
+                         ISW, BR, BC))
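To make the new PAD parameter concrete, here is roughly what one instantiation from the list below expands to (assuming PADDING_VALID resolves to Eigen::PADDING_VALID in this file):

// BM_PackRhs(float, 32, 64, 64, 32, 64, 5, 5, VALID, 1, 1, 1, 1, 256, 56):
static void BM_PackRhs_float_32_64x64_IC32_FC64_5x5_VALID_s1x1_is1x1_B256x56(
    int iters) {
  PackRhsHelper<float>(iters, 32, 64, 64, 32, 64, 5, 5, PADDING_VALID, 1, 1, 1,
                       1, 256, 56);
}
BENCHMARK(BM_PackRhs_float_32_64x64_IC32_FC64_5x5_VALID_s1x1_is1x1_B256x56);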
// The number of input channels (input depth) is equal to the number of patch
// channels (patch depth).
// NOTE: This is the most common case in TensorFlow models.
// Fast path: input channel dimension is a multiple of the packet size.
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
@@ -1780,6 +1794,7 @@ BM_PackRhs(/*type*/ float, //
/*channels*/ 32, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
/*padding*/ VALID, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
@@ -1790,6 +1805,29 @@ BM_PackRhs(/*type*/ float, //
/*channels*/ 32, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
/*padding*/ SAME, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 64, 64, //
/*channels*/ 32, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
/*padding*/ VALID, //
/*stride*/ 2, 2, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 64, 64, //
/*channels*/ 32, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
/*padding*/ SAME, //
/*stride*/ 2, 2, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
@@ -1801,6 +1839,7 @@ BM_PackRhs(/*type*/ float, //
/*channels*/ 30, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
/*padding*/ SAME, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
@@ -1811,6 +1850,29 @@ BM_PackRhs(/*type*/ float, //
/*channels*/ 30, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
/*padding*/ VALID, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 64, 64, //
/*channels*/ 30, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
/*padding*/ SAME, //
/*stride*/ 2, 2, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 64, 64, //
/*channels*/ 30, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
/*padding*/ VALID, //
/*stride*/ 2, 2, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
@@ -1822,6 +1884,7 @@ BM_PackRhs(/*type*/ float, //
/*channels*/ 4, //
/*num_filters*/ 16, //
/*filter*/ 8, 8, //
/*padding*/ SAME, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
@@ -1832,6 +1895,29 @@ BM_PackRhs(/*type*/ float, //
/*channels*/ 4, //
/*num_filters*/ 16, //
/*filter*/ 8, 8, //
/*padding*/ VALID, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 256, 256, //
/*channels*/ 4, //
/*num_filters*/ 16, //
/*filter*/ 8, 8, //
/*padding*/ SAME, //
/*stride*/ 2, 4, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 256, 256, //
/*channels*/ 4, //
/*num_filters*/ 16, //
/*filter*/ 8, 8, //
/*padding*/ VALID, //
/*stride*/ 2, 4, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);
@@ -1843,6 +1929,19 @@ BM_PackRhs(/*type*/ float, //
/*channels*/ 4, //
/*num_filters*/ 16, //
/*filter*/ 3, 3, //
/*padding*/ SAME, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 36, 432);
// Short and wide block with small input channel dimension.
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 64, 64, //
/*channels*/ 4, //
/*num_filters*/ 16, //
/*filter*/ 3, 3, //
/*padding*/ VALID, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 36, 432);
@@ -1853,16 +1952,41 @@ BM_PackRhs(/*type*/ float, //
/*channels*/ 4, //
/*num_filters*/ 16, //
/*filter*/ 3, 3, //
/*padding*/ SAME, //
/*stride*/ 2, 2, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 36, 432);
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 64, 64, //
/*channels*/ 4, //
/*num_filters*/ 16, //
/*filter*/ 3, 3, //
/*padding*/ VALID, //
/*stride*/ 2, 2, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 36, 432);
// Non-standard patches with inflated strides.
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 32, 32, //
/*channels*/ 96, //
/*num_filters*/ 96, //
/*filter*/ 5, 5, //
/*padding*/ SAME, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 2, 2, //
/*block*/ 272, 240);
BM_PackRhs(/*type*/ float, //
/*batch*/ 32, //
/*image*/ 32, 32, //
/*channels*/ 96, //
/*num_filters*/ 96, //
/*filter*/ 5, 5, //
/*padding*/ VALID, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 2, 2, //
/*block*/ 272, 240);
@@ -1875,6 +1999,7 @@ BM_PackRhs(/*type*/ qint8, //
/*channels*/ 32, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
/*padding*/ SAME, //
/*stride*/ 1, 1, //
/*patch inflate stride*/ 1, 1, //
/*block*/ 256, 56);