Optimize col-major packing for PADDING_VALID convolutions
PiperOrigin-RevId: 288937975 Change-Id: I17f01d80077882e254bed6f92cfdca6b4dbdecb4
This commit is contained in:
parent
20a00d1102
commit
10b22bf930
@ -129,6 +129,7 @@ class TensorContractionInputMapper<
|
|||||||
m_colStride = patch_rows;
|
m_colStride = patch_rows;
|
||||||
|
|
||||||
m_outputRows = tensor.impl().outputRows();
|
m_outputRows = tensor.impl().outputRows();
|
||||||
|
m_outputCols = tensor.impl().outputCols();
|
||||||
m_row_strides = tensor.impl().userRowStride();
|
m_row_strides = tensor.impl().userRowStride();
|
||||||
m_col_strides = tensor.impl().userColStride();
|
m_col_strides = tensor.impl().userColStride();
|
||||||
|
|
||||||
@ -187,6 +188,7 @@ class TensorContractionInputMapper<
|
|||||||
m_inputCols = base_mapper.m_inputCols;
|
m_inputCols = base_mapper.m_inputCols;
|
||||||
|
|
||||||
m_outputRows = base_mapper.m_outputRows;
|
m_outputRows = base_mapper.m_outputRows;
|
||||||
|
m_outputCols = base_mapper.m_outputCols;
|
||||||
m_row_strides = base_mapper.m_row_strides;
|
m_row_strides = base_mapper.m_row_strides;
|
||||||
m_col_strides = base_mapper.m_col_strides;
|
m_col_strides = base_mapper.m_col_strides;
|
||||||
|
|
||||||
@ -652,7 +654,8 @@ class TensorContractionInputMapper<
|
|||||||
Index m_inputRows; // Number of rows in the input tensor
|
Index m_inputRows; // Number of rows in the input tensor
|
||||||
Index m_inputCols; // Number of cols in the input tensor
|
Index m_inputCols; // Number of cols in the input tensor
|
||||||
|
|
||||||
Index m_outputRows; // Number of patch rows
|
Index m_outputRows; // Number of convolution output rows
|
||||||
|
Index m_outputCols; // Number of convolution output column
|
||||||
|
|
||||||
Index m_row_strides; // User specified row stride
|
Index m_row_strides; // User specified row stride
|
||||||
Index m_col_strides; // User specified col stride
|
Index m_col_strides; // User specified col stride
|
||||||
@ -872,6 +875,23 @@ class TensorContractionSubMapper<
|
|||||||
inputIndex, mask<PacketT>(0, num_coeffs));
|
inputIndex, mask<PacketT>(0, num_coeffs));
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_ALWAYS_INLINE bool hasPadding() const {
|
||||||
|
// TODO(ezhulenev): It does seems that for inflated filter it's still
|
||||||
|
// possible to guarantee "no padding or skipping" for non-standard packing.
|
||||||
|
if (nonStandardPatches()) return true;
|
||||||
|
|
||||||
|
// Check if output rows and columns matches the PADDING_VALID case. If they
|
||||||
|
// are it means that there is no padding for the input tensor.
|
||||||
|
const bool match_rows = m_base_mapper.m_outputRows ==
|
||||||
|
divup(m_base_mapper.m_inputRows - patchRows() + 1,
|
||||||
|
m_base_mapper.m_row_strides);
|
||||||
|
const bool match_cols = m_base_mapper.m_outputCols ==
|
||||||
|
divup(m_base_mapper.m_inputCols - patchCols() + 1,
|
||||||
|
m_base_mapper.m_col_strides);
|
||||||
|
|
||||||
|
return !match_rows || !match_cols;
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
|
EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
|
||||||
const Index r = m_rowIndex + row;
|
const Index r = m_rowIndex + row;
|
||||||
return r < 0 || r >= m_base_mapper.m_inputRows;
|
return r < 0 || r >= m_base_mapper.m_inputRows;
|
||||||
@ -1629,16 +1649,14 @@ EIGEN_DEVICE_FUNC
|
|||||||
case PADDING_VALID: {
|
case PADDING_VALID: {
|
||||||
const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
|
const TensorIndex InputRowsEff = InputRows + padding_top + padding_bottom;
|
||||||
const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
|
const TensorIndex InputColsEff = InputCols + padding_left + padding_right;
|
||||||
out_height = numext::ceil((InputRowsEff - kernelRowsEff + 1.f) /
|
out_height = divup(InputRowsEff - kernelRowsEff + 1, row_stride);
|
||||||
static_cast<float>(row_stride));
|
out_width = divup(InputColsEff - kernelColsEff + 1, col_stride);
|
||||||
out_width = numext::ceil((InputColsEff - kernelColsEff + 1.f) /
|
|
||||||
static_cast<float>(col_stride));
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case PADDING_SAME: {
|
case PADDING_SAME: {
|
||||||
eigen_assert(!padding_explicit);
|
eigen_assert(!padding_explicit);
|
||||||
out_height = numext::ceil(InputRows / static_cast<float>(row_stride));
|
out_height = divup(InputRows, row_stride);
|
||||||
out_width = numext::ceil(InputCols / static_cast<float>(col_stride));
|
out_width = divup(InputCols, col_stride);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
|
|||||||
@ -115,13 +115,23 @@ struct gemm_pack_colmajor_block<
|
|||||||
|
|
||||||
if (standard_patches && (rhs.patchDepth() % packet_size == 0)) {
|
if (standard_patches && (rhs.patchDepth() % packet_size == 0)) {
|
||||||
// Single packet always belong to single patch (row, col).
|
// Single packet always belong to single patch (row, col).
|
||||||
packStandardPatches</*patch_depth_is_multiple_of_packet_size*/ true>(
|
if (rhs.hasPadding()) {
|
||||||
block, rhs, rows, cols);
|
packStandardPatches</*patch_depth_is_multiple_of_packet_size=*/true,
|
||||||
|
/*has_padding=*/true>(block, rhs, rows, cols);
|
||||||
|
} else {
|
||||||
|
packStandardPatches</*patch_depth_is_multiple_of_packet_size=*/true,
|
||||||
|
/*has_padding=*/false>(block, rhs, rows, cols);
|
||||||
|
}
|
||||||
|
|
||||||
} else if (standard_patches) {
|
} else if (standard_patches) {
|
||||||
// Single packet can span across multiple patch rows or columns.
|
// Single packet can span across multiple patch rows or columns.
|
||||||
packStandardPatches</*patch_depth_is_multiple_of_packet_size*/ false>(
|
if (rhs.hasPadding()) {
|
||||||
block, rhs, rows, cols);
|
packStandardPatches</*patch_depth_is_multiple_of_packet_size=*/false,
|
||||||
|
/*has_padding=*/true>(block, rhs, rows, cols);
|
||||||
|
} else {
|
||||||
|
packStandardPatches</*patch_depth_is_multiple_of_packet_size=*/false,
|
||||||
|
/*has_padding=*/false>(block, rhs, rows, cols);
|
||||||
|
}
|
||||||
|
|
||||||
} else if (rhs.patchDepth() % packet_size == 0) {
|
} else if (rhs.patchDepth() % packet_size == 0) {
|
||||||
// Single packet always belong to single patch (row, col).
|
// Single packet always belong to single patch (row, col).
|
||||||
@ -138,8 +148,8 @@ struct gemm_pack_colmajor_block<
|
|||||||
private:
|
private:
|
||||||
// (A) Standard image patches:
|
// (A) Standard image patches:
|
||||||
//
|
//
|
||||||
// (1) in_row_stride = 1 && in_col_stide == 1
|
// (1) patch_row_inflate_strides == 1 AND
|
||||||
// (2) patch_row_inflate_strides == 1 && patch_col_inflate_strides == 1
|
// (2) patch_col_inflate_strides == 1
|
||||||
//
|
//
|
||||||
// Standard patches guarantee that two inner most dimensions (depth and rows)
|
// Standard patches guarantee that two inner most dimensions (depth and rows)
|
||||||
// are contiguous in memory and we can try to squeeze reads from them.
|
// are contiguous in memory and we can try to squeeze reads from them.
|
||||||
@ -154,8 +164,11 @@ struct gemm_pack_colmajor_block<
|
|||||||
// depth dimension size to be a multiple of packet size, so we can skip all
|
// depth dimension size to be a multiple of packet size, so we can skip all
|
||||||
// non vectorized loads and checks, because it's guaranteed that block size
|
// non vectorized loads and checks, because it's guaranteed that block size
|
||||||
// will be a multiple of a packet size (see TensorContractionBlocking).
|
// will be a multiple of a packet size (see TensorContractionBlocking).
|
||||||
|
//
|
||||||
template <bool patch_depth_is_multiple_of_packet_size>
|
// - has_padding: Input tensor has non-zero padding. In this case for each
|
||||||
|
// patch col and row we need to check that it doesn't correspond to the
|
||||||
|
// padded region of original input.
|
||||||
|
template <bool patch_depth_is_multiple_of_packet_size, bool has_padding>
|
||||||
EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
|
EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
|
||||||
const DataMapper rhs,
|
const DataMapper rhs,
|
||||||
StorageIndex rows,
|
StorageIndex rows,
|
||||||
@ -177,10 +190,14 @@ struct gemm_pack_colmajor_block<
|
|||||||
|
|
||||||
const StorageIndex start_row = (c == start_col) ? rhs.rowOffset() : 0;
|
const StorageIndex start_row = (c == start_col) ? rhs.rowOffset() : 0;
|
||||||
const StorageIndex max_row = rhs.maxRow(peeled_k, c);
|
const StorageIndex max_row = rhs.maxRow(peeled_k, c);
|
||||||
const bool pad_col = lm.padCol(c);
|
const bool pad_col = has_padding && lm.padCol(c);
|
||||||
|
|
||||||
|
eigen_assert(has_padding || !lm.padCol(c));
|
||||||
|
eigen_assert(has_padding || !lm.padAnyRow(start_row, max_row - 1));
|
||||||
|
|
||||||
// We can squeeze reads for all rows in [start_row, max_row) range.
|
// We can squeeze reads for all rows in [start_row, max_row) range.
|
||||||
if (!pad_col && !lm.padAnyRow(start_row, max_row - 1)) {
|
if (!has_padding ||
|
||||||
|
(!pad_col && !lm.padAnyRow(start_row, max_row - 1))) {
|
||||||
const StorageIndex start_depth =
|
const StorageIndex start_depth =
|
||||||
(c == start_col) ? rhs.depthOffset() : 0;
|
(c == start_col) ? rhs.depthOffset() : 0;
|
||||||
|
|
||||||
@ -196,6 +213,24 @@ struct gemm_pack_colmajor_block<
|
|||||||
eigen_assert((max_depth - start_depth) % packet_size == 0);
|
eigen_assert((max_depth - start_depth) % packet_size == 0);
|
||||||
StorageIndex d = start_depth;
|
StorageIndex d = start_depth;
|
||||||
|
|
||||||
|
const StorageIndex unrolled_depth = max_depth - 4 * packet_size;
|
||||||
|
for (; d <= unrolled_depth; d += 4 * packet_size) {
|
||||||
|
eigen_assert(k < peeled_k);
|
||||||
|
|
||||||
|
Packet p0 = rhs.packetNoPadding(d + 0 * packet_size, base_idx);
|
||||||
|
Packet p1 = rhs.packetNoPadding(d + 1 * packet_size, base_idx);
|
||||||
|
Packet p2 = rhs.packetNoPadding(d + 2 * packet_size, base_idx);
|
||||||
|
Packet p3 = rhs.packetNoPadding(d + 3 * packet_size, base_idx);
|
||||||
|
|
||||||
|
internal::pstoreu(block + 0 * packet_size, p0);
|
||||||
|
internal::pstoreu(block + 1 * packet_size, p1);
|
||||||
|
internal::pstoreu(block + 2 * packet_size, p2);
|
||||||
|
internal::pstoreu(block + 3 * packet_size, p3);
|
||||||
|
|
||||||
|
block += 4 * packet_size;
|
||||||
|
k += 4 * packet_size;
|
||||||
|
}
|
||||||
|
|
||||||
for (; d < max_depth; d += packet_size) {
|
for (; d < max_depth; d += packet_size) {
|
||||||
eigen_assert(k < peeled_k);
|
eigen_assert(k < peeled_k);
|
||||||
internal::pstoreu(block, rhs.packetNoPadding(d, base_idx));
|
internal::pstoreu(block, rhs.packetNoPadding(d, base_idx));
|
||||||
@ -205,8 +240,26 @@ struct gemm_pack_colmajor_block<
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
StorageIndex d = start_depth;
|
StorageIndex d = start_depth;
|
||||||
const StorageIndex vectorized_depth = max_depth - packet_size;
|
|
||||||
|
|
||||||
|
const StorageIndex unrolled_depth = max_depth - 4 * packet_size;
|
||||||
|
for (; d <= unrolled_depth; d += 4 * packet_size) {
|
||||||
|
eigen_assert(k < peeled_k);
|
||||||
|
|
||||||
|
Packet p0 = rhs.packetNoPadding(d + 0 * packet_size, base_idx);
|
||||||
|
Packet p1 = rhs.packetNoPadding(d + 1 * packet_size, base_idx);
|
||||||
|
Packet p2 = rhs.packetNoPadding(d + 2 * packet_size, base_idx);
|
||||||
|
Packet p3 = rhs.packetNoPadding(d + 3 * packet_size, base_idx);
|
||||||
|
|
||||||
|
internal::pstoreu(block + 0 * packet_size, p0);
|
||||||
|
internal::pstoreu(block + 1 * packet_size, p1);
|
||||||
|
internal::pstoreu(block + 2 * packet_size, p2);
|
||||||
|
internal::pstoreu(block + 3 * packet_size, p3);
|
||||||
|
|
||||||
|
block += 4 * packet_size;
|
||||||
|
k += 4 * packet_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
const StorageIndex vectorized_depth = max_depth - packet_size;
|
||||||
for (; d <= vectorized_depth; d += packet_size) {
|
for (; d <= vectorized_depth; d += packet_size) {
|
||||||
eigen_assert(k < peeled_k);
|
eigen_assert(k < peeled_k);
|
||||||
internal::pstoreu(block, rhs.packetNoPadding(d, base_idx));
|
internal::pstoreu(block, rhs.packetNoPadding(d, base_idx));
|
||||||
@ -237,7 +290,9 @@ struct gemm_pack_colmajor_block<
|
|||||||
const StorageIndex max_depth =
|
const StorageIndex max_depth =
|
||||||
rhs.maxDepth(peeled_k - k, start_depth);
|
rhs.maxDepth(peeled_k - k, start_depth);
|
||||||
|
|
||||||
const bool pad = pad_col || lm.padRow(r);
|
const bool pad = has_padding && (pad_col || lm.padRow(r));
|
||||||
|
eigen_assert(has_padding || !lm.padRow(r));
|
||||||
|
|
||||||
const StorageIndex base_idx = lm.baseIndex(r, c);
|
const StorageIndex base_idx = lm.baseIndex(r, c);
|
||||||
|
|
||||||
if (patch_depth_is_multiple_of_packet_size) {
|
if (patch_depth_is_multiple_of_packet_size) {
|
||||||
@ -248,7 +303,8 @@ struct gemm_pack_colmajor_block<
|
|||||||
|
|
||||||
for (; d < max_depth; d += packet_size) {
|
for (; d < max_depth; d += packet_size) {
|
||||||
eigen_assert(k < peeled_k);
|
eigen_assert(k < peeled_k);
|
||||||
const Packet p = pad ? pset1<Packet>(Scalar(0))
|
const Packet p = (has_padding && pad)
|
||||||
|
? pset1<Packet>(Scalar(0))
|
||||||
: rhs.packetNoPadding(d, base_idx);
|
: rhs.packetNoPadding(d, base_idx);
|
||||||
internal::pstoreu(block, p);
|
internal::pstoreu(block, p);
|
||||||
block += packet_size;
|
block += packet_size;
|
||||||
@ -256,11 +312,13 @@ struct gemm_pack_colmajor_block<
|
|||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
const StorageIndex vectorized_depth = max_depth - packet_size;
|
|
||||||
StorageIndex d = start_depth;
|
StorageIndex d = start_depth;
|
||||||
|
|
||||||
|
const StorageIndex vectorized_depth = max_depth - packet_size;
|
||||||
for (; d <= vectorized_depth; d += packet_size) {
|
for (; d <= vectorized_depth; d += packet_size) {
|
||||||
eigen_assert(k < peeled_k);
|
eigen_assert(k < peeled_k);
|
||||||
const Packet p = pad ? pset1<Packet>(Scalar(0))
|
const Packet p = (has_padding && pad)
|
||||||
|
? pset1<Packet>(Scalar(0))
|
||||||
: rhs.packetNoPadding(d, base_idx);
|
: rhs.packetNoPadding(d, base_idx);
|
||||||
internal::pstoreu(block, p);
|
internal::pstoreu(block, p);
|
||||||
block += packet_size;
|
block += packet_size;
|
||||||
@ -269,7 +327,7 @@ struct gemm_pack_colmajor_block<
|
|||||||
|
|
||||||
eigen_assert(k <= peeled_k);
|
eigen_assert(k <= peeled_k);
|
||||||
const Index num_coeffs = CoeffFinalizer::finalize(
|
const Index num_coeffs = CoeffFinalizer::finalize(
|
||||||
block, rhs, base_idx, d, max_depth, pad);
|
block, rhs, base_idx, d, max_depth, has_padding && pad);
|
||||||
|
|
||||||
k += num_coeffs;
|
k += num_coeffs;
|
||||||
block += num_coeffs;
|
block += num_coeffs;
|
||||||
|
|||||||
@ -1382,6 +1382,7 @@ static void PackRhsHelper(int iters,
|
|||||||
int input_depth,
|
int input_depth,
|
||||||
/* Filter (kernel) dimensions: */
|
/* Filter (kernel) dimensions: */
|
||||||
int filter_count, int filter_cols, int filter_rows,
|
int filter_count, int filter_cols, int filter_rows,
|
||||||
|
Eigen::PaddingType padding,
|
||||||
/* Input strides: */
|
/* Input strides: */
|
||||||
int col_strides, int row_strides,
|
int col_strides, int row_strides,
|
||||||
/* Patch inflate strides: */
|
/* Patch inflate strides: */
|
||||||
@ -1489,14 +1490,27 @@ static void PackRhsHelper(int iters,
|
|||||||
row_strides, col_strides, //
|
row_strides, col_strides, //
|
||||||
/*in_row_strides=*/1, /*in_col_strides=*/1, //
|
/*in_row_strides=*/1, /*in_col_strides=*/1, //
|
||||||
patch_row_inflate_stride, patch_col_inflate_stride, //
|
patch_row_inflate_stride, patch_col_inflate_stride, //
|
||||||
Eigen::PADDING_SAME, /*padding_value=*/0.0);
|
padding, /*padding_value=*/0.0);
|
||||||
|
|
||||||
// 2. Reshape extracted patches into "virtual" 2d tensor.
|
// 2. Reshape extracted patches into "virtual" 2d tensor.
|
||||||
// NOTE: This is valid for PADDING_SAME only.
|
|
||||||
Index input_rows_eff = (input_rows - 1) * patch_row_inflate_stride + 1;
|
Index input_rows_eff = (input_rows - 1) * patch_row_inflate_stride + 1;
|
||||||
Index input_cols_eff = (input_cols - 1) * patch_col_inflate_stride + 1;
|
Index input_cols_eff = (input_cols - 1) * patch_col_inflate_stride + 1;
|
||||||
Index output_rows = input_rows_eff / row_strides;
|
|
||||||
Index output_cols = input_cols_eff / col_strides;
|
Index output_rows = 0;
|
||||||
|
Index output_cols = 0;
|
||||||
|
|
||||||
|
if (padding == Eigen::PADDING_SAME) {
|
||||||
|
output_rows = input_rows_eff / row_strides;
|
||||||
|
output_cols = input_cols_eff / col_strides;
|
||||||
|
} else if (padding == Eigen::PADDING_VALID) {
|
||||||
|
output_rows =
|
||||||
|
numext::ceil((input_rows_eff - filter_rows + 1.f) / row_strides);
|
||||||
|
output_cols =
|
||||||
|
numext::ceil((input_cols_eff - filter_cols + 1.f) / row_strides);
|
||||||
|
} else {
|
||||||
|
eigen_assert(false && "not supported");
|
||||||
|
}
|
||||||
|
|
||||||
NewDimension reshape_dims;
|
NewDimension reshape_dims;
|
||||||
reshape_dims[0] = input_depth * filter_rows * filter_cols; // patch size
|
reshape_dims[0] = input_depth * filter_rows * filter_cols; // patch size
|
||||||
reshape_dims[1] = output_rows * output_cols * input_batches; // num_patches
|
reshape_dims[1] = output_rows * output_cols * input_batches; // num_patches
|
||||||
@ -1561,7 +1575,7 @@ static void PackRhsHelper(int iters,
|
|||||||
tensorflow::testing::SetLabel(
|
tensorflow::testing::SetLabel(
|
||||||
absl::StrCat("patch: ", patch_rows, "x", patch_cols, " D", patch_depth,
|
absl::StrCat("patch: ", patch_rows, "x", patch_cols, " D", patch_depth,
|
||||||
"; num_patches=", num_patches, " patch_size=", patch_size,
|
"; num_patches=", num_patches, " patch_size=", patch_size,
|
||||||
" num_inputs=", num_inputs));
|
" num_inputs=", num_inputs, " padding=", padding));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -1755,24 +1769,24 @@ static void PackLhsHelper(int iters,
|
|||||||
|
|
||||||
#define BM_CONCAT(a, b) a##b
|
#define BM_CONCAT(a, b) a##b
|
||||||
|
|
||||||
#define BM_RHS_NAME(prefix, T, N, H, W, C, FC, FH, FW, SH, SW, ISH, ISW, BR, \
|
#define BM_RHS_NAME(prefix, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, ISW, \
|
||||||
BC) \
|
BR, BC) \
|
||||||
BM_CONCAT( \
|
BM_CONCAT( \
|
||||||
BM_##prefix##_##T##_##N##_##H##x##W##_IC##C##_FC##FC##_##FH##x##FW, \
|
BM_##prefix##_##T##_##N##_##H##x##W##_IC##C##_FC##FC##_##FH##x##FW, \
|
||||||
_s##SH##x##SW##_is##ISH##x##ISW##_B##BR##x##BC)
|
_##PAD##_s##SH##x##SW##_is##ISH##x##ISW##_B##BR##x##BC)
|
||||||
|
|
||||||
#define BM_PackRhs(T, N, H, W, C, FC, FH, FW, SH, SW, ISH, ISW, BR, BC) \
|
#define BM_PackRhs(T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, ISW, BR, BC) \
|
||||||
static void BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, SH, SW, ISH, \
|
static void BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, \
|
||||||
ISW, BR, BC)(int iters) { \
|
ISH, ISW, BR, BC)(int iters) { \
|
||||||
PackRhsHelper<T>(iters, N, H, W, C, FC, FH, FW, SH, SW, ISH, ISW, BR, BC); \
|
PackRhsHelper<T>(iters, N, H, W, C, FC, FH, FW, PADDING_##PAD, SH, SW, \
|
||||||
|
ISH, ISW, BR, BC); \
|
||||||
} \
|
} \
|
||||||
BENCHMARK(BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, SH, SW, ISH, ISW, \
|
BENCHMARK(BM_RHS_NAME(PackRhs, T, N, H, W, C, FC, FH, FW, PAD, SH, SW, ISH, \
|
||||||
BR, BC))
|
ISW, BR, BC))
|
||||||
|
|
||||||
// Number of input channel (input depth) it equal to the number of patch
|
// Number of input channel (input depth) it equal to the number of patch
|
||||||
// channels (patch depth).
|
// channels (patch depth).
|
||||||
|
|
||||||
// NOTE: This is the most common case in Tensorflow models.
|
|
||||||
// Fast path: input channel dimension is the multiple of the packet size.
|
// Fast path: input channel dimension is the multiple of the packet size.
|
||||||
BM_PackRhs(/*type*/ float, //
|
BM_PackRhs(/*type*/ float, //
|
||||||
/*batch*/ 32, //
|
/*batch*/ 32, //
|
||||||
@ -1780,6 +1794,7 @@ BM_PackRhs(/*type*/ float, //
|
|||||||
/*channels*/ 32, //
|
/*channels*/ 32, //
|
||||||
/*num_filters*/ 64, //
|
/*num_filters*/ 64, //
|
||||||
/*filter*/ 5, 5, //
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ VALID, //
|
||||||
/*stride*/ 1, 1, //
|
/*stride*/ 1, 1, //
|
||||||
/*patch inflate stride*/ 1, 1, //
|
/*patch inflate stride*/ 1, 1, //
|
||||||
/*block*/ 256, 56);
|
/*block*/ 256, 56);
|
||||||
@ -1790,6 +1805,29 @@ BM_PackRhs(/*type*/ float, //
|
|||||||
/*channels*/ 32, //
|
/*channels*/ 32, //
|
||||||
/*num_filters*/ 64, //
|
/*num_filters*/ 64, //
|
||||||
/*filter*/ 5, 5, //
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
|
/*stride*/ 1, 1, //
|
||||||
|
/*patch inflate stride*/ 1, 1, //
|
||||||
|
/*block*/ 256, 56);
|
||||||
|
|
||||||
|
BM_PackRhs(/*type*/ float, //
|
||||||
|
/*batch*/ 32, //
|
||||||
|
/*image*/ 64, 64, //
|
||||||
|
/*channels*/ 32, //
|
||||||
|
/*num_filters*/ 64, //
|
||||||
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ VALID, //
|
||||||
|
/*stride*/ 2, 2, //
|
||||||
|
/*patch inflate stride*/ 1, 1, //
|
||||||
|
/*block*/ 256, 56);
|
||||||
|
|
||||||
|
BM_PackRhs(/*type*/ float, //
|
||||||
|
/*batch*/ 32, //
|
||||||
|
/*image*/ 64, 64, //
|
||||||
|
/*channels*/ 32, //
|
||||||
|
/*num_filters*/ 64, //
|
||||||
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
/*stride*/ 2, 2, //
|
/*stride*/ 2, 2, //
|
||||||
/*patch inflate stride*/ 1, 1, //
|
/*patch inflate stride*/ 1, 1, //
|
||||||
/*block*/ 256, 56);
|
/*block*/ 256, 56);
|
||||||
@ -1801,6 +1839,7 @@ BM_PackRhs(/*type*/ float, //
|
|||||||
/*channels*/ 30, //
|
/*channels*/ 30, //
|
||||||
/*num_filters*/ 64, //
|
/*num_filters*/ 64, //
|
||||||
/*filter*/ 5, 5, //
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
/*stride*/ 1, 1, //
|
/*stride*/ 1, 1, //
|
||||||
/*patch inflate stride*/ 1, 1, //
|
/*patch inflate stride*/ 1, 1, //
|
||||||
/*block*/ 256, 56);
|
/*block*/ 256, 56);
|
||||||
@ -1811,6 +1850,29 @@ BM_PackRhs(/*type*/ float, //
|
|||||||
/*channels*/ 30, //
|
/*channels*/ 30, //
|
||||||
/*num_filters*/ 64, //
|
/*num_filters*/ 64, //
|
||||||
/*filter*/ 5, 5, //
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ VALID, //
|
||||||
|
/*stride*/ 1, 1, //
|
||||||
|
/*patch inflate stride*/ 1, 1, //
|
||||||
|
/*block*/ 256, 56);
|
||||||
|
|
||||||
|
BM_PackRhs(/*type*/ float, //
|
||||||
|
/*batch*/ 32, //
|
||||||
|
/*image*/ 64, 64, //
|
||||||
|
/*channels*/ 30, //
|
||||||
|
/*num_filters*/ 64, //
|
||||||
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
|
/*stride*/ 2, 2, //
|
||||||
|
/*patch inflate stride*/ 1, 1, //
|
||||||
|
/*block*/ 256, 56);
|
||||||
|
|
||||||
|
BM_PackRhs(/*type*/ float, //
|
||||||
|
/*batch*/ 32, //
|
||||||
|
/*image*/ 64, 64, //
|
||||||
|
/*channels*/ 30, //
|
||||||
|
/*num_filters*/ 64, //
|
||||||
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ VALID, //
|
||||||
/*stride*/ 2, 2, //
|
/*stride*/ 2, 2, //
|
||||||
/*patch inflate stride*/ 1, 1, //
|
/*patch inflate stride*/ 1, 1, //
|
||||||
/*block*/ 256, 56);
|
/*block*/ 256, 56);
|
||||||
@ -1822,6 +1884,7 @@ BM_PackRhs(/*type*/ float, //
|
|||||||
/*channels*/ 4, //
|
/*channels*/ 4, //
|
||||||
/*num_filters*/ 16, //
|
/*num_filters*/ 16, //
|
||||||
/*filter*/ 8, 8, //
|
/*filter*/ 8, 8, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
/*stride*/ 1, 1, //
|
/*stride*/ 1, 1, //
|
||||||
/*patch inflate stride*/ 1, 1, //
|
/*patch inflate stride*/ 1, 1, //
|
||||||
/*block*/ 256, 56);
|
/*block*/ 256, 56);
|
||||||
@ -1832,6 +1895,29 @@ BM_PackRhs(/*type*/ float, //
|
|||||||
/*channels*/ 4, //
|
/*channels*/ 4, //
|
||||||
/*num_filters*/ 16, //
|
/*num_filters*/ 16, //
|
||||||
/*filter*/ 8, 8, //
|
/*filter*/ 8, 8, //
|
||||||
|
/*padding*/ VALID, //
|
||||||
|
/*stride*/ 1, 1, //
|
||||||
|
/*patch inflate stride*/ 1, 1, //
|
||||||
|
/*block*/ 256, 56);
|
||||||
|
|
||||||
|
BM_PackRhs(/*type*/ float, //
|
||||||
|
/*batch*/ 32, //
|
||||||
|
/*image*/ 256, 256, //
|
||||||
|
/*channels*/ 4, //
|
||||||
|
/*num_filters*/ 16, //
|
||||||
|
/*filter*/ 8, 8, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
|
/*stride*/ 2, 4, //
|
||||||
|
/*patch inflate stride*/ 1, 1, //
|
||||||
|
/*block*/ 256, 56);
|
||||||
|
|
||||||
|
BM_PackRhs(/*type*/ float, //
|
||||||
|
/*batch*/ 32, //
|
||||||
|
/*image*/ 256, 256, //
|
||||||
|
/*channels*/ 4, //
|
||||||
|
/*num_filters*/ 16, //
|
||||||
|
/*filter*/ 8, 8, //
|
||||||
|
/*padding*/ VALID, //
|
||||||
/*stride*/ 2, 4, //
|
/*stride*/ 2, 4, //
|
||||||
/*patch inflate stride*/ 1, 1, //
|
/*patch inflate stride*/ 1, 1, //
|
||||||
/*block*/ 256, 56);
|
/*block*/ 256, 56);
|
||||||
@ -1843,6 +1929,19 @@ BM_PackRhs(/*type*/ float, //
|
|||||||
/*channels*/ 4, //
|
/*channels*/ 4, //
|
||||||
/*num_filters*/ 16, //
|
/*num_filters*/ 16, //
|
||||||
/*filter*/ 3, 3, //
|
/*filter*/ 3, 3, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
|
/*stride*/ 1, 1, //
|
||||||
|
/*patch inflate stride*/ 1, 1, //
|
||||||
|
/*block*/ 36, 432);
|
||||||
|
|
||||||
|
// Short and wide block with small input channel dimension.
|
||||||
|
BM_PackRhs(/*type*/ float, //
|
||||||
|
/*batch*/ 32, //
|
||||||
|
/*image*/ 64, 64, //
|
||||||
|
/*channels*/ 4, //
|
||||||
|
/*num_filters*/ 16, //
|
||||||
|
/*filter*/ 3, 3, //
|
||||||
|
/*padding*/ VALID, //
|
||||||
/*stride*/ 1, 1, //
|
/*stride*/ 1, 1, //
|
||||||
/*patch inflate stride*/ 1, 1, //
|
/*patch inflate stride*/ 1, 1, //
|
||||||
/*block*/ 36, 432);
|
/*block*/ 36, 432);
|
||||||
@ -1853,16 +1952,41 @@ BM_PackRhs(/*type*/ float, //
|
|||||||
/*channels*/ 4, //
|
/*channels*/ 4, //
|
||||||
/*num_filters*/ 16, //
|
/*num_filters*/ 16, //
|
||||||
/*filter*/ 3, 3, //
|
/*filter*/ 3, 3, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
/*stride*/ 2, 2, //
|
/*stride*/ 2, 2, //
|
||||||
/*patch inflate stride*/ 1, 1, //
|
/*patch inflate stride*/ 1, 1, //
|
||||||
/*block*/ 36, 432);
|
/*block*/ 36, 432);
|
||||||
|
|
||||||
|
BM_PackRhs(/*type*/ float, //
|
||||||
|
/*batch*/ 32, //
|
||||||
|
/*image*/ 64, 64, //
|
||||||
|
/*channels*/ 4, //
|
||||||
|
/*num_filters*/ 16, //
|
||||||
|
/*filter*/ 3, 3, //
|
||||||
|
/*padding*/ VALID, //
|
||||||
|
/*stride*/ 2, 2, //
|
||||||
|
/*patch inflate stride*/ 1, 1, //
|
||||||
|
/*block*/ 36, 432);
|
||||||
|
|
||||||
|
// Non standard patches with inflated strides.
|
||||||
|
BM_PackRhs(/*type*/ float, //
|
||||||
|
/*batch*/ 32, //
|
||||||
|
/*image*/ 32, 32, //
|
||||||
|
/*channels*/ 96, //
|
||||||
|
/*num_filters*/ 96, //
|
||||||
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
|
/*stride*/ 1, 1, //
|
||||||
|
/*patch inflate stride*/ 2, 2, //
|
||||||
|
/*block*/ 272, 240);
|
||||||
|
|
||||||
BM_PackRhs(/*type*/ float, //
|
BM_PackRhs(/*type*/ float, //
|
||||||
/*batch*/ 32, //
|
/*batch*/ 32, //
|
||||||
/*image*/ 32, 32, //
|
/*image*/ 32, 32, //
|
||||||
/*channels*/ 96, //
|
/*channels*/ 96, //
|
||||||
/*num_filters*/ 96, //
|
/*num_filters*/ 96, //
|
||||||
/*filter*/ 5, 5, //
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ VALID, //
|
||||||
/*stride*/ 1, 1, //
|
/*stride*/ 1, 1, //
|
||||||
/*patch inflate stride*/ 2, 2, //
|
/*patch inflate stride*/ 2, 2, //
|
||||||
/*block*/ 272, 240);
|
/*block*/ 272, 240);
|
||||||
@ -1875,6 +1999,7 @@ BM_PackRhs(/*type*/ qint8, //
|
|||||||
/*channels*/ 32, //
|
/*channels*/ 32, //
|
||||||
/*num_filters*/ 64, //
|
/*num_filters*/ 64, //
|
||||||
/*filter*/ 5, 5, //
|
/*filter*/ 5, 5, //
|
||||||
|
/*padding*/ SAME, //
|
||||||
/*stride*/ 1, 1, //
|
/*stride*/ 1, 1, //
|
||||||
/*patch inflate stride*/ 1, 1, //
|
/*patch inflate stride*/ 1, 1, //
|
||||||
/*block*/ 256, 56);
|
/*block*/ 256, 56);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user