diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 0cc6850b813..b206b281754 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -1536,6 +1536,38 @@ static LogicalResult Verify(MaxPoolGradOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MeanOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
+  // Reduction indices must be defined by a constant operation.
+  auto reduction_op =
+      dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
+  if (!reduction_op) return failure();
+
+  auto reductions_value = reduction_op.value().dyn_cast<DenseIntElementsAttr>();
+  if (!reductions_value) return failure();
+
+  // Prepare new reduction indices according to operand permutation.
+  SmallVector<int64_t, 4> shuffled_reduction;
+  llvm::transform(reductions_value.getIntValues(),
+                  std::back_inserter(shuffled_reduction),
+                  [&](APInt idx) { return permutation[idx.getSExtValue()]; });
+
+  // Add constant operation with the new reduction indices.
+  OpBuilder builder(getOperation());
+  auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
+                                          builder.getIntegerType(64));
+  auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
+  auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);
+
+  // Use new reduction indices.
+  setOperand(1, shuffled_reduction_op);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NegOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index b391d5284a5..e95fcbbdad3 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -172,7 +172,7 @@ else_branch: A function that takes 'inputs' and returns a list of
   }];
 }
 
-def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> {
+def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
   let summary = "Computes the mean of elements across dimensions of a tensor.";
 
   let description = [{
@@ -195,6 +195,13 @@ retained with length 1.
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+  let extraClassDeclaration = [{
+    // TF_FoldOperandsTransposeInterface:
+    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
+    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
+    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
+  }];
 }
 
 def TF_LegacyCallOp : TF_Op<"LegacyCall",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir
index d89f5cbdf98..4e5a29dcfbe 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir
@@ -73,6 +73,25 @@ func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf
   return %2 : tensor<1x56x56x64xf32>
 }
 
+// CHECK-LABEL: func @fold_into_mean
+func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> {
+
+  // CHECK: %[[RED_IDX:[0-9]*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>}
+  // CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]])
+  // CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi64>) -> tensor<1x64xf32>
+  // CHECK: return %[[MEAN]]
+
+  // Transpose NCHW -> NHWC
+  %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+  %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
+
+  // Compute Mean over spatial dimensions in NHWC format.
+  %2 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %3 = "tf.Mean"(%1, %2) : (tensor<1x112x112x64xf32>, tensor<2xi64>) -> tensor<1x64xf32>
+
+  return %3 : tensor<1x64xf32>
+}
+
 // CHECK-LABEL: func @fold_into_fused_batch_norm
 func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> {
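
For reference, a minimal sketch of the rewrite this patch enables (not part of the patch itself; value names are illustrative, and the IR mirrors the test case above). With the NCHW-to-NHWC operand permutation `[0, 2, 3, 1]`, `FoldOperandsPermutation` maps reduction index `1` to `permutation[1] = 2` and index `2` to `permutation[2] = 3`, so `tf.Mean` can consume the untransposed NCHW value directly and the `tf.Transpose` becomes dead:

```mlir
// Before: tf.Mean reduces NHWC spatial dims (1, 2) through an explicit transpose.
%perm = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
%nhwc = "tf.Transpose"(%arg0, %perm) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
%idx  = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
%mean = "tf.Mean"(%nhwc, %idx) : (tensor<1x112x112x64xf32>, tensor<2xi64>) -> tensor<1x64xf32>

// After folding: reduction indices are permuted (1 -> 2, 2 -> 3) and tf.Mean
// reads the NCHW operand directly; the transpose is now dead code.
%idx2  = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>} : () -> tensor<2xi64>
%mean2 = "tf.Mean"(%arg0, %idx2) : (tensor<1x64x112x112xf32>, tensor<2xi64>) -> tensor<1x64xf32>
```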