Fix ReduceTransposer in GenericLayoutOptimizer to properly handle 5D tensors.

Reduction axes are not determined correctly when input is 5D.

PiperOrigin-RevId: 342977878
Change-Id: Ife73547e1bbcf4060906e15f585ccf21e4fb1ec7
This commit is contained in:
Andy Ly 2020-11-17 17:10:16 -08:00 committed by TensorFlower Gardener
parent cdae1a1ee0
commit 28c5e97db9
2 changed files with 18 additions and 9 deletions

View File

@ -1414,8 +1414,9 @@ bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
return true;
}
bool ReduceTransposer::IsReduceAxisSupported(
const TransposeContext& context, const utils::MutableNodeView& node) {
bool ReduceTransposer::IsReduceAxisSupported(const TransposeContext& context,
const utils::MutableNodeView& node,
int rank) {
if (KeepDims(node)) {
return true;
}
@ -1436,11 +1437,19 @@ bool ReduceTransposer::IsReduceAxisSupported(
auto indices = [&context](absl::Span<const char> labels) {
return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
};
return IsAlongAxis(tensor, indices({'N', 'H', 'W', 'C'}), kRank) ||
IsAlongAxis(tensor, indices({'H', 'W', 'C'}), kRank) ||
IsAlongAxis(tensor, indices({'N', 'H', 'W'}), kRank) ||
IsAlongAxis(tensor, indices({'H', 'W'}), kRank) ||
IsAlongAxis(tensor, indices({'C'}), kRank);
if (rank == 5) {
return IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W', 'C'}), 5) ||
IsAlongAxis(tensor, indices({'D', 'H', 'W', 'C'}), 5) ||
IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W'}), 5) ||
IsAlongAxis(tensor, indices({'D', 'H', 'W'}), 5) ||
IsAlongAxis(tensor, indices({'C'}), 5);
}
DCHECK_EQ(rank, 4);
return IsAlongAxis(tensor, indices({'N', 'H', 'W', 'C'}), 4) ||
IsAlongAxis(tensor, indices({'H', 'W', 'C'}), 4) ||
IsAlongAxis(tensor, indices({'N', 'H', 'W'}), 4) ||
IsAlongAxis(tensor, indices({'H', 'W'}), 4) ||
IsAlongAxis(tensor, indices({'C'}), 4);
}
Status ReduceTransposer::TransposeNode(TransposeContext* context,
@ -1452,7 +1461,7 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
}
ScopedDataFormatUpgrader data_format_upgrader(context, rank);
if (!ShouldProcess(*context, *node) ||
!IsReduceAxisSupported(*context, *node) ||
!IsReduceAxisSupported(*context, *node, rank) ||
!IsAfterDstToSrcTransform(*context, *node)) {
return Status::OK();
}

View File

@ -422,7 +422,7 @@ class ReduceTransposer : public LayoutAgnosticOpTransposer {
bool KeepDims(const utils::MutableNodeView& node);
bool IsAlongAxis(const Tensor& tensor, absl::Span<const int> axis, int rank);
bool IsReduceAxisSupported(const TransposeContext& context,
const utils::MutableNodeView& node);
const utils::MutableNodeView& node, int rank);
};
class ReverseV2Transposer : public LayoutAgnosticOpTransposer {