From 2ca35b7a30df39582c1c37cc06c1d13b9d0a2ecb Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev <ezhulenev@google.com>
Date: Thu, 20 Feb 2020 21:11:35 -0800
Subject: [PATCH] [TF:MLIR] Add support for folding Transpose into Mean

PiperOrigin-RevId: 296361326
Change-Id: I677bfd6aa17865514a8770b49bce6b7681d5c289
---
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     | 32 +++++++++++++++++++
 .../compiler/mlir/tensorflow/ir/tf_ops.td     |  9 +++++-
 ...yout_optimization_move_transposes_end.mlir | 19 +++++++++++
 3 files changed, 59 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 0cc6850b813..b206b281754 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -1536,6 +1536,38 @@ static LogicalResult Verify(MaxPoolGradOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// MeanOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef<int64_t> permutation) {
+  // Reduction indices must be defined by a constant operation.
+  auto reduction_op =
+      dyn_cast_or_null<TF::ConstOp>(reduction_indices().getDefiningOp());
+  if (!reduction_op) return failure();
+
+  auto reductions_value = reduction_op.value().dyn_cast<DenseElementsAttr>();
+  if (!reductions_value) return failure();
+
+  // Prepare new reduction indices according to operand permutation.
+  SmallVector<int64_t, 4> shuffled_reduction;
+  llvm::transform(reductions_value.getIntValues(),
+                  std::back_inserter(shuffled_reduction),
+                  [&](APInt idx) { return permutation[idx.getSExtValue()]; });
+
+  // Add constant operation with a new reduction indices.
+  OpBuilder builder(getOperation());
+  auto type = mlir::RankedTensorType::get(shuffled_reduction.size(),
+                                          builder.getIntegerType(64));
+  auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction);
+  auto shuffled_reduction_op = builder.create<TF::ConstOp>(getLoc(), values);
+
+  // Use new reduction indices.
+  setOperand(1, shuffled_reduction_op);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NegOp
 //===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
index b391d5284a5..e95fcbbdad3 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td
@@ -172,7 +172,7 @@ else_branch: A function that takes 'inputs' and returns a list of
   }];
 }
 
-def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> {
+def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> {
   let summary = "Computes the mean of elements across dimensions of a tensor.";
 
   let description = [{
@@ -195,6 +195,13 @@ retained with length 1.
 
   TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
   TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
+
+  let extraClassDeclaration = [{
+    // TF_FoldOperandsTransposeInterface:
+    SmallVector<unsigned, 4> GetLayoutDependentArgs() { return {0}; }
+    SmallVector<unsigned, 4> GetLayoutDependentResults() { return {}; }
+    LogicalResult FoldOperandsPermutation(ArrayRef<int64_t> permutation);
+  }];
 }
 
 def TF_LegacyCallOp : TF_Op<"LegacyCall",
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir
index d89f5cbdf98..4e5a29dcfbe 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/layout_optimization_move_transposes_end.mlir
@@ -73,6 +73,25 @@ func @fold_into_max_pool(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x56x56x64xf
   return %2 : tensor<1x56x56x64xf32>
 }
 
+// CHECK-LABEL: func @fold_into_mean
+func @fold_into_mean(%arg0: tensor<1x64x112x112xf32>) -> tensor<1x64xf32> {
+
+  // CHECK: %[[RED_IDX:[0-9]*]] = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi64>}
+  // CHECK: %[[MEAN:[0-9]*]] = "tf.Mean"(%arg0, %[[RED_IDX]])
+  // CHECK-SAME: (tensor<1x64x112x112xf32>, tensor<2xi64>) -> tensor<1x64xf32>
+  // CHECK: return %[[MEAN]]
+
+  // Transpose NCHW -> NHWC
+  %0 = "tf.Const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi64>} : () -> tensor<4xi64>
+  %1 = "tf.Transpose"(%arg0, %0) : (tensor<1x64x112x112xf32>, tensor<4xi64>) -> tensor<1x112x112x64xf32>
+
+  // Compute Mean over spatial dimensions in NHWC format.
+  %2 = "tf.Const"() {value = dense<[1, 2]> : tensor<2xi64>} : () -> tensor<2xi64>
+  %3 = "tf.Mean"(%1, %2) : (tensor<1x112x112x64xf32>, tensor<2xi64>) -> tensor<1x64xf32>
+
+  return %3 : tensor<1x64xf32>
+}
+
 // CHECK-LABEL: func @fold_into_fused_batch_norm
 func @fold_into_fused_batch_norm(%arg0: tensor<1x64x112x112xf32>, %arg1: tensor<64xf32>) -> tensor<1x112x112x64xf32> {