[XLA] Add tests for large numbers of parameter / return values and while loops.

PiperOrigin-RevId: 168487225
This commit is contained in:
Chris Leary 2017-09-12 20:10:34 -07:00 committed by TensorFlower Gardener
parent 3cd6bdef5f
commit 11d3ac29d5
6 changed files with 178 additions and 4 deletions

View File

@ -2419,7 +2419,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
HloComputation* condition = xla_while->while_condition();
TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
condition->root_instruction()->shape().element_type() == PRED)
<< "While condition computation must return bool";
<< "While condition computation must return bool; got: "
<< ShapeUtil::HumanString(condition->root_instruction()->shape());
// Check that all while-related buffers share an allocation slice.
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
xla_while->shape(),

View File

@ -36,8 +36,8 @@ void DumpModule(const HloModule& module,
const string& message) {
hlo_graph_dumper::MaybeDumpHloModule(module, message);
VLOG(2) << "HLO " << message << ":";
XLA_VLOG_LINES(2, module.ToString());
VLOG(3) << "HLO " << message << ":";
XLA_VLOG_LINES(3, module.ToString());
}
} // namespace

View File

@ -370,6 +370,7 @@ xla_test(
xla_test(
name = "params_test",
srcs = ["params_test.cc"],
shard_count = 30,
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",

View File

@ -258,7 +258,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
LOG(WARNING) << "performing exact comparison of floating point numbers";
} else {
TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
expected.shape().element_type() == PRED);
expected.shape().element_type() == PRED)
<< ShapeUtil::HumanString(expected.shape());
}
auto expect_equal = [&](const Literal& actual, const string& error_message) {
LiteralTestUtil::ExpectEqual(expected, actual, error_message);

View File

@ -251,6 +251,85 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
ComputeAndCompareR1<float>(&builder, sum, param_data, ErrorSpec(0.0001f));
}
// TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too
// much space in parameter memory for the kernel.
//
// TODO(b/65526061) Failed on CPU on 2017-09-10 due to timeout in LLVM
// compilation.
XLA_TEST_F(ParamsTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(ThreeThousandParameters))) {
ComputationBuilder builder(client_, TestName());
std::vector<std::unique_ptr<GlobalData>> param_data_owner;
ComputationDataHandle sum_handle = builder.ConstantR0<float>(0.0f);
float target = 0.0;
constexpr int kParamCount = 3000;
for (int i = 0; i < kParamCount; ++i) {
target += i;
std::unique_ptr<Literal> literal = Literal::CreateR0<float>(i);
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
ComputationDataHandle param =
builder.Parameter(i, literal->shape(), "param");
sum_handle = builder.Add(sum_handle, param);
}
std::vector<GlobalData*> param_data;
param_data.reserve(param_data_owner.size());
for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
param_data.push_back(data.get());
}
ComputeAndCompareR0<float>(&builder, target, param_data, ErrorSpec(0.0001f));
}
// TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too
// much space in parameter memory for the kernel.
//
// TODO(b/65526061) Failed on CPU on 2017-09-10 due to timeout in LLVM
// compilation.
XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
ThreeThousandParametersAndOutputElements))) {
ComputationBuilder builder(client_, TestName());
std::vector<std::unique_ptr<GlobalData>> param_data_owner;
ComputationDataHandle sum_handle = builder.ConstantR1<int32>({0, 0});
int32 target = 0;
constexpr int kParamCount = 3000;
std::vector<ComputationDataHandle> params;
for (int i = 0; i < kParamCount; ++i) {
target += i;
std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
ComputationDataHandle param =
builder.Parameter(i, literal->shape(), "param");
params.push_back(param);
sum_handle = builder.Add(sum_handle, param);
}
std::vector<ComputationDataHandle> outputs;
for (int i = 0; i < kParamCount; ++i) {
outputs.push_back(builder.Add(params[i], sum_handle));
}
builder.Tuple(outputs);
std::vector<GlobalData*> param_data;
param_data.reserve(param_data_owner.size());
for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
param_data.push_back(data.get());
}
std::vector<std::unique_ptr<Literal>> elements;
std::vector<const Literal*> ptrs;
for (int i = 0; i < kParamCount; ++i) {
elements.push_back(Literal::CreateR1<int32>({target + i, target + i}));
ptrs.push_back(elements.back().get());
}
ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data);
}
XLA_TEST_F(ParamsTest,
DISABLED_ON_CPU_PARALLEL(TupleOfR1ParametersAddedTogether)) {
ComputationBuilder builder(client_, TestName());

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@ -770,6 +771,97 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) {
}
}
// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11.
TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) {
auto element_shape = ShapeUtil::MakeShape(F32, {2});
ComputationBuilder outer(client_, "outer");
auto p = outer.Parameter(0, element_shape, "param");
auto t = outer.Tuple({p, outer.ConstantR1<float>({1, 1})});
TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr<Shape> tuple_shape,
outer.GetShape(t));
ComputationBuilder cond(client_, "cond");
auto cond_t = cond.Parameter(0, *tuple_shape, "t");
TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0),
cond.ConstantR1<float>({42, 42})),
&cond)
.status());
ComputationBuilder body(client_, "body");
auto body_t = body.Parameter(0, *tuple_shape, "t");
auto e = body.GetTupleElement(body_t, 1);
body.Tuple({e, e});
TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
outer.While(cond_computation, body_computation, t);
auto expected_element = Literal::CreateR1<float>({1, 1});
auto expected =
Literal::MakeTuple({expected_element.get(), expected_element.get()});
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()},
ErrorSpec(1e-6));
}
// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11.
TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithBroadcast)) {
auto element_shape = ShapeUtil::MakeShape(F32, {2});
ComputationBuilder outer(client_, "outer");
auto p = outer.Parameter(0, element_shape, "param");
ComputationBuilder cond(client_, "cond");
auto cond_t = cond.Parameter(0, element_shape, "t");
TF_ASSERT_OK(
Any(cond.Eq(cond_t, cond.ConstantR1<float>({42, 42})), &cond).status());
ComputationBuilder body(client_, "body");
auto body_t = body.Parameter(0, element_shape, "t");
auto e = body.Broadcast(body.ConstantR0<float>(1.0), {2});
TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
outer.While(cond_computation, body_computation, p);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
client_->TransferToServer(*Literal::CreateR1<float>({42, 42})));
ComputeAndCompareR1<float>(&outer, {1.0f, 1.0f}, {parameter_data.get()},
ErrorSpec(1e-6));
}
TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) {
auto element_shape = ShapeUtil::MakeShape(F32, {});
ComputationBuilder outer(client_, "outer");
auto p = outer.Parameter(0, element_shape, "param");
ComputationBuilder cond(client_, "cond");
auto cond_t = cond.Parameter(0, element_shape, "t");
cond.Eq(cond_t, cond.ConstantR0<float>(42));
ComputationBuilder body(client_, "body");
auto body_t = body.Parameter(0, element_shape, "t");
auto tuple =
body.Tuple({body_t, body.Add(body_t, body.ConstantR0<float>(1))});
auto e = body.GetTupleElement(tuple, 1);
TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build());
TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build());
outer.While(cond_computation, body_computation, p);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> parameter_data,
client_->TransferToServer(*Literal::CreateR0<float>(42)));
ComputeAndCompareR0<float>(&outer, 43.0f, {parameter_data.get()},
ErrorSpec(1e-6));
}
// Tests nested while loops.
//
// int32 result = 0;