Support folding TF::TransposeOp when perm is a constant instead of TF::ConstOp
PiperOrigin-RevId: 328149666 Change-Id: I0c5561152383f12126ab9568c0facc4c3043c6a3
This commit is contained in:
parent
9578a394a0
commit
484f0e5fd9
@ -1939,11 +1939,9 @@ void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x,
|
||||
namespace {
|
||||
|
||||
OpFoldResult FoldIdentityTranspose(TransposeOp op) {
|
||||
auto const_perm = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
|
||||
if (!const_perm) return {};
|
||||
|
||||
auto const_value = const_perm.value();
|
||||
const auto elements = const_value.getValues<APInt>();
|
||||
DenseIntElementsAttr perm;
|
||||
if (!matchPattern(op.perm(), m_Constant(&perm))) return {};
|
||||
const auto elements = perm.getValues<APInt>();
|
||||
|
||||
for (auto it : llvm::enumerate(elements)) {
|
||||
if (it.index() != it.value()) return {};
|
||||
@ -1966,14 +1964,14 @@ OpFoldResult FoldCancellableTranspose(TransposeOp op) {
|
||||
if (!transpose) return {};
|
||||
|
||||
// Permutations defined by constant operations.
|
||||
auto perm0 = dyn_cast_or_null<TF::ConstOp>(op.perm().getDefiningOp());
|
||||
auto perm1 = dyn_cast_or_null<TF::ConstOp>(transpose.perm().getDefiningOp());
|
||||
if (!perm0 || !perm1) return {};
|
||||
DenseIntElementsAttr perm0;
|
||||
DenseIntElementsAttr perm1;
|
||||
if (!matchPattern(op.perm(), m_Constant(&perm0)) ||
|
||||
!matchPattern(transpose.perm(), m_Constant(&perm1)))
|
||||
return {};
|
||||
|
||||
// With permutation indices that cancel each other
|
||||
auto perm0_value = perm0.value().cast<DenseIntElementsAttr>();
|
||||
auto perm1_value = perm1.value().cast<DenseIntElementsAttr>();
|
||||
if (!AreCancellablePermutations(perm0_value, perm1_value)) return {};
|
||||
if (!AreCancellablePermutations(perm0, perm1)) return {};
|
||||
|
||||
return transpose.x();
|
||||
}
|
||||
|
@ -702,6 +702,15 @@ func @identityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> {
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @identityTransposeConst
|
||||
func @identityTransposeConst(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> {
|
||||
%0 = constant dense<[0, 1, 2, 3, 4]> : tensor<5xi32>
|
||||
%1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x5x6xf32>
|
||||
|
||||
return %1 : tensor<2x3x4x5x6xf32>
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @nonIdentityTranspose
|
||||
func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32> {
|
||||
%0 = "tf.Const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32>
|
||||
@ -724,6 +733,17 @@ func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @cancellableTransposeConst
|
||||
func @cancellableTransposeConst(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> {
|
||||
%0 = constant dense<[0, 3, 1, 2]> : tensor<4xi32>
|
||||
%1 = constant dense<[0, 2, 3, 1]> : tensor<4xi32>
|
||||
%2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32>
|
||||
%3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32>
|
||||
|
||||
return %3 : tensor<1x4x4x8xf32>
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @nonCancellableTranspose
|
||||
func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> {
|
||||
%0 = "tf.Const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user