From dc9685322deda182a8ef5c585fe65befbfb079aa Mon Sep 17 00:00:00 2001
From: Gaurav Jain <gjn@google.com>
Date: Thu, 23 Jul 2020 11:37:54 -0700
Subject: [PATCH] Handle int64 axis in ReduceTransposer

PiperOrigin-RevId: 322830387
Change-Id: I4c5c7a536926fd032d5efc08cddf67e3844bca38
---
 .../generic_layout_optimizer_transposer.cc    |   7 +-
 ...eneric_layout_optimizer_transposer_test.cc | 256 +++++++++---------
 2 files changed, 140 insertions(+), 123 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index a3449621405..0d836fda265 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -1236,7 +1236,12 @@ bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
     return false;
   }
   for (int i = 0; i < axis_size; ++i) {
-    int local_axis = tensor.flat<int>()(i);
+    int local_axis = 0;
+    if (tensor.dtype() == DT_INT32) {
+      local_axis = tensor.flat<int32>()(i);
+    } else {
+      local_axis = tensor.flat<int64>()(i);
+    }
     if (local_axis < 0) {
       local_axis += rank;
     }
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc
index bf938b650bf..ab0ccf57a4b 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc
@@ -370,6 +370,136 @@ class TransposerTest : public ::testing::Test {
 
   void TearDown() override { TF_ASSERT_OK(virtual_cluster_->Shutdown()); }
 
+  template <typename T>  // T: element type of the "axis" constant.
+  void ReduceTransposerKeepDims() {  // Conv2D -> Sum(keep_dims) -> Identity.
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+    GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
+#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+    GrapplerItem item;
+    Scope scope = Scope::NewRootScope();
+    // Graph under test: a Conv2D placed on GPU:0 using kSrcFormat.
+    auto input =
+        ops::RandomUniform(scope.WithOpName("input"),
+                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
+    auto filter =
+        ops::RandomUniform(scope.WithOpName("filter"),
+                           {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
+    Output conv2d = ops::Conv2D(
+        scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
+        {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
+    // Sum over axes {0, 1, 2} with keep_dims; the axis Const has type T.
+    auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
+    auto attrs = ops::Sum::Attrs().KeepDims(true);
+    auto sum_op = ops::Sum(scope.WithOpName("sum").WithDevice("/device:GPU:0"),
+                           conv2d, axis, attrs);
+    // "z" consumes the sum so its fanin can be checked after the rewrite.
+    auto z = ops::Identity(scope.WithOpName("z"), sum_op);
+    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
+    // Build the transpose context with kGPU and the src/dst data formats.
+    TransposeContext context;
+    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
+        item, virtual_cluster_.get(), &context));
+    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
+    // Transpose the Conv2D node first.
+    DefaultLayoutSensitiveOpTransposer conv2d_transposer;
+    auto* c2d = context.graph_view->GetNode("conv2d");
+    ASSERT_NE(c2d, nullptr);
+    TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
+    // Then run the transposer under test on "sum".
+    ReduceTransposer reducer_transposer;
+    auto* sum = context.graph_view->GetNode("sum");
+    ASSERT_NE(sum, nullptr);
+    TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, sum));
+    // An NHWCToNCHW-named transpose node for "sum" must now exist.
+    auto* input_transpose_node = context.graph_view->GetNode(
+        "sum-0-TransposeNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(input_transpose_node, nullptr);
+    // "sum" still has two regular fanins; fanin 0 is that transpose node.
+    auto* updated_sum_node = context.graph_view->GetNode("sum");
+    ASSERT_NE(updated_sum_node, nullptr);
+    ASSERT_EQ(updated_sum_node->NumRegularFanins(), 2);
+    VerifyRegularFaninMatch(updated_sum_node, 0,
+                            input_transpose_node->GetName(), 0);
+    // The DataFormatDimMap-named node must exist with "axis" as sole fanin.
+    auto* axis_node = context.graph_view->GetNode(
+        "sum-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(axis_node, nullptr);
+    ASSERT_EQ(axis_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
+    // An NCHWToNHWC-named transpose node for "sum" must also exist.
+    auto* output_transpose_node = context.graph_view->GetNode(
+        "sum-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
+    ASSERT_NE(output_transpose_node, nullptr);
+    // "z" must read from that output transpose node.
+    auto* z_output_node = context.graph_view->GetNode("z");
+    ASSERT_NE(z_output_node, nullptr);
+    ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
+                            0);
+  }
+
+  template <typename T>  // T: element type of the "axis" constant.
+  void ReduceTransposerValidAxisNode() {  // Conv2D -> Max -> Identity.
+#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+    GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
+#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
+    GrapplerItem item;
+    Scope scope = Scope::NewRootScope();
+    // Graph under test: a Conv2D placed on GPU:0 using kSrcFormat.
+    auto input =
+        ops::RandomUniform(scope.WithOpName("input"),
+                           {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
+    auto filter =
+        ops::RandomUniform(scope.WithOpName("filter"),
+                           {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
+    Output conv2d = ops::Conv2D(
+        scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
+        {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
+    // Max over axes {0, 1, 2} with default attrs; the axis Const has type T.
+    auto axis = ops::Const<T>(scope.WithOpName("axis"), {0, 1, 2}, {3});
+    auto sum_op = ops::Max(scope.WithOpName("max").WithDevice("/device:GPU:0"),
+                           conv2d, axis);
+    // "z" consumes the reduction so its fanin can be checked afterwards.
+    auto z = ops::Identity(scope.WithOpName("z"), sum_op);
+    TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
+    // Build the transpose context with kGPU and the src/dst data formats.
+    TransposeContext context;
+    TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
+        item, virtual_cluster_.get(), &context));
+    context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
+    // Transpose the Conv2D node first.
+    DefaultLayoutSensitiveOpTransposer conv2d_transposer;
+    auto* c2d = context.graph_view->GetNode("conv2d");
+    ASSERT_NE(c2d, nullptr);
+    TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
+    // Then run the transposer under test on "max".
+    ReduceTransposer reducer_transposer;
+    auto* max = context.graph_view->GetNode("max");
+    ASSERT_NE(max, nullptr);
+    TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, max));
+    // An NHWCToNCHW-named transpose node for "max" must now exist.
+    auto* input_transpose_node = context.graph_view->GetNode(
+        "max-0-TransposeNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(input_transpose_node, nullptr);
+    // "max" still has two regular fanins; fanin 0 is that transpose node.
+    auto* updated_max_node = context.graph_view->GetNode("max");
+    ASSERT_NE(updated_max_node, nullptr);
+    ASSERT_EQ(updated_max_node->NumRegularFanins(), 2);
+    VerifyRegularFaninMatch(updated_max_node, 0,
+                            input_transpose_node->GetName(), 0);
+    // The DataFormatDimMap-named node must exist with "axis" as sole fanin.
+    auto* axis_node = context.graph_view->GetNode(
+        "max-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(axis_node, nullptr);
+    ASSERT_EQ(axis_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
+    // Unlike the keep_dims case, "z" must read directly from "max".
+    auto* z_output_node = context.graph_view->GetNode("z");
+    ASSERT_NE(z_output_node, nullptr);
+    ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(z_output_node, 0, updated_max_node->GetName(), 0);
+  }
+
   std::unique_ptr<Cluster> virtual_cluster_;
 };
 
@@ -3637,131 +3767,13 @@ TEST_F(TransposerTest, StridedSliceTransposerConstFaninBadRank) {
 }
 
 TEST_F(TransposerTest, ReduceTransposerKeepDims) {
-#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
-#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GrapplerItem item;
-  Scope scope = Scope::NewRootScope();
-
-  auto input =
-      ops::RandomUniform(scope.WithOpName("input"),
-                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
-  auto filter =
-      ops::RandomUniform(scope.WithOpName("filter"),
-                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
-  Output conv2d = ops::Conv2D(
-      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
-      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
-
-  auto axis = ops::Const(scope.WithOpName("axis"), {0, 1, 2}, {3});
-  auto attrs = ops::Sum::Attrs().KeepDims(true);
-  auto sum_op = ops::Sum(scope.WithOpName("sum").WithDevice("/device:GPU:0"),
-                         conv2d, axis, attrs);
-
-  auto z = ops::Identity(scope.WithOpName("z"), sum_op);
-  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
-
-  TransposeContext context;
-  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
-      item, virtual_cluster_.get(), &context));
-  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
-
-  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
-  auto* c2d = context.graph_view->GetNode("conv2d");
-  ASSERT_NE(c2d, nullptr);
-  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
-
-  ReduceTransposer reducer_transposer;
-  auto* sum = context.graph_view->GetNode("sum");
-  ASSERT_NE(sum, nullptr);
-  TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, sum));
-
-  auto* input_transpose_node =
-      context.graph_view->GetNode("sum-0-TransposeNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(input_transpose_node, nullptr);
-
-  auto* updated_sum_node = context.graph_view->GetNode("sum");
-  ASSERT_NE(updated_sum_node, nullptr);
-  ASSERT_EQ(updated_sum_node->NumRegularFanins(), 2);
-  VerifyRegularFaninMatch(updated_sum_node, 0, input_transpose_node->GetName(),
-                          0);
-
-  auto* axis_node = context.graph_view->GetNode(
-      "sum-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(axis_node, nullptr);
-  ASSERT_EQ(axis_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
-
-  auto* output_transpose_node = context.graph_view->GetNode(
-      "sum-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
-  ASSERT_NE(output_transpose_node, nullptr);
-
-  auto* z_output_node = context.graph_view->GetNode("z");
-  ASSERT_NE(z_output_node, nullptr);
-  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(z_output_node, 0, output_transpose_node->GetName(),
-                          0);
+  ReduceTransposerKeepDims<int32>();  // "axis" constant built as int32.
+  ReduceTransposerKeepDims<int64>();  // "axis" constant built as int64.
 }
 
 TEST_F(TransposerTest, ReduceTransposerValidAxisNode) {
-#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
-#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GrapplerItem item;
-  Scope scope = Scope::NewRootScope();
-
-  auto input =
-      ops::RandomUniform(scope.WithOpName("input"),
-                         {kBatchSize, kHeight, kWidth, kDepthIn}, DT_FLOAT);
-  auto filter =
-      ops::RandomUniform(scope.WithOpName("filter"),
-                         {kHeight, kWidth, kDepthIn, kDepthOut}, DT_FLOAT);
-  Output conv2d = ops::Conv2D(
-      scope.WithOpName("conv2d").WithDevice("/device:GPU:0"), input, filter,
-      {1, 2, 4, 1}, "SAME", ops::Conv2D::DataFormat(kSrcFormat));
-
-  auto axis = ops::Const(scope.WithOpName("axis"), {0, 1, 2}, {3});
-  auto sum_op = ops::Max(scope.WithOpName("max").WithDevice("/device:GPU:0"),
-                         conv2d, axis);
-
-  auto z = ops::Identity(scope.WithOpName("z"), sum_op);
-  TF_ASSERT_OK(scope.ToGraphDef(&item.graph));
-
-  TransposeContext context;
-  TF_ASSERT_OK(TransposeContext::InitializeTransposeContext(
-      item, virtual_cluster_.get(), &context));
-  context.AssignDeviceAndDataFormats(kGPU, kSrcFormat, kDstFormat);
-
-  DefaultLayoutSensitiveOpTransposer conv2d_transposer;
-  auto* c2d = context.graph_view->GetNode("conv2d");
-  ASSERT_NE(c2d, nullptr);
-  TF_ASSERT_OK(conv2d_transposer.TransposeNode(&context, c2d));
-
-  ReduceTransposer reducer_transposer;
-  auto* max = context.graph_view->GetNode("max");
-  ASSERT_NE(max, nullptr);
-  TF_ASSERT_OK(reducer_transposer.TransposeNode(&context, max));
-
-  auto* input_transpose_node =
-      context.graph_view->GetNode("max-0-TransposeNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(input_transpose_node, nullptr);
-
-  auto* updated_max_node = context.graph_view->GetNode("max");
-  ASSERT_NE(updated_max_node, nullptr);
-  ASSERT_EQ(updated_max_node->NumRegularFanins(), 2);
-  VerifyRegularFaninMatch(updated_max_node, 0, input_transpose_node->GetName(),
-                          0);
-
-  auto* axis_node = context.graph_view->GetNode(
-      "max-1-DataFormatDimMapNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(axis_node, nullptr);
-  ASSERT_EQ(axis_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(axis_node, 0, "axis", 0);
-
-  auto* z_output_node = context.graph_view->GetNode("z");
-  ASSERT_NE(z_output_node, nullptr);
-  ASSERT_EQ(z_output_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(z_output_node, 0, updated_max_node->GetName(), 0);
+  ReduceTransposerValidAxisNode<int32>();  // "axis" constant built as int32.
+  ReduceTransposerValidAxisNode<int64>();  // "axis" constant built as int64.
 }
 
 TEST(PermutationTest, PermutesVector) {