[TF:MLIR] Add support for folding Transpose into Mean
PiperOrigin-RevId: 296361326 Change-Id: I677bfd6aa17865514a8770b49bce6b7681d5c289
This commit is contained in:
parent
120e5e6ea0
commit
2ca35b7a30
tensorflow/compiler/mlir/tensorflow
@ -1536,6 +1536,38 @@ static LogicalResult Verify(MaxPoolGradOp op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MeanOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Rewrites the reduction indices of this Mean op according to `permutation`,
// enabling a preceding Transpose on the data operand to be folded away.
// Returns failure (leaving the op untouched) when the reduction indices are
// not a statically known integer constant.
LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
  // Folding is only possible when the reduction indices are produced by a
  // constant operation; otherwise we cannot remap them at compile time.
  auto indices_op =
      dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
  if (!indices_op) return failure();

  auto indices_attr = indices_op.value().dyn_cast<DenseElementsAttr>();
  if (!indices_attr) return failure();

  // Remap every reduction index through the operand permutation.
  SmallVector<int64_t, 4> permuted_indices;
  for (const APInt &index : indices_attr.getIntValues())
    permuted_indices.push_back(permutation[index.getSExtValue()]);

  // Materialize the remapped indices as a fresh i64 constant op.
  OpBuilder builder(getOperation());
  auto indices_type = mlir::RankedTensorType::get(permuted_indices.size(),
                                                  builder.getIntegerType(64));
  auto indices_values =
      mlir::DenseIntElementsAttr::get(indices_type, permuted_indices);
  auto permuted_indices_op =
      builder.create<TF::ConstOp>(getLoc(), indices_values);

  // Point the reduction_indices operand (operand #1) at the new constant.
  setOperand(1, permuted_indices_op);

  return success();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NegOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -172,7 +172,7 @@ else_branch: A function that takes 'inputs' and returns a list of
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> {
|
||||
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
|
||||
let summary = "Computes the mean of elements across dimensions of a tensor.";
|
||||
|
||||
let description = [{
|
||||
@ -195,6 +195,13 @@ retained with length 1.
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// TF_FoldOperandsTransposeInterface:
|
||||
SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
|
||||
SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
|
||||
LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
|
||||
}];
|
||||
}
|
||||
|
||||
def TF_LegacyCallOp : TF_Op<"LegacyCall",
|
||||
|
@ -73,6 +73,25 @@ func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf
|
||||
return %2 : tensor<1x56x56x64xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fold_into_mean
func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> {

  // After folding, Mean should consume %arg0 directly (no tf.Transpose) with
  // the reduction indices rewritten from NHWC {1, 2} to NCHW {2, 3}.
  // CHECK: %[[RED_IDX:[0-9]*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>}
  // CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]])
  // CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi64>) -> tensor<1x64xf32>
  // CHECK: return %[[MEAN]]

  // Transpose NCHW -> NHWC
  %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
  %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>

  // Compute Mean over spatial dimensions in NHWC format.
  %2 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
  %3 = "tf.Mean"(%1, %2) : (tensor<1x112x112x64xf32>, tensor<2xi64>) -> tensor<1x64xf32>

  return %3 : tensor<1x64xf32>
}
|
||||
|
||||
// CHECK-LABEL: func @fold_into_fused_batch_norm
|
||||
func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> {
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user