Split convolution tests into two.
PiperOrigin-RevId: 321634608 Change-Id: I1dfd1c5ab7010af10962ec021cc66f8fe9c6ce6e
This commit is contained in:
parent
0e9e9ea8c0
commit
57680ec9be
@ -1115,10 +1115,24 @@ xla_test(
|
||||
name = "convolution_test",
|
||||
timeout = "long",
|
||||
srcs = ["convolution_test.cc"],
|
||||
shard_count = 40,
|
||||
shard_count = 50,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"optonly",
|
||||
],
|
||||
deps = CONVOLUTION_TEST_DEPS + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "convolution_test_1d",
|
||||
timeout = "long",
|
||||
srcs = ["convolution_test_1d.cc"],
|
||||
shard_count = 50,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"nozapfhahn",
|
||||
"optonly",
|
||||
],
|
||||
deps = CONVOLUTION_TEST_DEPS + [
|
||||
@ -1147,6 +1161,23 @@ xla_test(
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "convolution_test_1d_autotune_disabled",
|
||||
timeout = "long",
|
||||
srcs = ["convolution_test_1d.cc"],
|
||||
args = ["--xla_gpu_autotune_level=0"],
|
||||
backends = ["gpu"],
|
||||
shard_count = 40,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"optonly",
|
||||
],
|
||||
deps = CONVOLUTION_TEST_DEPS + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "convolution_test_gpu_alternative_layout",
|
||||
timeout = "long",
|
||||
@ -1163,6 +1194,22 @@ xla_test(
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "convolution_test_1d_gpu_alternative_layout",
|
||||
timeout = "long",
|
||||
srcs = ["convolution_test_1d.cc"],
|
||||
backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]},
|
||||
backends = ["gpu"],
|
||||
shard_count = 25,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
],
|
||||
deps = CONVOLUTION_TEST_DEPS + [
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
xla_test(
|
||||
name = "convolution_variants_test",
|
||||
timeout = "long",
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Tests of convolution with trivial kernels and no special variations (like
|
||||
// Tests of 2+D convolution with trivial kernels and no special variations (like
|
||||
// strides and padding).
|
||||
|
||||
#include <memory>
|
||||
@ -240,174 +240,6 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest {
|
||||
TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes);
|
||||
TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); }
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
Conv(input, filter, {1}, Padding::kValid);
|
||||
}
|
||||
|
||||
Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
|
||||
Array3D<float> filter({{{10, 20}, {30, 40}}});
|
||||
|
||||
Array3D<float> expected({{{510, 610, 710, 810}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
|
||||
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<T> input(
|
||||
{{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
|
||||
Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
|
||||
|
||||
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<T>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
|
||||
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
|
||||
/*lhs_dilation=*/{2}, /*rhs_dilation=*/{1},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
|
||||
Array3D<float> filter({{{10, 20}, {30, 40}}});
|
||||
|
||||
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
|
||||
/*lhs_dilation=*/{2}, /*rhs_dilation=*/{2},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
|
||||
Array3D<float> filter({{{10, 20}, {30, 40}}});
|
||||
|
||||
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
|
||||
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<T> input(
|
||||
{{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
|
||||
Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
|
||||
|
||||
Array3D<T> expected(
|
||||
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<T>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
|
||||
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
|
||||
XlaBuilder builder(TestName());
|
||||
std::vector<int64> input_dims = {1, 4, 2, 3, 3};
|
||||
@ -1714,150 +1546,7 @@ INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation,
|
||||
ConvolveWithAndWithoutCanonicalization,
|
||||
::testing::Values(true, false));
|
||||
|
||||
struct Convolve1DTestParam {
|
||||
int64 input_feature;
|
||||
int64 output_feature;
|
||||
int64 batch;
|
||||
int64 window_size;
|
||||
int64 num_windows;
|
||||
};
|
||||
|
||||
class Convolve1D1WindowTestBase
|
||||
: public ConvolutionTest,
|
||||
public ::testing::WithParamInterface<Convolve1DTestParam> {
|
||||
protected:
|
||||
template <typename T>
|
||||
void TestImpl() {
|
||||
XlaBuilder builder(TestName());
|
||||
int64 input_feature = GetParam().input_feature;
|
||||
int64 output_feature = GetParam().output_feature;
|
||||
int64 batch = GetParam().batch;
|
||||
int64 num_windows = GetParam().num_windows;
|
||||
int64 window_size = GetParam().window_size;
|
||||
std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
|
||||
input_feature};
|
||||
std::vector<int64> filter_dims = {window_size, input_feature,
|
||||
output_feature};
|
||||
Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
|
||||
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
|
||||
{
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
|
||||
// Tensorflow dimension numbers for 1D convolution.
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
dnums.set_input_batch_dimension(0);
|
||||
dnums.set_output_batch_dimension(0);
|
||||
dnums.add_input_spatial_dimensions(1);
|
||||
dnums.add_output_spatial_dimensions(1);
|
||||
dnums.set_input_feature_dimension(2);
|
||||
dnums.set_output_feature_dimension(2);
|
||||
dnums.add_kernel_spatial_dimensions(0);
|
||||
dnums.set_kernel_input_feature_dimension(1);
|
||||
dnums.set_kernel_output_feature_dimension(2);
|
||||
|
||||
ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums);
|
||||
}
|
||||
|
||||
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
|
||||
static_cast<T>(1.0f));
|
||||
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
|
||||
auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
|
||||
static_cast<T>(1.0f));
|
||||
|
||||
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
|
||||
auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> expect_elems(batch * output_feature * num_windows,
|
||||
static_cast<T>(window_size * input_feature));
|
||||
auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
|
||||
auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(input_r3).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(filter_r3).ConsumeValueOrDie();
|
||||
ComputeAndCompareLiteral(&builder, expected_r3,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
|
||||
|
||||
XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
|
||||
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{160, 1, 1, 5, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{30, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{23, 1, 1, 20, 20},
|
||||
Convolve1DTestParam{25, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 10, 5},
|
||||
Convolve1DTestParam{160, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{255, 1, 1, 3, 1},
|
||||
Convolve1DTestParam{130, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{136, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{64, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{128, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{139, 1, 1, 128, 1},
|
||||
Convolve1DTestParam{1, 10, 10, 1, 10},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 2},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 1},
|
||||
Convolve1DTestParam{1, 64, 64, 1, 10},
|
||||
Convolve1DTestParam{1, 65, 65, 1, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{128, 128, 128, 128, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{2, 2, 2, 2, 1},
|
||||
Convolve1DTestParam{161, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{900, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{640, 3, 3, 128, 1})
|
||||
|
||||
);
|
||||
|
||||
#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
|
||||
class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};
|
||||
|
||||
XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) {
|
||||
TestImpl<Eigen::half>();
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
|
||||
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{160, 1, 1, 5, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{30, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{23, 1, 1, 20, 20},
|
||||
Convolve1DTestParam{25, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 10, 5},
|
||||
Convolve1DTestParam{160, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{255, 1, 1, 3, 1},
|
||||
Convolve1DTestParam{130, 1, 1, 1, 3},
|
||||
Convolve1DTestParam{64, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{128, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{139, 1, 1, 128, 1},
|
||||
Convolve1DTestParam{640, 3, 3, 128, 1},
|
||||
Convolve1DTestParam{900, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{1, 10, 10, 1, 10},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 1},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 2},
|
||||
Convolve1DTestParam{1, 64, 64, 1, 10},
|
||||
Convolve1DTestParam{1, 65, 65, 1, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{128, 128, 128, 128, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{2, 2, 2, 2, 1},
|
||||
Convolve1DTestParam{161, 1, 1, 10, 1})
|
||||
|
||||
);
|
||||
#endif
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) {
|
||||
XlaBuilder builder(TestName());
|
||||
|
376
tensorflow/compiler/xla/tests/convolution_test_1d.cc
Normal file
376
tensorflow/compiler/xla/tests/convolution_test_1d.cc
Normal file
@ -0,0 +1,376 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Tests of 1D convolution with trivial kernels and no special variations (like
|
||||
// strides and padding).
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/array2d.h"
|
||||
#include "tensorflow/compiler/xla/array4d.h"
|
||||
#include "tensorflow/compiler/xla/client/global_data.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/padding.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/reference_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class ConvolutionTest : public ClientLibraryTestBase {
|
||||
protected:
|
||||
#if XLA_TEST_BACKEND_GPU
|
||||
// XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial
|
||||
// convolution. So relax the absolute error threshold.
|
||||
ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-3);
|
||||
#else
|
||||
ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-3);
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
|
||||
using TestTypes = ::testing::Types<float>;
|
||||
#else
|
||||
using TestTypes = ::testing::Types<float, Eigen::half>;
|
||||
#endif
|
||||
|
||||
struct Convolve1DTestParam {
|
||||
int64 input_feature;
|
||||
int64 output_feature;
|
||||
int64 batch;
|
||||
int64 window_size;
|
||||
int64 num_windows;
|
||||
};
|
||||
|
||||
class Convolve1D1WindowTestBase
|
||||
: public ConvolutionTest,
|
||||
public ::testing::WithParamInterface<Convolve1DTestParam> {
|
||||
protected:
|
||||
template <typename T>
|
||||
void TestImpl() {
|
||||
XlaBuilder builder(TestName());
|
||||
int64 input_feature = GetParam().input_feature;
|
||||
int64 output_feature = GetParam().output_feature;
|
||||
int64 batch = GetParam().batch;
|
||||
int64 num_windows = GetParam().num_windows;
|
||||
int64 window_size = GetParam().window_size;
|
||||
std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
|
||||
input_feature};
|
||||
std::vector<int64> filter_dims = {window_size, input_feature,
|
||||
output_feature};
|
||||
Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
|
||||
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
|
||||
{
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
|
||||
// Tensorflow dimension numbers for 1D convolution.
|
||||
ConvolutionDimensionNumbers dnums;
|
||||
dnums.set_input_batch_dimension(0);
|
||||
dnums.set_output_batch_dimension(0);
|
||||
dnums.add_input_spatial_dimensions(1);
|
||||
dnums.add_output_spatial_dimensions(1);
|
||||
dnums.set_input_feature_dimension(2);
|
||||
dnums.set_output_feature_dimension(2);
|
||||
dnums.add_kernel_spatial_dimensions(0);
|
||||
dnums.set_kernel_input_feature_dimension(1);
|
||||
dnums.set_kernel_output_feature_dimension(2);
|
||||
|
||||
ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums);
|
||||
}
|
||||
|
||||
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
|
||||
static_cast<T>(1.0f));
|
||||
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
|
||||
auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
|
||||
static_cast<T>(1.0f));
|
||||
|
||||
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
|
||||
auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
|
||||
|
||||
std::vector<T> expect_elems(batch * output_feature * num_windows,
|
||||
static_cast<T>(window_size * input_feature));
|
||||
auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
|
||||
auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(input_r3).ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(filter_r3).ConsumeValueOrDie();
|
||||
ComputeAndCompareLiteral(&builder, expected_r3,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
|
||||
|
||||
XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
|
||||
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{160, 1, 1, 5, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{30, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{23, 1, 1, 20, 20},
|
||||
Convolve1DTestParam{25, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 10, 5},
|
||||
Convolve1DTestParam{160, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{255, 1, 1, 3, 1},
|
||||
Convolve1DTestParam{130, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{136, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{64, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{128, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{139, 1, 1, 128, 1},
|
||||
Convolve1DTestParam{1, 10, 10, 1, 10},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 2},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 1},
|
||||
Convolve1DTestParam{1, 64, 64, 1, 10},
|
||||
Convolve1DTestParam{1, 65, 65, 1, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{128, 128, 128, 128, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{2, 2, 2, 2, 1},
|
||||
Convolve1DTestParam{161, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{900, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{640, 3, 3, 128, 1})
|
||||
|
||||
);
|
||||
|
||||
#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
|
||||
class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};
|
||||
|
||||
XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) {
|
||||
TestImpl<Eigen::half>();
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
|
||||
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
|
||||
Convolve1DTestParam{160, 1, 1, 5, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{30, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{23, 1, 1, 20, 20},
|
||||
Convolve1DTestParam{25, 1, 1, 20, 1},
|
||||
Convolve1DTestParam{24, 1, 1, 10, 5},
|
||||
Convolve1DTestParam{160, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{255, 1, 1, 3, 1},
|
||||
Convolve1DTestParam{130, 1, 1, 1, 3},
|
||||
Convolve1DTestParam{64, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{128, 1, 1, 1, 1},
|
||||
Convolve1DTestParam{139, 1, 1, 128, 1},
|
||||
Convolve1DTestParam{640, 3, 3, 128, 1},
|
||||
Convolve1DTestParam{900, 1, 1, 10, 1},
|
||||
Convolve1DTestParam{1, 10, 10, 1, 10},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 1},
|
||||
Convolve1DTestParam{1, 10, 130, 1, 2},
|
||||
Convolve1DTestParam{1, 64, 64, 1, 10},
|
||||
Convolve1DTestParam{1, 65, 65, 1, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{128, 128, 128, 128, 1},
|
||||
Convolve1DTestParam{1, 128, 128, 1, 1},
|
||||
Convolve1DTestParam{2, 2, 2, 2, 1},
|
||||
Convolve1DTestParam{161, 1, 1, 10, 1})
|
||||
|
||||
);
|
||||
#endif
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
Conv(input, filter, {1}, Padding::kValid);
|
||||
}
|
||||
|
||||
Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
|
||||
Array3D<float> filter({{{10, 20}, {30, 40}}});
|
||||
|
||||
Array3D<float> expected({{{510, 610, 710, 810}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
|
||||
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<T> input(
|
||||
{{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
|
||||
Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
|
||||
|
||||
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<T>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
|
||||
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
|
||||
/*lhs_dilation=*/{2}, /*rhs_dilation=*/{1},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
|
||||
Array3D<float> filter({{{10, 20}, {30, 40}}});
|
||||
|
||||
Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
|
||||
/*lhs_dilation=*/{2}, /*rhs_dilation=*/{2},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
|
||||
Array3D<float> filter({{{10, 20}, {30, 40}}});
|
||||
|
||||
Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
|
||||
public:
|
||||
void RunTest() {
|
||||
XlaBuilder builder(TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
|
||||
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
|
||||
auto input = Parameter(&builder, 0, input_shape, "input");
|
||||
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
||||
// Convolution dimensions are bf0_oi0->bo0.
|
||||
ConvGeneralDilated(
|
||||
input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
|
||||
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
|
||||
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
||||
}
|
||||
|
||||
Array3D<T> input(
|
||||
{{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
|
||||
Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
|
||||
|
||||
Array3D<T> expected(
|
||||
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR3<T>(&builder, expected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
|
||||
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user