From edfc5938ba99cbe81ac50796f6ff647a374daf82 Mon Sep 17 00:00:00 2001
From: Adrian Kuegel
Date: Mon, 26 Oct 2020 07:03:26 -0700
Subject: [PATCH] Don't match to backward input convolution in unsupported
 case.

For grouped convolutions, we assume that in the backward input convolution
case, the input and output feature dimensions of the kernel are adjacent.
If that is not the case, don't treat it as backward input convolution.

PiperOrigin-RevId: 339029980
Change-Id: If0b4f8a64cd3ca73e9648358d8a579ce262b27c9
---
 .../xla/service/gpu/gpu_conv_rewriter.cc      |  9 +++---
 tensorflow/compiler/xla/tests/BUILD           |  8 ++++--
 .../xla/tests/grouped_convolution_test.cc     | 28 +++++++++++++++++++
 3 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
index fb8c05798d8..a6113273d8f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.cc
@@ -536,11 +536,12 @@ MatchBackwardInput(HloInstruction* conv) {
   // 'kernel_output_feature_dimension' by 'feature_group_count'.
   int64 input_feature_dimension = dnums.kernel_input_feature_dimension();
   int64 output_feature_dimension = dnums.kernel_output_feature_dimension();
+  // The following code assumes that input_feature_dimension and
+  // output_feature_dimension are adjacent.
+  if (std::abs(input_feature_dimension - output_feature_dimension) != 1) {
+    return no_match_result;
+  }
 
-  // In the backward convolution case, the spatial dimensions become the
-  // feature dimensions, and we are guaranteed that the spatial dimensions are
-  // adjacent.
-  CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL);
   int64 input_features = rhs->shape().dimensions(input_feature_dimension);
   int64 output_features = rhs->shape().dimensions(output_feature_dimension);
 
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 98ed49ad76a..a429bf7f2bc 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -413,16 +413,18 @@ xla_test(
     ],
     shard_count = 50,
     deps = [
+        ":client_library_test_base",
+        ":hlo_test_base",
         ":test_macros_header",
+        ":test_utils",
+        ":xla_internal_test_main",
         "//tensorflow/compiler/xla:execution_options_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/service:bfloat16_normalization",
         "//tensorflow/compiler/xla/service:despecializer",
-        "//tensorflow/compiler/xla/tests:client_library_test_base",
-        "//tensorflow/compiler/xla/tests:hlo_test_base",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/types:optional",
     ],
 )
diff --git a/tensorflow/compiler/xla/tests/grouped_convolution_test.cc b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc
index 4b06fe2678f..36a1ee112d4 100644
--- a/tensorflow/compiler/xla/tests/grouped_convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/grouped_convolution_test.cc
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <memory>
+#include <vector>
+
+#include "absl/algorithm/container.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/execution_options_util.h"
@@ -23,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 
 namespace xla {
 namespace {
@@ -248,5 +253,28 @@ INSTANTIATE_TEST_CASE_P(
                        ::testing::Bool()),
     GroupedConvolution2DTestDataToString);
 
+using GroupedConvolutionTest = HloTestBase;
+
+XLA_TEST_F(GroupedConvolutionTest, BackwardInputConvolution) {
+  auto module = ParseAndReturnVerifiedModule(R"(
+    HloModule convolution_module
+
+ENTRY convolution {
+  p1 = f32[2,1,1,1]{3,2,1,0} parameter(0)
+  p2 = f32[2,4,4,1]{3,2,1,0} parameter(1)
+  reverse = f32[2,4,4,1]{3,2,1,0} reverse(p2), dimensions={1,2}
+  ROOT convolution = f32[2,4,4,1]{3,2,1,0} convolution(p1, reverse), window={size=4x4 pad=3_3x3_3}, dim_labels=fb01_o01i->f01b, feature_group_count=2
+}
+)")
+                    .ValueOrDie();
+  TF_ASSERT_OK_AND_ASSIGN(auto fake_arguments, MakeFakeArguments(module.get()));
+  std::vector<Literal*> fake_argument_ptrs;
+  absl::c_transform(
+      fake_arguments, std::back_inserter(fake_argument_ptrs),
+      [](const Literal& literal) { return &const_cast<Literal&>(literal); });
+  EXPECT_TRUE(RunAndCompare(std::move(module), fake_argument_ptrs,
+                            ErrorSpec{0.01, 0.01}));
+}
+
 }  // namespace
 }  // namespace xla