Previously we had a large number of ComputeAndCompare* methods to run a
computation and then compare the reuslt to a specified value (Array or Literal). The new method takes adventage of the recently added ComputeConstant method to calculate the expected value using the HloEvaluator eliminating the need for doing the calculation manually. As a usage example I converted the convolution tests to the new method what simplified them by quite a bit. If there is interest then we can migrate the other tests as well and then remove the old style ComputeAndCompare* methods. PiperOrigin-RevId: 175145596
This commit is contained in:
parent
71bd045af1
commit
18d5c3e4cf
tensorflow/compiler/xla/tests
@ -346,6 +346,60 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
|
||||
LiteralTestUtil::ExpectNearTuple(expected, *actual, error);
|
||||
}
|
||||
|
||||
void ClientLibraryTestBase::ComputeAndCompare(
|
||||
ComputationBuilder* builder, const ComputationDataHandle& operand,
|
||||
tensorflow::gtl::ArraySlice<Literal> arguments) {
|
||||
auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
|
||||
EXPECT_IS_OK(status_or_data);
|
||||
if (!status_or_data.ok()) {
|
||||
return;
|
||||
}
|
||||
std::unique_ptr<Literal> reference, result;
|
||||
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
|
||||
LiteralTestUtil::ExpectEqual(*reference, *result);
|
||||
}
|
||||
|
||||
void ClientLibraryTestBase::ComputeAndCompare(
|
||||
ComputationBuilder* builder, const ComputationDataHandle& operand,
|
||||
tensorflow::gtl::ArraySlice<Literal> arguments, ErrorSpec error) {
|
||||
auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
|
||||
EXPECT_IS_OK(status_or_data);
|
||||
if (!status_or_data.ok()) {
|
||||
return;
|
||||
}
|
||||
std::unique_ptr<Literal> reference, result;
|
||||
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
|
||||
LiteralTestUtil::ExpectNear(*reference, *result, error);
|
||||
}
|
||||
|
||||
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
|
||||
ClientLibraryTestBase::ComputeValueAndReference(
|
||||
ComputationBuilder* builder, const ComputationDataHandle& operand,
|
||||
tensorflow::gtl::ArraySlice<Literal> arguments) {
|
||||
// Transfer the arguments to the executor service. We put the unique_ptr's
|
||||
// into a vector to keep the data alive on the service until the end of this
|
||||
// function.
|
||||
std::vector<std::unique_ptr<GlobalData>> argument_data;
|
||||
for (const auto& arg : arguments) {
|
||||
TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg));
|
||||
argument_data.push_back(std::move(data));
|
||||
}
|
||||
|
||||
// Create raw pointers to the GlobalData for the rest of the call stack.
|
||||
std::vector<GlobalData*> argument_data_ptr;
|
||||
std::transform(
|
||||
argument_data.begin(), argument_data.end(),
|
||||
std::back_inserter(argument_data_ptr),
|
||||
[](const std::unique_ptr<GlobalData>& data) { return data.get(); });
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto reference,
|
||||
builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments));
|
||||
TF_ASSIGN_OR_RETURN(auto result,
|
||||
ExecuteAndTransfer(builder, argument_data_ptr));
|
||||
return std::make_pair(std::move(reference), std::move(result));
|
||||
}
|
||||
|
||||
Computation ClientLibraryTestBase::CreateScalarRelu() {
|
||||
ComputationBuilder builder(client_, "relu");
|
||||
auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
|
||||
|
@ -196,6 +196,16 @@ class ClientLibraryTestBase : public ::testing::Test {
|
||||
ComputationBuilder* builder, const Literal& expected,
|
||||
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error);
|
||||
|
||||
// Convenience method for running a built computation and comparing the result
|
||||
// with the HloEvaluator.
|
||||
void ComputeAndCompare(ComputationBuilder* builder,
|
||||
const ComputationDataHandle& operand,
|
||||
tensorflow::gtl::ArraySlice<Literal> arguments);
|
||||
void ComputeAndCompare(ComputationBuilder* builder,
|
||||
const ComputationDataHandle& operand,
|
||||
tensorflow::gtl::ArraySlice<Literal> arguments,
|
||||
ErrorSpec error);
|
||||
|
||||
// Create scalar operations for use in reductions.
|
||||
Computation CreateScalarRelu();
|
||||
Computation CreateScalarMax();
|
||||
@ -298,6 +308,13 @@ class ClientLibraryTestBase : public ::testing::Test {
|
||||
const std::function<void(const Literal& actual,
|
||||
const string& error_message)>& verify_output,
|
||||
const Shape* output_with_layout = nullptr);
|
||||
|
||||
// Executes the computation and calculates the expected reference value using
|
||||
// the HloEvaluator. Returns two literal in the order of (expected, actual).
|
||||
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
|
||||
ComputeValueAndReference(ComputationBuilder* builder,
|
||||
const ComputationDataHandle& operand,
|
||||
tensorflow::gtl::ArraySlice<Literal> arguments);
|
||||
};
|
||||
|
||||
template <typename NativeT>
|
||||
|
@ -82,177 +82,127 @@ XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
|
||||
auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
|
||||
builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
|
||||
auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
|
||||
|
||||
std::unique_ptr<Array4D<float>> aexpected =
|
||||
ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid);
|
||||
|
||||
ComputeAndCompareR4<float>(&builder, *aexpected, {}, error_spec_);
|
||||
ComputeAndCompare(&builder, conv, {}, error_spec_);
|
||||
}
|
||||
|
||||
TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
builder.Conv(input, filter, {1, 1}, Padding::kValid);
|
||||
}
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
|
||||
|
||||
Array4D<float> input(1, 1, 1, 2);
|
||||
input.FillWithYX(Array2D<float>({
|
||||
Array4D<float> input_data(1, 1, 1, 2);
|
||||
input_data.FillWithYX(Array2D<float>({
|
||||
{1, 2},
|
||||
}));
|
||||
Array4D<float> filter(1, 1, 1, 2);
|
||||
filter.FillWithYX(Array2D<float>({
|
||||
Array4D<float> filter_data(1, 1, 1, 2);
|
||||
filter_data.FillWithYX(Array2D<float>({
|
||||
{5, 6},
|
||||
}));
|
||||
|
||||
std::unique_ptr<Array4D<float>> aexpected =
|
||||
ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR4<float>(&builder, *aexpected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{*Literal::CreateFromArray(input_data),
|
||||
*Literal::CreateFromArray(filter_data)},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
// Tests valid padding for 2D convolution in raster space.
|
||||
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
builder.Conv(input, filter, {1, 1}, Padding::kValid);
|
||||
}
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
|
||||
|
||||
Array4D<float> input(1, 1, 4, 4);
|
||||
Array4D<float> input_data(1, 1, 4, 4);
|
||||
// clang-format off
|
||||
input.FillWithYX(Array2D<float>({
|
||||
input_data.FillWithYX(Array2D<float>({
|
||||
{1, 2, 3, 4 },
|
||||
{5, 6, 7, 8 },
|
||||
{9, 10, 11, 12},
|
||||
{13, 14, 15, 16},
|
||||
}));
|
||||
// clang-format on
|
||||
Array4D<float> filter(1, 1, 2, 2);
|
||||
Array4D<float> filter_data(1, 1, 2, 2);
|
||||
// clang-format off
|
||||
filter.FillWithYX(Array2D<float>({
|
||||
filter_data.FillWithYX(Array2D<float>({
|
||||
{5, 6},
|
||||
{7, 8},
|
||||
}));
|
||||
// clang-format on
|
||||
|
||||
std::unique_ptr<Array4D<float>> aexpected =
|
||||
ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR4<float>(&builder, *aexpected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{*Literal::CreateFromArray(input_data),
|
||||
*Literal::CreateFromArray(filter_data)},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
// Tests same padding for 2D convolution in raster space.
|
||||
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
builder.Conv(input, filter, {1, 1}, Padding::kSame);
|
||||
}
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
|
||||
|
||||
Array4D<float> input(1, 1, 4, 4);
|
||||
Array4D<float> input_data(1, 1, 4, 4);
|
||||
// clang-format off
|
||||
input.FillWithYX(Array2D<float>({
|
||||
input_data.FillWithYX(Array2D<float>({
|
||||
{1, 2, 3, 4 },
|
||||
{5, 6, 7, 8 },
|
||||
{9, 10, 11, 12},
|
||||
{13, 14, 15, 16},
|
||||
}));
|
||||
// clang-format on
|
||||
Array4D<float> filter(1, 1, 2, 2);
|
||||
Array4D<float> filter_data(1, 1, 2, 2);
|
||||
// clang-format off
|
||||
filter.FillWithYX(Array2D<float>({
|
||||
filter_data.FillWithYX(Array2D<float>({
|
||||
{5, 6},
|
||||
{7, 8},
|
||||
}));
|
||||
// clang-format on
|
||||
|
||||
std::unique_ptr<Array4D<float>> aexpected =
|
||||
ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR4<float>(&builder, *aexpected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{*Literal::CreateFromArray(input_data),
|
||||
*Literal::CreateFromArray(filter_data)},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
// Tests same padding for 2D convolution in raster space with an odd sized
|
||||
// kernel.
|
||||
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
{
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
builder.Conv(input, filter, {1, 1}, Padding::kSame);
|
||||
}
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
|
||||
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto filter = builder.Parameter(1, filter_shape, "filter");
|
||||
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
|
||||
|
||||
Array4D<float> input(1, 1, 4, 4);
|
||||
Array4D<float> input_data(1, 1, 4, 4);
|
||||
// clang-format off
|
||||
input.FillWithYX(Array2D<float>({
|
||||
input_data.FillWithYX(Array2D<float>({
|
||||
{1, 2, 3, 4 },
|
||||
{5, 6, 7, 8 },
|
||||
{9, 10, 11, 12},
|
||||
{13, 14, 15, 16},
|
||||
}));
|
||||
// clang-format on
|
||||
Array4D<float> filter(1, 1, 3, 3);
|
||||
Array4D<float> filter_data(1, 1, 3, 3);
|
||||
// clang-format off
|
||||
filter.FillWithYX(Array2D<float>({
|
||||
filter_data.FillWithYX(Array2D<float>({
|
||||
{ 5, 6, 7},
|
||||
{ 8, 9, 10},
|
||||
{11, 12, 13},
|
||||
}));
|
||||
// clang-format on
|
||||
|
||||
std::unique_ptr<Array4D<float>> aexpected =
|
||||
ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
|
||||
|
||||
auto input_literal =
|
||||
client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
|
||||
.ConsumeValueOrDie();
|
||||
auto filter_literal =
|
||||
client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
|
||||
.ConsumeValueOrDie();
|
||||
|
||||
ComputeAndCompareR4<float>(&builder, *aexpected,
|
||||
{input_literal.get(), filter_literal.get()},
|
||||
error_spec_);
|
||||
ComputeAndCompare(&builder, conv,
|
||||
{*Literal::CreateFromArray(input_data),
|
||||
*Literal::CreateFromArray(filter_data)},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {
|
||||
|
Loading…
Reference in New Issue
Block a user