[XLA] Miscellaneous code cleanup to prepare for adding a Philox random bit
generator. Modify prng.h to export the ThreeFry random bit generators. Add three functions to support both stateful and stateless random number generation. These functions take a random bit generator to generate random bits, convert the random bits to uniformly or normally distributed values, and return the random numbers as well as the updated state of the random bit generator. This allows TensorFlow to use the same XLA client APIs to implement stateful random ops and stateless random ops. There was a subtle difference between the old stateless and stateful random number generation implementations. After merging the two implementations into one, we need to slightly increase the relative error tolerance for a test. Move StatelessRngUniform from prng.h to random_ops_util.h in the tf2xla bridge. PiperOrigin-RevId: 242557394
This commit is contained in:
parent
91cbed99b6
commit
4b50df1569
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
// XLA implementations of Categorical op.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
@ -140,8 +141,6 @@ class StatelessCategoricalOp : public CategoricalOp {
|
||||
xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type,
|
||||
XlaOpKernelContext* ctx) override {
|
||||
xla::XlaOp seed = ctx->Input(2);
|
||||
auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
if (uniform_shape.element_type() == xla::BF16) {
|
||||
@ -150,8 +149,8 @@ class StatelessCategoricalOp : public CategoricalOp {
|
||||
// We want a number in (0, 1) rather than [0, 1) or (0, 1]:
|
||||
// * log(-log(0)) is ∞.
|
||||
// * log(-log(1)) is -∞.
|
||||
auto uniforms = xla::StatelessRngUniform(
|
||||
{seed0, seed1}, uniform_shape,
|
||||
xla::XlaOp uniforms = StatelessRngUniform(
|
||||
seed, uniform_shape,
|
||||
xla::MinPositiveNormalValue(builder, uniform_shape.element_type()),
|
||||
xla::One(builder, uniform_shape.element_type()));
|
||||
return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type);
|
||||
|
@ -22,6 +22,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// Returns a tensor containing 'shape' random values uniformly distributed in
|
||||
// the range [minval, maxval). The raw random bits are generated by ThreeFry
|
||||
// and converted to the requested data type and range. This
|
||||
// routine requires 2 32-bit integer seeds and currently only supports 'shape's
|
||||
// of type F32, S32 and S64.
|
||||
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
|
||||
xla::XlaOp minval, xla::XlaOp maxval);
|
||||
|
||||
// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.
|
||||
// It masks the last 16 bits. With normal rounding, values near "maxval" would be
|
||||
|
@ -35,127 +35,50 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter(
|
||||
xla::XlaOp counter, const int64 size) {
|
||||
auto builder = counter.builder();
|
||||
auto input_u64 = Iota(builder, xla::U64, size);
|
||||
input_u64 = input_u64 + counter;
|
||||
counter = counter + xla::ConstantR0<uint64>(builder, size);
|
||||
return std::make_pair(xla::Uint64ToUint32s(input_u64), counter);
|
||||
}
|
||||
|
||||
// `StatelessRngUniformU32` uses ThreeFry2x32’s counter space too
|
||||
// wastefully, only able to generate 2^32*2 int32 numbers for each key, while
|
||||
// the real capacity is 2^64*2. Counter-space efficiency is important for
|
||||
// stateful ops, hence the following 2 new functions.
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
auto builder = key.builder();
|
||||
const int64 size = xla::ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = xla::CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
auto inputs_counter = GetInputsFromCounter(counter, half_size);
|
||||
auto inputs = inputs_counter.first;
|
||||
counter = inputs_counter.second;
|
||||
auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key));
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
auto result = ConcatInDim(builder, outputs, 0);
|
||||
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
|
||||
counter);
|
||||
}
|
||||
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
const int64 size = xla::ShapeUtil::ElementsIn(shape);
|
||||
auto inputs_counter = GetInputsFromCounter(counter, size);
|
||||
auto inputs = inputs_counter.first;
|
||||
counter = inputs_counter.second;
|
||||
auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
auto result = Uint32sToUint64(outputs);
|
||||
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
|
||||
counter);
|
||||
}
|
||||
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
|
||||
xla::XlaOp counter,
|
||||
const xla::Shape& shape,
|
||||
xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
auto builder = key.builder();
|
||||
xla::RngOutput StatefulRngUniform(xla::XlaOp key, xla::XlaOp initial_state,
|
||||
const xla::Shape& shape, xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case xla::F32: {
|
||||
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
|
||||
auto bits = bits_counter.first;
|
||||
counter = bits_counter.second;
|
||||
return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval),
|
||||
counter);
|
||||
}
|
||||
case xla::U32: // fall through
|
||||
case xla::S32: {
|
||||
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
|
||||
auto bits = bits_counter.first;
|
||||
counter = bits_counter.second;
|
||||
return std::make_pair(
|
||||
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32),
|
||||
counter);
|
||||
}
|
||||
case xla::U64: // fall through
|
||||
case xla::S64: {
|
||||
auto bits_counter = StatefulRngUniformU64(key, counter, shape);
|
||||
auto bits = bits_counter.first;
|
||||
counter = bits_counter.second;
|
||||
return std::make_pair(
|
||||
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64),
|
||||
counter);
|
||||
}
|
||||
case xla::F32:
|
||||
return xla::UniformF32Distribution(
|
||||
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
|
||||
case xla::U32:
|
||||
case xla::S32:
|
||||
case xla::U64:
|
||||
case xla::S64:
|
||||
return UniformIntDistribution(
|
||||
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
|
||||
default:
|
||||
return std::make_pair(
|
||||
builder->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by "
|
||||
"StatefulRngUniform; got: %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
counter);
|
||||
return {key.builder()->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by "
|
||||
"StatefulRngUniform; got %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
initial_state};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename A, typename B, typename A2>
|
||||
std::pair<A2, B> map_first(std::function<A2(A)> f, std::pair<A, B> p) {
|
||||
return std::make_pair(f(p.first), p.second);
|
||||
}
|
||||
|
||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformFullInt(
|
||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
||||
xla::RngOutput StatefulRngUniformFullInt(xla::XlaOp key,
|
||||
xla::XlaOp initial_state,
|
||||
const xla::Shape& shape) {
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
xla::RngOutput output = xla::ThreeFryBitGenerator(key, initial_state, shape);
|
||||
switch (type) {
|
||||
case xla::U32:
|
||||
return StatefulRngUniformU32(key, counter, shape);
|
||||
case xla::S32: {
|
||||
// Needs explicit function type because of type-inference failure.
|
||||
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
|
||||
return BitcastConvertType(x, xla::S32);
|
||||
};
|
||||
return map_first(f, StatefulRngUniformU32(key, counter, shape));
|
||||
}
|
||||
case xla::U64:
|
||||
return StatefulRngUniformU64(key, counter, shape);
|
||||
case xla::S64: {
|
||||
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
|
||||
return BitcastConvertType(x, xla::S64);
|
||||
};
|
||||
return map_first(f, StatefulRngUniformU64(key, counter, shape));
|
||||
}
|
||||
return output;
|
||||
case xla::S32:
|
||||
case xla::S64:
|
||||
output.value = BitcastConvertType(output.value, type);
|
||||
return output;
|
||||
default:
|
||||
auto builder = key.builder();
|
||||
return std::make_pair(
|
||||
builder->ReportError(xla::Unimplemented(
|
||||
return {
|
||||
key.builder()->ReportError(xla::Unimplemented(
|
||||
"Types other than U32, S32, U64 and S64 are not implemented by "
|
||||
"StatefulRngUniformFullInt; got: %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
counter);
|
||||
initial_state};
|
||||
}
|
||||
}
|
||||
|
||||
@ -177,15 +100,15 @@ xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
|
||||
0);
|
||||
}
|
||||
|
||||
using sampler_return_type = xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>>;
|
||||
using SamplerReturnType = xla::StatusOr<xla::RngOutput>;
|
||||
|
||||
// A helper function containing the common part of several kernels below.
|
||||
// Precondition: 'algorithm' and 'shape' are compile-time constants.
|
||||
Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
||||
int alg_input_idx, int shape_input_idx,
|
||||
std::function<sampler_return_type(xla::XlaOp, xla::XlaOp,
|
||||
TensorShape)> const&
|
||||
sample_with_threefry) {
|
||||
Status CompileImpl(
|
||||
XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
|
||||
int shape_input_idx,
|
||||
std::function<SamplerReturnType(xla::XlaOp, xla::XlaOp, TensorShape)> const&
|
||||
sampler) {
|
||||
auto alg_shape = ctx->InputShape(alg_input_idx);
|
||||
if (alg_shape.dims() != 0) {
|
||||
return errors::InvalidArgument("algorithm must be of shape [], not ",
|
||||
@ -215,24 +138,22 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
|
||||
|
||||
static constexpr int COUNTER_SIZE = 1;
|
||||
auto counter = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64);
|
||||
static constexpr int kStateSize = 1;
|
||||
auto state = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
|
||||
auto key = BitcastConvertType(
|
||||
xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}),
|
||||
{}),
|
||||
xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
|
||||
xla::U64);
|
||||
|
||||
auto status_or_value = sample_with_threefry(counter, key, shape);
|
||||
auto status_or_value = sampler(state, key, shape);
|
||||
if (!status_or_value.ok()) {
|
||||
return status_or_value.status();
|
||||
}
|
||||
auto output_counter = status_or_value.ConsumeValueOrDie();
|
||||
auto output = output_counter.first;
|
||||
counter = output_counter.second;
|
||||
ctx->SetOutput(0, output);
|
||||
auto builder = ctx->builder();
|
||||
var = ConcatScalars(builder, {counter, key});
|
||||
xla::RngOutput value_state = status_or_value.ConsumeValueOrDie();
|
||||
state = value_state.state;
|
||||
ctx->SetOutput(0, value_state.value);
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
var = ConcatScalars(builder, {state, key});
|
||||
xla::PrimitiveType state_element_type;
|
||||
TF_RETURN_IF_ERROR(
|
||||
DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
|
||||
@ -252,23 +173,22 @@ class StatefulUniformOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry = [builder, this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
auto sampler = [builder, this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::RngOutput uniform_state = StatefulRngUniform(
|
||||
key, state, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
xla::XlaOp uniform = uniform_state.value;
|
||||
state = uniform_state.state;
|
||||
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
||||
return {{uniform, counter}};
|
||||
return {{uniform, state}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
/*shape_input_idx=*/2, sampler));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -293,30 +213,20 @@ class StatefulStandardNormalOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry =
|
||||
auto sampler =
|
||||
// Needs explicit lambda return type because it fails to be inferred.
|
||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
[this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape,
|
||||
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
// Convert uniform distribution to normal distribution by computing
|
||||
// sqrt(2) * erfinv(x)
|
||||
auto normal =
|
||||
xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
|
||||
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
||||
return {{normal, counter}};
|
||||
xla::RngOutput value_state = xla::NormalF32Distribution(
|
||||
key, state, xla::ThreeFryBitGenerator, xla_shape);
|
||||
xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
|
||||
return {{normal, value_state.state}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
/*shape_input_idx=*/2, sampler));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -341,27 +251,27 @@ class StatefulTruncatedNormalOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto builder = ctx->builder();
|
||||
auto sample_with_threefry =
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
auto sampler =
|
||||
// Needs explicit lambda return type because it fails to be inferred.
|
||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
[builder, this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
|
||||
auto uniform_counter = StatefulRngUniform(
|
||||
key, counter, xla_shape,
|
||||
xla::RngOutput uniform_result = StatefulRngUniform(
|
||||
key, state, xla_shape,
|
||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||
xla::One(builder, xla_shape.element_type()));
|
||||
auto uniform = uniform_counter.first;
|
||||
counter = uniform_counter.second;
|
||||
xla::XlaOp uniform = uniform_result.value;
|
||||
state = uniform_result.state;
|
||||
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
||||
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
|
||||
return {{truncated_normal, counter}};
|
||||
return {{truncated_normal, state}};
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
/*shape_input_idx=*/2, sample_with_threefry));
|
||||
/*shape_input_idx=*/2, sampler));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -388,11 +298,11 @@ class StatefulUniformIntOp : public XlaOpKernel {
|
||||
xla::XlaOp minval = ctx->Input(3);
|
||||
xla::XlaOp maxval = ctx->Input(4);
|
||||
auto sample_with_threefry = [minval, maxval, this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
return StatefulRngUniform(key, counter, xla_shape, minval, maxval);
|
||||
return StatefulRngUniform(key, state, xla_shape, minval, maxval);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
@ -420,12 +330,11 @@ class StatefulUniformFullIntOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
auto sample_with_threefry = [this](
|
||||
xla::XlaOp counter, xla::XlaOp key,
|
||||
TensorShape shape) -> sampler_return_type {
|
||||
auto sample_with_threefry = [this](xla::XlaOp state, xla::XlaOp key,
|
||||
TensorShape shape) -> SamplerReturnType {
|
||||
xla::Shape xla_shape;
|
||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
return StatefulRngUniformFullInt(key, counter, xla_shape);
|
||||
return StatefulRngUniformFullInt(key, state, xla_shape);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||
|
@ -36,8 +36,8 @@ namespace tensorflow {
|
||||
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
||||
if (dtype == DT_BFLOAT16) {
|
||||
xla::XlaBuilder* builder = input.builder();
|
||||
auto output = xla::BitcastConvertType(input, xla::U32) &
|
||||
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
|
||||
xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
|
||||
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
|
||||
return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
|
||||
xla::BF16);
|
||||
} else {
|
||||
@ -45,22 +45,36 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
|
||||
// Convert uniform distribution to normal distribution by computing
|
||||
// sqrt(2) * erfinv(x)
|
||||
return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
|
||||
}
|
||||
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
|
||||
xla::XlaOp minval, xla::XlaOp maxval) {
|
||||
xla::XlaBuilder* builder = seeds.builder();
|
||||
|
||||
// A wrapper of xla::StatelessRngUniform. Returns an op that produces random
|
||||
// values with uniform distribution in the range [minval, maxval) for the given
|
||||
// shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and
|
||||
// S64 are implemented.
|
||||
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused,
|
||||
xla::XlaOp seed, xla::XlaOp minval,
|
||||
xla::XlaOp maxval) {
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval);
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||
xla::PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case xla::F32:
|
||||
return xla::UniformF32Distribution(key, initial_state,
|
||||
xla::ThreeFryBitGenerator, minval,
|
||||
maxval, shape)
|
||||
.value;
|
||||
case xla::S32: // fall through
|
||||
case xla::S64:
|
||||
return UniformIntDistribution(key, initial_state,
|
||||
xla::ThreeFryBitGenerator, minval, maxval,
|
||||
shape)
|
||||
.value;
|
||||
break;
|
||||
default:
|
||||
return builder->ReportError(xla::Unimplemented(
|
||||
"Types other than F32, S32 and S64 are not implemented by "
|
||||
"StatelessRngUniform; got %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type)));
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -86,8 +100,8 @@ class StatelessRandomUniformOp : public XlaOpKernel {
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
||||
xla_shape, dtype_, seed, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::XlaOp uniform = StatelessRngUniform(
|
||||
seed, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
||||
ctx->SetOutput(0, uniform);
|
||||
@ -136,8 +150,8 @@ class StatelessRandomUniformIntOp : public XlaOpKernel {
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||
xla::XlaOp uniform =
|
||||
StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval);
|
||||
xla::XlaOp uniform = StatelessRngUniform(seed, xla_shape, minval, maxval);
|
||||
|
||||
ctx->SetOutput(0, uniform);
|
||||
}
|
||||
|
||||
@ -170,14 +184,20 @@ class StatelessRandomNormalOp : public XlaOpKernel {
|
||||
errors::InvalidArgument("seed must have shape [2], not ",
|
||||
seed_shape.DebugString()));
|
||||
xla::XlaOp seed = ctx->Input(1);
|
||||
xla::XlaBuilder* builder = ctx->builder();
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
||||
xla_shape, dtype_, seed,
|
||||
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
||||
xla::ConstantR0<float>(builder, 1.0));
|
||||
xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform);
|
||||
|
||||
xla::XlaBuilder* builder = seed.builder();
|
||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||
ConstantR0WithType(builder, xla::U64, 32));
|
||||
xla::XlaOp normal =
|
||||
xla::NormalF32Distribution(key, initial_state,
|
||||
xla::ThreeFryBitGenerator, xla_shape)
|
||||
.value;
|
||||
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
||||
ctx->SetOutput(0, normal);
|
||||
}
|
||||
@ -215,8 +235,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
|
||||
|
||||
xla::Shape xla_shape;
|
||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
||||
xla_shape, dtype_, seed,
|
||||
xla::XlaOp uniform = StatelessRngUniform(
|
||||
seed, xla_shape,
|
||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||
xla::One(builder, xla_shape.element_type()));
|
||||
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
||||
|
@ -32,8 +32,12 @@ XlaOp RotateLeftU32(XlaOp v, int distance) {
|
||||
ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
// The internal state of the Three Fry implementation.
|
||||
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
||||
|
||||
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
||||
XlaBuilder* builder = input[0].builder();
|
||||
key[0] = BitcastConvertType(key[0], U32);
|
||||
@ -104,56 +108,68 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// Returns the inputs with unique counter values for ThreeFry2x32.
|
||||
ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) {
|
||||
ThreeFry2x32State inputs;
|
||||
inputs[0] = Iota(builder, U32, size);
|
||||
inputs[1] = inputs[0] + ConstantR0<uint32>(builder, size);
|
||||
return inputs;
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
|
||||
XlaBuilder* builder = key[0].builder();
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
ThreeFry2x32State inputs = GetInputs(half_size, builder);
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
auto result = ConcatInDim(builder, outputs, 0);
|
||||
return Reshape(result, AsInt64Slice(shape.dimensions()));
|
||||
}
|
||||
|
||||
// Converts a uint64 to two uint32s.
|
||||
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
|
||||
auto builder = u64.builder();
|
||||
auto const32 = ConstantR0WithType(builder, U64, 32);
|
||||
auto fst = ConvertElementType(u64, U32);
|
||||
auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
|
||||
XlaBuilder* builder = u64.builder();
|
||||
XlaOp const32 = ConstantR0WithType(builder, U64, 32);
|
||||
XlaOp fst = ConvertElementType(u64, U32);
|
||||
XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
|
||||
return {fst, snd};
|
||||
}
|
||||
|
||||
// Converts two uint32s to a uint64.
|
||||
XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
|
||||
auto builder = u32s[0].builder();
|
||||
XlaBuilder* builder = u32s[0].builder();
|
||||
return ConvertElementType(u32s[0], U64) |
|
||||
ShiftLeft(ConvertElementType(u32s[1], U64),
|
||||
ConstantR0WithType(builder, U64, 32));
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
|
||||
XlaBuilder* builder = key[0].builder();
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
ThreeFry2x32State inputs = GetInputs(size, builder);
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
|
||||
// low 32 bit: outputs[0], high 32 bit: outputs[1]
|
||||
auto result = Uint32sToUint64(outputs);
|
||||
return Reshape(result, AsInt64Slice(shape.dimensions()));
|
||||
// Given the initial state and the requested number of random numbers to be
|
||||
// generated, returns the input for the random number generator and a new state.
|
||||
std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
|
||||
XlaOp initial_state, const int64 size) {
|
||||
XlaBuilder* builder = initial_state.builder();
|
||||
XlaOp input_u64 = Iota(builder, U64, size);
|
||||
input_u64 = input_u64 + initial_state;
|
||||
XlaOp new_state = initial_state + ConstantR0<uint64>(builder, size);
|
||||
return std::make_pair(Uint64ToUint32s(input_u64), new_state);
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
||||
XlaBuilder* builder = bits.builder();
|
||||
// Generates random 32-bit values with the given shape using the Three Fry
|
||||
// implementation. Returns the random bits and the new state.
|
||||
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
|
||||
XlaBuilder* builder = key.builder();
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||
GetThreeFryInputsAndUpdatedState(initial_state, half_size);
|
||||
ThreeFry2x32State inputs = inputs_state.first;
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
XlaOp result = ConcatInDim(builder, outputs, 0);
|
||||
return {Reshape(result, AsInt64Slice(shape.dimensions())),
|
||||
inputs_state.second};
|
||||
}
|
||||
|
||||
// Generates random 64-bit values with the given shape using the Three Fry
|
||||
// implementation. Returns the random bits and the new state.
|
||||
RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||
GetThreeFryInputsAndUpdatedState(initial_state, size);
|
||||
ThreeFry2x32State inputs = inputs_state.first;
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
XlaOp result = Uint32sToUint64(outputs);
|
||||
return {Reshape(result, AsInt64Slice(shape.dimensions())),
|
||||
inputs_state.second};
|
||||
}
|
||||
|
||||
XlaOp ConvertRandomBitsToUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
||||
XlaBuilder* builder = bits.builder();
|
||||
// Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
|
||||
// forces the random bits into the mantissa.
|
||||
constexpr int kFloatBits = 32;
|
||||
@ -161,50 +177,95 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
||||
bits = ShiftRightLogical(
|
||||
bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
|
||||
ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
|
||||
auto floats = BitcastConvertType(bits, F32);
|
||||
XlaOp values = BitcastConvertType(bits, F32);
|
||||
|
||||
// We have a floating point number in the range [1.0, 2.0).
|
||||
// Subtract 1.0f to shift to the range [0.0, 1.0)
|
||||
floats = floats - ConstantR0<float>(builder, 1.0f);
|
||||
values = values - ConstantR0<float>(builder, 1.0f);
|
||||
// Multiply and add to shift to the range [minval, maxval).
|
||||
return floats * (maxval - minval) + minval;
|
||||
return values * (maxval - minval) + minval;
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
PrimitiveType type, PrimitiveType unsigned_type) {
|
||||
XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
PrimitiveType type,
|
||||
PrimitiveType unsigned_type) {
|
||||
XlaBuilder* builder = bits.builder();
|
||||
auto range = BitcastConvertType(maxval, unsigned_type) -
|
||||
BitcastConvertType(minval, unsigned_type);
|
||||
auto dist = Rem(bits, range);
|
||||
auto dist_div_2 =
|
||||
XlaOp range = BitcastConvertType(maxval, unsigned_type) -
|
||||
BitcastConvertType(minval, unsigned_type);
|
||||
XlaOp dist = Rem(bits, range);
|
||||
XlaOp dist_div_2 =
|
||||
ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
|
||||
|
||||
return minval + BitcastConvertType(dist_div_2, type) +
|
||||
BitcastConvertType(dist - dist_div_2, type);
|
||||
}
|
||||
|
||||
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
|
||||
XlaOp minval, XlaOp maxval) {
|
||||
XlaBuilder* builder = seeds[0].builder();
|
||||
// Returns sqrt(2) * ErfInv(uniform), computed elementwise on `uniform`.
// NOTE(review): presumably `uniform` holds samples from a uniform distribution
// on (-1, 1), which this transform maps to normally distributed values --
// confirm against the caller.
XlaOp UniformToNormalUsingSqrtErfInv(XlaOp uniform) {
  // Scalar constant sqrt(2) built with ScalarLike from `uniform`.
  XlaOp sqrt_two = ScalarLike(uniform, std::sqrt(2.0));
  XlaOp inverse_erf = ErfInv(uniform);
  return sqrt_two * inverse_erf;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
const Shape& shape) {
|
||||
PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case F32: {
|
||||
auto bits = StatelessRngUniformU32(seeds, shape);
|
||||
return StatelessRngUniformF32(bits, minval, maxval);
|
||||
}
|
||||
case S32: {
|
||||
auto bits = StatelessRngUniformU32(seeds, shape);
|
||||
return StatelessRngUniformInt(bits, minval, maxval, type, U32);
|
||||
}
|
||||
case S64: {
|
||||
auto bits = StatelessRngUniformU64(seeds, shape);
|
||||
return StatelessRngUniformInt(bits, minval, maxval, type, U64);
|
||||
}
|
||||
case F32:
|
||||
case U32:
|
||||
case S32:
|
||||
return ThreeFryRngBit32(key, initial_state, shape);
|
||||
case U64:
|
||||
case S64:
|
||||
return ThreeFryRngBit64(key, initial_state, shape);
|
||||
default:
|
||||
return builder->ReportError(Unimplemented(
|
||||
"Types other than F32, S32 and S64 are not implemented by "
|
||||
"StatelessRngUniform."));
|
||||
return {key.builder()->ReportError(Unimplemented(
|
||||
"Types other than F32, U32, S32, U64 and S64 "
|
||||
"are not implemented by ThreeFryBitGenerator; got %s",
|
||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||
initial_state};
|
||||
}
|
||||
}
|
||||
|
||||
// Asks `bit_generator` for random bits of the given `shape` and converts them
// to F32 values in the range [minval, maxval). The returned RngOutput carries
// the converted values and the state reported by `bit_generator`.
// `shape` is expected to have element type F32 (checked in debug builds).
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
                                 BitGeneratorTy bit_generator, XlaOp minval,
                                 XlaOp maxval, const Shape& shape) {
  DCHECK_EQ(shape.element_type(), F32);
  RngOutput generated = bit_generator(key, initial_state, shape);
  return {ConvertRandomBitsToUniformF32(generated.value, minval, maxval),
          generated.state};
}
|
||||
|
||||
// Asks `bit_generator` for random bits of the given `shape` and converts them
// to integers between `minval` and `maxval` via ConvertRandomBitsToUniformInt.
// The returned RngOutput carries the converted values and the state reported
// by `bit_generator`.
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
                                 BitGeneratorTy bit_generator, XlaOp minval,
                                 XlaOp maxval, const Shape& shape) {
  RngOutput generated = bit_generator(key, initial_state, shape);
  const PrimitiveType type = shape.element_type();

  // The conversion needs the unsigned type with the same bit width as `type`.
  // 32-bit integers map to U32; everything else is treated as 64-bit.
  PrimitiveType unsigned_type = U32;
  if (type != U32 && type != S32) {
    DCHECK(type == U64 || type == S64);
    unsigned_type = U64;
  }

  return {ConvertRandomBitsToUniformInt(generated.value, minval, maxval, type,
                                        unsigned_type),
          generated.state};
}
|
||||
|
||||
// Generates F32 values by drawing uniform samples through
// UniformF32Distribution and passing them to UniformToNormalUsingSqrtErfInv.
// Returns the transformed values together with the generator state produced
// by the uniform draw. `shape` is expected to have element type F32 (checked
// in debug builds).
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
                                BitGeneratorTy bit_generator,
                                const Shape& shape) {
  DCHECK_EQ(shape.element_type(), F32);
  XlaBuilder* builder = key.builder();

  // The lower bound is the first representable float after -1.0 in the
  // direction of 0. NOTE(review): presumably this keeps exactly -1 out of the
  // samples fed to ErfInv -- confirm.
  XlaOp lower_bound = ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f));
  XlaOp upper_bound = ConstantR0<float>(builder, 1.0f);

  RngOutput uniform = UniformF32Distribution(
      key, initial_state, bit_generator, lower_bound, upper_bound, shape);
  return {UniformToNormalUsingSqrtErfInv(uniform.value), uniform.state};
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -23,37 +23,52 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Records the bits and state generated by a random number generator.
|
||||
struct RngOutput {
|
||||
// The generated output: random bits from a bit generator, or the random
// numbers a distribution function derived from those bits.
XlaOp value;
|
||||
// The random bit generator state returned alongside `value`.
XlaOp state;
|
||||
};
|
||||
|
||||
// A BitGenerator returns random bits and updated random bit generator state.
|
||||
//
|
||||
// key: a value input to a random number generator that can affect the
|
||||
// sequence of numbers it will generate. A random number generator constructs
|
||||
// its seed using the key and the initial state. The tf2xla bridge passes the
|
||||
// seed operand of a tensorflow random operation as a key to the random bit
|
||||
// generator, for example.
|
||||
// initial_state: the initial state of the current random
|
||||
// number generation. It could be 0 for a stateless random operation, and
|
||||
// the returned state from a previous execution for a stateful random
|
||||
// operation.
|
||||
// shape: the shape of the random bits.
|
||||
using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state,
|
||||
const xla::Shape& shape)>;
|
||||
|
||||
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key);
|
||||
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
const xla::Shape& shape);
|
||||
|
||||
// Returns a tensor containing 'shape' random values uniformly distributed in
|
||||
// the range [minval, maxval). Requires 2 32-bit integer seeds.
|
||||
// Currently only 'shape's of type F32, S32 and S64 are implemented.
|
||||
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
|
||||
XlaOp minval, XlaOp maxval);
|
||||
// Uses the given bit generator to generate random bits and then converts the
|
||||
// random bits to random numbers of uniform distribution in the given range.
|
||||
// Returns the random numbers and the state of the random number generator.
|
||||
// This function is for shapes with F32 element type.
|
||||
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const xla::Shape& shape);
|
||||
|
||||
// Converts a 32-bit (signed or unsigned) integer random number `bits` into a
|
||||
// float32 in the range [minval, maxval).
|
||||
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
|
||||
// Similar to UniformF32Distribution but for shapes with integer element types.
|
||||
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator, XlaOp minval,
|
||||
XlaOp maxval, const xla::Shape& shape);
|
||||
|
||||
// Converts an integer random number 'bits' of type 'type' to a random number
|
||||
// in the range [minval, maxval), of the same type. 'unsigned_type' is the
|
||||
// unsigned version of 'type' (could be the same) with the same bit width.
|
||||
// The algorithm is the same one that TF uses right now, but it's
|
||||
// uniform only when maxval - minval is a divisor of the range that bits is
|
||||
// generated from.
|
||||
// TODO(b/72573764): Generate real uniform integer distribution.
|
||||
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||
PrimitiveType type, PrimitiveType unsigned_type);
|
||||
|
||||
// The following 2 functions, for converting between one uint64 and two uint32s,
|
||||
// use the contract "lower 32 bits for the first uint32, higher 32 bits for the
|
||||
// second".
|
||||
ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
|
||||
XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
|
||||
// Uses the given bit generator to generate random bits and then converts the
|
||||
// random bits to random numbers of normal distribution.
|
||||
// Returns the random numbers and the state of the random number generator.
|
||||
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator,
|
||||
const xla::Shape& shape);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
|
@ -140,5 +140,7 @@ def test_truncated_normal(assert_equal, assert_all_close, dtype, n, y):
|
||||
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
|
||||
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
|
||||
actual_variance = np.var(y)
|
||||
assert_all_close(actual_variance, expected_variance,
|
||||
rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3)
|
||||
assert_all_close(
|
||||
actual_variance,
|
||||
expected_variance,
|
||||
rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3)
|
||||
|
Loading…
Reference in New Issue
Block a user