From e5213e91eb961c7d2d12a296dca47cfb1986e712 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Mon, 13 Jan 2020 14:47:57 -0800 Subject: [PATCH] Support empty window and 0D convolution. For non-GPUs it's already accidentally supported; for GPUs it's not hard to support anyway. PiperOrigin-RevId: 289523901 Change-Id: I44bd121145e5a5a6dd47cd4a63f5ceec87ef7729 --- .../compiler/xla/service/gpu/gpu_conv_runner.cc | 17 +++++++++-------- .../tests/gpu_convolution_regression_test.cc | 11 +++++++++++ .../compiler/xla/service/hlo_instructions.cc | 2 +- tensorflow/compiler/xla/service/hlo_parser.cc | 4 ---- .../compiler/xla/tests/convolution_test.cc | 12 ++++++++++++ tensorflow/compiler/xla/window_util.cc | 6 ++++-- 6 files changed, 37 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index 03da7cebec5..ea6d1666c56 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -337,7 +337,7 @@ StatusOr<GpuConvParams> GetGpuConvParams( const int num_dimensions = window.dimensions_size(); CHECK_LE(num_dimensions, 3) << conv->ToString(); - CHECK_GE(num_dimensions, 1) << conv->ToString(); + // cuDNN does not support 1D convolutions. We therefore express 1D // convolutions as 2D convolutions where the first spatial dimension is 1. // This matches the behavior of TF (see definition of conv1d in @@ -346,7 +346,8 @@ StatusOr<GpuConvParams> GetGpuConvParams( // If one dimension is reversed, we need to have all dimensions reversed (so // we're doing convolution not cross correlation). - const bool dims_reversed = window.dimensions()[0].window_reversal(); + const bool dims_reversed = + window.dimensions_size() > 0 && window.dimensions()[0].window_reversal(); CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()) << conv->ToString(); @@ -429,12 +430,12 @@ StatusOr<GpuConvParams> GetGpuConvParams( } // Add a singleton dimension in the 1D convolution case. 
- if (num_dimensions == 1) { - input_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1); - output_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1); - filter_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1); - params.conv_desc.set_zero_padding(static_cast<DimIndex>(0), 0) - .set_filter_stride(static_cast<DimIndex>(0), 1); + for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) { + input_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1); + output_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1); + filter_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1); + params.conv_desc.set_zero_padding(static_cast<DimIndex>(dim), 0) + .set_filter_stride(static_cast<DimIndex>(dim), 1); } return params; diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_convolution_regression_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_convolution_regression_test.cc index 7433414c800..2a84b66d101 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_convolution_regression_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_convolution_regression_test.cc @@ -106,6 +106,17 @@ ENTRY %TestComputation { })"); } +TEST_F(GpuConvolutionRegressionTest, Conv0D) { + CheckForHloText(R"( +HloModule TestModule + +ENTRY TestComputation { + %parameter.1 = f32[10,5]{1,0} parameter(0) + %parameter.2 = f32[5,7]{0,1} parameter(1) + ROOT %custom-call.1 = (f32[10,7]{1,0}, u8[0]{0}) custom-call(f32[10,5]{1,0} %parameter.1, f32[5,7]{0,1} %parameter.2), window={}, dim_labels=bf_io->bf, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}" +})"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 94b5926d876..0ed8d767953 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2196,7 +2196,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { std::vector<string> 
HloCustomCallInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { std::vector<string> extra; - if (window_ != nullptr && window_->dimensions_size() != 0) { + if (window_ != nullptr) { extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } if (convolution_dimension_numbers_ != nullptr) { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ecb25298288..d6e8a8be893 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -3146,10 +3146,6 @@ bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) { } } - if (size.empty()) { - return Error(loc, - "sub-attribute 'size=' is required in the window attribute"); - } if (!stride.empty() && stride.size() != size.size()) { return Error(loc, "expects 'stride=' has the same size as 'size='"); } diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 097265f3bb1..6ff0f9d6b2a 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -2008,5 +2008,17 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01})); } +XLA_TEST_F(ConvolutionHloTest, TestConv0D) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY TestComputation { + %parameter.1 = f32[10,5]{1,0} parameter(0) + %parameter.2 = f32[5,7]{1,0} parameter(1) + ROOT %convolution.3 = f32[10,7]{1,0} convolution(f32[10,5]{1,0} %parameter.1, f32[5,7]{1,0} %parameter.2), dim_labels=bf_io->bf +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index f660116771b..a58179c3ee0 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -104,8 +104,10 @@ string ToString(const Window& 
window) { } }; - add_field("size", - [](const WindowDimension& dim) { return StrCat(dim.size()); }); + if (window.dimensions_size() > 0) { + add_field("size", + [](const WindowDimension& dim) { return StrCat(dim.size()); }); + } if (HasStride(window)) { add_field(" stride", [](const WindowDimension& dim) { return StrCat(dim.stride()); });