Fix bug in padding detection in contraction packing

PiperOrigin-RevId: 289016140
Change-Id: Idb936e596dbeb4a2ec3b669dadcdb46d0f7a69d8
This commit is contained in:
Eugene Zhulenev 2020-01-09 18:50:32 -08:00 committed by TensorFlower Gardener
parent c156b61321
commit 6d40b8f587
2 changed files with 15 additions and 10 deletions

View File

@ -880,16 +880,21 @@ class TensorContractionSubMapper<
// possible to guarantee "no padding or skipping" for non-standard packing.
if (nonStandardPatches()) return true;
// Check if output rows and columns matches the PADDING_VALID case. If they
// are it means that there is no padding for the input tensor.
const bool match_rows = m_base_mapper.m_outputRows ==
divup(m_base_mapper.m_inputRows - patchRows() + 1,
m_base_mapper.m_row_strides);
const bool match_cols = m_base_mapper.m_outputCols ==
divup(m_base_mapper.m_inputCols - patchCols() + 1,
m_base_mapper.m_col_strides);
// Non zero padding before.
if (m_base_mapper.m_rowPaddingTop > 0) return true;
if (m_base_mapper.m_colPaddingLeft > 0) return true;
return !match_rows || !match_cols;
// Non zero padding after in rows.
const Index last_row =
(m_base_mapper.m_outputRows - 1) * m_base_mapper.m_row_strides;
if (last_row + (patchRows() - 1) >= m_base_mapper.m_inputRows) return true;
// Non zero padding after in cols.
const Index last_col =
(m_base_mapper.m_outputCols - 1) * m_base_mapper.m_col_strides;
if (last_col + (patchCols() - 1) >= m_base_mapper.m_inputCols) return true;
return false;
}
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {

View File

@ -1506,7 +1506,7 @@ static void PackRhsHelper(int iters,
output_rows =
numext::ceil((input_rows_eff - filter_rows + 1.f) / row_strides);
output_cols =
numext::ceil((input_cols_eff - filter_cols + 1.f) / row_strides);
numext::ceil((input_cols_eff - filter_cols + 1.f) / col_strides);
} else {
eigen_assert(false && "not supported");
}