Support empty window and 0D convolution. For non-GPUs it's already accidentally supported; for GPUs it's not hard to support anyway.
PiperOrigin-RevId: 289523901 Change-Id: I44bd121145e5a5a6dd47cd4a63f5ceec87ef7729
This commit is contained in:
parent
b55f48a63b
commit
e5213e91eb
@ -337,7 +337,7 @@ StatusOr<GpuConvParams> GetGpuConvParams(
|
||||
|
||||
const int num_dimensions = window.dimensions_size();
|
||||
CHECK_LE(num_dimensions, 3) << conv->ToString();
|
||||
CHECK_GE(num_dimensions, 1) << conv->ToString();
|
||||
|
||||
// cuDNN does not support 1D convolutions. We therefore express 1D
|
||||
// convolutions as 2D convolutions where the first spatial dimension is 1.
|
||||
// This matches the behavior of TF (see definition of conv1d in
|
||||
@ -346,7 +346,8 @@ StatusOr<GpuConvParams> GetGpuConvParams(
|
||||
|
||||
// If one dimension is reversed, we need to have all dimensions reversed (so
|
||||
// we're doing convolution not cross correlation).
|
||||
const bool dims_reversed = window.dimensions()[0].window_reversal();
|
||||
const bool dims_reversed =
|
||||
window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();
|
||||
|
||||
CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
|
||||
<< conv->ToString();
|
||||
@ -429,12 +430,12 @@ StatusOr<GpuConvParams> GetGpuConvParams(
|
||||
}
|
||||
|
||||
// Add a singleton dimension in the 1D convolution case.
|
||||
if (num_dimensions == 1) {
|
||||
input_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
|
||||
output_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
|
||||
filter_descriptor.set_spatial_dim(static_cast<DimIndex>(0), 1);
|
||||
params.conv_desc.set_zero_padding(static_cast<DimIndex>(0), 0)
|
||||
.set_filter_stride(static_cast<DimIndex>(0), 1);
|
||||
for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) {
|
||||
input_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
|
||||
output_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
|
||||
filter_descriptor.set_spatial_dim(static_cast<DimIndex>(dim), 1);
|
||||
params.conv_desc.set_zero_padding(static_cast<DimIndex>(dim), 0)
|
||||
.set_filter_stride(static_cast<DimIndex>(dim), 1);
|
||||
}
|
||||
|
||||
return params;
|
||||
|
@ -106,6 +106,17 @@ ENTRY %TestComputation {
|
||||
})");
|
||||
}
|
||||
|
||||
TEST_F(GpuConvolutionRegressionTest, Conv0D) {
|
||||
CheckForHloText(R"(
|
||||
HloModule TestModule
|
||||
|
||||
ENTRY TestComputation {
|
||||
%parameter.1 = f32[10,5]{1,0} parameter(0)
|
||||
%parameter.2 = f32[5,7]{0,1} parameter(1)
|
||||
ROOT %custom-call.1 = (f32[10,7]{1,0}, u8[0]{0}) custom-call(f32[10,5]{1,0} %parameter.1, f32[5,7]{0,1} %parameter.2), window={}, dim_labels=bf_io->bf, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}"
|
||||
})");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -2196,7 +2196,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
|
||||
std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const {
|
||||
std::vector<string> extra;
|
||||
if (window_ != nullptr && window_->dimensions_size() != 0) {
|
||||
if (window_ != nullptr) {
|
||||
extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
|
||||
}
|
||||
if (convolution_dimension_numbers_ != nullptr) {
|
||||
|
@ -3146,10 +3146,6 @@ bool HloParserImpl::ParseWindow(Window* window, bool expect_outer_curlies) {
|
||||
}
|
||||
}
|
||||
|
||||
if (size.empty()) {
|
||||
return Error(loc,
|
||||
"sub-attribute 'size=' is required in the window attribute");
|
||||
}
|
||||
if (!stride.empty() && stride.size() != size.size()) {
|
||||
return Error(loc, "expects 'stride=' has the same size as 'size='");
|
||||
}
|
||||
|
@ -2008,5 +2008,17 @@ ENTRY Test {
|
||||
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConvolutionHloTest, TestConv0D) {
|
||||
constexpr char kHlo[] = R"(
|
||||
HloModule TestModule
|
||||
|
||||
ENTRY TestComputation {
|
||||
%parameter.1 = f32[10,5]{1,0} parameter(0)
|
||||
%parameter.2 = f32[5,7]{1,0} parameter(1)
|
||||
ROOT %convolution.3 = f32[10,7]{1,0} convolution(f32[10,5]{1,0} %parameter.1, f32[5,7]{1,0} %parameter.2), dim_labels=bf_io->bf
|
||||
})";
|
||||
EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -104,8 +104,10 @@ string ToString(const Window& window) {
|
||||
}
|
||||
};
|
||||
|
||||
add_field("size",
|
||||
[](const WindowDimension& dim) { return StrCat(dim.size()); });
|
||||
if (window.dimensions_size() > 0) {
|
||||
add_field("size",
|
||||
[](const WindowDimension& dim) { return StrCat(dim.size()); });
|
||||
}
|
||||
if (HasStride(window)) {
|
||||
add_field(" stride",
|
||||
[](const WindowDimension& dim) { return StrCat(dim.stride()); });
|
||||
|
Loading…
Reference in New Issue
Block a user