Previously we had a large number of ComputeAndCompare* methods to run a

computation and then compare the reuslt to a specified value (Array or
Literal). The new method takes adventage of the recently added
ComputeConstant method to calculate the expected value using the
HloEvaluator eliminating the need for doing the calculation manually.

As a usage example I converted the convolution tests to the new method
what simplified them by quite a bit. If there is interest then we can
migrate the other tests as well and then remove the old style
ComputeAndCompare* methods.

PiperOrigin-RevId: 175145596
This commit is contained in:
A. Unique TensorFlower 2017-11-09 05:36:43 -08:00 committed by TensorFlower Gardener
parent 71bd045af1
commit 18d5c3e4cf
3 changed files with 125 additions and 104 deletions

View File

@ -346,6 +346,60 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
LiteralTestUtil::ExpectNearTuple(expected, *actual, error);
}
void ClientLibraryTestBase::ComputeAndCompare(
ComputationBuilder* builder, const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<Literal> arguments) {
auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
EXPECT_IS_OK(status_or_data);
if (!status_or_data.ok()) {
return;
}
std::unique_ptr<Literal> reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
LiteralTestUtil::ExpectEqual(*reference, *result);
}
void ClientLibraryTestBase::ComputeAndCompare(
ComputationBuilder* builder, const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<Literal> arguments, ErrorSpec error) {
auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
EXPECT_IS_OK(status_or_data);
if (!status_or_data.ok()) {
return;
}
std::unique_ptr<Literal> reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
LiteralTestUtil::ExpectNear(*reference, *result, error);
}
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
ClientLibraryTestBase::ComputeValueAndReference(
ComputationBuilder* builder, const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<Literal> arguments) {
// Transfer the arguments to the executor service. We put the unique_ptr's
// into a vector to keep the data alive on the service until the end of this
// function.
std::vector<std::unique_ptr<GlobalData>> argument_data;
for (const auto& arg : arguments) {
TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg));
argument_data.push_back(std::move(data));
}
// Create raw pointers to the GlobalData for the rest of the call stack.
std::vector<GlobalData*> argument_data_ptr;
std::transform(
argument_data.begin(), argument_data.end(),
std::back_inserter(argument_data_ptr),
[](const std::unique_ptr<GlobalData>& data) { return data.get(); });
TF_ASSIGN_OR_RETURN(
auto reference,
builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments));
TF_ASSIGN_OR_RETURN(auto result,
ExecuteAndTransfer(builder, argument_data_ptr));
return std::make_pair(std::move(reference), std::move(result));
}
Computation ClientLibraryTestBase::CreateScalarRelu() {
ComputationBuilder builder(client_, "relu");
auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");

View File

@ -196,6 +196,16 @@ class ClientLibraryTestBase : public ::testing::Test {
ComputationBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error);
// Convenience method for running a built computation and comparing the result
// with the HloEvaluator.
void ComputeAndCompare(ComputationBuilder* builder,
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<Literal> arguments);
void ComputeAndCompare(ComputationBuilder* builder,
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<Literal> arguments,
ErrorSpec error);
// Create scalar operations for use in reductions.
Computation CreateScalarRelu();
Computation CreateScalarMax();
@ -298,6 +308,13 @@ class ClientLibraryTestBase : public ::testing::Test {
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output,
const Shape* output_with_layout = nullptr);
// Executes the computation and calculates the expected reference value using
// the HloEvaluator. Returns two literal in the order of (expected, actual).
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
ComputeValueAndReference(ComputationBuilder* builder,
const ComputationDataHandle& operand,
tensorflow::gtl::ArraySlice<Literal> arguments);
};
template <typename NativeT>

View File

@ -82,177 +82,127 @@ XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
std::unique_ptr<Array4D<float>> aexpected =
ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid);
ComputeAndCompareR4<float>(&builder, *aexpected, {}, error_spec_);
ComputeAndCompare(&builder, conv, {}, error_spec_);
}
TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
ComputationBuilder builder(client_, TestName());
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
builder.Conv(input, filter, {1, 1}, Padding::kValid);
}
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> input(1, 1, 1, 2);
input.FillWithYX(Array2D<float>({
Array4D<float> input_data(1, 1, 1, 2);
input_data.FillWithYX(Array2D<float>({
{1, 2},
}));
Array4D<float> filter(1, 1, 1, 2);
filter.FillWithYX(Array2D<float>({
Array4D<float> filter_data(1, 1, 1, 2);
filter_data.FillWithYX(Array2D<float>({
{5, 6},
}));
std::unique_ptr<Array4D<float>> aexpected =
ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
auto input_literal =
client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR4<float>(&builder, *aexpected,
{input_literal.get(), filter_literal.get()},
error_spec_);
ComputeAndCompare(&builder, conv,
{*Literal::CreateFromArray(input_data),
*Literal::CreateFromArray(filter_data)},
error_spec_);
}
// Tests valid padding for 2D convolution in raster space.
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
ComputationBuilder builder(client_, TestName());
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
builder.Conv(input, filter, {1, 1}, Padding::kValid);
}
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
Array4D<float> input(1, 1, 4, 4);
Array4D<float> input_data(1, 1, 4, 4);
// clang-format off
input.FillWithYX(Array2D<float>({
input_data.FillWithYX(Array2D<float>({
{1, 2, 3, 4 },
{5, 6, 7, 8 },
{9, 10, 11, 12},
{13, 14, 15, 16},
}));
// clang-format on
Array4D<float> filter(1, 1, 2, 2);
Array4D<float> filter_data(1, 1, 2, 2);
// clang-format off
filter.FillWithYX(Array2D<float>({
filter_data.FillWithYX(Array2D<float>({
{5, 6},
{7, 8},
}));
// clang-format on
std::unique_ptr<Array4D<float>> aexpected =
ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
auto input_literal =
client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR4<float>(&builder, *aexpected,
{input_literal.get(), filter_literal.get()},
error_spec_);
ComputeAndCompare(&builder, conv,
{*Literal::CreateFromArray(input_data),
*Literal::CreateFromArray(filter_data)},
error_spec_);
}
// Tests same padding for 2D convolution in raster space.
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
ComputationBuilder builder(client_, TestName());
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
builder.Conv(input, filter, {1, 1}, Padding::kSame);
}
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<float> input(1, 1, 4, 4);
Array4D<float> input_data(1, 1, 4, 4);
// clang-format off
input.FillWithYX(Array2D<float>({
input_data.FillWithYX(Array2D<float>({
{1, 2, 3, 4 },
{5, 6, 7, 8 },
{9, 10, 11, 12},
{13, 14, 15, 16},
}));
// clang-format on
Array4D<float> filter(1, 1, 2, 2);
Array4D<float> filter_data(1, 1, 2, 2);
// clang-format off
filter.FillWithYX(Array2D<float>({
filter_data.FillWithYX(Array2D<float>({
{5, 6},
{7, 8},
}));
// clang-format on
std::unique_ptr<Array4D<float>> aexpected =
ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
auto input_literal =
client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR4<float>(&builder, *aexpected,
{input_literal.get(), filter_literal.get()},
error_spec_);
ComputeAndCompare(&builder, conv,
{*Literal::CreateFromArray(input_data),
*Literal::CreateFromArray(filter_data)},
error_spec_);
}
// Tests same padding for 2D convolution in raster space with an odd sized
// kernel.
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
ComputationBuilder builder(client_, TestName());
{
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
builder.Conv(input, filter, {1, 1}, Padding::kSame);
}
Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
auto input = builder.Parameter(0, input_shape, "input");
auto filter = builder.Parameter(1, filter_shape, "filter");
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
Array4D<float> input(1, 1, 4, 4);
Array4D<float> input_data(1, 1, 4, 4);
// clang-format off
input.FillWithYX(Array2D<float>({
input_data.FillWithYX(Array2D<float>({
{1, 2, 3, 4 },
{5, 6, 7, 8 },
{9, 10, 11, 12},
{13, 14, 15, 16},
}));
// clang-format on
Array4D<float> filter(1, 1, 3, 3);
Array4D<float> filter_data(1, 1, 3, 3);
// clang-format off
filter.FillWithYX(Array2D<float>({
filter_data.FillWithYX(Array2D<float>({
{ 5, 6, 7},
{ 8, 9, 10},
{11, 12, 13},
}));
// clang-format on
std::unique_ptr<Array4D<float>> aexpected =
ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
auto input_literal =
client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
.ConsumeValueOrDie();
auto filter_literal =
client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
.ConsumeValueOrDie();
ComputeAndCompareR4<float>(&builder, *aexpected,
{input_literal.get(), filter_literal.get()},
error_spec_);
ComputeAndCompare(&builder, conv,
{*Literal::CreateFromArray(input_data),
*Literal::CreateFromArray(filter_data)},
error_spec_);
}
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {