[TF:MLIR] Add support for folding Transpose into Mean
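
In IR terms, the fold removes a data-format tf.Transpose that feeds a tf.Mean by remapping the reduction indices into the pre-transpose layout. A rough sketch of the rewrite, mirroring the new test case below (the value names %input, %perm, %dims and %mean are illustrative only):

    // Before: Mean over the NHWC spatial dims of a transposed NCHW input.
    %perm = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
    %t    = "tf.Transpose"(%input, %perm) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
    %dims = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
    %mean = "tf.Mean"(%t, %dims) : (tensor<1x112x112x64xf32>, tensor<2xi64>) -> tensor<1x64xf32>

    // After folding: reduce the original NCHW input directly, with the
    // reduction indices remapped through the permutation ([1, 2] -> [2, 3]).
    %new_dims = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
    %mean     = "tf.Mean"(%input, %new_dims) : (tensor<1x64x112x112xf32>, tensor<2xi64>) -> tensor<1x64xf32>

Each reduction index i in the transposed layout maps to permutation[i] in the original layout, which is what the new MeanOp::FoldOperandsPermutation hook computes.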

PiperOrigin-RevId: 296361326
Change-Id: I677bfd6aa17865514a8770b49bce6b7681d5c289
Eugene Zhulenev 2020-02-20 21:11:35 -08:00 committed by TensorFlower Gardener
parent 120e5e6ea0
commit 2ca35b7a30
3 changed files with 59 additions and 1 deletion
tensorflow/compiler/mlir/tensorflow


@@ -1536,6 +1536,38 @@ static LogicalResult Verify(MaxPoolGradOp op) {
  return success();
}

//===----------------------------------------------------------------------===//
// MeanOp
//===----------------------------------------------------------------------===//

LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
  // Reduction indices must be defined by a constant operation.
  auto reduction_op =
      dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
  if (!reduction_op) return failure();

  auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
  if (!reductions_value) return failure();

  // Prepare new reduction indices according to the operand permutation.
  SmallVector<int64_t, 4> shuffled_reduction;
  llvm::transform(reductions_value.getIntValues(),
                  std::back_inserter(shuffled_reduction),
                  [&](APInt idx) { return permutation[idx.getSExtValue()]; });

  // Add a constant operation with the new reduction indices.
  OpBuilder builder(getOperation());
  auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
                                          builder.getIntegerType(64));
  auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
  auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);

  // Use the new reduction indices.
  setOperand(1, shuffled_reduction_op);

  return success();
}

//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//


@@ -172,7 +172,7 @@ else_branch: A function that takes 'inputs' and returns a list of
}];
}
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> {
def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
  let summary = "Computes the mean of elements across dimensions of a tensor.";

  let description = [{
@@ -195,6 +195,13 @@ retained with length 1.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
  TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;

  let extraClassDeclaration = [{
    // TF_FoldOperandsTransposeInterface:
    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
  }];
}

def TF_LegacyCallOp : TF_Op<"LegacyCall",


@@ -73,6 +73,25 @@ func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf
  return %2 : tensor<1x56x56x64xf32>
}

// CHECK-LABEL: func @fold_into_mean
func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> {

  // CHECK: %[[RED_IDX:[0-9]*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>}
  // CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]])
  // CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi64>) -> tensor<1x64xf32>
  // CHECK: return %[[MEAN]]

  // Transpose NCHW -> NHWC
  %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
  %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>

  // Compute Mean over spatial dimensions in NHWC format.
  %2 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
  %3 = "tf.Mean"(%1, %2) : (tensor<1x112x112x64xf32>, tensor<2xi64>) -> tensor<1x64xf32>

  return %3 : tensor<1x64xf32>
}

// CHECK-LABEL: func @fold_into_fused_batch_norm
func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> {