diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index f2e0eac2d99..159fa6685b5 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -74,7 +74,7 @@ class UnaryOpsTest(xla_test.XLATestCase): if equality_test is None: self.assertEqual(output.dtype, expected.dtype) self.assertAllCloseAccordingToType( - result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) else: equality_test(result, expected, rtol=rtol, atol=atol) @@ -956,6 +956,15 @@ class UnaryOpsTest(xla_test.XLATestCase): [[9], [10], [13], [14]], [[11], [12], [15], [16]]]], dtype=dtype), data_format)) + self._assertOpOutputMatchesExpected( + make_op("NCHW_VECT_C"), + np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)), + expected=np.array([[[[[0, 1], [8, 9]], [[16, 17], [24, 25]]], + [[[2, 3], [10, 11]], [[18, 19], [26, 27]]], + [[[4, 5], [12, 13]], [[20, 21], [28, 29]]], + [[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]], + dtype=dtype)) + def testSpaceToDepth(self): def make_op(data_format): @@ -999,6 +1008,15 @@ class UnaryOpsTest(xla_test.XLATestCase): [13, 14, 15, 16]]]], dtype=dtype), data_format)) + self._assertOpOutputMatchesExpected( + make_op("NCHW_VECT_C"), + np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)), + expected=np.array([[[[[0, 1, 2, 3, 16, 17, 18, 19]]], + [[[4, 5, 6, 7, 20, 21, 22, 23]]], + [[[8, 9, 10, 11, 24, 25, 26, 27]]], + [[[12, 13, 14, 15, 28, 29, 30, 31]]]]], + dtype=dtype)) + def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) zero = np.asarray(0).astype(dtype) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index ef7492590b1..cf297786888 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -127,6 +127,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:broadcast", + "//tensorflow/compiler/tf2xla/lib:data_format", "//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:util", diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index e96a1adce43..9fe91d16d77 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -30,11 +31,6 @@ class DepthToSpaceOp : public XlaOpKernel { OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); - OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, - errors::InvalidArgument("Unsupported data format ", - ToString(data_format_), - "; expected formats NHWC or NCHW")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -42,19 +38,36 @@ class DepthToSpaceOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_tensor_shape = ctx->InputShape(0); - int input_rank = input_tensor_shape.dims(); + xla::XlaOp input = ctx->Input(0); + + TensorFormat data_format = data_format_; + // If the data is in a vectorized format, reformat it into a non-vectorized + // version first. We'll undo the transformation later. + if (data_format == FORMAT_NCHW_VECT_C) { + data_format = FORMAT_NCHW; + auto input_reshaped = NCHW_VECT_CToNCHW(input); + OP_REQUIRES_OK(ctx, input_reshaped.status()); + input = input_reshaped.ValueOrDie(); + } + + OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_))); + + xla::XlaBuilder* builder = input.builder(); + auto input_xla_shape = builder->GetShape(input); + OP_REQUIRES_OK(ctx, input_xla_shape.status()); + const std::vector& input_shape = + input_xla_shape.ValueOrDie().dimensions(); + int input_rank = input_shape.size(); + static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got: ", input_rank)); - const absl::InlinedVector input_shape = - input_tensor_shape.dim_sizes(); - xla::XlaOp input = ctx->Input(0); - - int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); - int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format); std::vector reshaped_shape; std::vector transpose_order; @@ -62,7 +75,7 @@ class DepthToSpaceOp : public XlaOpKernel { reshaped_shape.reserve(input_rank); transpose_order.reserve(input_rank); output_shape.reserve(input_rank); - if (data_format_ == FORMAT_NHWC) { + if (data_format == FORMAT_NHWC) { reshaped_shape.push_back(input_shape[0]); for (int i = 0; i < num_spatial_dims; ++i) { reshaped_shape.push_back(input_shape[1 + i]); @@ -153,6 +166,14 @@ class DepthToSpaceOp : public XlaOpKernel { // xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); + // If this used to be a vectorized format turn it back now. + if (data_format != data_format_) { + DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C); + auto output_reshaped = NCHWToNCHW_VECT_C(output); + OP_REQUIRES_OK(ctx, output_reshaped.status()); + output = output_reshaped.ValueOrDie(); + } + ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 3293c13b21b..96863d6d1ba 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -30,11 +31,6 @@ class SpaceToDepthOp : public XlaOpKernel { OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format")); - OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, - errors::InvalidArgument("Unsupported data format ", - ToString(data_format_), - "; expected formats NHWC or NCHW")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -42,19 +38,36 @@ class SpaceToDepthOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_tensor_shape = ctx->InputShape(0); - int input_rank = input_tensor_shape.dims(); + xla::XlaOp input = ctx->Input(0); + + TensorFormat data_format = data_format_; + // If the data is in a vectorized format, reformat it into a non-vectorized + // version first. We'll undo the transformation later. + if (data_format == FORMAT_NCHW_VECT_C) { + data_format = FORMAT_NCHW; + auto input_reshaped = NCHW_VECT_CToNCHW(input); + OP_REQUIRES_OK(ctx, input_reshaped.status()); + input = input_reshaped.ValueOrDie(); + } + + OP_REQUIRES(ctx, data_format == FORMAT_NCHW || data_format == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_))); + + xla::XlaBuilder* builder = input.builder(); + auto input_xla_shape = builder->GetShape(input); + OP_REQUIRES_OK(ctx, input_xla_shape.status()); + const std::vector& input_shape = + input_xla_shape.ValueOrDie().dimensions(); + int input_rank = input_shape.size(); + static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got ", input_rank)); - const absl::InlinedVector input_shape = - input_tensor_shape.dim_sizes(); - xla::XlaOp input = ctx->Input(0); - - int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); - int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format); std::vector reshaped_shape; std::vector transpose_order; @@ -62,7 +75,7 @@ class SpaceToDepthOp : public XlaOpKernel { reshaped_shape.reserve(input_rank); transpose_order.reserve(input_rank); output_shape.reserve(input_rank); - if (data_format_ == FORMAT_NHWC) { + if (data_format == FORMAT_NHWC) { int64 block_elems = 1; for (int i = 0; i < num_spatial_dims; ++i) { OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, @@ -157,6 +170,14 @@ class SpaceToDepthOp : public XlaOpKernel { // xla::XlaOp output = xla::Reshape(permuted_reshaped, output_shape); + // If this used to be a vectorized format turn it back now. + if (data_format != data_format_) { + DCHECK(data_format == FORMAT_NCHW && data_format_ == FORMAT_NCHW_VECT_C); + auto output_reshaped = NCHWToNCHW_VECT_C(output); + OP_REQUIRES_OK(ctx, output_reshaped.status()); + output = output_reshaped.ValueOrDie(); + } + ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 3d7b0bc959f..f9ce50be6e3 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -82,3 +82,15 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +cc_library( + name = "data_format", + srcs = ["data_format.cc"], + hdrs = ["data_format.h"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc new file mode 100644 index 00000000000..0253bcdc5f9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -0,0 +1,87 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/data_format.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +xla::StatusOr Contract(xla::XlaOp input, int64 dim) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + + if (input_shape.dimensions().back() != 4) { + return errors::InvalidArgument("Expected last dimension to be 4; got ", + input_shape.dimensions().back()); + } + + // Transpose the input so C is directly followed by VECT_C. + std::vector permutation; + for (int64 i = 0; i != input_shape.rank() - 1; ++i) { + permutation.push_back(i); + if (i == dim) { + permutation.push_back(input_shape.rank() - 1); + } + } + + // Now merge the adjacent dimensions with a reshape. + std::vector contracted_shape(input_shape.dimensions().begin(), + input_shape.dimensions().end() - 1); + contracted_shape[dim] *= 4; + + return xla::Reshape(xla::Transpose(input, permutation), contracted_shape); +} + +xla::StatusOr Expand(xla::XlaOp input, int64 dim) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + + if (input_shape.dimensions(dim) % 4 != 0) { + return errors::InvalidArgument( + "Expected vectorized dimension to be evenly divisible by 4; got ", + input_shape.dimensions(dim)); + } + + // Split the `dim` into two dimensions with a reshape. The size of the new + // dimension is always 4. + std::vector expanded_shape(input_shape.dimensions()); + expanded_shape[dim] /= 4; + expanded_shape.insert(expanded_shape.begin() + dim, 4); + + // Move the newly created dimension to the end with a transpose. + std::vector permutation; + for (int64 i = 0; i != expanded_shape.size(); ++i) { + permutation.push_back(i); + if (i == dim) { + ++i; + } + } + permutation.push_back(dim + 1); + + return xla::Transpose(xla::Reshape(input, expanded_shape), permutation); +} + +} // namespace + +xla::StatusOr NCHW_VECT_CToNCHW(xla::XlaOp input) { + return Contract(input, 1); +} + +xla::StatusOr NCHWToNCHW_VECT_C(xla::XlaOp input) { + return Expand(input, 1); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/data_format.h b/tensorflow/compiler/tf2xla/lib/data_format.h new file mode 100644 index 00000000000..839723b0ea8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/data_format.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +// Reformat from NCHW_VECT_C to NCHW. +// +// Prerequisites: the last dimension of the input must be of size 4. +xla::StatusOr NCHW_VECT_CToNCHW(xla::XlaOp input); + +// Reformat from NCHW to NCHW_VECT_C. +// +// Prerequisites: the vectorized dimension `C` must be a multiple of 4. +xla::StatusOr NCHWToNCHW_VECT_C(xla::XlaOp input); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py index b7a865cf13e..96c9b5258e2 100644 --- a/tensorflow/python/kernel_tests/depthtospace_op_test.py +++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py @@ -295,7 +295,6 @@ class DepthToSpaceTest(test.TestCase): actual_vals, expected_vals = self.evaluate([actual, expected]) self.assertTrue(np.array_equal(actual_vals, expected_vals)) - @test_util.disable_xla("b/123553551") # Unsupported data format def testAgainstTranspose(self): self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False) self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", False) diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py index 69243afb69c..e96bc09f365 100644 --- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py +++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py @@ -285,7 +285,6 @@ class SpaceToDepthTest(test.TestCase): actual_vals, expected_vals = self.evaluate([actual, expected]) self.assertTrue(np.array_equal(actual_vals, expected_vals)) - @test_util.disable_xla("b/123553551") # Unsupported data format def testAgainstTranspose(self): self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False) self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", False)