[XLA] Support RNG in HloEvaluator.

PiperOrigin-RevId: 227642508
Kay Zhu 2019-01-03 00:30:28 -08:00 committed by TensorFlower Gardener
parent a8bf983c60
commit df0bd2904e
6 changed files with 153 additions and 9 deletions
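Before the per-file diffs, a rough sense of the scheme this change adds: HloEvaluator keeps a std::minstd_rand0 engine seeded from the module config's seed (or std::random_device when unset), and HandleRng fills the output literal by drawing one value per element from a <random> distribution. The snippet below is a minimal standalone sketch of that idea, not the actual XLA code; the function name EvaluateRngUniform and the plain std::vector output are invented for illustration.

#include <cstdint>
#include <random>
#include <vector>

// Hypothetical helper mirroring the evaluator's RNG_UNIFORM handling:
// one engine per evaluation, one draw per output element.
std::vector<float> EvaluateRngUniform(uint64_t config_seed, float low,
                                      float high, int64_t num_elements) {
  // Use the module-level seed if set, otherwise a nondeterministic seed.
  uint64_t seed = config_seed != 0 ? config_seed : std::random_device{}();
  std::minstd_rand0 engine(seed);
  std::uniform_real_distribution<float> dist(low, high);

  std::vector<float> result(num_elements);
  for (auto& v : result) v = dist(engine);  // one sample per element
  return result;
}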


@@ -52,7 +52,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
           computation->root_instruction() != instruction) {
         continue;
       }
-      // Skip Constant, Parameter, Tuple, AfterAll operation.
+      // Skip Constant, Parameter, Tuple, AfterAll, Rng operations.
       // Tuple constants are not directly supported by any backends, hence
       // folding Tuple is not useful and would in fact be expanded back into
       // kTuple by Algebraic Simplifier.
@@ -62,7 +62,8 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
       if (instruction->opcode() == HloOpcode::kParameter ||
          instruction->opcode() == HloOpcode::kConstant ||
          instruction->opcode() == HloOpcode::kTuple ||
-          instruction->opcode() == HloOpcode::kAfterAll) {
+          instruction->opcode() == HloOpcode::kAfterAll ||
+          instruction->opcode() == HloOpcode::kRng) {
        continue;
      }
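
Why kRng joins this skip list: HloConstantFolding folds by running HloEvaluator, and now that the evaluator can execute Rng (see the diffs below), folding would replace a random op with one fixed literal, freezing a single sample for every future execution. A toy standalone illustration of that failure mode (plain C++, not XLA code):

#include <iostream>
#include <random>

int main() {
  std::minstd_rand0 engine(1);
  std::uniform_int_distribution<int> dist(0, 99);

  // Live RNG op: every evaluation yields a fresh draw.
  std::cout << dist(engine) << " " << dist(engine) << "\n";  // usually differ

  // "Constant-folded" RNG op: one draw taken at fold time, reused forever.
  const int folded = dist(engine);
  std::cout << folded << " " << folded << "\n";  // always identical
  return 0;
}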


@@ -233,6 +233,14 @@ StatusOr<Literal> HloEvaluator::Evaluate(
   for (const auto& literal_ptr : arg_literals) {
     arg_literals_.push_back(&*literal_ptr);
   }
 
+  if (computation.parent()->config().seed()) {
+    seed_ = computation.parent()->config().seed();
+  } else {
+    std::random_device rd;
+    seed_ = rd();
+  }
+  engine_ = std::minstd_rand0(seed_);
+
   TF_RETURN_IF_ERROR(computation.Accept(this));
   return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
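
A note on the seeding logic above: because engine_ is reconstructed from seed_ at the top of every Evaluate call, a module with a fixed config seed reproduces the same random stream on each evaluation, while the std::random_device fallback yields a fresh stream when no seed is set. A tiny standalone check of that property (plain <random>, no XLA types):

#include <cassert>
#include <random>

int main() {
  // Engines built from the same seed generate identical streams, which is
  // what makes a fixed module seed reproducible across Evaluate calls.
  std::minstd_rand0 a(42), b(42);
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  for (int i = 0; i < 8; ++i) {
    assert(dist(a) == dist(b));
  }
  return 0;
}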


@@ -290,6 +290,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
   // Max loop iterations to execute with no maximum if negative.
   int64 max_loop_iterations_;
 
+  // Module-level seed handle.
+  uint64 seed_;
+  // RNG engine.
+  std::minstd_rand0 engine_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
 };


@@ -2653,7 +2653,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
   template <typename NativeT, typename std::enable_if<std::is_same<
                                   double, NativeT>::value>::type* = nullptr>
   Status HandleReducePrecision(HloInstruction* reduce_precision) {
-    return InvalidArgument("Double not supported for reduce precision");
+    return InvalidArgument("Double is not supported for reduce precision");
   }
 
   template <
@@ -2719,6 +2719,103 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     return HandleIota<ReturnT>(iota);
   }
 
+  template <typename NativeT,
+            typename std::enable_if<
+                !(std::is_integral<NativeT>::value ||
+                  std::is_floating_point<NativeT>::value)>::type* = nullptr>
+  Status HandleRng(HloInstruction* random) {
+    return UnsupportedTypeError(random);
+  }
+
+  template <typename NativeT,
+            typename std::enable_if<
+                (std::is_floating_point<NativeT>::value)>::type* = nullptr>
+  Status HandleRng(HloInstruction* random) {
+    RandomDistribution distribution = random->random_distribution();
+    const auto result_shape = random->shape();
+    Literal result(result_shape);
+
+    switch (distribution) {
+      case RNG_UNIFORM: {
+        const Literal& low =
+            parent_->GetEvaluatedLiteralFor(random->operand(0));
+        const Literal& high =
+            parent_->GetEvaluatedLiteralFor(random->operand(1));
+        std::uniform_real_distribution<NativeT> generator(
+            low.Get<NativeT>({}), high.Get<NativeT>({}));
+
+        TF_RETURN_IF_ERROR(
+            result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
+              return generator(parent_->engine_);
+            }));
+        break;
+      }
+      case RNG_NORMAL: {
+        const Literal& mean =
+            parent_->GetEvaluatedLiteralFor(random->operand(0));
+        const Literal& stddev =
+            parent_->GetEvaluatedLiteralFor(random->operand(1));
+        std::normal_distribution<NativeT> generator(mean.Get<NativeT>({}),
+                                                    stddev.Get<NativeT>({}));
+
+        TF_RETURN_IF_ERROR(
+            result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
+              return generator(parent_->engine_);
+            }));
+        break;
+      }
+      default:
+        return UnimplementedStrCat("The distribution ",
+                                   RandomDistribution_Name(distribution),
+                                   " is not implemented.");
+    }
+
+    parent_->evaluated_[random] = std::move(result);
+    return Status::OK();
+  }
+
+  template <typename NativeT,
+            typename std::enable_if<(std::is_integral<NativeT>::value)>::type* =
+                nullptr>
+  Status HandleRng(HloInstruction* random) {
+    RandomDistribution distribution = random->random_distribution();
+    const auto result_shape = random->shape();
+    Literal result(result_shape);
+
+    switch (distribution) {
+      case RNG_UNIFORM: {
+        const Literal& low =
+            parent_->GetEvaluatedLiteralFor(random->operand(0));
+        const Literal& high =
+            parent_->GetEvaluatedLiteralFor(random->operand(1));
+
+        // Note: std::uniform_int_distribution assumes a closed interval, i.e.
+        // [low, high], but we want [low, high), so high - 1 is used as the
+        // upper bound.
+        std::uniform_int_distribution<int64> generator(
+            low.Get<NativeT>({}), high.Get<NativeT>({}) - 1);
+
+        TF_RETURN_IF_ERROR(
+            result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
+              return static_cast<NativeT>(generator(parent_->engine_));
+            }));
+        break;
+      }
+      case RNG_NORMAL: {
+        return Unimplemented(
+            "Normal distribution is not supported for integral types.");
+      }
+      default:
+        return UnimplementedStrCat("The distribution ",
+                                   RandomDistribution_Name(distribution),
+                                   " is not implemented.");
+    }
+
+    parent_->evaluated_[random] = std::move(result);
+    return Status::OK();
+  }
+
+  Status HandleRng(HloInstruction* random) override {
+    return HandleRng<ReturnT>(random);
+  }
+
  private:
   // Creates a vector of multipliers which can be used to create a linear index
   // into shape.
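
One detail in the integral HandleRng above worth calling out: XLA's RngUniform is defined over the half-open interval [low, high), while std::uniform_int_distribution samples a closed interval, hence the high - 1 upper bound in the code. A throwaway standalone check of that adjustment (not XLA code):

#include <cassert>
#include <cstdint>
#include <random>

int main() {
  std::minstd_rand0 engine(7);
  const int64_t low = 0, high = 100;
  // Closed-interval distribution over [low, high - 1] == half-open [low, high).
  std::uniform_int_distribution<int64_t> dist(low, high - 1);
  for (int i = 0; i < 1000; ++i) {
    int64_t v = dist(engine);
    assert(v >= low && v < high);  // never reaches `high`
  }
  return 0;
}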


@@ -1422,10 +1422,6 @@ xla_test(
 xla_test(
     name = "prng_test",
     srcs = ["prng_test.cc"],
-    blacklisted_backends = [
-        # TODO(b/122047800) support RNGs on the interpreter backend.
-        "interpreter",
-    ],
     deps = [
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",


@@ -80,7 +80,9 @@ XLA_TEST_F(PrngTest, LargeU01) { UniformTest<float>(0, 1, {0x100, 0x100}); }
 XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }
 
 // TODO(b/71543667): Fix Rng ops on LLVM backends.
-XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) {
+// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
+XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(
+                         DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) {
   for (int64 seed = 0; seed < 100; ++seed) {
     // The largest negative number smaller than zero in bf16 that's not
     // denormalized.
@@ -103,7 +105,9 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) {
 }
 
 // TODO(b/71543667): Fix Rng ops on LLVM backends.
-XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) {
+// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
+XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(
+                         DISABLED_ON_CPU(ScalarBF16CountTests)))) {
   // There are 3 BF16 values in the range [32.25, 33): 32.25, 32.5, and 32.75;
   // they should get similar counts.
   bfloat16 low = static_cast<bfloat16>(32.25);
@@ -276,6 +280,39 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
   EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
 }
 
+// This test verifies that two RNG instructions with the same parameters in
+// the same HloComputation produce different values.
+XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) {
+  // Build a computation with two identical U[0, 100) RngUniform instructions.
+  auto build_computation = [this]() {
+    XlaBuilder builder(TestName());
+    auto a = RngUniform(ConstantR0<int32>(&builder, 0),
+                        ConstantR0<int32>(&builder, 100),
+                        ShapeUtil::MakeShape(S32, {10}));
+    auto b = RngUniform(ConstantR0<int32>(&builder, 0),
+                        ConstantR0<int32>(&builder, 100),
+                        ShapeUtil::MakeShape(S32, {10}));
+    Tuple(&builder, {a, b});
+    return builder.Build();
+  };
+
+  ExecutionOptions execution_options = execution_options_;
+  execution_options.set_seed(42);
+
+  Literal result_tuple;
+  {
+    TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
+    TF_ASSERT_OK_AND_ASSIGN(
+        result_tuple, client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+                                                  &execution_options));
+  }
+
+  auto results = result_tuple.DecomposeTuple();
+  ASSERT_EQ(results.size(), 2);
+  EXPECT_FALSE(LiteralTestUtil::Equal(results[0], results[1]));
+}
+
 XLA_TEST_F(PrngTest, TenValuesN01) {
   XlaBuilder builder(TestName());
   RngNormal(ConstantR0<float>(&builder, 0), ConstantR0<float>(&builder, 1),