From ade8058c51e6837b506a451f21edb71f67750f60 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Thu, 28 Dec 2017 11:26:03 -0800 Subject: [PATCH] [XLA:CPU] Implement RngBernoulli for F32 and F64 PiperOrigin-RevId: 180283205 --- .../xla/service/elemental_ir_emitter.cc | 22 ++++++-- tensorflow/compiler/xla/tests/prng_test.cc | 51 ++++++++++++++----- 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 37929294327..e026dba4efa 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1267,10 +1267,24 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( case RNG_BERNOULLI: { TF_ASSIGN_OR_RETURN(llvm::Value * p, operand_to_generator.at(hlo->operand(0))(index)); - return ir_builder_->CreateZExt( - ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_)); + PrimitiveType element_type = hlo->shape().element_type(); + llvm::Value* zero; + llvm::Value* one; + llvm::Type* result_ir_type = llvm_ir::PrimitiveTypeToIrType( + hlo->shape().element_type(), module_); + if (primitive_util::IsFloatingPointType(element_type)) { + zero = llvm::ConstantFP::get(result_ir_type, 0.0); + one = llvm::ConstantFP::get(result_ir_type, 1.0); + } else if (primitive_util::IsIntegralType(element_type)) { + zero = llvm::ConstantInt::get(result_ir_type, 0); + one = llvm::ConstantInt::get(result_ir_type, 1); + } else { + return Unimplemented( + "Rng Bernoulli unimplemented for requested type!"); + } + + return ir_builder_->CreateSelect( + ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), one, zero); } default: return InvalidArgument( diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 209f063cc5a..9c690ac8dc2 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -37,6 +37,8 @@ class PrngTest : public ClientLibraryTestBase { protected: template void UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims); + + template void BernoulliTest(float p, tensorflow::gtl::ArraySlice dims); // Computes the χ² statistic of a sample of the discrete uniform distribution @@ -62,9 +64,11 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { }); } +template void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { ComputationBuilder builder(client_, TestName()); - auto shape = ShapeUtil::MakeShape(U32, dims); + auto shape = + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims); builder.RngBernoulli(builder.ConstantR0(p), shape); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); @@ -74,21 +78,22 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { auto actual, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options)); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - int32 sum = 0; - actual->EachCell( - [&sum](tensorflow::gtl::ArraySlice, uint32 value) { - EXPECT_TRUE(value == 0 || value == 1); - sum += value; - }); - int32 total = ShapeUtil::ElementsIn(shape); - float p_tilde = sum / static_cast(total); + T sum = 0; + actual->EachCell([&sum](tensorflow::gtl::ArraySlice, T value) { + EXPECT_TRUE(value == static_cast(0) || value == static_cast(1)); + sum += value; + }); + + int32 elements_in_output = ShapeUtil::ElementsIn(shape); + float p_tilde = sum / static_cast(elements_in_output); // Test within expected range using normal approximation. The test uses a // fixed seed and has a fixed output per p and backend. Using the normal // approximation as this test is invoked for different `p` and the different // backends could use different random number generators and produce different // values. Choose 95% confidence level, so that z_{1-\alpha/2} = 1.96. - float normal_approximation_term = 1.96 * sqrt(p * (1 - p) / total); + float normal_approximation_term = + 1.96 * sqrt(p * (1 - p) / elements_in_output); EXPECT_GE(p_tilde, p - normal_approximation_term); EXPECT_LE(p_tilde, p + normal_approximation_term); } @@ -251,8 +256,30 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { } // Bernoulli random number generation tests -XLA_TEST_F(PrngTest, HundredValuesB10p5) { BernoulliTest(0.5, {100}); } -XLA_TEST_F(PrngTest, HundredValuesB10p1) { BernoulliTest(0.1, {100}); } +XLA_TEST_F(PrngTest, HundredValuesB10p5U32) { + BernoulliTest(0.5, {100}); +} +XLA_TEST_F(PrngTest, HundredValuesB10p1U32) { + BernoulliTest(0.1, {100}); +} +XLA_TEST_F(PrngTest, HundredValuesB10p5S32) { + BernoulliTest(0.5, {100}); +} +XLA_TEST_F(PrngTest, HundredValuesB10p1S32) { + BernoulliTest(0.1, {100}); +} +XLA_TEST_F(PrngTest, HundredValuesB10p5F32) { + BernoulliTest(0.5, {100}); +} +XLA_TEST_F(PrngTest, HundredValuesB10p1F32) { + BernoulliTest(0.1, {100}); +} +XLA_TEST_F(PrngTest, HundredValuesB10p5F64) { + BernoulliTest(0.5, {100}); +} +XLA_TEST_F(PrngTest, HundredValuesB10p1F64) { + BernoulliTest(0.1, {100}); +} XLA_TEST_F(PrngTest, TenValuesN01) { ComputationBuilder builder(client_, TestName());