[TF:XLA] Implement NCHW_VECT_C for DepthToSpace/SpaceToBatch by desugaring to NCHW
XLA is smart enough to simplify the extra steps away so this shouldn't be significantly more expensive than a "native" implementation. TF only uses NCHW_VECT_C for quantized int8 convolutions which XLA doesn't support, but the data formatting around it can be compiled by XLA. The actual formatting is factored into separate functions, it'll likely come in handy again for other ops. PiperOrigin-RevId: 238202809
This commit is contained in:
parent
444a8bc05f
commit
bd1bcb4f21
@ -74,7 +74,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
if equality_test is None:
|
||||
self.assertEqual(output.dtype, expected.dtype)
|
||||
self.assertAllCloseAccordingToType(
|
||||
result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
|
||||
expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
|
||||
else:
|
||||
equality_test(result, expected, rtol=rtol, atol=atol)
|
||||
|
||||
@ -956,6 +956,15 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
[[9], [10], [13], [14]], [[11], [12], [15], [16]]]],
|
||||
dtype=dtype), data_format))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
make_op("NCHW_VECT_C"),
|
||||
np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)),
|
||||
expected=np.array([[[[[0, 1], [8, 9]], [[16, 17], [24, 25]]],
|
||||
[[[2, 3], [10, 11]], [[18, 19], [26, 27]]],
|
||||
[[[4, 5], [12, 13]], [[20, 21], [28, 29]]],
|
||||
[[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]],
|
||||
dtype=dtype))
|
||||
|
||||
def testSpaceToDepth(self):
|
||||
|
||||
def make_op(data_format):
|
||||
@ -999,6 +1008,15 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
[13, 14, 15, 16]]]],
|
||||
dtype=dtype), data_format))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
make_op("NCHW_VECT_C"),
|
||||
np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)),
|
||||
expected=np.array([[[[[0, 1, 2, 3, 16, 17, 18, 19]]],
|
||||
[[[4, 5, 6, 7, 20, 21, 22, 23]]],
|
||||
[[[8, 9, 10, 11, 24, 25, 26, 27]]],
|
||||
[[[12, 13, 14, 15, 28, 29, 30, 31]]]]],
|
||||
dtype=dtype))
|
||||
|
||||
def _assertSoftplusMatchesExpected(self, features, dtype):
|
||||
features = np.array(features, dtype=dtype)
|
||||
zero = np.asarray(0).astype(dtype)
|
||||
|
@ -127,6 +127,7 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/lib:broadcast",
|
||||
"//tensorflow/compiler/tf2xla/lib:data_format",
|
||||
"//tensorflow/compiler/tf2xla/lib:random",
|
||||
"//tensorflow/compiler/tf2xla/lib:scatter",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/lib/data_format.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -30,11 +31,6 @@ class DepthToSpaceOp : public XlaOpKernel {
|
||||
OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
|
||||
OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC,
|
||||
errors::InvalidArgument("Unsupported data format ",
|
||||
ToString(data_format_),
|
||||
"; expected formats NHWC or NCHW"));
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
|
||||
OP_REQUIRES(
|
||||
ctx, block_size_ > 1,
|
||||
@ -42,19 +38,36 @@ class DepthToSpaceOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_tensor_shape = ctx->InputShape(0);
|
||||
int input_rank = input_tensor_shape.dims();
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
|
||||
TensorFormat data_format = data_format_;
|
||||
// If the data is in a vectorized format, reformat it into a non-vectorized
|
||||
// version first. We'll undo the transformation later.
|
||||
if (data_format == FORMAT_NCHW_VECT_C) {
|
||||
data_format = FORMAT_NCHW;
|
||||
auto input_reshaped = NCHW_VECT_CToNCHW(input);
|
||||
OP_REQUIRES_OK(ctx, input_reshaped.status());
|
||||
input = input_reshaped.ValueOrDie();
|
||||
}
|
||||
|
||||
OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC,
|
||||
errors::InvalidArgument("Unsupported data format ",
|
||||
ToString(data_format_)));
|
||||
|
||||
xla::XlaBuilder* builder = input.builder();
|
||||
auto input_xla_shape = builder->GetShape(input);
|
||||
OP_REQUIRES_OK(ctx, input_xla_shape.status());
|
||||
const std::vector<int64>& input_shape =
|
||||
input_xla_shape.ValueOrDie().dimensions();
|
||||
int input_rank = input_shape.size();
|
||||
|
||||
static const int kRequiredDims = 4;
|
||||
OP_REQUIRES(ctx, kRequiredDims == input_rank,
|
||||
errors::InvalidArgument("Input rank should be ", kRequiredDims,
|
||||
"; got: ", input_rank));
|
||||
const absl::InlinedVector<int64, 4> input_shape =
|
||||
input_tensor_shape.dim_sizes();
|
||||
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
|
||||
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
|
||||
int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_);
|
||||
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format);
|
||||
int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format);
|
||||
|
||||
std::vector<int64> reshaped_shape;
|
||||
std::vector<int64> transpose_order;
|
||||
@ -62,7 +75,7 @@ class DepthToSpaceOp : public XlaOpKernel {
|
||||
reshaped_shape.reserve(input_rank);
|
||||
transpose_order.reserve(input_rank);
|
||||
output_shape.reserve(input_rank);
|
||||
if (data_format_ == FORMAT_NHWC) {
|
||||
if (data_format == FORMAT_NHWC) {
|
||||
reshaped_shape.push_back(input_shape[0]);
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
reshaped_shape.push_back(input_shape[1 + i]);
|
||||
@ -153,6 +166,14 @@ class DepthToSpaceOp : public XlaOpKernel {
|
||||
//
|
||||
xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);
|
||||
|
||||
// If this used to be a vectorized format turn it back now.
|
||||
if (data_format != data_format_) {
|
||||
DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C);
|
||||
auto output_reshaped = NCHWToNCHW_VECT_C(output);
|
||||
OP_REQUIRES_OK(ctx, output_reshaped.status());
|
||||
output = output_reshaped.ValueOrDie();
|
||||
}
|
||||
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/lib/data_format.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
@ -30,11 +31,6 @@ class SpaceToDepthOp : public XlaOpKernel {
|
||||
OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
|
||||
OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC,
|
||||
errors::InvalidArgument("Unsupported data format ",
|
||||
ToString(data_format_),
|
||||
"; expected formats NHWC or NCHW"));
|
||||
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
|
||||
OP_REQUIRES(
|
||||
ctx, block_size_ > 1,
|
||||
@ -42,19 +38,36 @@ class SpaceToDepthOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
const TensorShape input_tensor_shape = ctx->InputShape(0);
|
||||
int input_rank = input_tensor_shape.dims();
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
|
||||
TensorFormat data_format = data_format_;
|
||||
// If the data is in a vectorized format, reformat it into a non-vectorized
|
||||
// version first. We'll undo the transformation later.
|
||||
if (data_format == FORMAT_NCHW_VECT_C) {
|
||||
data_format = FORMAT_NCHW;
|
||||
auto input_reshaped = NCHW_VECT_CToNCHW(input);
|
||||
OP_REQUIRES_OK(ctx, input_reshaped.status());
|
||||
input = input_reshaped.ValueOrDie();
|
||||
}
|
||||
|
||||
OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC,
|
||||
errors::InvalidArgument("Unsupported data format ",
|
||||
ToString(data_format_)));
|
||||
|
||||
xla::XlaBuilder* builder = input.builder();
|
||||
auto input_xla_shape = builder->GetShape(input);
|
||||
OP_REQUIRES_OK(ctx, input_xla_shape.status());
|
||||
const std::vector<int64>& input_shape =
|
||||
input_xla_shape.ValueOrDie().dimensions();
|
||||
int input_rank = input_shape.size();
|
||||
|
||||
static const int kRequiredDims = 4;
|
||||
OP_REQUIRES(ctx, kRequiredDims == input_rank,
|
||||
errors::InvalidArgument("Input rank should be ", kRequiredDims,
|
||||
"; got ", input_rank));
|
||||
const absl::InlinedVector<int64, 4> input_shape =
|
||||
input_tensor_shape.dim_sizes();
|
||||
|
||||
xla::XlaOp input = ctx->Input(0);
|
||||
|
||||
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_);
|
||||
int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_);
|
||||
int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format);
|
||||
int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format);
|
||||
|
||||
std::vector<int64> reshaped_shape;
|
||||
std::vector<int64> transpose_order;
|
||||
@ -62,7 +75,7 @@ class SpaceToDepthOp : public XlaOpKernel {
|
||||
reshaped_shape.reserve(input_rank);
|
||||
transpose_order.reserve(input_rank);
|
||||
output_shape.reserve(input_rank);
|
||||
if (data_format_ == FORMAT_NHWC) {
|
||||
if (data_format == FORMAT_NHWC) {
|
||||
int64 block_elems = 1;
|
||||
for (int i = 0; i < num_spatial_dims; ++i) {
|
||||
OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0,
|
||||
@ -157,6 +170,14 @@ class SpaceToDepthOp : public XlaOpKernel {
|
||||
//
|
||||
xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape);
|
||||
|
||||
// If this used to be a vectorized format turn it back now.
|
||||
if (data_format != data_format_) {
|
||||
DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C);
|
||||
auto output_reshaped = NCHWToNCHW_VECT_C(output);
|
||||
OP_REQUIRES_OK(ctx, output_reshaped.status());
|
||||
output = output_reshaped.ValueOrDie();
|
||||
}
|
||||
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
|
||||
|
@ -82,3 +82,15 @@ cc_library(
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "data_format",
|
||||
srcs = ["data_format.cc"],
|
||||
hdrs = ["data_format.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
87
tensorflow/compiler/tf2xla/lib/data_format.cc
Normal file
87
tensorflow/compiler/tf2xla/lib/data_format.cc
Normal file
@ -0,0 +1,87 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/lib/data_format.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
xla::StatusOr<xla::XlaOp> Contract(xla::XlaOp input, int64 dim) {
|
||||
xla::XlaBuilder* builder = input.builder();
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
|
||||
|
||||
if (input_shape.dimensions().back() != 4) {
|
||||
return errors::InvalidArgument("Expected last dimension to be 4; got ",
|
||||
input_shape.dimensions().back());
|
||||
}
|
||||
|
||||
// Transpose the input so C is directly followed by VECT_C.
|
||||
std::vector<int64> permutation;
|
||||
for (int64 i = 0; i != input_shape.rank() - 1; ++i) {
|
||||
permutation.push_back(i);
|
||||
if (i == dim) {
|
||||
permutation.push_back(input_shape.rank() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Now merge the adjacent dimensions with a reshape.
|
||||
std::vector<int64> contracted_shape(input_shape.dimensions().begin(),
|
||||
input_shape.dimensions().end() - 1);
|
||||
contracted_shape[dim] *= 4;
|
||||
|
||||
return xla::Reshape(xla::Transpose(input, permutation), contracted_shape);
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaOp> Expand(xla::XlaOp input, int64 dim) {
|
||||
xla::XlaBuilder* builder = input.builder();
|
||||
TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
|
||||
|
||||
if (input_shape.dimensions(dim) % 4 != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Expected vectorized dimension to be evenly divisible by 4; got ",
|
||||
input_shape.dimensions(dim));
|
||||
}
|
||||
|
||||
// Split the `dim` into two dimensions with a reshape. The size of the new
|
||||
// dimension is always 4.
|
||||
std::vector<int64> expanded_shape(input_shape.dimensions());
|
||||
expanded_shape[dim] /= 4;
|
||||
expanded_shape.insert(expanded_shape.begin() + dim, 4);
|
||||
|
||||
// Move the newly created dimension to the end with a transpose.
|
||||
std::vector<int64> permutation;
|
||||
for (int64 i = 0; i != expanded_shape.size(); ++i) {
|
||||
permutation.push_back(i);
|
||||
if (i == dim) {
|
||||
++i;
|
||||
}
|
||||
}
|
||||
permutation.push_back(dim + 1);
|
||||
|
||||
return xla::Transpose(xla::Reshape(input, expanded_shape), permutation);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
xla::StatusOr<xla::XlaOp> NCHW_VECT_CToNCHW(xla::XlaOp input) {
|
||||
return Contract(input, 1);
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaOp> NCHWToNCHW_VECT_C(xla::XlaOp input) {
|
||||
return Expand(input, 1);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
37
tensorflow/compiler/tf2xla/lib/data_format.h
Normal file
37
tensorflow/compiler/tf2xla/lib/data_format.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Reformat from NCHW_VECT_C to NCHW.
|
||||
//
|
||||
// Prerequisites: the last dimension of the input must be of size 4.
|
||||
xla::StatusOr<xla::XlaOp> NCHW_VECT_CToNCHW(xla::XlaOp input);
|
||||
|
||||
// Reformat from NCHW to NCHW_VECT_C.
|
||||
//
|
||||
// Prerequisites: the vectorized dimension `C` must be a multiple of 4.
|
||||
xla::StatusOr<xla::XlaOp> NCHWToNCHW_VECT_C(xla::XlaOp input);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_
|
@ -295,7 +295,6 @@ class DepthToSpaceTest(test.TestCase):
|
||||
actual_vals, expected_vals = self.evaluate([actual, expected])
|
||||
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
||||
|
||||
@test_util.disable_xla("b/123553551") # Unsupported data format
|
||||
def testAgainstTranspose(self):
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
|
||||
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", False)
|
||||
|
@ -285,7 +285,6 @@ class SpaceToDepthTest(test.TestCase):
|
||||
actual_vals, expected_vals = self.evaluate([actual, expected])
|
||||
self.assertTrue(np.array_equal(actual_vals, expected_vals))
|
||||
|
||||
@test_util.disable_xla("b/123553551") # Unsupported data format
|
||||
def testAgainstTranspose(self):
|
||||
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
|
||||
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", False)
|
||||
|
Loading…
Reference in New Issue
Block a user