From 1a11d01c1fdd6683e9aa210dccde81de127dbf3e Mon Sep 17 00:00:00 2001
From: Kaixi Hou <kaixih@nvidia.com>
Date: Mon, 14 Sep 2020 15:52:22 -0700
Subject: [PATCH 1/7] Support reduce ops for 5D tensors in layout optimizer

---
 .../generic_layout_optimizer_transposer.cc    | 27 +++++++++-
 tensorflow/core/kernels/data_format_ops.cc    | 10 ++--
 tensorflow/core/kernels/data_format_ops.h     | 53 ++++++++++++++-----
 .../python/grappler/layout_optimizer_test.py  | 39 ++++++++++++++
 tensorflow/python/ops/nn_test.py              | 27 ++++++++++
 5 files changed, 136 insertions(+), 20 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index 2f806ba6b6a..265a6ae7cac 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -1371,11 +1371,31 @@ bool ReduceTransposer::IsReduceAxisSupported(
 Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
   DCHECK(IsReduceOp(*node->node()));
-  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
+  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& shape = output_shape_attr->list().shape(0);
+  const int rank = shape.dim_size();
+  std::string src_format = context->src_format;
+  std::string dst_format = context->dst_format;
+  // Update the format from 4D to 5D layout if necessary.
+  if (rank == 5) {
+    std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
+    std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
+    context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
+                                        dst_format_3d);
+  }
+  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, rank) ||
       !IsReduceAxisSupported(*context, *node) ||
       !IsAfterDstToSrcTransform(*context, *node)) {
+    // Change back to the original layout due to early exit.
+    if (rank == 5) {
+      context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                          dst_format);
+    }
     return Status::OK();
   }
+  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
+          << "' with op '" << node->GetOp() << "' from data format '"
+          << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(
       UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
@@ -1383,6 +1403,11 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
     TF_RETURN_IF_ERROR(
         UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   }
+  // Change back the format from 5D to 4D layout.
+  if (rank == 5) {
+    context->AssignDeviceAndDataFormats(context->target_device, src_format,
+                                        dst_format);
+  }
   return context->graph_view->GetMutationBuilder()->Apply();
 }
 
diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index c62c710faf1..2f08f52c6dc 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -37,14 +37,14 @@ class DataFormatDimMapOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
     string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
-    OP_REQUIRES(context, src_format.size() == 4,
+    OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
                 errors::InvalidArgument(strings::StrCat(
-                    "Source format must of length 4, received src_format = ",
-                    src_format)));
+                    "Source format must be of length 4 or 5, received "
+                    "src_format = ", src_format)));
     OP_REQUIRES(
-        context, dst_format.size() == 4,
+        context, dst_format.size() == 4 || dst_format.size() == 5,
         errors::InvalidArgument(strings::StrCat(
-            "Destination format must of length 4, received dst_format = ",
+            "Destination format must be of length 4 or 5, received dst_format = ",
             dst_format)));
     dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
     for (int i = 0; i < src_format.size(); ++i) {
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index bc416fa78bc..89b54901223 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -28,24 +28,49 @@ template <typename Device, typename T>
 struct DataFormatDimMap {
   void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
                   typename TTypes<T>::Flat y, const TTypes<int>::Vec dst) {
-    auto zero = x.constant(0);
-    auto one = x.constant(1);
-    auto two = x.constant(2);
+    if (dst.size() == 4) {
+      auto zero = x.constant(0);
+      auto one = x.constant(1);
+      auto two = x.constant(2);
 
-    auto f_zero = x.constant(dst(0));
-    auto f_one = x.constant(dst(1));
-    auto f_two = x.constant(dst(2));
-    auto f_three = x.constant(dst(3));
+      auto f_zero = x.constant(dst(0));
+      auto f_one = x.constant(dst(1));
+      auto f_two = x.constant(dst(2));
+      auto f_three = x.constant(dst(3));
 
-    auto four = x.constant(4);
-    auto x_mod = (x + four) % 4;
+      auto four = x.constant(4);
+      auto x_mod = (x + four) % 4;
 
-    auto is_zero = (x_mod == zero);
-    auto is_one = (x_mod == one);
-    auto is_two = (x_mod == two);
+      auto is_zero = (x_mod == zero);
+      auto is_one = (x_mod == one);
+      auto is_two = (x_mod == two);
 
-    y.device(d) = is_zero.select(
-        f_zero, is_one.select(f_one, is_two.select(f_two, f_three)));
+      y.device(d) = is_zero.select(
+          f_zero, is_one.select(f_one, is_two.select(f_two, f_three)));
+    } else {
+      auto zero = x.constant(0);
+      auto one = x.constant(1);
+      auto two = x.constant(2);
+      auto three = x.constant(3);
+
+      auto f_zero = x.constant(dst(0));
+      auto f_one = x.constant(dst(1));
+      auto f_two = x.constant(dst(2));
+      auto f_three = x.constant(dst(3));
+      auto f_four = x.constant(dst(4));
+
+      auto five = x.constant(5);
+      auto x_mod = (x + five) % 5;
+
+      auto is_zero = (x_mod == zero);
+      auto is_one = (x_mod == one);
+      auto is_two = (x_mod == two);
+      auto is_three = (x_mod == three);
+
+      y.device(d) = is_zero.select(
+          f_zero, is_one.select(f_one, is_two.select(f_two,
+              is_three.select(f_three, f_four))));
+    }
   }
 };
 
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 198f5a7d83a..db85bf5bccd 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -221,6 +221,9 @@ class LayoutOptimizerTest(test.TestCase):
   def _assert_map_nhwc_to_nchw(self, name, nodes):
     self.assertIn(name + '-DimMapNHWCToNCHW-LayoutOptimizer', nodes)
 
+  def _assert_map_ndhwc_to_ncdhw(self, name, nodes):
+    self.assertIn(name + '-DataFormatDimMapNDHWCToNCDHW-LayoutOptimizer', nodes)
+
   def _assert_vec_nchw_to_nhwc(self, name, nodes):
     self.assertIn(name + '-VecPermuteNCHWToNHWC-LayoutOptimizer', nodes)
 
@@ -1194,6 +1197,42 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_nchw_to_nhwc('LeakyReluGrad-0-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  @test_util.deprecated_graph_mode_only
+  def testReduceOpsFor5DTensors(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
+      w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
+      gamma = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      beta = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
+      conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME')
+      y = math_ops.reduce_mean(conv3d, [0, 1, 2, 3], keepdims=True)
+      output = array_ops.identity(y)
+
+      with session.Session(config=_get_config(False)) as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+        print(node.name)
+
+      # The reduce op Mean needs its reduction indices remapped to NCDHW.
+      # Then, the output needs to be transposed back to NDHWC.
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_ndhwc_to_ncdhw('Conv3D-0', nodes)
+      self._assert_map_ndhwc_to_ncdhw('Mean-1', nodes)
+      self._assert_trans_ncdhw_to_ndhwc('Mean-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   @test_util.deprecated_graph_mode_only
   def testConv3D(self):
     if test.is_gpu_available(cuda_only=True):
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 7f3d9f6e286..851bfcb66de 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1233,6 +1233,33 @@ class DataFormatDimMapTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, y_val_expected)
 
+  def testNDHWCtoNCDHW(self):
+    x_val = [1, -4, -3, -2]
+    y_val_expected = [2, 2, 3, 4]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="NCDHW")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, y_val_expected)
+
+  def testNDHWCtoDHWNC(self):
+    x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
+    y_val_expected = [3, 0, 1, 2, 4, 3, 0, 1, 2, 4]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="DHWNC")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, y_val_expected)
+
+  def testNDHWCtoWHDCN(self):
+    x_val = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4]
+    y_val_expected = [4, 2, 1, 0, 3, 4, 2, 1, 0, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_dim_map(x, src_format="NDHWC", dst_format="WHDCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, y_val_expected)
+
   def testArbitraryASCII(self):
     x_val = [-4, -3, -2, -1, 0, 1, 2, 3]
     y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]

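For reference, the axis-mapping semantics that the new 4D/5D branches implement can be modeled in a few lines of NumPy. This is only an illustrative sketch of the 5D case, not code from the patch; dim_map_5d and its dst table are hypothetical names:

    # Mimic the nested 5D select chain from data_format_ops.h: fold negative
    # axis indices with (x + 5) % 5, then map through the destination table.
    import numpy as np

    def dim_map_5d(x, dst):
      x_mod = (np.asarray(x, dtype=np.int32) + 5) % 5
      return np.where(x_mod == 0, dst[0],
             np.where(x_mod == 1, dst[1],
             np.where(x_mod == 2, dst[2],
             np.where(x_mod == 3, dst[3], dst[4]))))

    # NDHWC -> NCDHW permutation table: N->0, D->2, H->3, W->4, C->1.
    print(dim_map_5d([1, -4, -3, -2], [0, 2, 3, 4, 1]))  # [2 2 3 4]

The printed vector matches the NDHWC-to-NCDHW case exercised by testNDHWCtoNCDHW in nn_test.py above.
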
From a4444ef184802ba812090ece129c85b068c4f390 Mon Sep 17 00:00:00 2001
From: Kaixi Hou <kaixih@nvidia.com>
Date: Mon, 14 Sep 2020 18:09:21 -0700
Subject: [PATCH 2/7] Fix a failing test

The input rank has to come from the reduce op's regular fanin: the reduce
node's own _output_shapes describes the already-reduced output, not the
input being transposed.
---
 .../optimizers/generic_layout_optimizer_transposer.cc         | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index 265a6ae7cac..f43297d8fac 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -1371,7 +1371,9 @@ bool ReduceTransposer::IsReduceAxisSupported(
 Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
   DCHECK(IsReduceOp(*node->node()));
-  const auto* output_shape_attr = node->GetAttr(kAttrOutputShape);
+  const auto& regular_fanin = node->GetRegularFanin(0);
+  const auto* output_shape_attr =
+      regular_fanin.node_view()->GetAttr(kAttrOutputShape);
   const auto& shape = output_shape_attr->list().shape(0);
   const int rank = shape.dim_size();
   std::string src_format = context->src_format;

From edbfd773932fc43a4c7b4378325803510727d0d1 Mon Sep 17 00:00:00 2001
From: Kaixi Hou <kaixih@nvidia.com>
Date: Mon, 21 Sep 2020 10:16:09 -0700
Subject: [PATCH 3/7] Restrict 5D cases to NHWC/NCHW

---
 .../optimizers/generic_layout_optimizer_transposer.cc      | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index f43297d8fac..2db39995132 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -1379,7 +1379,8 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
   std::string src_format = context->src_format;
   std::string dst_format = context->dst_format;
   // Update the format from 4D to 5D layout if necessary.
-  if (rank == 5) {
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
+  if (allow_5d) {
     std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
     std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
     context->AssignDeviceAndDataFormats(context->target_device, src_format_3d,
@@ -1389,7 +1390,7 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
       !IsReduceAxisSupported(*context, *node) ||
       !IsAfterDstToSrcTransform(*context, *node)) {
     // Change back to the original layout due to early exit.
-    if (rank == 5) {
+    if (allow_5d) {
       context->AssignDeviceAndDataFormats(context->target_device, src_format,
                                           dst_format);
     }
@@ -1406,7 +1407,7 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
         UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   }
   // Change back the format from 5D to 4D layout.
-  if (rank == 5) {
+  if (allow_5d) {
     context->AssignDeviceAndDataFormats(context->target_device, src_format,
                                         dst_format);
   }

From 25567bd8418c01cc23a169a369d7e9b156d2bcc5 Mon Sep 17 00:00:00 2001
From: Kaixi Hou <kaixih@nvidia.com>
Date: Tue, 22 Sep 2020 12:25:38 -0700
Subject: [PATCH 4/7] Remove a debugging print

---
 tensorflow/python/grappler/layout_optimizer_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index db85bf5bccd..74416c44ef7 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -1222,7 +1222,6 @@ class LayoutOptimizerTest(test.TestCase):
         if _is_transpose(node.name):
           num_transposes += 1
         nodes.append(node.name)
-        print(node.name)
 
       # The reduce op Mean needs its reduction indices remapped to NCDHW.
       # Then, the output needs to be transposed back to NDHWC.

From 7a38d3fd961c35cc4e113f684e4c4035d1904764 Mon Sep 17 00:00:00 2001
From: Kaixi Hou <kaixih@nvidia.com>
Date: Wed, 23 Sep 2020 11:49:20 -0700
Subject: [PATCH 5/7] Update XLA data format map ops

---
 .../compiler/tests/data_format_ops_test.py    | 15 ++++++++++++++
 .../tf2xla/kernels/data_format_ops.cc         | 20 +++++++++++--------
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/tensorflow/compiler/tests/data_format_ops_test.py b/tensorflow/compiler/tests/data_format_ops_test.py
index 08d44256b50..ca833326a50 100644
--- a/tensorflow/compiler/tests/data_format_ops_test.py
+++ b/tensorflow/compiler/tests/data_format_ops_test.py
@@ -63,6 +63,21 @@ class XlaDataFormatDimMapTest(xla_test.XLATestCase):
     self._test([-4, -3, -2, -1, 0, 1, 2, 3], "qwer", "rewq",
                [3, 2, 1, 0, 3, 2, 1, 0])
 
+    self._test(0, "NDHWC", "NCDHW", 0)
+    self._test(1, "NDHWC", "NCDHW", 2)
+    self._test(2, "NDHWC", "NCDHW", 3)
+    self._test(3, "NDHWC", "NCDHW", 4)
+    self._test(4, "NDHWC", "NCDHW", 1)
+    self._test([1, 4], "NDHWC", "NCDHW", [2, 1])
+    self._test([1, 4, -2], "NDHWC", "NCDHW", [2, 1, 4])
+    self._test([1, -3, -2], "NDHWC", "NCDHW", [2, 3, 4])
+    self._test([[1, -4], [1, -1]], "NDHWC", "NCDHW", [[2, 2], [2, 1]])
+
+    self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "DHWNC",
+               [3, 0, 1, 2, 4, 3, 0, 1, 2, 4])
+    self._test([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4], "NDHWC", "WHDCN",
+               [4, 2, 1, 0, 3, 4, 2, 1, 0, 3])
+
 
 class XlaPermuteOpTest(xla_test.XLATestCase):
 
diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
index c1f60abc0d6..687d394972b 100644
--- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
@@ -35,15 +35,18 @@ class DataFormatDimMapOp : public XlaOpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
     string dst_format;
     OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
-    OP_REQUIRES(context, src_format.size() == 4,
+    OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
                 errors::InvalidArgument(absl::StrCat(
-                    "Source format must of length 4, received src_format = ",
-                    src_format)));
+                    "Source format must be of length 4 or 5, "
+                    "received src_format = ", src_format)));
     OP_REQUIRES(
-        context, dst_format.size() == 4,
+        context, dst_format.size() == 4 || dst_format.size() == 5,
         errors::InvalidArgument(absl::StrCat(
-            "Destination format must of length 4, received dst_format = ",
+            "Destination format must be of length 4 or 5, received dst_format = ",
             dst_format)));
+    for (int i = 0; i < src_format.size(); ++i) {
+      dst_idx_.push_back(-1);
+    }
     for (int i = 0; i < src_format.size(); ++i) {
       for (int j = 0; j < dst_format.size(); ++j) {
         if (dst_format[j] == src_format[i]) {
@@ -61,9 +64,10 @@ class DataFormatDimMapOp : public XlaOpKernel {
     auto builder = context->builder();
     xla::XlaOp dst_indices =
         xla::ConstantR1(builder, absl::Span<const int32>(dst_idx_));
-    xla::XlaOp four = xla::ConstantR0<int32>(builder, 4);
+    const int dims = dst_idx_.size();
+    xla::XlaOp rank = xla::ConstantR0<int32>(builder, dims);
     xla::XlaOp src_indices =
-        (xla::ConvertElementType(context->Input(0), xla::S32) + four) % four;
+        (xla::ConvertElementType(context->Input(0), xla::S32) + rank) % rank;
     xla::XlaOp output =
         xla::TorchIndexSelect(dst_indices, src_indices, /*dim=*/0);
     context->SetOutput(
@@ -71,7 +75,7 @@ class DataFormatDimMapOp : public XlaOpKernel {
   }
 
  private:
-  std::array<int32, 4> dst_idx_ = {{-1, -1, -1, -1}};
+  std::vector<int32> dst_idx_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(DataFormatDimMapOp);
 };

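Unlike the Eigen kernel's select chain, the XLA kernel maps axes through a lookup table: dst_idx_ is sized from the format strings, and the input is normalized with (x + rank) % rank before the TorchIndexSelect gather. A minimal Python sketch of the same lookup, with hypothetical names:

    # Build the permutation table from the two format strings, then gather
    # each axis index after folding negatives into [0, rank).
    def xla_dim_map(xs, src_format, dst_format):
      dst_idx = [dst_format.index(c) for c in src_format]
      rank = len(src_format)
      return [dst_idx[(x + rank) % rank] for x in xs]

    # Mirrors one of the new cases in data_format_ops_test.py above.
    assert xla_dim_map([1, 4, -2], "NDHWC", "NCDHW") == [2, 1, 4]
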
From d602375436790d93675af2b2ccb726756476e233 Mon Sep 17 00:00:00 2001
From: Kaixi Hou <kaixih@nvidia.com>
Date: Wed, 23 Sep 2020 11:52:28 -0700
Subject: [PATCH 6/7] Remove two unused lines

---
 tensorflow/python/grappler/layout_optimizer_test.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 74416c44ef7..a69ed72db87 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -1203,8 +1203,6 @@ class LayoutOptimizerTest(test.TestCase):
       random_seed.set_random_seed(0)
       x = random_ops.truncated_normal([1, 4, 2, 3, 3], seed=0)
       w = random_ops.truncated_normal([2, 2, 2, 3, 3], seed=0)
-      gamma = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
-      beta = random_ops.truncated_normal([1, 1, 1, 1, 3], seed=0)
       conv3d = gen_nn_ops.conv3d(x, w, [1, 1, 1, 1, 1], 'SAME')
       y = math_ops.reduce_mean(conv3d, [0, 1, 2, 3], keepdims=True)
       output = array_ops.identity(y)

From 9b947dd6377c022091c8aa005cdcff52c53ff5f0 Mon Sep 17 00:00:00 2001
From: Kaixi Hou <kaixih@nvidia.com>
Date: Wed, 23 Sep 2020 12:04:10 -0700
Subject: [PATCH 7/7] Also check dst_format

---
 .../grappler/optimizers/generic_layout_optimizer_transposer.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index 2db39995132..c425ef51c2f 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -1379,7 +1379,8 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context,
   std::string src_format = context->src_format;
   std::string dst_format = context->dst_format;
   // Update the format from 4D to 5D layout if necessary.
-  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW");
+  bool allow_5d = rank == 5 && (src_format == "NHWC" || src_format == "NCHW") &&
+                  (dst_format == "NHWC" || dst_format == "NCHW");
   if (allow_5d) {
     std::string src_format_3d = src_format == "NHWC" ? "NDHWC" : "NCDHW";
     std::string dst_format_3d = dst_format == "NHWC" ? "NDHWC" : "NCDHW";
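
Taken together, patches 3 and 7 gate the 5D path behind both format checks: the transposer only widens the default NHWC/NCHW pair to its 3D-spatial counterpart. A small Python sketch of the final guard; expand_formats_for_5d is an illustrative helper, not a function in the patch:

    # Widen the default 4D formats to their 3D-spatial counterparts only
    # when the fanin rank is 5 and both formats are NHWC or NCHW.
    def expand_formats_for_5d(rank, src_format, dst_format):
      to_3d = {"NHWC": "NDHWC", "NCHW": "NCDHW"}
      if rank == 5 and src_format in to_3d and dst_format in to_3d:
        return to_3d[src_format], to_3d[dst_format]
      return src_format, dst_format

    assert expand_formats_for_5d(5, "NHWC", "NCHW") == ("NDHWC", "NCDHW")
    assert expand_formats_for_5d(4, "NHWC", "NCHW") == ("NHWC", "NCHW")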