377 lines
15 KiB
C++
377 lines
15 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
// Tests of 1D convolution: all-ones inputs and kernels over a range of shapes,
// plus small hand-computed cases covering valid convolution, LHS/RHS dilation,
// and padding.
|
|
|
|
#include <memory>
|
|
|
|
#include "absl/memory/memory.h"
|
|
#include "absl/strings/str_cat.h"
|
|
#include "tensorflow/compiler/xla/array2d.h"
|
|
#include "tensorflow/compiler/xla/array4d.h"
|
|
#include "tensorflow/compiler/xla/client/global_data.h"
|
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
|
#include "tensorflow/compiler/xla/client/padding.h"
|
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
|
#include "tensorflow/compiler/xla/layout_util.h"
|
|
#include "tensorflow/compiler/xla/literal.h"
|
|
#include "tensorflow/compiler/xla/reference_util.h"
|
|
#include "tensorflow/compiler/xla/shape_util.h"
|
|
#include "tensorflow/compiler/xla/statusor.h"
|
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
#include "tensorflow/core/platform/test.h"
|
|
#include "tensorflow/core/platform/types.h"
|
|
|
|
namespace xla {
|
|
namespace {
|
|
|
|
class ConvolutionTest : public ClientLibraryTestBase {
|
|
protected:
|
|
#if XLA_TEST_BACKEND_GPU
|
|
// XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial
|
|
// convolution. So relax the absolute error threshold.
|
|
ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-3);
|
|
#else
|
|
ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-3);
|
|
#endif
|
|
};
|
|
|
|
#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16
|
|
using TestTypes = ::testing::Types<float>;
|
|
#else
|
|
using TestTypes = ::testing::Types<float, Eigen::half>;
|
|
#endif
|
|
|
|
struct Convolve1DTestParam {
|
|
int64 input_feature;
|
|
int64 output_feature;
|
|
int64 batch;
|
|
int64 window_size;
|
|
int64 num_windows;
|
|
};
|
|
|
|
class Convolve1D1WindowTestBase
|
|
: public ConvolutionTest,
|
|
public ::testing::WithParamInterface<Convolve1DTestParam> {
|
|
protected:
|
|
template <typename T>
|
|
void TestImpl() {
|
|
XlaBuilder builder(TestName());
|
|
int64 input_feature = GetParam().input_feature;
|
|
int64 output_feature = GetParam().output_feature;
|
|
int64 batch = GetParam().batch;
|
|
int64 num_windows = GetParam().num_windows;
|
|
int64 window_size = GetParam().window_size;
|
|
std::vector<int64> input_dims = {batch, window_size + num_windows - 1,
|
|
input_feature};
|
|
std::vector<int64> filter_dims = {window_size, input_feature,
|
|
output_feature};
|
|
Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
|
|
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
|
|
{
|
|
auto input = Parameter(&builder, 0, input_shape, "input");
|
|
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
|
|
|
// Tensorflow dimension numbers for 1D convolution.
|
|
ConvolutionDimensionNumbers dnums;
|
|
dnums.set_input_batch_dimension(0);
|
|
dnums.set_output_batch_dimension(0);
|
|
dnums.add_input_spatial_dimensions(1);
|
|
dnums.add_output_spatial_dimensions(1);
|
|
dnums.set_input_feature_dimension(2);
|
|
dnums.set_output_feature_dimension(2);
|
|
dnums.add_kernel_spatial_dimensions(0);
|
|
dnums.set_kernel_input_feature_dimension(1);
|
|
dnums.set_kernel_output_feature_dimension(2);
|
|
|
|
ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums);
|
|
}
|
|
|
|
std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
|
|
static_cast<T>(1.0f));
|
|
auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
|
|
auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
|
|
|
|
std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
|
|
static_cast<T>(1.0f));
|
|
|
|
auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
|
|
auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
|
|
|
|
std::vector<T> expect_elems(batch * output_feature * num_windows,
|
|
static_cast<T>(window_size * input_feature));
|
|
auto expected_r1 = LiteralUtil::CreateR1<T>(expect_elems);
|
|
auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature})
|
|
.ConsumeValueOrDie();
|
|
|
|
auto input_literal =
|
|
client_->TransferToServer(input_r3).ConsumeValueOrDie();
|
|
auto filter_literal =
|
|
client_->TransferToServer(filter_r3).ConsumeValueOrDie();
|
|
ComputeAndCompareLiteral(&builder, expected_r3,
|
|
{input_literal.get(), filter_literal.get()},
|
|
error_spec_);
|
|
}
|
|
};
|
|
|
|
class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {};
|
|
|
|
XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl<float>(); }
|
|
|
|
INSTANTIATE_TEST_CASE_P(
|
|
Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat,
|
|
::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
|
|
Convolve1DTestParam{160, 1, 1, 5, 1},
|
|
Convolve1DTestParam{24, 1, 1, 20, 1},
|
|
Convolve1DTestParam{30, 1, 1, 20, 1},
|
|
Convolve1DTestParam{23, 1, 1, 20, 20},
|
|
Convolve1DTestParam{25, 1, 1, 20, 1},
|
|
Convolve1DTestParam{24, 1, 1, 10, 5},
|
|
Convolve1DTestParam{160, 1, 1, 10, 1},
|
|
Convolve1DTestParam{255, 1, 1, 3, 1},
|
|
Convolve1DTestParam{130, 1, 1, 1, 2},
|
|
Convolve1DTestParam{136, 1, 1, 1, 2},
|
|
Convolve1DTestParam{64, 1, 1, 1, 1},
|
|
Convolve1DTestParam{128, 1, 1, 1, 1},
|
|
Convolve1DTestParam{139, 1, 1, 128, 1},
|
|
Convolve1DTestParam{1, 10, 10, 1, 10},
|
|
Convolve1DTestParam{1, 10, 130, 1, 2},
|
|
Convolve1DTestParam{1, 10, 130, 1, 1},
|
|
Convolve1DTestParam{1, 64, 64, 1, 10},
|
|
Convolve1DTestParam{1, 65, 65, 1, 1},
|
|
Convolve1DTestParam{1, 128, 128, 1, 1},
|
|
Convolve1DTestParam{128, 128, 128, 128, 1},
|
|
Convolve1DTestParam{1, 128, 128, 1, 1},
|
|
Convolve1DTestParam{2, 2, 2, 2, 1},
|
|
Convolve1DTestParam{161, 1, 1, 10, 1},
|
|
Convolve1DTestParam{900, 1, 1, 10, 1},
|
|
Convolve1DTestParam{640, 3, 3, 128, 1})
|
|
|
|
);
|
|
|
|
#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU)
// Eigen::half instantiation of the all-ones 1D convolution test; only built
// for the GPU and CPU backends.
class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {};

XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) {
  TestImpl<Eigen::half>();
}

// Each entry is {input_feature, output_feature, batch, window_size,
// num_windows}; every output element is expected to equal
// window_size * input_feature converted to half. Entries are kept distinct:
// a repeated entry only re-runs an identical case.
INSTANTIATE_TEST_CASE_P(
    Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf,
    ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2},
                      Convolve1DTestParam{160, 1, 1, 5, 1},
                      Convolve1DTestParam{24, 1, 1, 20, 1},
                      Convolve1DTestParam{30, 1, 1, 20, 1},
                      Convolve1DTestParam{23, 1, 1, 20, 20},
                      Convolve1DTestParam{25, 1, 1, 20, 1},
                      Convolve1DTestParam{24, 1, 1, 10, 5},
                      Convolve1DTestParam{160, 1, 1, 10, 1},
                      Convolve1DTestParam{255, 1, 1, 3, 1},
                      Convolve1DTestParam{130, 1, 1, 1, 3},
                      Convolve1DTestParam{64, 1, 1, 1, 1},
                      Convolve1DTestParam{128, 1, 1, 1, 1},
                      Convolve1DTestParam{139, 1, 1, 128, 1},
                      // NOTE(review): 128 * 640 = 81920 exceeds the largest
                      // finite half value (65504), so the expected value for
                      // this case is +inf after conversion to half; it does
                      // not check a finite sum. Confirm this is intended.
                      Convolve1DTestParam{640, 3, 3, 128, 1},
                      Convolve1DTestParam{900, 1, 1, 10, 1},
                      Convolve1DTestParam{1, 10, 10, 1, 10},
                      Convolve1DTestParam{1, 10, 130, 1, 1},
                      Convolve1DTestParam{1, 10, 130, 1, 2},
                      Convolve1DTestParam{1, 64, 64, 1, 10},
                      Convolve1DTestParam{1, 65, 65, 1, 1},
                      Convolve1DTestParam{1, 128, 128, 1, 1},
                      Convolve1DTestParam{128, 128, 128, 128, 1},
                      Convolve1DTestParam{2, 2, 2, 2, 1},
                      Convolve1DTestParam{161, 1, 1, 10, 1}));
#endif
|
|
|
|
// Unpadded, stride-1 convolution of a 1x2x5 input with a 1x2x2 filter using
// Conv's default dimension numbers. The expected values correspond to a
// (batch, feature, spatial) input and an (output feature, input feature,
// spatial) filter; e.g. the first output is 1*10 + 2*20 + 6*30 + 7*40 = 510.
// Five input positions and a two-wide window give four outputs.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
  XlaBuilder builder(TestName());
  {
    auto lhs =
        Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {1, 2, 5}), "input");
    auto rhs =
        Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {1, 2, 2}), "filter");
    Conv(lhs, rhs, /*window_strides=*/{1}, Padding::kValid);
  }

  Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  Array3D<float> filter({{{10, 20}, {30, 40}}});
  Array3D<float> expected({{{510, 610, 710, 810}}});

  auto input_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
          .ConsumeValueOrDie();
  auto filter_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
          .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(&builder, expected,
                             {input_data.get(), filter_data.get()},
                             error_spec_);
}
|
|
|
|
template <typename T>
|
|
class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest {
|
|
public:
|
|
void RunTest() {
|
|
XlaBuilder builder(TestName());
|
|
{
|
|
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
|
|
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
|
|
auto input = Parameter(&builder, 0, input_shape, "input");
|
|
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
|
// Convolution dimensions are bf0_oi0->bo0.
|
|
ConvGeneralDilated(
|
|
input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
|
|
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{2},
|
|
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
|
}
|
|
|
|
Array3D<T> input(
|
|
{{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
|
|
Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
|
|
|
|
Array3D<T> expected({{{570.0f, 670.0f, 770.0f}}});
|
|
|
|
auto input_literal =
|
|
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
|
.ConsumeValueOrDie();
|
|
auto filter_literal =
|
|
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
|
.ConsumeValueOrDie();
|
|
|
|
ComputeAndCompareR3<T>(&builder, expected,
|
|
{input_literal.get(), filter_literal.get()},
|
|
error_spec_);
|
|
}
|
|
}; // namespace
|
|
|
|
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes);
|
|
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); }
|
|
|
|
// Convolution whose input is dilated by 2 (lhs_dilation): a zero sits between
// adjacent input elements, giving 9 spatial positions and 9 - 2 + 1 = 8
// outputs. E.g. the first two outputs are 1*10 + 6*30 = 190 and
// 2*20 + 7*40 = 320.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) {
  XlaBuilder builder(TestName());
  {
    auto lhs =
        Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {1, 2, 5}), "input");
    auto rhs =
        Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {1, 2, 2}), "filter");
    // Dimension numbers bf0_oi0->bo0; no padding, unit stride.
    ConvGeneralDilated(
        lhs, rhs, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
        /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1},
        /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
  }

  Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  Array3D<float> filter({{{10, 20}, {30, 40}}});
  Array3D<float> expected({{{190, 320, 230, 380, 270, 440, 310, 500}}});

  auto input_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
          .ConsumeValueOrDie();
  auto filter_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
          .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(&builder, expected,
                             {input_data.get(), filter_data.get()},
                             error_spec_);
}
|
|
|
|
// Convolution with both the input and the filter window dilated by 2. The
// dilated input has 9 positions and the dilated window spans 3, giving 7
// outputs. Even outputs line up with real input elements (e.g.
// 1*10 + 2*20 + 6*30 + 7*40 = 510); odd outputs see only inserted zeros.
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) {
  XlaBuilder builder(TestName());
  {
    auto lhs =
        Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {1, 2, 5}), "input");
    auto rhs =
        Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {1, 2, 2}), "filter");
    // Dimension numbers bf0_oi0->bo0; no padding, unit stride.
    ConvGeneralDilated(
        lhs, rhs, /*window_strides=*/{1}, /*padding=*/{{0, 0}},
        /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2},
        /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
  }

  Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
  Array3D<float> filter({{{10, 20}, {30, 40}}});
  Array3D<float> expected({{{510, 0, 610, 0, 710, 0, 810}}});

  auto input_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
          .ConsumeValueOrDie();
  auto filter_data =
      client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
          .ConsumeValueOrDie();

  ComputeAndCompareR3<float>(&builder, expected,
                             {input_data.get(), filter_data.get()},
                             error_spec_);
}
|
|
|
|
template <typename T>
|
|
class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest {
|
|
public:
|
|
void RunTest() {
|
|
XlaBuilder builder(TestName());
|
|
{
|
|
Shape input_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 5});
|
|
Shape filter_shape = ShapeUtil::MakeShapeWithType<T>({1, 2, 2});
|
|
auto input = Parameter(&builder, 0, input_shape, "input");
|
|
auto filter = Parameter(&builder, 1, filter_shape, "filter");
|
|
// Convolution dimensions are bf0_oi0->bo0.
|
|
ConvGeneralDilated(
|
|
input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}},
|
|
/*lhs_dilation=*/{1}, /*rhs_dilation=*/{1},
|
|
/*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1));
|
|
}
|
|
|
|
Array3D<T> input(
|
|
{{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}});
|
|
Array3D<T> filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}});
|
|
|
|
Array3D<T> expected(
|
|
{{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}});
|
|
|
|
auto input_literal =
|
|
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input))
|
|
.ConsumeValueOrDie();
|
|
auto filter_literal =
|
|
client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter))
|
|
.ConsumeValueOrDie();
|
|
|
|
ComputeAndCompareR3<T>(&builder, expected,
|
|
{input_literal.get(), filter_literal.get()},
|
|
error_spec_);
|
|
}
|
|
};
|
|
|
|
TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes);
|
|
TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); }
|
|
|
|
} // namespace
|
|
} // namespace xla
|