[XLA:CPU] Implement RngBernoulli for F32 and F64

PiperOrigin-RevId: 180283205
This commit is contained in:
Sanjoy Das 2017-12-28 11:26:03 -08:00 committed by TensorFlower Gardener
parent e4b27cb2ea
commit ade8058c51
2 changed files with 57 additions and 16 deletions

View File

@ -1267,10 +1267,24 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
case RNG_BERNOULLI: {
TF_ASSIGN_OR_RETURN(llvm::Value * p,
operand_to_generator.at(hlo->operand(0))(index));
return ir_builder_->CreateZExt(
ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p),
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
module_));
PrimitiveType element_type = hlo->shape().element_type();
llvm::Value* zero;
llvm::Value* one;
llvm::Type* result_ir_type = llvm_ir::PrimitiveTypeToIrType(
hlo->shape().element_type(), module_);
if (primitive_util::IsFloatingPointType(element_type)) {
zero = llvm::ConstantFP::get(result_ir_type, 0.0);
one = llvm::ConstantFP::get(result_ir_type, 1.0);
} else if (primitive_util::IsIntegralType(element_type)) {
zero = llvm::ConstantInt::get(result_ir_type, 0);
one = llvm::ConstantInt::get(result_ir_type, 1);
} else {
return Unimplemented(
"Rng Bernoulli unimplemented for requested type!");
}
return ir_builder_->CreateSelect(
ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), one, zero);
}
default:
return InvalidArgument(

View File

@ -37,6 +37,8 @@ class PrngTest : public ClientLibraryTestBase {
protected:
template <typename T>
void UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims);
template <typename T>
void BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims);
// Computes the χ² statistic of a sample of the discrete uniform distribution
@ -62,9 +64,11 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims) {
});
}
template <typename T>
void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims) {
ComputationBuilder builder(client_, TestName());
auto shape = ShapeUtil::MakeShape(U32, dims);
auto shape =
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims);
builder.RngBernoulli(builder.ConstantR0<float>(p), shape);
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
@ -74,21 +78,22 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims) {
auto actual, client_->ExecuteAndTransfer(computation, /*arguments=*/{},
&execution_options));
EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
int32 sum = 0;
actual->EachCell<uint32>(
[&sum](tensorflow::gtl::ArraySlice<int64>, uint32 value) {
EXPECT_TRUE(value == 0 || value == 1);
sum += value;
});
int32 total = ShapeUtil::ElementsIn(shape);
float p_tilde = sum / static_cast<float>(total);
T sum = 0;
actual->EachCell<T>([&sum](tensorflow::gtl::ArraySlice<int64>, T value) {
EXPECT_TRUE(value == static_cast<T>(0) || value == static_cast<T>(1));
sum += value;
});
int32 elements_in_output = ShapeUtil::ElementsIn(shape);
float p_tilde = sum / static_cast<float>(elements_in_output);
// Test within expected range using normal approximation. The test uses a
// fixed seed and has a fixed output per p and backend. Using the normal
// approximation as this test is invoked for different `p` and the different
// backends could use different random number generators and produce different
// values. Choose 95% confidence level, so that z_{1-\alpha/2} = 1.96.
float normal_approximation_term = 1.96 * sqrt(p * (1 - p) / total);
float normal_approximation_term =
1.96 * sqrt(p * (1 - p) / elements_in_output);
EXPECT_GE(p_tilde, p - normal_approximation_term);
EXPECT_LE(p_tilde, p + normal_approximation_term);
}
@ -251,8 +256,30 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
}
// Bernoulli random number generation tests
XLA_TEST_F(PrngTest, HundredValuesB10p5) { BernoulliTest(0.5, {100}); }
XLA_TEST_F(PrngTest, HundredValuesB10p1) { BernoulliTest(0.1, {100}); }
XLA_TEST_F(PrngTest, HundredValuesB10p5U32) {
BernoulliTest<uint32>(0.5, {100});
}
XLA_TEST_F(PrngTest, HundredValuesB10p1U32) {
BernoulliTest<uint32>(0.1, {100});
}
XLA_TEST_F(PrngTest, HundredValuesB10p5S32) {
BernoulliTest<int32>(0.5, {100});
}
XLA_TEST_F(PrngTest, HundredValuesB10p1S32) {
BernoulliTest<int32>(0.1, {100});
}
XLA_TEST_F(PrngTest, HundredValuesB10p5F32) {
BernoulliTest<float>(0.5, {100});
}
XLA_TEST_F(PrngTest, HundredValuesB10p1F32) {
BernoulliTest<float>(0.1, {100});
}
XLA_TEST_F(PrngTest, HundredValuesB10p5F64) {
BernoulliTest<double>(0.5, {100});
}
XLA_TEST_F(PrngTest, HundredValuesB10p1F64) {
BernoulliTest<double>(0.1, {100});
}
XLA_TEST_F(PrngTest, TenValuesN01) {
ComputationBuilder builder(client_, TestName());