[XLA:CPU] Implement RngBernoulli for F32 and F64
PiperOrigin-RevId: 180283205
This commit is contained in:
parent
e4b27cb2ea
commit
ade8058c51
@ -1267,10 +1267,24 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
|
||||
case RNG_BERNOULLI: {
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * p,
|
||||
operand_to_generator.at(hlo->operand(0))(index));
|
||||
return ir_builder_->CreateZExt(
|
||||
ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p),
|
||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||
module_));
|
||||
PrimitiveType element_type = hlo->shape().element_type();
|
||||
llvm::Value* zero;
|
||||
llvm::Value* one;
|
||||
llvm::Type* result_ir_type = llvm_ir::PrimitiveTypeToIrType(
|
||||
hlo->shape().element_type(), module_);
|
||||
if (primitive_util::IsFloatingPointType(element_type)) {
|
||||
zero = llvm::ConstantFP::get(result_ir_type, 0.0);
|
||||
one = llvm::ConstantFP::get(result_ir_type, 1.0);
|
||||
} else if (primitive_util::IsIntegralType(element_type)) {
|
||||
zero = llvm::ConstantInt::get(result_ir_type, 0);
|
||||
one = llvm::ConstantInt::get(result_ir_type, 1);
|
||||
} else {
|
||||
return Unimplemented(
|
||||
"Rng Bernoulli unimplemented for requested type!");
|
||||
}
|
||||
|
||||
return ir_builder_->CreateSelect(
|
||||
ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), one, zero);
|
||||
}
|
||||
default:
|
||||
return InvalidArgument(
|
||||
|
@ -37,6 +37,8 @@ class PrngTest : public ClientLibraryTestBase {
|
||||
protected:
|
||||
template <typename T>
|
||||
void UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims);
|
||||
|
||||
template <typename T>
|
||||
void BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims);
|
||||
|
||||
// Computes the χ² statistic of a sample of the discrete uniform distribution
|
||||
@ -62,9 +64,11 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims) {
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto shape = ShapeUtil::MakeShape(U32, dims);
|
||||
auto shape =
|
||||
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims);
|
||||
builder.RngBernoulli(builder.ConstantR0<float>(p), shape);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
|
||||
@ -74,21 +78,22 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims) {
|
||||
auto actual, client_->ExecuteAndTransfer(computation, /*arguments=*/{},
|
||||
&execution_options));
|
||||
EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
|
||||
int32 sum = 0;
|
||||
actual->EachCell<uint32>(
|
||||
[&sum](tensorflow::gtl::ArraySlice<int64>, uint32 value) {
|
||||
EXPECT_TRUE(value == 0 || value == 1);
|
||||
sum += value;
|
||||
});
|
||||
int32 total = ShapeUtil::ElementsIn(shape);
|
||||
float p_tilde = sum / static_cast<float>(total);
|
||||
T sum = 0;
|
||||
actual->EachCell<T>([&sum](tensorflow::gtl::ArraySlice<int64>, T value) {
|
||||
EXPECT_TRUE(value == static_cast<T>(0) || value == static_cast<T>(1));
|
||||
sum += value;
|
||||
});
|
||||
|
||||
int32 elements_in_output = ShapeUtil::ElementsIn(shape);
|
||||
float p_tilde = sum / static_cast<float>(elements_in_output);
|
||||
|
||||
// Test within expected range using normal approximation. The test uses a
|
||||
// fixed seed and has a fixed output per p and backend. Using the normal
|
||||
// approximation as this test is invoked for different `p` and the different
|
||||
// backends could use different random number generators and produce different
|
||||
// values. Choose 95% confidence level, so that z_{1-\alpha/2} = 1.96.
|
||||
float normal_approximation_term = 1.96 * sqrt(p * (1 - p) / total);
|
||||
float normal_approximation_term =
|
||||
1.96 * sqrt(p * (1 - p) / elements_in_output);
|
||||
EXPECT_GE(p_tilde, p - normal_approximation_term);
|
||||
EXPECT_LE(p_tilde, p + normal_approximation_term);
|
||||
}
|
||||
@ -251,8 +256,30 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
|
||||
}
|
||||
|
||||
// Bernoulli random number generation tests
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p5) { BernoulliTest(0.5, {100}); }
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p1) { BernoulliTest(0.1, {100}); }
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p5U32) {
|
||||
BernoulliTest<uint32>(0.5, {100});
|
||||
}
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p1U32) {
|
||||
BernoulliTest<uint32>(0.1, {100});
|
||||
}
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p5S32) {
|
||||
BernoulliTest<int32>(0.5, {100});
|
||||
}
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p1S32) {
|
||||
BernoulliTest<int32>(0.1, {100});
|
||||
}
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p5F32) {
|
||||
BernoulliTest<float>(0.5, {100});
|
||||
}
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p1F32) {
|
||||
BernoulliTest<float>(0.1, {100});
|
||||
}
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p5F64) {
|
||||
BernoulliTest<double>(0.5, {100});
|
||||
}
|
||||
XLA_TEST_F(PrngTest, HundredValuesB10p1F64) {
|
||||
BernoulliTest<double>(0.1, {100});
|
||||
}
|
||||
|
||||
XLA_TEST_F(PrngTest, TenValuesN01) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
Loading…
Reference in New Issue
Block a user