From bd1bcb4f21c93896afa57aec2f540cb0e0147d32 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer <kramerb@google.com>
Date: Wed, 13 Mar 2019 04:36:06 -0700
Subject: [PATCH] [TF:XLA] Implement NCHW_VECT_C for DepthToSpace/SpaceToDepth
 by desugaring to NCHW

XLA is smart enough to simplify the extra steps away, so this shouldn't be
significantly more expensive than a "native" implementation. TF only uses
NCHW_VECT_C for quantized int8 convolutions, which XLA doesn't support, but
the data formatting around them can be compiled by XLA.

The actual formatting is factored into separate helper functions; it will
likely come in handy again for other ops.
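
For the record, the desugaring in each direction is just a transpose plus a
reshape. A rough NumPy sketch of the round trip (illustration only, not part
of this change; the helper names are made up):

    import numpy as np

    def nchw_vect_c_to_nchw(x):
      # [N, C/4, H, W, 4] -> [N, C/4, 4, H, W] -> [N, C, H, W]
      n, c4, h, w, v = x.shape
      return x.transpose(0, 1, 4, 2, 3).reshape(n, c4 * v, h, w)

    def nchw_to_nchw_vect_c(x):
      # [N, C, H, W] -> [N, C/4, 4, H, W] -> [N, C/4, H, W, 4]
      n, c, h, w = x.shape
      return x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)

    x = np.arange(32).reshape(1, 8, 1, 1, 4)  # NCHW_VECT_C input
    assert np.array_equal(nchw_to_nchw_vect_c(nchw_vect_c_to_nchw(x)), x)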

PiperOrigin-RevId: 238202809
---
 tensorflow/compiler/tests/unary_ops_test.py   | 20 ++++-
 tensorflow/compiler/tf2xla/kernels/BUILD      |  1 +
 .../tf2xla/kernels/depthtospace_op.cc         | 49 ++++++++---
 .../tf2xla/kernels/spacetodepth_op.cc         | 49 ++++++++---
 tensorflow/compiler/tf2xla/lib/BUILD          | 12 +++
 tensorflow/compiler/tf2xla/lib/data_format.cc | 88 +++++++++++++++++++
 tensorflow/compiler/tf2xla/lib/data_format.h  | 37 ++++++++
 .../kernel_tests/depthtospace_op_test.py      |  1 -
 .../kernel_tests/spacetodepth_op_test.py      |  1 -
 9 files changed, 227 insertions(+), 31 deletions(-)
 create mode 100644 tensorflow/compiler/tf2xla/lib/data_format.cc
 create mode 100644 tensorflow/compiler/tf2xla/lib/data_format.h

diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index f2e0eac2d99..159fa6685b5 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -74,7 +74,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
       if equality_test is None:
         self.assertEqual(output.dtype, expected.dtype)
         self.assertAllCloseAccordingToType(
-            result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
+            expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
       else:
         equality_test(result, expected, rtol=rtol, atol=atol)
 
@@ -956,6 +956,15 @@ class UnaryOpsTest(xla_test.XLATestCase):
                       [[9], [10], [13], [14]], [[11], [12], [15], [16]]]],
                     dtype=dtype), data_format))
 
+      self._assertOpOutputMatchesExpected(
+          make_op("NCHW_VECT_C"),
+          np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)),
+          expected=np.array([[[[[0, 1, 2, 3], [8, 9, 10, 11]],
+                               [[16, 17, 18, 19], [24, 25, 26, 27]]],
+                              [[[4, 5, 6, 7], [12, 13, 14, 15]],
+                               [[20, 21, 22, 23], [28, 29, 30, 31]]]]],
+                            dtype=dtype))
+
   def testSpaceToDepth(self):
 
     def make_op(data_format):
@@ -999,6 +1008,15 @@ class UnaryOpsTest(xla_test.XLATestCase):
                                                      [13, 14, 15, 16]]]],
                     dtype=dtype), data_format))
 
+      self._assertOpOutputMatchesExpected(
+          make_op("NCHW_VECT_C"),
+          np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)),
+          expected=np.array([[[[[0, 1, 2, 3]]], [[[16, 17, 18, 19]]],
+                              [[[4, 5, 6, 7]]], [[[20, 21, 22, 23]]],
+                              [[[8, 9, 10, 11]]], [[[24, 25, 26, 27]]],
+                              [[[12, 13, 14, 15]]], [[[28, 29, 30, 31]]]]],
+                            dtype=dtype))
+
   def _assertSoftplusMatchesExpected(self, features, dtype):
     features = np.array(features, dtype=dtype)
     zero = np.asarray(0).astype(dtype)
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index ef7492590b1..cf297786888 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -127,6 +127,7 @@ tf_kernel_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/lib:broadcast",
+        "//tensorflow/compiler/tf2xla/lib:data_format",
         "//tensorflow/compiler/tf2xla/lib:random",
         "//tensorflow/compiler/tf2xla/lib:scatter",
         "//tensorflow/compiler/tf2xla/lib:util",
diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
index e96a1adce43..9fe91d16d77 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/tf2xla/lib/data_format.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -30,11 +31,6 @@ class DepthToSpaceOp : public XlaOpKernel {
     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                 errors::InvalidArgument("Invalid data format"));
 
-    OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC,
-                errors::InvalidArgument("Unsupported data format ",
-                                        ToString(data_format_),
-                                        "; expected formats NHWC or NCHW"));
-
     OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
     OP_REQUIRES(
         ctx, block_size_ > 1,
@@ -42,19 +38,36 @@ class DepthToSpaceOp : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    const TensorShape input_tensor_shape = ctx->InputShape(0);
-    int input_rank = input_tensor_shape.dims();
+    xla::XlaOp input = ctx->Input(0);
+
+    TensorFormat data_format = data_format_;
+    // If the data is in a vectorized format, reformat it into a non-vectorized
+    // version first. We'll undo the transformation later.
+    if (data_format == FORMAT_NCHW_VECT_C) {
+      data_format = FORMAT_NCHW;
+      auto input_reshaped = NCHW_VECT_CToNCHW(input);
+      OP_REQUIRES_OK(ctx, input_reshaped.status());
+      input = input_reshaped.ValueOrDie();
+    }
+
+    OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC,
+                errors::InvalidArgument("Unsupported data format ",
+                                        ToString(data_format_)));
+
+    xla::XlaBuilder* builder = input.builder();
+    auto input_xla_shape = builder->GetShape(input);
+    OP_REQUIRES_OK(ctx, input_xla_shape.status());
+    absl::Span<const int64> input_shape =
+        input_xla_shape.ValueOrDie().dimensions();
+    int input_rank = input_shape.size();
+
     static const int kRequiredDims = 4;
     OP_REQUIRES(ctx, kRequiredDims == input_rank,
                 errors::InvalidArgument("Input rank should be ", kRequiredDims,
                                         "; got: ", input_rank));
-    const absl::InlinedVector<int64, 4> input_shape =
-        input_tensor_shape.dim_sizes();
 
-    xla::XlaOp input = ctx->Input(0);
-
-    int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
-    int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_);
+    int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format);
+    int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format);
 
     std::vector<int64> reshaped_shape;
     std::vector<int64> transpose_order;
@@ -62,7 +75,7 @@ class DepthToSpaceOp : public XlaOpKernel {
     reshaped_shape.reserve(input_rank);
     transpose_order.reserve(input_rank);
     output_shape.reserve(input_rank);
-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format == FORMAT_NHWC) {
       reshaped_shape.push_back(input_shape[0]);
       for (int i = 0; i < num_spatial_dims; ++i) {
         reshaped_shape.push_back(input_shape[1 + i]);
@@ -153,6 +166,14 @@ class DepthToSpaceOp : public XlaOpKernel {
     //
     xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);
 
+    // If this used to be a vectorized format, turn it back now.
+    if (data_format != data_format_) {
+      DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C);
+      auto output_reshaped = NCHWToNCHW_VECT_C(output);
+      OP_REQUIRES_OK(ctx, output_reshaped.status());
+      output = output_reshaped.ValueOrDie();
+    }
+
     ctx->SetOutput(0, output);
   }
 
diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
index 3293c13b21b..96863d6d1ba 100644
--- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/compiler/tf2xla/lib/data_format.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -30,11 +31,6 @@ class SpaceToDepthOp : public XlaOpKernel {
     OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
                 errors::InvalidArgument("Invalid data format"));
 
-    OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC,
-                errors::InvalidArgument("Unsupported data format ",
-                                        ToString(data_format_),
-                                        "; expected formats NHWC or NCHW"));
-
     OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
     OP_REQUIRES(
         ctx, block_size_ > 1,
@@ -42,19 +38,36 @@ class SpaceToDepthOp : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    const TensorShape input_tensor_shape = ctx->InputShape(0);
-    int input_rank = input_tensor_shape.dims();
+    xla::XlaOp input = ctx->Input(0);
+
+    TensorFormat data_format = data_format_;
+    // If the data is in a vectorized format, reformat it into a non-vectorized
+    // version first. We'll undo the transformation later.
+    if (data_format == FORMAT_NCHW_VECT_C) {
+      data_format = FORMAT_NCHW;
+      auto input_reshaped = NCHW_VECT_CToNCHW(input);
+      OP_REQUIRES_OK(ctx, input_reshaped.status());
+      input = input_reshaped.ValueOrDie();
+    }
+
+    OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC,
+                errors::InvalidArgument("Unsupported data format ",
+                                        ToString(data_format_)));
+
+    xla::XlaBuilder* builder = input.builder();
+    auto input_xla_shape = builder->GetShape(input);
+    OP_REQUIRES_OK(ctx, input_xla_shape.status());
+    absl::Span<const int64> input_shape =
+        input_xla_shape.ValueOrDie().dimensions();
+    int input_rank = input_shape.size();
+
     static const int kRequiredDims = 4;
     OP_REQUIRES(ctx, kRequiredDims == input_rank,
                 errors::InvalidArgument("Input rank should be ", kRequiredDims,
                                         "; got ", input_rank));
-    const absl::InlinedVector<int64, 4> input_shape =
-        input_tensor_shape.dim_sizes();
 
-    xla::XlaOp input = ctx->Input(0);
-
-    int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
-    int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_);
+    int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format);
+    int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format);
 
     std::vector<int64> reshaped_shape;
     std::vector<int64> transpose_order;
@@ -62,7 +75,7 @@ class SpaceToDepthOp : public XlaOpKernel {
     reshaped_shape.reserve(input_rank);
     transpose_order.reserve(input_rank);
     output_shape.reserve(input_rank);
-    if (data_format_ == FORMAT_NHWC) {
+    if (data_format == FORMAT_NHWC) {
       int64 block_elems = 1;
       for (int i = 0; i < num_spatial_dims; ++i) {
         OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0,
@@ -157,6 +170,14 @@ class SpaceToDepthOp : public XlaOpKernel {
     //
     xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);
 
+    // If this used to be a vectorized format, turn it back now.
+    if (data_format != data_format_) {
+      DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C);
+      auto output_reshaped = NCHWToNCHW_VECT_C(output);
+      OP_REQUIRES_OK(ctx, output_reshaped.status());
+      output = output_reshaped.ValueOrDie();
+    }
+
     ctx->SetOutput(0, output);
   }
 
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 3d7b0bc959f..f9ce50be6e3 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -82,3 +82,15 @@ cc_library(
         "@com_google_absl//absl/types:span",
     ],
 )
+
+cc_library(
+    name = "data_format",
+    srcs = ["data_format.cc"],
+    hdrs = ["data_format.h"],
+    deps = [
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc
new file mode 100644
index 00000000000..0253bcdc5f9
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/data_format.cc
@@ -0,0 +1,88 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/data_format.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+xla::StatusOr<xla::XlaOp> Contract(xla::XlaOp input, int64 dim) {
+  xla::XlaBuilder* builder = input.builder();
+  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+
+  if (input_shape.dimensions().back() != 4) {
+    return errors::InvalidArgument("Expected last dimension to be 4; got ",
+                                   input_shape.dimensions().back());
+  }
+
+  // Transpose the input so C is directly followed by VECT_C.
+  std::vector<int64> permutation;
+  for (int64 i = 0; i != input_shape.rank() - 1; ++i) {
+    permutation.push_back(i);
+    if (i == dim) {
+      permutation.push_back(input_shape.rank() - 1);
+    }
+  }
+
+  // Now merge the adjacent dimensions with a reshape.
+  std::vector<int64> contracted_shape(input_shape.dimensions().begin(),
+                                      input_shape.dimensions().end() - 1);
+  contracted_shape[dim] *= 4;
+
+  return xla::Reshape(xla::Transpose(input, permutation), contracted_shape);
+}
+
+xla::StatusOr<xla::XlaOp> Expand(xla::XlaOp input, int64 dim) {
+  xla::XlaBuilder* builder = input.builder();
+  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+
+  if (input_shape.dimensions(dim) % 4 != 0) {
+    return errors::InvalidArgument(
+        "Expected vectorized dimension to be evenly divisible by 4; got ",
+        input_shape.dimensions(dim));
+  }
+
+  // Split the `dim` into two dimensions with a reshape. The size of the new
+  // dimension is always 4.
+  std::vector<int64> expanded_shape(input_shape.dimensions().begin(),
+                                    input_shape.dimensions().end());
+  expanded_shape[dim] /= 4;
+  expanded_shape.insert(expanded_shape.begin() + dim + 1, 4);
+
+  // Move the newly created dimension to the end with a transpose.
+  std::vector<int64> permutation;
+  for (int64 i = 0; i != expanded_shape.size(); ++i) {
+    permutation.push_back(i);
+    if (i == dim) {
+      ++i;
+    }
+  }
+  permutation.push_back(dim + 1);
+
+  return xla::Transpose(xla::Reshape(input, expanded_shape), permutation);
+}
+
+}  // namespace
+
+xla::StatusOr<xla::XlaOp> NCHW_VECT_CToNCHW(xla::XlaOp input) {
+  return Contract(input, 1);
+}
+
+xla::StatusOr<xla::XlaOp> NCHWToNCHW_VECT_C(xla::XlaOp input) {
+  return Expand(input, 1);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/data_format.h b/tensorflow/compiler/tf2xla/lib/data_format.h
new file mode 100644
index 00000000000..839723b0ea8
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/data_format.h
@@ -0,0 +1,37 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+// Reformat from NCHW_VECT_C to NCHW.
+//
+// Prerequisites: the last dimension of the input must be of size 4.
+xla::StatusOr<xla::XlaOp> NCHW_VECT_CToNCHW(xla::XlaOp input);
+
+// Reformat from NCHW to NCHW_VECT_C.
+//
+// Prerequisites: the vectorized dimension `C` must be a multiple of 4.
+xla::StatusOr<xla::XlaOp> NCHWToNCHW_VECT_C(xla::XlaOp input);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_
diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py
index b7a865cf13e..96c9b5258e2 100644
--- a/tensorflow/python/kernel_tests/depthtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py
@@ -295,7 +295,6 @@ class DepthToSpaceTest(test.TestCase):
       actual_vals, expected_vals = self.evaluate([actual, expected])
       self.assertTrue(np.array_equal(actual_vals, expected_vals))
 
-  @test_util.disable_xla("b/123553551")  # Unsupported data format
   def testAgainstTranspose(self):
     self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
     self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", False)
diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
index 69243afb69c..e96bc09f365 100644
--- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
@@ -285,7 +285,6 @@ class SpaceToDepthTest(test.TestCase):
       actual_vals, expected_vals = self.evaluate([actual, expected])
       self.assertTrue(np.array_equal(actual_vals, expected_vals))
 
-  @test_util.disable_xla("b/123553551")  # Unsupported data format
   def testAgainstTranspose(self):
     self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
     self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", False)