Add tests for convolution 1D

RELNOTES: n/a

PiperOrigin-RevId: 173060283
This commit is contained in:
Yunxing Dai 2017-10-22 16:48:19 -07:00 committed by TensorFlower Gardener
parent a699458107
commit fd8d517b97

View File

@ -508,21 +508,35 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) {
error_spec_);
}
XLA_TEST_F(ConvolutionTest, Convolve1D_Valid) {
struct Convolve1DTestParam {
int64 input_feature;
int64 output_feature;
int64 batch;
int64 window_size;
int64 num_windows;
};
class Convolve1D1WindowTest
: public ConvolutionTest,
public ::testing::WithParamInterface<Convolve1DTestParam> {};
XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) {
ComputationBuilder builder(client_, TestName());
int64 output_feature = 1;
int64 input_feature = 64;
int64 batch = 1;
int64 length = 1;
std::vector<int64> input_dims = {batch, 4 + length - 1, input_feature};
std::vector<int64> filter_dims = {4, input_feature, output_feature};
int64 input_feature = GetParam().input_feature;
int64 output_feature = GetParam().output_feature;
int64 batch = GetParam().batch;
int64 num_windows = GetParam().num_windows;
int64 window_size = GetParam().window_size;
std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
input_feature};
std::vector<int64> filter_dims = {window_size, input_feature, output_feature};
Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
{
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
// Tensorflow dimension numbers for 2D convolution.
// Tensorflow dimension numbers for 1D convolution.
ConvolutionDimensionNumbers dnums;
dnums.set_input_batch_dimension(0);
dnums.set_output_batch_dimension(0);
@ -538,28 +552,57 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_Valid) {
}
std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape), 1.0);
// std::iota(input_elems.begin(), input_elems.end(), 1.0f);
auto input_r1 = Literal::CreateR1<float>(input_elems);
auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie();
std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape), 1.0);
// std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
auto filter_r1 = Literal::CreateR1<float>(filter_elems);
auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie();
std::vector<float> expect_elems(batch * output_feature * length, 256);
std::vector<float> expect_elems(batch * output_feature * num_windows,
window_size * input_feature);
auto expected_r1 = Literal::CreateR1<float>(expect_elems);
auto expected_r4 =
expected_r1->Reshape({batch, length, output_feature}).ConsumeValueOrDie();
auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature})
.ConsumeValueOrDie();
auto input_literal = client_->TransferToServer(*input_r4).ConsumeValueOrDie();
auto input_literal = client_->TransferToServer(*input_r3).ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*filter_r4).ConsumeValueOrDie();
ComputeAndCompareLiteral(&builder, *expected_r4,
client_->TransferToServer(*filter_r3).ConsumeValueOrDie();
ComputeAndCompareLiteral(&builder, *expected_r3,
{input_literal.get(), filter_literal.get()},
error_spec_);
}
INSTANTIATE_TEST_CASE_P(
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTest,
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
Convolve1DTestParam{160, 1, 1, 5, 1},
Convolve1DTestParam{24, 1, 1, 20, 1},
Convolve1DTestParam{30, 1, 1, 20, 1},
Convolve1DTestParam{23, 1, 1, 20, 20},
Convolve1DTestParam{25, 1, 1, 20, 1},
Convolve1DTestParam{24, 1, 1, 10, 5},
Convolve1DTestParam{160, 1, 1, 10, 1},
Convolve1DTestParam{255, 1, 1, 3, 1},
Convolve1DTestParam{130, 1, 1, 1, 3},
Convolve1DTestParam{64, 1, 1, 1, 1},
Convolve1DTestParam{128, 1, 1, 1, 1},
Convolve1DTestParam{139, 1, 1, 128, 1},
Convolve1DTestParam{1, 10, 10, 1, 10},
Convolve1DTestParam{1, 10, 130, 1, 2},
Convolve1DTestParam{1, 10, 130, 1, 1},
Convolve1DTestParam{1, 64, 64, 1, 10},
Convolve1DTestParam{1, 65, 65, 1, 1},
Convolve1DTestParam{1, 128, 128, 1, 1},
Convolve1DTestParam{128, 128, 128, 128, 1},
Convolve1DTestParam{1, 128, 128, 1, 1},
Convolve1DTestParam{2, 2, 2, 2, 1},
Convolve1DTestParam{161, 1, 1, 10, 1},
Convolve1DTestParam{900, 1, 1, 10, 1},
Convolve1DTestParam{640, 3, 3, 128, 1})
);
} // namespace
} // namespace xla