From e622f15b21adaaf0b707db7be6febf8a76b55e25 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Wed, 20 May 2020 15:54:53 -0700
Subject: [PATCH] DataFormatVecPermute accepts a vector of size 2.

This partially rolls back cl/307496027. The code before cl/307496027 assumes
the actual length of input_sizes is always 4 and always permutes the vector.
However, this is unsafe because the length of input_sizes can also be 2.

cl/307496027 made the code safe, but this way LayoutOptimizer misses some
optimizations, which apparently causes more memory usage.

This CL makes DataFormatVecPermute accept a vector of size 2 as well as a
vector of size 4. When the size is 2, the two dimensions are interpreted as
spatial dimensions. This way LayoutOptimizer doesn't need to check the static
shape of input_sizes. Instead, it applies DataFormatVecPermute regardless of
the vector size.

See b/156645925 for details.

PiperOrigin-RevId: 312571735
Change-Id: I257e2bef328882dbbcd0fe6bf07ef1f8989daf36
---
 .../compiler/mlir/tensorflow/ir/tf_ops.cc     |  4 +-
 .../compiler/tests/data_format_ops_test.py    | 10 +++
 .../tf2xla/kernels/data_format_ops.cc         | 26 +++++--
 .../generic_layout_optimizer_test.cc          | 76 +++++++------------
 .../generic_layout_optimizer_transposer.cc    | 19 +----
 tensorflow/core/kernels/data_format_ops.cc    | 43 +++++++----
 tensorflow/python/ops/nn_test.py              | 32 ++++++++
 7 files changed, 123 insertions(+), 87 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
index 95e888179e1..ea41c8224f0 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc
@@ -1297,8 +1297,8 @@ static LogicalResult Verify(DataFormatVecPermuteOp op) {
 
   if (rank == 1) {
     int64_t dim0 = input_ty.getDimSize(0);
-    if (dim0 != ShapedType::kDynamicSize && dim0 != 4)
-      return op.emitOpError("requires 1D input of size 4");
+    if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2)
+      return op.emitOpError("requires 1D input of size 4 or size 2");
   }
 
   if (rank == 2) {
diff --git a/tensorflow/compiler/tests/data_format_ops_test.py b/tensorflow/compiler/tests/data_format_ops_test.py
index 681c1f3499e..08d44256b50 100644
--- a/tensorflow/compiler/tests/data_format_ops_test.py
+++ b/tensorflow/compiler/tests/data_format_ops_test.py
@@ -81,11 +81,21 @@ class XlaPermuteOpTest(xla_test.XLATestCase):
       x = np.array([7, 4, 9, 3], dtype=dtype)
       self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9])
 
+  def testNHWCToNCHW_Size2(self):
+    for dtype in {np.int32, np.int64}:
+      x = np.array([4, 9], dtype=dtype)
+      self._runPermuteAndCompare(x, "NHWC", "NCHW", [4, 9])
+
   def testNCHWToNHWC(self):
     for dtype in {np.int32, np.int64}:
       x = np.array([7, 4, 9, 3], dtype=dtype)
       self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4])
 
+  def testNCHWToNHWC_Size2(self):
+    for dtype in {np.int32, np.int64}:
+      x = np.array([9, 3], dtype=dtype)
+      self._runPermuteAndCompare(x, "NCHW", "NHWC", [9, 3])
+
   def testNHWCToHWNC(self):
     for dtype in {np.int32, np.int64}:
       x = np.array([7, 4, 9, 3], dtype=dtype)
diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
index fb89742b139..c1f60abc0d6 100644
--- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc
@@ -106,8 +106,9 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
                 errors::InvalidArgument(
                     "Input must be a vector or matrix, but got shape ",
                     input_tensor_shape.DebugString()));
+    const int dim0 = input_tensor_shape.dim_size(0);
     OP_REQUIRES(
-        ctx, input_tensor_shape.dim_size(0) == 4,
+        ctx, dim0 == 2 || dim0 == 4,
         errors::InvalidArgument(
             "First dimension of input must be of size 4, but got shape ",
             input_tensor_shape.DebugString()));
@@ -118,10 +119,25 @@ class DataFormatVecPermuteOp : public XlaOpKernel {
               "Second dimension of 2D input must be of size 2, but got shape ",
               input_tensor_shape.DebugString()));
     }
-    int32 dst_indices[4];
-    for (int i = 0; i < 4; ++i) {
-      for (int j = 0; j < 4; ++j) {
-        if (src_format_[i] == dst_format_[j]) {
+
+    string src_format_str = src_format_;
+    string dst_format_str = dst_format_;
+    if (dim0 == 2) {
+      // If the input is a vector of size 2, treat the two elements as spatial
+      // dimensions.
+      auto keep_only_spatial_dimensions = [](string* format_str) -> void {
+        auto new_end = std::remove_if(
+            format_str->begin(), format_str->end(),
+            [](const char dim) { return dim != 'H' && dim != 'W'; });
+        format_str->erase(new_end, format_str->end());
+      };
+      keep_only_spatial_dimensions(&src_format_str);
+      keep_only_spatial_dimensions(&dst_format_str);
+    }
+    std::vector<int32> dst_indices(dim0);
+    for (int i = 0; i < dim0; ++i) {
+      for (int j = 0; j < dim0; ++j) {
+        if (src_format_str[i] == dst_format_str[j]) {
           dst_indices[j] = i;
           break;
         }
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc
index c85d85e69ff..79bedf5f2e6 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc
@@ -356,57 +356,35 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
 #if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
   GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
 #endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
-                                        /*input_sizes_length=*/4);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+  for (const int input_sizes_length : {2, 4}) {
+    Scope s = Scope::NewRootScope();
+    auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
+                                          input_sizes_length);
+    Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
+    GrapplerItem item;
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
 
-  GenericLayoutOptimizer optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
+    GenericLayoutOptimizer optimizer;
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
 
-  Status status;
-  utils::GraphView graph_view(&output, &status);
-  TF_ASSERT_OK(status);
-  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
-  ASSERT_NE(conv2d_backprop_node, nullptr);
-  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
-  VerifyRegularFaninMatch(
-      conv2d_backprop_node, 0,
-      "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer",
-      0);
-  auto* input_sizes_node = graph_view.GetNode(
-      "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
-  ASSERT_NE(input_sizes_node, nullptr);
-  EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
-  ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
-  VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
-}
-
-TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInput2DInputSizes) {
-#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  GTEST_SKIP() << "Neither CUDA nor ROCm is enabled";
-#endif  // !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
-  Scope s = Scope::NewRootScope();
-  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", /*dilated=*/false,
-                                        /*input_sizes_length=*/2);
-  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
-  GrapplerItem item;
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  GenericLayoutOptimizer optimizer;
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output));
-
-  Status status;
-  utils::GraphView graph_view(&output, &status);
-  TF_ASSERT_OK(status);
-  auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
-  ASSERT_NE(conv2d_backprop_node, nullptr);
-  ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
-  VerifyRegularFaninMatch(conv2d_backprop_node, 0, "InputSizesIdentity", 0);
+    Status status;
+    utils::GraphView graph_view(&output, &status);
+    TF_ASSERT_OK(status);
+    auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput");
+    ASSERT_NE(conv2d_backprop_node, nullptr);
+    ASSERT_EQ(conv2d_backprop_node->NumRegularFanins(), 3);
+    VerifyRegularFaninMatch(
+        conv2d_backprop_node, 0,
+        "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer",
+        0);
+    auto* input_sizes_node = graph_view.GetNode(
+        "Conv2DBackpropInput-0-DataFormatVecPermuteNHWCToNCHW-LayoutOptimizer");
+    ASSERT_NE(input_sizes_node, nullptr);
+    EXPECT_EQ(input_sizes_node->GetOp(), "DataFormatVecPermute");
+    ASSERT_EQ(input_sizes_node->NumRegularFanins(), 1);
+    VerifyRegularFaninMatch(input_sizes_node, 0, "InputSizesIdentity", 0);
+  }
 }
 
 TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) {
diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
index a5a5f7ae64a..ab7d8fcd6cf 100644
--- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
+++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc
@@ -739,28 +739,13 @@ Status Conv2DBackpropInputTransposer::TransposeNode(
     VLOG(3) << fanin_node->GetName() << " is not a vector.";
     return Status::OK();
   }
-  int vector_size = fanin_shape.dim(0).size();
-  if (vector_size == -1) {
-    VLOG(3) << "The number of elements in " << fanin_node->GetName()
-            << " is unknown.";
-    return Status::OK();
-  }
-  if (vector_size != 2 && vector_size != 4) {
-    return errors::InvalidArgument(
-        fanin_node->GetName(), " must be a vector of size 2 or 4, but found ",
-        vector_size);
-  }
 
   VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
           << "' with op '" << node->GetOp() << "' from data format '"
           << context->src_format << "' to '" << context->dst_format << "'";
   TF_RETURN_IF_ERROR(UpdateNode(context, node));
-  // Do not permute a input_sizes of size 2 because it represents HW regardless
-  // of whether NCHW or NHWC.
-  if (vector_size != 2) {
-    TF_RETURN_IF_ERROR(
-        UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
-  }
+  TF_RETURN_IF_ERROR(
+      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
   TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
   TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
   return context->graph_view->GetMutationBuilder()->Apply();
diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc
index 0b4241dbb93..181aa1b8a2c 100644
--- a/tensorflow/core/kernels/data_format_ops.cc
+++ b/tensorflow/core/kernels/data_format_ops.cc
@@ -90,16 +90,15 @@ class DataFormatVecPermuteOp : public OpKernel {
                     "input must be a vector or 2D tensor, but got shape ",
                     input.shape().DebugString()));
     if (input.dims() == 1) {
-      OP_REQUIRES(
-          context, input.NumElements() == 4,
-          errors::InvalidArgument("1D input must be of size 4, but got shape ",
-                                  input.shape().DebugString()));
+      OP_REQUIRES(context, input.NumElements() == 2 || input.NumElements() == 4,
+                  errors::InvalidArgument(
+                      "1D input must be of size 2 or 4, but got shape ",
+                      input.shape().DebugString()));
     } else if (input.dims() == 2) {
-      OP_REQUIRES(
-          context, input.dim_size(0) == 4,
-          errors::InvalidArgument(
-              "First dimension of 2D input must be of size 4, but got shape ",
-              input.shape().DebugString()));
+      OP_REQUIRES(context, input.dim_size(0) == 2 || input.dim_size(0) == 4,
+                  errors::InvalidArgument("First dimension of 2D input must be "
+                                          "of size 2 or 4, but got shape ",
+                                          input.shape().DebugString()));
       OP_REQUIRES(
           context, input.dim_size(1) == 2,
           errors::InvalidArgument(
@@ -112,7 +111,21 @@ class DataFormatVecPermuteOp : public OpKernel {
                    context->allocate_output(0, input.shape(), &output));
     // Support 1D and 2D cases.
     Eigen::DSizes<Eigen::DenseIndex, 8> dst_idx;
-    ComputeDstIndex(input.dims(), &dst_idx);
+    string src_format_str = src_format_;
+    string dst_format_str = dst_format_;
+    if (input.dim_size(0) == 2) {
+      // If the input is a vector of size 2, treat the two elements as spatial
+      // dimensions.
+      auto keep_only_spatial_dimensions = [](string* format_str) -> void {
+        auto new_end = std::remove_if(
+            format_str->begin(), format_str->end(),
+            [](const char dim) { return dim != 'H' && dim != 'W'; });
+        format_str->erase(new_end, format_str->end());
+      };
+      keep_only_spatial_dimensions(&src_format_str);
+      keep_only_spatial_dimensions(&dst_format_str);
+    }
+    ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);
 
     functor::DataFormatVecPermute<Device, T>()(context->eigen_device<Device>(),
                                                input.flat<T>(),
@@ -124,10 +137,12 @@ class DataFormatVecPermuteOp : public OpKernel {
   // Example: HWNC --> NHWC
   // 1D: dst = [1, 2, 0, 3],
   // 2D: dst = [2, 3, 4, 5, 0, 1, 6, 7]
-  void ComputeDstIndex(int num_dim, Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
-    for (int i = 0; i < src_format_.size(); ++i) {
-      for (int j = 0; j < dst_format_.size(); ++j) {
-        if (dst_format_[j] != src_format_[i]) continue;
+  static void ComputeDstIndex(const string& src_format_str,
+                              const string& dst_format_str, int num_dim,
+                              Eigen::DSizes<Eigen::DenseIndex, 8>* dst) {
+    for (int i = 0; i < src_format_str.size(); ++i) {
+      for (int j = 0; j < dst_format_str.size(); ++j) {
+        if (dst_format_str[j] != src_format_str[i]) continue;
         // Found the dst index. Set output based on the number of dims.
         for (int k = 0; k < num_dim; ++k) {
           (*dst)[i * num_dim + k] = j * num_dim + k;
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 860bdc60387..0088c04f909 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1199,6 +1199,30 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = self.evaluate(y)
       self.assertAllEqual(y_val, [7, 3, 4, 9])
 
+  def testNHWCToNCHW_Size2(self):
+    x_val = [4, 9]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x)
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [4, 9])
+
+  def testNHWCToWHCN(self):
+    x_val = [7, 4, 9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 4, 3, 7])
+
+  def testNHWCToWHCN_Size2(self):
+    x_val = [4, 9]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="WHCN")
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 4])
+
   def testNCHWToNHWC(self):
     x_val = [7, 4, 9, 3]
     x = constant_op.constant(x_val)
@@ -1207,6 +1231,14 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
       y_val = self.evaluate(y)
      self.assertAllEqual(y_val, [7, 9, 3, 4])
 
+  def testNCHWToNHWC_Size2(self):
+    x_val = [9, 3]
+    x = constant_op.constant(x_val)
+    y = nn_ops.data_format_vec_permute(x)
+    with test_util.use_gpu():
+      y_val = self.evaluate(y)
+      self.assertAllEqual(y_val, [9, 3])
+
   def testNHWCToHWNC(self):
     x_val = [7, 4, 9, 3]
    x = constant_op.constant(x_val)
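For reference, a minimal sketch of the behavior this patch enables, shown through the
public raw op. This assumes an eager TensorFlow build that already includes the change;
the expected values simply mirror the tests added above.

    import tensorflow as tf

    # Size-4 vector: full layout permutation, as before (NHWC -> NCHW).
    tf.raw_ops.DataFormatVecPermute(
        x=tf.constant([7, 4, 9, 3]), src_format="NHWC", dst_format="NCHW")
    # -> [7, 3, 4, 9]

    # Size-2 vector: the two elements are treated as the spatial dimensions
    # (H, W), so only the H/W characters of the formats are used.
    # NHWC -> NCHW keeps the spatial order; NHWC -> WHCN reverses it.
    tf.raw_ops.DataFormatVecPermute(
        x=tf.constant([4, 9]), src_format="NHWC", dst_format="NCHW")
    # -> [4, 9]
    tf.raw_ops.DataFormatVecPermute(
        x=tf.constant([4, 9]), src_format="NHWC", dst_format="WHCN")
    # -> [9, 4]

The nn_ops.data_format_vec_permute helper exercised in nn_test.py is the thin Python
wrapper around this same op.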