[XLA] Support RNG in HloEvaluator.
PiperOrigin-RevId: 227642508
parent a8bf983c60
commit df0bd2904e
@@ -52,7 +52,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
         computation->root_instruction() != instruction) {
       continue;
     }
-    // Skip Constant, Parameter, Tuple, AfterAll operation.
+    // Skip Constant, Parameter, Tuple, AfterAll, Rng operations.
     // Tuple constants are not directly supported by any backends, hence
     // folding Tuple is not useful and would in fact be expanded back into
     // kTuple by Algebraic Simplifier.
@@ -62,7 +62,8 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
     if (instruction->opcode() == HloOpcode::kParameter ||
         instruction->opcode() == HloOpcode::kConstant ||
         instruction->opcode() == HloOpcode::kTuple ||
-        instruction->opcode() == HloOpcode::kAfterAll) {
+        instruction->opcode() == HloOpcode::kAfterAll ||
+        instruction->opcode() == HloOpcode::kRng) {
       continue;
     }
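
Why kRng joins the skip list: an Rng instruction is not a pure function of its operands, so constant-folding it would freeze a single sample into the module. A standalone C++ sketch of the hazard; illustrative only, not XLA code:

#include <iostream>
#include <random>

int main() {
  std::minstd_rand0 engine(42);
  std::uniform_int_distribution<int> dist(0, 99);

  // "Folding" the draw: one sample, reused forever.
  const int folded = dist(engine);
  std::cout << "folded: " << folded << " " << folded << "\n";

  // Evaluating the draw each time: fresh samples.
  std::cout << "live:   " << dist(engine) << " " << dist(engine) << "\n";
}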
@@ -233,6 +233,14 @@ StatusOr<Literal> HloEvaluator::Evaluate(
   for (const auto& literal_ptr : arg_literals) {
     arg_literals_.push_back(&*literal_ptr);
   }
+  if (computation.parent()->config().seed()) {
+    seed_ = computation.parent()->config().seed();
+  } else {
+    std::random_device rd;
+    seed_ = rd();
+  }
+
+  engine_ = std::minstd_rand0(seed_);

   TF_RETURN_IF_ERROR(computation.Accept(this));
   return GetEvaluatedLiteralFor(computation.root_instruction()).Clone();
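
A minimal sketch of the seeding policy added above, using only <random>: an explicit seed gives reproducible evaluations, while std::random_device supplies a fresh seed otherwise. Variable names here are illustrative, not XLA API:

#include <cstdint>
#include <iostream>
#include <random>

int main() {
  const std::uint64_t module_seed = 42;  // stand-in for a module config seed

  // Explicit seed: two evaluations see the same stream.
  std::minstd_rand0 a(module_seed), b(module_seed);
  std::cout << (a() == b()) << "\n";  // prints 1

  // No seed configured: std::random_device supplies a fresh one per run.
  std::random_device rd;
  std::minstd_rand0 c(rd()), d(rd());
  std::cout << (c() == d()) << "\n";  // almost surely prints 0
}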
@@ -290,6 +290,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
   // Max loop iterations to execute with no maximum if negative.
   int64 max_loop_iterations_;

+  // Module-level seed handle.
+  uint64 seed_;
+  // RNG engine.
+  std::minstd_rand0 engine_;
+
   TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator);
 };
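
Because engine_ is a member, every HandleRng call during one evaluation draws from the same advancing stream; that is what makes two identical Rng instructions produce different values (exercised by the prng_test change below). A standalone sketch of that property, names illustrative:

#include <iostream>
#include <random>

int main() {
  std::minstd_rand0 engine(42);  // stand-in for HloEvaluator::engine_
  std::uniform_real_distribution<float> u01(0.0f, 1.0f);

  // Two "identical instructions" share the engine, so its state advances
  // between draws and the samples differ.
  const float first = u01(engine);
  const float second = u01(engine);
  std::cout << (first != second) << "\n";  // prints 1 with this seed
}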
@@ -2653,7 +2653,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
   template <typename NativeT, typename std::enable_if<std::is_same<
                                   double, NativeT>::value>::type* = nullptr>
   Status HandleReducePrecision(HloInstruction* reduce_precision) {
-    return InvalidArgument("Double not supported for reduce precision");
+    return InvalidArgument("Double is not supported for reduce precision");
   }

   template <
@@ -2719,6 +2719,103 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     return HandleIota<ReturnT>(iota);
   }

+  template <typename NativeT,
+            typename std::enable_if<
+                !(std::is_integral<NativeT>::value ||
+                  std::is_floating_point<NativeT>::value)>::type* = nullptr>
+  Status HandleRng(HloInstruction* random) {
+    return UnsupportedTypeError(random);
+  }
+  template <typename NativeT,
+            typename std::enable_if<
+                (std::is_floating_point<NativeT>::value)>::type* = nullptr>
+  Status HandleRng(HloInstruction* random) {
+    RandomDistribution distribution = random->random_distribution();
+    const auto result_shape = random->shape();
+    Literal result(result_shape);
+
+    switch (distribution) {
+      case RNG_UNIFORM: {
+        const Literal& low =
+            parent_->GetEvaluatedLiteralFor(random->operand(0));
+        const Literal& high =
+            parent_->GetEvaluatedLiteralFor(random->operand(1));
+
+        std::uniform_real_distribution<NativeT> generator(
+            low.Get<NativeT>({}), high.Get<NativeT>({}));
+
+        TF_RETURN_IF_ERROR(
+            result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
+              return generator(parent_->engine_);
+            }));
+        break;
+      }
+      case RNG_NORMAL: {
+        const Literal& mean =
+            parent_->GetEvaluatedLiteralFor(random->operand(0));
+        const Literal& stddev =
+            parent_->GetEvaluatedLiteralFor(random->operand(1));
+
+        std::normal_distribution<NativeT> generator(mean.Get<NativeT>({}),
+                                                    stddev.Get<NativeT>({}));
+
+        TF_RETURN_IF_ERROR(
+            result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
+              return generator(parent_->engine_);
+            }));
+        break;
+      }
+      default:
+        return UnimplementedStrCat("The distribution ",
+                                   RandomDistribution_Name(distribution),
+                                   " is not implemented.");
+    }
+    parent_->evaluated_[random] = std::move(result);
+    return Status::OK();
+  }
+  template <typename NativeT,
+            typename std::enable_if<(std::is_integral<NativeT>::value)>::type* =
+                nullptr>
+  Status HandleRng(HloInstruction* random) {
+    RandomDistribution distribution = random->random_distribution();
+    const auto result_shape = random->shape();
+    Literal result(result_shape);
+
+    switch (distribution) {
+      case RNG_UNIFORM: {
+        const Literal& low =
+            parent_->GetEvaluatedLiteralFor(random->operand(0));
+        const Literal& high =
+            parent_->GetEvaluatedLiteralFor(random->operand(1));
+
+        // Note that std::uniform_int_distribution samples a closed interval,
+        // i.e., [low, high], but we want [low, high) instead. Hence high - 1
+        // is used as the upper bound.
+        std::uniform_int_distribution<int64> generator(
+            low.Get<NativeT>({}), high.Get<NativeT>({}) - 1);
+
+        TF_RETURN_IF_ERROR(
+            result.Populate<NativeT>([&](absl::Span<const int64> /*indexes*/) {
+              return static_cast<NativeT>(generator(parent_->engine_));
+            }));
+        break;
+      }
+      case RNG_NORMAL: {
+        return Unimplemented(
+            "Normal distribution is not supported for integral types.");
+      }
+      default:
+        return UnimplementedStrCat("The distribution ",
+                                   RandomDistribution_Name(distribution),
+                                   " is not implemented.");
+    }
+    parent_->evaluated_[random] = std::move(result);
+    return Status::OK();
+  }
+  Status HandleRng(HloInstruction* random) override {
+    return HandleRng<ReturnT>(random);
+  }
+
  private:
   // Creates a vector of multipliers which can be used to create a linear index
   // into shape.
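
The high - 1 in the integral handler exists because std::uniform_int_distribution samples the closed interval [a, b], while the Rng op's uniform distribution is half-open [low, high). A standalone check of the adjustment:

#include <iostream>
#include <random>

int main() {
  std::minstd_rand0 engine(42);
  const int low = 5, high = 24;
  // Closed [low, high - 1] equals half-open [low, high).
  std::uniform_int_distribution<int> dist(low, high - 1);

  for (int i = 0; i < 1000; ++i) {
    const int v = dist(engine);
    if (v < low || v >= high) {
      std::cout << "out of range: " << v << "\n";
      return 1;
    }
  }
  std::cout << "1000 samples all in [5, 24)\n";
}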
@@ -1422,10 +1422,6 @@ xla_test(
 xla_test(
     name = "prng_test",
     srcs = ["prng_test.cc"],
-    blacklisted_backends = [
-        # TODO(b/122047800) support RNGs on the interpreter backend.
-        "interpreter",
-    ],
     deps = [
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
@@ -80,7 +80,9 @@ XLA_TEST_F(PrngTest, LargeU01) { UniformTest<float>(0, 1, {0x100, 0x100}); }
 XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }

 // TODO(b/71543667): Fix Rng ops on LLVM backends.
-XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) {
+// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
+XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(
+                         DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests)))) {
   for (int64 seed = 0; seed < 100; ++seed) {
     // The largest negative number smaller than zero in bf16 that's not
     // denormalized.
@@ -103,7 +105,9 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) {
 }

 // TODO(b/71543667): Fix Rng ops on LLVM backends.
-XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) {
+// TODO(b/122047800): Interpreter does not support BF16 for RNG ops.
+XLA_TEST_F(PrngTest, DISABLED_ON_INTERPRETER(DISABLED_ON_GPU(
+                         DISABLED_ON_CPU(ScalarBF16CountTests)))) {
   // There are 3 BF16 values in the range [32.25, 33): 32.25, 32.5, and 32.75;
   // they should get similar counts.
   bfloat16 low = static_cast<bfloat16>(32.25);
@@ -276,6 +280,39 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
   EXPECT_FALSE(LiteralTestUtil::Equal(result5, result6));
 }

+// This test verifies that two RNG instructions with the same parameters in
+// the same HloComputation produce different values.
+XLA_TEST_F(PrngTest, DifferentValuesForIdenticalRngNodesInSameComputation) {
+  // Build a computation with two identical U[0, 100) RNG instructions.
+  auto build_computation = [this]() {
+    XlaBuilder builder(TestName());
+    auto a = RngUniform(ConstantR0<int32>(&builder, 0),
+                        ConstantR0<int32>(&builder, 100),
+                        ShapeUtil::MakeShape(S32, {10}));
+    auto b = RngUniform(ConstantR0<int32>(&builder, 0),
+                        ConstantR0<int32>(&builder, 100),
+                        ShapeUtil::MakeShape(S32, {10}));
+    Tuple(&builder, {a, b});
+    return builder.Build();
+  };
+
+  ExecutionOptions execution_options = execution_options_;
+  execution_options.set_seed(42);
+
+  Literal result_tuple;
+  {
+    TF_ASSERT_OK_AND_ASSIGN(auto computation, build_computation());
+    TF_ASSERT_OK_AND_ASSIGN(
+        result_tuple, client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+                                                  &execution_options));
+  }
+
+  auto results = result_tuple.DecomposeTuple();
+  ASSERT_EQ(results.size(), 2);
+
+  EXPECT_FALSE(LiteralTestUtil::Equal(results[0], results[1]));
+}
+
 XLA_TEST_F(PrngTest, TenValuesN01) {
   XlaBuilder builder(TestName());
   RngNormal(ConstantR0<float>(&builder, 0), ConstantR0<float>(&builder, 1),