Fix bug in padding detection in contraction packing
PiperOrigin-RevId: 289016140
Change-Id: Idb936e596dbeb4a2ec3b669dadcdb46d0f7a69d8
This commit is contained in:
parent
c156b61321
commit
6d40b8f587
@ -880,16 +880,21 @@ class TensorContractionSubMapper<
|
|||||||
// possible to guarantee "no padding or skipping" for non-standard packing.
|
// possible to guarantee "no padding or skipping" for non-standard packing.
|
||||||
if (nonStandardPatches()) return true;
|
if (nonStandardPatches()) return true;
|
||||||
|
|
||||||
// Check if output rows and columns matches the PADDING_VALID case. If they
|
// Non zero padding before.
|
||||||
// are it means that there is no padding for the input tensor.
|
if (m_base_mapper.m_rowPaddingTop > 0) return true;
|
||||||
const bool match_rows = m_base_mapper.m_outputRows ==
|
if (m_base_mapper.m_colPaddingLeft > 0) return true;
|
||||||
divup(m_base_mapper.m_inputRows - patchRows() + 1,
|
|
||||||
m_base_mapper.m_row_strides);
|
|
||||||
const bool match_cols = m_base_mapper.m_outputCols ==
|
|
||||||
divup(m_base_mapper.m_inputCols - patchCols() + 1,
|
|
||||||
m_base_mapper.m_col_strides);
|
|
||||||
|
|
||||||
return !match_rows || !match_cols;
|
// Non zero padding after in rows.
|
||||||
|
const Index last_row =
|
||||||
|
(m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
|
||||||
|
if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows) return true;
|
||||||
|
|
||||||
|
// Non zero padding after in cols.
|
||||||
|
const Index last_col =
|
||||||
|
(m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
|
||||||
|
if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols) return true;
|
||||||
|
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
|
EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
|
||||||
|
@ -1506,7 +1506,7 @@ static void PackRhsHelper(int iters,
|
|||||||
output_rows =
|
output_rows =
|
||||||
numext::ceil((input_rows_eff - filter_rows + 1.f) / row_strides);
|
numext::ceil((input_rows_eff - filter_rows + 1.f) / row_strides);
|
||||||
output_cols =
|
output_cols =
|
||||||
numext::ceil((input_cols_eff - filter_cols + 1.f) / row_strides);
|
numext::ceil((input_cols_eff - filter_cols + 1.f) / col_strides);
|
||||||
} else {
|
} else {
|
||||||
eigen_assert(false && "not supported");
|
eigen_assert(false && "not supported");
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user