Fix ReduceTransposer in GenericLayoutOptimizer to properly handle 5D tensors.
Reduction axes are not determined correctly when input is 5D. PiperOrigin-RevId: 342977878 Change-Id: Ife73547e1bbcf4060906e15f585ccf21e4fb1ec7
This commit is contained in:
parent
cdae1a1ee0
commit
28c5e97db9
@ -1414,8 +1414,9 @@ bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ReduceTransposer::IsReduceAxisSupported(
|
||||
const TransposeContext& context, const utils::MutableNodeView& node) {
|
||||
bool ReduceTransposer::IsReduceAxisSupported(const TransposeContext& context,
|
||||
const utils::MutableNodeView& node,
|
||||
int rank) {
|
||||
if (KeepDims(node)) {
|
||||
return true;
|
||||
}
|
||||
@ -1436,11 +1437,19 @@ bool ReduceTransposer::IsReduceAxisSupported(
|
||||
auto indices = [&context](absl::Span<const char> labels) {
|
||||
return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
|
||||
};
|
||||
return IsAlongAxis(tensor, indices({'N', 'H', 'W', 'C'}), kRank) ||
|
||||
IsAlongAxis(tensor, indices({'H', 'W', 'C'}), kRank) ||
|
||||
IsAlongAxis(tensor, indices({'N', 'H', 'W'}), kRank) ||
|
||||
IsAlongAxis(tensor, indices({'H', 'W'}), kRank) ||
|
||||
IsAlongAxis(tensor, indices({'C'}), kRank);
|
||||
if (rank == 5) {
|
||||
return IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W', 'C'}), 5) ||
|
||||
IsAlongAxis(tensor, indices({'D', 'H', 'W', 'C'}), 5) ||
|
||||
IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W'}), 5) ||
|
||||
IsAlongAxis(tensor, indices({'D', 'H', 'W'}), 5) ||
|
||||
IsAlongAxis(tensor, indices({'C'}), 5);
|
||||
}
|
||||
DCHECK_EQ(rank, 4);
|
||||
return IsAlongAxis(tensor, indices({'N', 'H', 'W', 'C'}), 4) ||
|
||||
IsAlongAxis(tensor, indices({'H', 'W', 'C'}), 4) ||
|
||||
IsAlongAxis(tensor, indices({'N', 'H', 'W'}), 4) ||
|
||||
IsAlongAxis(tensor, indices({'H', 'W'}), 4) ||
|
||||
IsAlongAxis(tensor, indices({'C'}), 4);
|
||||
}
|
||||
|
||||
Status ReduceTransposer::TransposeNode(TransposeContext* context,
|
||||
@ -1452,7 +1461,7 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
|
||||
}
|
||||
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
|
||||
if (!ShouldProcess(*context, *node) ||
|
||||
!IsReduceAxisSupported(*context, *node) ||
|
||||
!IsReduceAxisSupported(*context, *node, rank) ||
|
||||
!IsAfterDstToSrcTransform(*context, *node)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -422,7 +422,7 @@ class ReduceTransposer : public LayoutAgnosticOpTransposer {
|
||||
bool KeepDims(const utils::MutableNodeView& node);
|
||||
bool IsAlongAxis(const Tensor& tensor, absl::Span<const int> axis, int rank);
|
||||
bool IsReduceAxisSupported(const TransposeContext& context,
|
||||
const utils::MutableNodeView& node);
|
||||
const utils::MutableNodeView& node, int rank);
|
||||
};
|
||||
|
||||
class ReverseV2Transposer : public LayoutAgnosticOpTransposer {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user