Support empty window and 0D convolution. For non-GPUs it's already accidentally supported; for GPUs it's not hard to support anyway.

PiperOrigin-RevId: 289523901
Change-Id: I44bd121145e5a5a6dd47cd4a63f5ceec87ef7729
This commit is contained in:
Tim Shen 2020-01-13 14:47:57 -08:00 committed by TensorFlower Gardener
parent b55f48a63b
commit e5213e91eb
6 changed files with 37 additions and 15 deletions

View File

@ -337,7 +337,7 @@ StatusOr<GpuConvParams> GetGpuConvParams(
const int num_dimensions = window.dimensions_size();
CHECK_LE(num_dimensions, 3) << conv->ToString();
CHECK_GE(num_dimensions, 1) << conv->ToString();
// cuDNN does not support 1D convolutions. We therefore express 1D
// convolutions as 2D convolutions where the first spatial dimension is 1.
// This matches the behavior of TF (see definition of conv1d in
@ -346,7 +346,8 @@ StatusOr<GpuConvParams> GetGpuConvParams(
// If one dimension is reversed, we need to have all dimensions reversed (so
// we're doing convolution not cross correlation).
const bool dims_reversed = window.dimensions()[0].window_reversal();
const bool dims_reversed =
window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
<< conv->ToString();
@ -429,12 +430,12 @@ StatusOr<GpuConvParams> GetGpuConvParams(
}
// Add a singleton dimension in the 1D convolution case.
if (num_dimensions == 1) {
input_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
output_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
filter_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
params.conv_desc.set_zero_padding(static_cast<DimIndex>(0), 0)
.set_filter_stride(static_cast<DimIndex>(0), 1);
for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) {
input_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
output_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
filter_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
params.conv_desc.set_zero_padding(static_cast<DimIndex>(dim), 0)
.set_filter_stride(static_cast<DimIndex>(dim), 1);
}
return params;

View File

@ -106,6 +106,17 @@ ENTRY %TestComputation {
})");
}
TEST_F(GpuConvolutionRegressionTest, Conv0D) {
CheckForHloText(R"(
HloModule TestModule
ENTRY TestComputation {
%parameter.1 = f32[10,5]{1,0} parameter(0)
%parameter.2 = f32[5,7]{0,1} parameter(1)
ROOT %custom-call.1 = (f32[10,7]{1,0}, u8[0]{0}) custom-call(f32[10,5]{1,0} %parameter.1, f32[5,7]{0,1} %parameter.2), window={}, dim_labels=bf_io->bf, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}"
})");
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -2196,7 +2196,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
std::vector<string> extra;
if (window_ != nullptr && window_->dimensions_size() != 0) {
if (window_ != nullptr) {
extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
}
if (convolution_dimension_numbers_ != nullptr) {

View File

@ -3146,10 +3146,6 @@ bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
}
}
if (size.empty()) {
return Error(loc,
"sub-attribute 'size=' is required in the window attribute");
}
if (!stride.empty() && stride.size() != size.size()) {
return Error(loc, "expects 'stride=' has the same size as 'size='");
}

View File

@ -2008,5 +2008,17 @@ ENTRY Test {
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
XLA_TEST_F(ConvolutionHloTest, TestConv0D) {
constexpr char kHlo[] = R"(
HloModule TestModule
ENTRY TestComputation {
%parameter.1 = f32[10,5]{1,0} parameter(0)
%parameter.2 = f32[5,7]{1,0} parameter(1)
ROOT %convolution.3 = f32[10,7]{1,0} convolution(f32[10,5]{1,0} %parameter.1, f32[5,7]{1,0} %parameter.2), dim_labels=bf_io->bf
})";
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
}
} // namespace
} // namespace xla

View File

@ -104,8 +104,10 @@ string ToString(const Window& window) {
}
};
add_field("size",
[](const WindowDimension& dim) { return StrCat(dim.size()); });
if (window.dimensions_size() > 0) {
add_field("size",
[](const WindowDimension& dim) { return StrCat(dim.size()); });
}
if (HasStride(window)) {
add_field(" stride",
[](const WindowDimension& dim) { return StrCat(dim.stride()); });