Add tests for convolution 1D
RELNOTES: n/a PiperOrigin-RevId: 173060283
This commit is contained in:
parent
a699458107
commit
fd8d517b97
@ -508,21 +508,35 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) {
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_Valid) {
|
||||
struct Convolve1DTestParam {
|
||||
int64 input_feature;
|
||||
int64 output_feature;
|
||||
int64 batch;
|
||||
int64 window_size;
|
||||
int64 num_windows;
|
||||
};
|
||||
|
||||
class Convolve1D1WindowTest
|
||||
: public ConvolutionTest,
|
||||
public ::testing::WithParamInterface<Convolve1DTestParam> {};
|
||||
|
||||
XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
int64 output_feature = 1;
|
||||
int64 input_feature = 64;
|
||||
int64 batch = 1;
|
||||
int64 length = 1;
|
||||
std::vector<int64> input_dims = {batch, 4 + length - 1, input_feature};
|
||||
std::vector<int64> filter_dims = {4, input_feature, output_feature};
|
||||
int64 input_feature = GetParam().input_feature;
|
||||
int64 output_feature = GetParam().output_feature;
|
||||
int64 batch = GetParam().batch;
|
||||
int64 num_windows = GetParam().num_windows;
|
||||
int64 window_size = GetParam().window_size;
|
||||
std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
|
||||
input_feature};
|
||||
std::vector<int64> filter_dims = {window_size, input_feature, output_feature};
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
|
||||
{
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
|
||||
// Tensorflow dimension numbers for 2D convolution.
|
||||
// Tensorflow dimension numbers for 1D convolution.
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
dnums.set_input_batch_dimension(0);
|
||||
dnums.set_output_batch_dimension(0);
|
||||
@ -538,28 +552,57 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_Valid) {
|
||||
}
|
||||
|
||||
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape), 1.0);
|
||||
// std::iota(input_elems.begin(), input_elems.end(), 1.0f);
|
||||
auto input_r1 = Literal::CreateR1<float>(input_elems);
|
||||
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
|
||||
auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape), 1.0);
|
||||
// std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
|
||||
|
||||
auto filter_r1 = Literal::CreateR1<float>(filter_elems);
|
||||
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
|
||||
auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<float> expect_elems(batch * output_feature * length, 256);
|
||||
std::vector<float> expect_elems(batch * output_feature * num_windows,
|
||||
window_size * input_feature);
|
||||
auto expected_r1 = Literal::CreateR1<float>(expect_elems);
|
||||
auto expected_r4 =
|
||||
expected_r1->Reshape({batch, length, output_feature}).ConsumeValueOrDie();
|
||||
auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature})
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
auto input_literal = client_->TransferToServer(*input_r4).ConsumeValueOrDie();
|
||||
auto input_literal = client_->TransferToServer(*input_r3).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
|
||||
ComputeAndCompareLiteral(&builder, *expected_r4,
|
||||
client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
|
||||
ComputeAndCompareLiteral(&builder, *expected_r3,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTest,
|
||||
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{160, 1, 1, 5, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{30, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{23, 1, 1, 20, 20},
|
||||
Convolve1DTestParam{25, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 10, 5},
|
||||
Convolve1DTestParam{160, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{255, 1, 1, 3, 1},
|
||||
Convolve1DTestParam{130, 1, 1, 1, 3},
|
||||
Convolve1DTestParam{64, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{128, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{139, 1, 1, 128, 1},
|
||||
Convolve1DTestParam{1, 10, 10, 1, 10},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 2},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 1},
|
||||
Convolve1DTestParam{1, 64, 64, 1, 10},
|
||||
Convolve1DTestParam{1, 65, 65, 1, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{128, 128, 128, 128, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{2, 2, 2, 2, 1},
|
||||
Convolve1DTestParam{161, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{900, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{640, 3, 3, 128, 1})
|
||||
|
||||
);
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user