[XLA] Miscellaneous code clean up to prepare for adding a Philox random bit
generator. Modify prng.h to export the ThreeFry random bit generators. Add three functions to support both stateful and stateless random number generation. These functions take a random bit generator to generate random bits, convert the random bits to values in uniform or normal distributions, and return the random numbers as well as the updated state of the random bit generator. This allows tensorflow to use the same XLA client APIs to implement stateful random ops and stateless random ops. There was a subtle difference between the old stateless and stateful random number generations. After we merge the two implementations into one, we need to slightly increase the relative error tolerance for a test. Move StatelessRngUniform from prng.h to random_ops_util.h in the tf2xla bridge. PiperOrigin-RevId: 242557394
This commit is contained in:
parent
91cbed99b6
commit
4b50df1569
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
// XLA implementations of Categorical op.
|
// XLA implementations of Categorical op.
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||||
@ -140,8 +141,6 @@ class StatelessCategoricalOp : public CategoricalOp {
|
|||||||
xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type,
|
xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type,
|
||||||
XlaOpKernelContext* ctx) override {
|
XlaOpKernelContext* ctx) override {
|
||||||
xla::XlaOp seed = ctx->Input(2);
|
xla::XlaOp seed = ctx->Input(2);
|
||||||
auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
|
||||||
auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
|
||||||
|
|
||||||
xla::XlaBuilder* builder = ctx->builder();
|
xla::XlaBuilder* builder = ctx->builder();
|
||||||
if (uniform_shape.element_type() == xla::BF16) {
|
if (uniform_shape.element_type() == xla::BF16) {
|
||||||
@ -150,8 +149,8 @@ class StatelessCategoricalOp : public CategoricalOp {
|
|||||||
// We want a number in (0, 1) rather than [0, 1) or (0, 1]:
|
// We want a number in (0, 1) rather than [0, 1) or (0, 1]:
|
||||||
// * log(-log(0)) is ∞.
|
// * log(-log(0)) is ∞.
|
||||||
// * log(-log(1)) is -∞.
|
// * log(-log(1)) is -∞.
|
||||||
auto uniforms = xla::StatelessRngUniform(
|
xla::XlaOp uniforms = StatelessRngUniform(
|
||||||
{seed0, seed1}, uniform_shape,
|
seed, uniform_shape,
|
||||||
xla::MinPositiveNormalValue(builder, uniform_shape.element_type()),
|
xla::MinPositiveNormalValue(builder, uniform_shape.element_type()),
|
||||||
xla::One(builder, uniform_shape.element_type()));
|
xla::One(builder, uniform_shape.element_type()));
|
||||||
return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type);
|
return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type);
|
||||||
|
@ -22,6 +22,13 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
// Returns a tensor containing 'shape' random values uniformly distributed in
|
||||||
|
// the range [minval, maxval). The raw random bits are generated by the given
|
||||||
|
// `bit_generator` and converted to the requested data type and range. This
|
||||||
|
// routine requires 2 32-bit integer seeds and currently only supports 'shape's
|
||||||
|
// of type F32, S32 and S64.
|
||||||
|
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
|
||||||
|
xla::XlaOp minval, xla::XlaOp maxval);
|
||||||
|
|
||||||
// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.
|
// Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise.
|
||||||
// It masks the last 16 bit. With normal rounding, values near "maxval" would be
|
// It masks the last 16 bit. With normal rounding, values near "maxval" would be
|
||||||
|
@ -35,127 +35,50 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
std::pair<xla::ThreeFry2x32State, xla::XlaOp> GetInputsFromCounter(
|
xla::RngOutput StatefulRngUniform(xla::XlaOp key, xla::XlaOp initial_state,
|
||||||
xla::XlaOp counter, const int64 size) {
|
const xla::Shape& shape, xla::XlaOp minval,
|
||||||
auto builder = counter.builder();
|
xla::XlaOp maxval) {
|
||||||
auto input_u64 = Iota(builder, xla::U64, size);
|
|
||||||
input_u64 = input_u64 + counter;
|
|
||||||
counter = counter + xla::ConstantR0<uint64>(builder, size);
|
|
||||||
return std::make_pair(xla::Uint64ToUint32s(input_u64), counter);
|
|
||||||
}
|
|
||||||
|
|
||||||
// `StatelessRngUniformU32` uses ThreeFry2x32’s counter space too
|
|
||||||
// wastefully, only able to generate 2^32*2 int32 numbers for each key, while
|
|
||||||
// the real capacity is 2^64*2. Counter-space efficiency is important for
|
|
||||||
// stateful ops, hence the following 2 new functions.
|
|
||||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU32(
|
|
||||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
|
||||||
auto builder = key.builder();
|
|
||||||
const int64 size = xla::ShapeUtil::ElementsIn(shape);
|
|
||||||
const int64 half_size = xla::CeilOfRatio<int64>(size, 2);
|
|
||||||
const bool size_is_odd = (half_size * 2 != size);
|
|
||||||
auto inputs_counter = GetInputsFromCounter(counter, half_size);
|
|
||||||
auto inputs = inputs_counter.first;
|
|
||||||
counter = inputs_counter.second;
|
|
||||||
auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key));
|
|
||||||
if (size_is_odd) {
|
|
||||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
|
||||||
}
|
|
||||||
auto result = ConcatInDim(builder, outputs, 0);
|
|
||||||
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
|
|
||||||
counter);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformU64(
|
|
||||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
|
||||||
const int64 size = xla::ShapeUtil::ElementsIn(shape);
|
|
||||||
auto inputs_counter = GetInputsFromCounter(counter, size);
|
|
||||||
auto inputs = inputs_counter.first;
|
|
||||||
counter = inputs_counter.second;
|
|
||||||
auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
|
||||||
auto result = Uint32sToUint64(outputs);
|
|
||||||
return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())),
|
|
||||||
counter);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniform(xla::XlaOp key,
|
|
||||||
xla::XlaOp counter,
|
|
||||||
const xla::Shape& shape,
|
|
||||||
xla::XlaOp minval,
|
|
||||||
xla::XlaOp maxval) {
|
|
||||||
auto builder = key.builder();
|
|
||||||
xla::PrimitiveType type = shape.element_type();
|
xla::PrimitiveType type = shape.element_type();
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case xla::F32: {
|
case xla::F32:
|
||||||
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
|
return xla::UniformF32Distribution(
|
||||||
auto bits = bits_counter.first;
|
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
|
||||||
counter = bits_counter.second;
|
case xla::U32:
|
||||||
return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval),
|
case xla::S32:
|
||||||
counter);
|
case xla::U64:
|
||||||
}
|
case xla::S64:
|
||||||
case xla::U32: // fall through
|
return UniformIntDistribution(
|
||||||
case xla::S32: {
|
key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape);
|
||||||
auto bits_counter = StatefulRngUniformU32(key, counter, shape);
|
|
||||||
auto bits = bits_counter.first;
|
|
||||||
counter = bits_counter.second;
|
|
||||||
return std::make_pair(
|
|
||||||
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32),
|
|
||||||
counter);
|
|
||||||
}
|
|
||||||
case xla::U64: // fall through
|
|
||||||
case xla::S64: {
|
|
||||||
auto bits_counter = StatefulRngUniformU64(key, counter, shape);
|
|
||||||
auto bits = bits_counter.first;
|
|
||||||
counter = bits_counter.second;
|
|
||||||
return std::make_pair(
|
|
||||||
xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64),
|
|
||||||
counter);
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
return std::make_pair(
|
return {key.builder()->ReportError(xla::Unimplemented(
|
||||||
builder->ReportError(xla::Unimplemented(
|
"Types other than F32, U32, S32, U64 and S64 "
|
||||||
"Types other than F32, U32, S32, U64 and S64 "
|
"are not implemented by "
|
||||||
"are not implemented by "
|
"StatefulRngUniform; got %s",
|
||||||
"StatefulRngUniform; got: %s",
|
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
initial_state};
|
||||||
counter);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename A, typename B, typename A2>
|
xla::RngOutput StatefulRngUniformFullInt(xla::XlaOp key,
|
||||||
std::pair<A2, B> map_first(std::function<A2(A)> f, std::pair<A, B> p) {
|
xla::XlaOp initial_state,
|
||||||
return std::make_pair(f(p.first), p.second);
|
const xla::Shape& shape) {
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<xla::XlaOp, xla::XlaOp> StatefulRngUniformFullInt(
|
|
||||||
xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) {
|
|
||||||
xla::PrimitiveType type = shape.element_type();
|
xla::PrimitiveType type = shape.element_type();
|
||||||
|
xla::RngOutput output = xla::ThreeFryBitGenerator(key, initial_state, shape);
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case xla::U32:
|
case xla::U32:
|
||||||
return StatefulRngUniformU32(key, counter, shape);
|
|
||||||
case xla::S32: {
|
|
||||||
// Needs explicit function type because of type-inference failure.
|
|
||||||
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
|
|
||||||
return BitcastConvertType(x, xla::S32);
|
|
||||||
};
|
|
||||||
return map_first(f, StatefulRngUniformU32(key, counter, shape));
|
|
||||||
}
|
|
||||||
case xla::U64:
|
case xla::U64:
|
||||||
return StatefulRngUniformU64(key, counter, shape);
|
return output;
|
||||||
case xla::S64: {
|
case xla::S32:
|
||||||
std::function<xla::XlaOp(xla::XlaOp)> f = [](xla::XlaOp x) {
|
case xla::S64:
|
||||||
return BitcastConvertType(x, xla::S64);
|
output.value = BitcastConvertType(output.value, type);
|
||||||
};
|
return output;
|
||||||
return map_first(f, StatefulRngUniformU64(key, counter, shape));
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
auto builder = key.builder();
|
return {
|
||||||
return std::make_pair(
|
key.builder()->ReportError(xla::Unimplemented(
|
||||||
builder->ReportError(xla::Unimplemented(
|
|
||||||
"Types other than U32, S32, U64 and S64 are not implemented by "
|
"Types other than U32, S32, U64 and S64 are not implemented by "
|
||||||
"StatefulRngUniformFullInt; got: %s",
|
"StatefulRngUniformFullInt; got: %s",
|
||||||
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||||
counter);
|
initial_state};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -177,15 +100,15 @@ xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
|
|||||||
0);
|
0);
|
||||||
}
|
}
|
||||||
|
|
||||||
using sampler_return_type = xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>>;
|
using SamplerReturnType = xla::StatusOr<xla::RngOutput>;
|
||||||
|
|
||||||
// A helper function containing the common part of several kernels below.
|
// A helper function containing the common part of several kernels below.
|
||||||
// Precondition: 'algorithm' and 'shape' are compile-time constants.
|
// Precondition: 'algorithm' and 'shape' are compile-time constants.
|
||||||
Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
Status CompileImpl(
|
||||||
int alg_input_idx, int shape_input_idx,
|
XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx,
|
||||||
std::function<sampler_return_type(xla::XlaOp, xla::XlaOp,
|
int shape_input_idx,
|
||||||
TensorShape)> const&
|
std::function<SamplerReturnType(xla::XlaOp, xla::XlaOp, TensorShape)> const&
|
||||||
sample_with_threefry) {
|
sampler) {
|
||||||
auto alg_shape = ctx->InputShape(alg_input_idx);
|
auto alg_shape = ctx->InputShape(alg_input_idx);
|
||||||
if (alg_shape.dims() != 0) {
|
if (alg_shape.dims() != 0) {
|
||||||
return errors::InvalidArgument("algorithm must be of shape [], not ",
|
return errors::InvalidArgument("algorithm must be of shape [], not ",
|
||||||
@ -215,24 +138,22 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx,
|
|||||||
TensorShape shape;
|
TensorShape shape;
|
||||||
TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
|
TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape));
|
||||||
|
|
||||||
static constexpr int COUNTER_SIZE = 1;
|
static constexpr int kStateSize = 1;
|
||||||
auto counter = BitcastConvertType(
|
auto state = BitcastConvertType(
|
||||||
xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64);
|
xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64);
|
||||||
auto key = BitcastConvertType(
|
auto key = BitcastConvertType(
|
||||||
xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}),
|
xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}),
|
||||||
{}),
|
|
||||||
xla::U64);
|
xla::U64);
|
||||||
|
|
||||||
auto status_or_value = sample_with_threefry(counter, key, shape);
|
auto status_or_value = sampler(state, key, shape);
|
||||||
if (!status_or_value.ok()) {
|
if (!status_or_value.ok()) {
|
||||||
return status_or_value.status();
|
return status_or_value.status();
|
||||||
}
|
}
|
||||||
auto output_counter = status_or_value.ConsumeValueOrDie();
|
xla::RngOutput value_state = status_or_value.ConsumeValueOrDie();
|
||||||
auto output = output_counter.first;
|
state = value_state.state;
|
||||||
counter = output_counter.second;
|
ctx->SetOutput(0, value_state.value);
|
||||||
ctx->SetOutput(0, output);
|
xla::XlaBuilder* builder = ctx->builder();
|
||||||
auto builder = ctx->builder();
|
var = ConcatScalars(builder, {state, key});
|
||||||
var = ConcatScalars(builder, {counter, key});
|
|
||||||
xla::PrimitiveType state_element_type;
|
xla::PrimitiveType state_element_type;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
|
DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type));
|
||||||
@ -252,23 +173,22 @@ class StatefulUniformOp : public XlaOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
auto builder = ctx->builder();
|
xla::XlaBuilder* builder = ctx->builder();
|
||||||
auto sample_with_threefry = [builder, this](
|
auto sampler = [builder, this](xla::XlaOp state, xla::XlaOp key,
|
||||||
xla::XlaOp counter, xla::XlaOp key,
|
TensorShape shape) -> SamplerReturnType {
|
||||||
TensorShape shape) -> sampler_return_type {
|
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||||
auto uniform_counter = StatefulRngUniform(
|
xla::RngOutput uniform_state = StatefulRngUniform(
|
||||||
key, counter, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
key, state, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||||
xla::ConstantR0<float>(builder, 1.0));
|
xla::ConstantR0<float>(builder, 1.0));
|
||||||
auto uniform = uniform_counter.first;
|
xla::XlaOp uniform = uniform_state.value;
|
||||||
counter = uniform_counter.second;
|
state = uniform_state.state;
|
||||||
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
||||||
return {{uniform, counter}};
|
return {{uniform, state}};
|
||||||
};
|
};
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||||
/*shape_input_idx=*/2, sample_with_threefry));
|
/*shape_input_idx=*/2, sampler));
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -293,30 +213,20 @@ class StatefulStandardNormalOp : public XlaOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
auto builder = ctx->builder();
|
auto sampler =
|
||||||
auto sample_with_threefry =
|
|
||||||
// Needs explicit lambda return type because it fails to be inferred.
|
// Needs explicit lambda return type because it fails to be inferred.
|
||||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
[this](xla::XlaOp state, xla::XlaOp key,
|
||||||
TensorShape shape) -> sampler_return_type {
|
TensorShape shape) -> SamplerReturnType {
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||||
|
xla::RngOutput value_state = xla::NormalF32Distribution(
|
||||||
auto uniform_counter = StatefulRngUniform(
|
key, state, xla::ThreeFryBitGenerator, xla_shape);
|
||||||
key, counter, xla_shape,
|
xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_);
|
||||||
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
return {{normal, value_state.state}};
|
||||||
xla::ConstantR0<float>(builder, 1.0));
|
|
||||||
auto uniform = uniform_counter.first;
|
|
||||||
counter = uniform_counter.second;
|
|
||||||
// Convert uniform distribution to normal distribution by computing
|
|
||||||
// sqrt(2) * erfinv(x)
|
|
||||||
auto normal =
|
|
||||||
xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
|
|
||||||
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
|
||||||
return {{normal, counter}};
|
|
||||||
};
|
};
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||||
/*shape_input_idx=*/2, sample_with_threefry));
|
/*shape_input_idx=*/2, sampler));
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -341,27 +251,27 @@ class StatefulTruncatedNormalOp : public XlaOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
auto builder = ctx->builder();
|
xla::XlaBuilder* builder = ctx->builder();
|
||||||
auto sample_with_threefry =
|
auto sampler =
|
||||||
// Needs explicit lambda return type because it fails to be inferred.
|
// Needs explicit lambda return type because it fails to be inferred.
|
||||||
[builder, this](xla::XlaOp counter, xla::XlaOp key,
|
[builder, this](xla::XlaOp state, xla::XlaOp key,
|
||||||
TensorShape shape) -> sampler_return_type {
|
TensorShape shape) -> SamplerReturnType {
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||||
|
|
||||||
auto uniform_counter = StatefulRngUniform(
|
xla::RngOutput uniform_result = StatefulRngUniform(
|
||||||
key, counter, xla_shape,
|
key, state, xla_shape,
|
||||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||||
xla::One(builder, xla_shape.element_type()));
|
xla::One(builder, xla_shape.element_type()));
|
||||||
auto uniform = uniform_counter.first;
|
xla::XlaOp uniform = uniform_result.value;
|
||||||
counter = uniform_counter.second;
|
state = uniform_result.state;
|
||||||
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
||||||
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
|
truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_);
|
||||||
return {{truncated_normal, counter}};
|
return {{truncated_normal, state}};
|
||||||
};
|
};
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||||
/*shape_input_idx=*/2, sample_with_threefry));
|
/*shape_input_idx=*/2, sampler));
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -388,11 +298,11 @@ class StatefulUniformIntOp : public XlaOpKernel {
|
|||||||
xla::XlaOp minval = ctx->Input(3);
|
xla::XlaOp minval = ctx->Input(3);
|
||||||
xla::XlaOp maxval = ctx->Input(4);
|
xla::XlaOp maxval = ctx->Input(4);
|
||||||
auto sample_with_threefry = [minval, maxval, this](
|
auto sample_with_threefry = [minval, maxval, this](
|
||||||
xla::XlaOp counter, xla::XlaOp key,
|
xla::XlaOp state, xla::XlaOp key,
|
||||||
TensorShape shape) -> sampler_return_type {
|
TensorShape shape) -> SamplerReturnType {
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||||
return StatefulRngUniform(key, counter, xla_shape, minval, maxval);
|
return StatefulRngUniform(key, state, xla_shape, minval, maxval);
|
||||||
};
|
};
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||||
@ -420,12 +330,11 @@ class StatefulUniformFullIntOp : public XlaOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Compile(XlaOpKernelContext* ctx) override {
|
void Compile(XlaOpKernelContext* ctx) override {
|
||||||
auto sample_with_threefry = [this](
|
auto sample_with_threefry = [this](xla::XlaOp state, xla::XlaOp key,
|
||||||
xla::XlaOp counter, xla::XlaOp key,
|
TensorShape shape) -> SamplerReturnType {
|
||||||
TensorShape shape) -> sampler_return_type {
|
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||||
return StatefulRngUniformFullInt(key, counter, xla_shape);
|
return StatefulRngUniformFullInt(key, state, xla_shape);
|
||||||
};
|
};
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1,
|
||||||
|
@ -36,8 +36,8 @@ namespace tensorflow {
|
|||||||
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
||||||
if (dtype == DT_BFLOAT16) {
|
if (dtype == DT_BFLOAT16) {
|
||||||
xla::XlaBuilder* builder = input.builder();
|
xla::XlaBuilder* builder = input.builder();
|
||||||
auto output = xla::BitcastConvertType(input, xla::U32) &
|
xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) &
|
||||||
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
|
xla::ConstantR0<uint32>(builder, 0xFFFF0000);
|
||||||
return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
|
return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32),
|
||||||
xla::BF16);
|
xla::BF16);
|
||||||
} else {
|
} else {
|
||||||
@ -45,22 +45,36 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) {
|
xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape,
|
||||||
// Convert uniform distribution to normal distribution by computing
|
xla::XlaOp minval, xla::XlaOp maxval) {
|
||||||
// sqrt(2) * erfinv(x)
|
xla::XlaBuilder* builder = seeds.builder();
|
||||||
return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform);
|
|
||||||
}
|
|
||||||
|
|
||||||
// A wrapper of xla::StatelessRngUniform. Returns an op that produces random
|
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
|
||||||
// values with uniform distribution in the range [minval, maxval) for the given
|
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
|
||||||
// shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and
|
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||||
// S64 are implemented.
|
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||||
xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused,
|
ConstantR0WithType(builder, xla::U64, 32));
|
||||||
xla::XlaOp seed, xla::XlaOp minval,
|
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||||
xla::XlaOp maxval) {
|
xla::PrimitiveType type = shape.element_type();
|
||||||
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
switch (type) {
|
||||||
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
case xla::F32:
|
||||||
return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval);
|
return xla::UniformF32Distribution(key, initial_state,
|
||||||
|
xla::ThreeFryBitGenerator, minval,
|
||||||
|
maxval, shape)
|
||||||
|
.value;
|
||||||
|
case xla::S32: // fall through
|
||||||
|
case xla::S64:
|
||||||
|
return UniformIntDistribution(key, initial_state,
|
||||||
|
xla::ThreeFryBitGenerator, minval, maxval,
|
||||||
|
shape)
|
||||||
|
.value;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return builder->ReportError(xla::Unimplemented(
|
||||||
|
"Types other than F32, S32 and S64 are not implemented by "
|
||||||
|
"StatelessRngUniform; got %s",
|
||||||
|
xla::primitive_util::LowercasePrimitiveTypeName(type)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -86,8 +100,8 @@ class StatelessRandomUniformOp : public XlaOpKernel {
|
|||||||
|
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
xla::XlaOp uniform = StatelessRngUniform(
|
||||||
xla_shape, dtype_, seed, xla::ConstantR0<float>(builder, 0.0),
|
seed, xla_shape, xla::ConstantR0<float>(builder, 0.0),
|
||||||
xla::ConstantR0<float>(builder, 1.0));
|
xla::ConstantR0<float>(builder, 1.0));
|
||||||
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
uniform = MaybeConvertF32ToBF16(uniform, dtype_);
|
||||||
ctx->SetOutput(0, uniform);
|
ctx->SetOutput(0, uniform);
|
||||||
@ -136,8 +150,8 @@ class StatelessRandomUniformIntOp : public XlaOpKernel {
|
|||||||
|
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape));
|
||||||
xla::XlaOp uniform =
|
xla::XlaOp uniform = StatelessRngUniform(seed, xla_shape, minval, maxval);
|
||||||
StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval);
|
|
||||||
ctx->SetOutput(0, uniform);
|
ctx->SetOutput(0, uniform);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,14 +184,20 @@ class StatelessRandomNormalOp : public XlaOpKernel {
|
|||||||
errors::InvalidArgument("seed must have shape [2], not ",
|
errors::InvalidArgument("seed must have shape [2], not ",
|
||||||
seed_shape.DebugString()));
|
seed_shape.DebugString()));
|
||||||
xla::XlaOp seed = ctx->Input(1);
|
xla::XlaOp seed = ctx->Input(1);
|
||||||
xla::XlaBuilder* builder = ctx->builder();
|
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
|
||||||
xla_shape, dtype_, seed,
|
xla::XlaBuilder* builder = seed.builder();
|
||||||
xla::ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
|
||||||
xla::ConstantR0<float>(builder, 1.0));
|
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
|
||||||
xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform);
|
xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0);
|
||||||
|
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
|
||||||
|
ShiftLeft(ConvertElementType(seed1, xla::U64),
|
||||||
|
ConstantR0WithType(builder, xla::U64, 32));
|
||||||
|
xla::XlaOp normal =
|
||||||
|
xla::NormalF32Distribution(key, initial_state,
|
||||||
|
xla::ThreeFryBitGenerator, xla_shape)
|
||||||
|
.value;
|
||||||
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
normal = MaybeConvertF32ToBF16(normal, dtype_);
|
||||||
ctx->SetOutput(0, normal);
|
ctx->SetOutput(0, normal);
|
||||||
}
|
}
|
||||||
@ -215,8 +235,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel {
|
|||||||
|
|
||||||
xla::Shape xla_shape;
|
xla::Shape xla_shape;
|
||||||
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape));
|
||||||
xla::XlaOp uniform = StatelessRandomUniformImpl(
|
xla::XlaOp uniform = StatelessRngUniform(
|
||||||
xla_shape, dtype_, seed,
|
seed, xla_shape,
|
||||||
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
xla::MinPositiveNormalValue(builder, xla_shape.element_type()),
|
||||||
xla::One(builder, xla_shape.element_type()));
|
xla::One(builder, xla_shape.element_type()));
|
||||||
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
xla::XlaOp truncated_normal = TruncatedNormal(uniform);
|
||||||
|
@ -32,8 +32,12 @@ XlaOp RotateLeftU32(XlaOp v, int distance) {
|
|||||||
ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
|
ShiftRightLogical(v, ConstantR0<uint32>(v.builder(), 32 - distance));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
// The internal state of the Three Fry implementation.
|
||||||
|
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
||||||
|
|
||||||
|
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||||
|
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||||
|
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
||||||
XlaBuilder* builder = input[0].builder();
|
XlaBuilder* builder = input[0].builder();
|
||||||
key[0] = BitcastConvertType(key[0], U32);
|
key[0] = BitcastConvertType(key[0], U32);
|
||||||
@ -104,56 +108,68 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the inputs with unique counter values for ThreeFry2x32.
|
// Converts a uint64 to two uint32s.
|
||||||
ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) {
|
|
||||||
ThreeFry2x32State inputs;
|
|
||||||
inputs[0] = Iota(builder, U32, size);
|
|
||||||
inputs[1] = inputs[0] + ConstantR0<uint32>(builder, size);
|
|
||||||
return inputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
XlaOp StatelessRngUniformU32(std::array<XlaOp, 2> key, const Shape& shape) {
|
|
||||||
XlaBuilder* builder = key[0].builder();
|
|
||||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
|
||||||
const int64 half_size = CeilOfRatio<int64>(size, 2);
|
|
||||||
const bool size_is_odd = (half_size * 2 != size);
|
|
||||||
ThreeFry2x32State inputs = GetInputs(half_size, builder);
|
|
||||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
|
|
||||||
if (size_is_odd) {
|
|
||||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
|
||||||
}
|
|
||||||
auto result = ConcatInDim(builder, outputs, 0);
|
|
||||||
return Reshape(result, AsInt64Slice(shape.dimensions()));
|
|
||||||
}
|
|
||||||
|
|
||||||
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
|
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) {
|
||||||
auto builder = u64.builder();
|
XlaBuilder* builder = u64.builder();
|
||||||
auto const32 = ConstantR0WithType(builder, U64, 32);
|
XlaOp const32 = ConstantR0WithType(builder, U64, 32);
|
||||||
auto fst = ConvertElementType(u64, U32);
|
XlaOp fst = ConvertElementType(u64, U32);
|
||||||
auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
|
XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32);
|
||||||
return {fst, snd};
|
return {fst, snd};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Converts two uint32s to a uint64.
|
||||||
XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
|
XlaOp Uint32sToUint64(ThreeFry2x32State u32s) {
|
||||||
auto builder = u32s[0].builder();
|
XlaBuilder* builder = u32s[0].builder();
|
||||||
return ConvertElementType(u32s[0], U64) |
|
return ConvertElementType(u32s[0], U64) |
|
||||||
ShiftLeft(ConvertElementType(u32s[1], U64),
|
ShiftLeft(ConvertElementType(u32s[1], U64),
|
||||||
ConstantR0WithType(builder, U64, 32));
|
ConstantR0WithType(builder, U64, 32));
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp StatelessRngUniformU64(std::array<XlaOp, 2> key, const Shape& shape) {
|
// Given the initial state and the request number of random numbers to be
|
||||||
XlaBuilder* builder = key[0].builder();
|
// generated, returns the input for the random number generator and a new state.
|
||||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
|
||||||
ThreeFry2x32State inputs = GetInputs(size, builder);
|
XlaOp initial_state, const int64 size) {
|
||||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, key);
|
XlaBuilder* builder = initial_state.builder();
|
||||||
// low 32 bit: outputs[0], high 32 bit: outputs[1]
|
XlaOp input_u64 = Iota(builder, U64, size);
|
||||||
auto result = Uint32sToUint64(outputs);
|
input_u64 = input_u64 + initial_state;
|
||||||
return Reshape(result, AsInt64Slice(shape.dimensions()));
|
XlaOp new_state = initial_state + ConstantR0<uint64>(builder, size);
|
||||||
|
return std::make_pair(Uint64ToUint32s(input_u64), new_state);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
// Generates random 32bits with the given shape using the Three Fry
|
||||||
XlaBuilder* builder = bits.builder();
|
// implementation. Returns the random bits and the new state.
|
||||||
|
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
|
||||||
|
XlaBuilder* builder = key.builder();
|
||||||
|
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||||
|
const int64 half_size = CeilOfRatio<int64>(size, 2);
|
||||||
|
const bool size_is_odd = (half_size * 2 != size);
|
||||||
|
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||||
|
GetThreeFryInputsAndUpdatedState(initial_state, half_size);
|
||||||
|
ThreeFry2x32State inputs = inputs_state.first;
|
||||||
|
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||||
|
if (size_is_odd) {
|
||||||
|
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||||
|
}
|
||||||
|
XlaOp result = ConcatInDim(builder, outputs, 0);
|
||||||
|
return {Reshape(result, AsInt64Slice(shape.dimensions())),
|
||||||
|
inputs_state.second};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates random 64bits with the given shape using the Three Fry
|
||||||
|
// implementation. Returns the random bits and the new state.
|
||||||
|
RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
|
||||||
|
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||||
|
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||||
|
GetThreeFryInputsAndUpdatedState(initial_state, size);
|
||||||
|
ThreeFry2x32State inputs = inputs_state.first;
|
||||||
|
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||||
|
XlaOp result = Uint32sToUint64(outputs);
|
||||||
|
return {Reshape(result, AsInt64Slice(shape.dimensions())),
|
||||||
|
inputs_state.second};
|
||||||
|
}
|
||||||
|
|
||||||
|
XlaOp ConvertRandomBitsToUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
||||||
|
XlaBuilder* builder = bits.builder();
|
||||||
// Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
|
// Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit
|
||||||
// forces the random bits into the mantissa.
|
// forces the random bits into the mantissa.
|
||||||
constexpr int kFloatBits = 32;
|
constexpr int kFloatBits = 32;
|
||||||
@ -161,50 +177,95 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) {
|
|||||||
bits = ShiftRightLogical(
|
bits = ShiftRightLogical(
|
||||||
bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
|
bits, ConstantR0<uint32>(builder, kFloatBits - kMantissaBits)) |
|
||||||
ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
|
ConstantR0<uint32>(builder, absl::bit_cast<uint32>(1.0f));
|
||||||
auto floats = BitcastConvertType(bits, F32);
|
XlaOp values = BitcastConvertType(bits, F32);
|
||||||
|
|
||||||
// We have a floating point number in the range [1.0, 2.0).
|
// We have a floating point number in the range [1.0, 2.0).
|
||||||
// Subtract 1.0f to shift to the range [0.0, 1.0)
|
// Subtract 1.0f to shift to the range [0.0, 1.0)
|
||||||
floats = floats - ConstantR0<float>(builder, 1.0f);
|
values = values - ConstantR0<float>(builder, 1.0f);
|
||||||
// Multiply and add to shift to the range [minval, maxval).
|
// Multiply and add to shift to the range [minval, maxval).
|
||||||
return floats * (maxval - minval) + minval;
|
return values * (maxval - minval) + minval;
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
||||||
PrimitiveType type, PrimitiveType unsigned_type) {
|
PrimitiveType type,
|
||||||
|
PrimitiveType unsigned_type) {
|
||||||
XlaBuilder* builder = bits.builder();
|
XlaBuilder* builder = bits.builder();
|
||||||
auto range = BitcastConvertType(maxval, unsigned_type) -
|
XlaOp range = BitcastConvertType(maxval, unsigned_type) -
|
||||||
BitcastConvertType(minval, unsigned_type);
|
BitcastConvertType(minval, unsigned_type);
|
||||||
auto dist = Rem(bits, range);
|
XlaOp dist = Rem(bits, range);
|
||||||
auto dist_div_2 =
|
XlaOp dist_div_2 =
|
||||||
ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
|
ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1));
|
||||||
|
|
||||||
return minval + BitcastConvertType(dist_div_2, type) +
|
return minval + BitcastConvertType(dist_div_2, type) +
|
||||||
BitcastConvertType(dist - dist_div_2, type);
|
BitcastConvertType(dist - dist_div_2, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
|
XlaOp UniformToNormalUsingSqrtErfInv(XlaOp uniform) {
|
||||||
XlaOp minval, XlaOp maxval) {
|
return ScalarLike(uniform, std::sqrt(2.0)) * ErfInv(uniform);
|
||||||
XlaBuilder* builder = seeds[0].builder();
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||||
|
const Shape& shape) {
|
||||||
PrimitiveType type = shape.element_type();
|
PrimitiveType type = shape.element_type();
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case F32: {
|
case F32:
|
||||||
auto bits = StatelessRngUniformU32(seeds, shape);
|
case U32:
|
||||||
return StatelessRngUniformF32(bits, minval, maxval);
|
case S32:
|
||||||
}
|
return ThreeFryRngBit32(key, initial_state, shape);
|
||||||
case S32: {
|
case U64:
|
||||||
auto bits = StatelessRngUniformU32(seeds, shape);
|
case S64:
|
||||||
return StatelessRngUniformInt(bits, minval, maxval, type, U32);
|
return ThreeFryRngBit64(key, initial_state, shape);
|
||||||
}
|
|
||||||
case S64: {
|
|
||||||
auto bits = StatelessRngUniformU64(seeds, shape);
|
|
||||||
return StatelessRngUniformInt(bits, minval, maxval, type, U64);
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
return builder->ReportError(Unimplemented(
|
return {key.builder()->ReportError(Unimplemented(
|
||||||
"Types other than F32, S32 and S64 are not implemented by "
|
"Types other than F32, U32, S32, U64 and S64 "
|
||||||
"StatelessRngUniform."));
|
"are not implemented by ThreeFryBitGenerator; got %s",
|
||||||
|
xla::primitive_util::LowercasePrimitiveTypeName(type))),
|
||||||
|
initial_state};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
|
||||||
|
BitGeneratorTy bit_generator, XlaOp minval,
|
||||||
|
XlaOp maxval, const Shape& shape) {
|
||||||
|
DCHECK_EQ(shape.element_type(), F32);
|
||||||
|
RngOutput bits_state = bit_generator(key, initial_state, shape);
|
||||||
|
XlaOp bits = bits_state.value;
|
||||||
|
XlaOp new_state = bits_state.state;
|
||||||
|
return {ConvertRandomBitsToUniformF32(bits, minval, maxval), new_state};
|
||||||
|
}
|
||||||
|
|
||||||
|
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
|
||||||
|
BitGeneratorTy bit_generator, XlaOp minval,
|
||||||
|
XlaOp maxval, const Shape& shape) {
|
||||||
|
RngOutput bits_state = bit_generator(key, initial_state, shape);
|
||||||
|
XlaOp bits = bits_state.value;
|
||||||
|
XlaOp new_state = bits_state.state;
|
||||||
|
PrimitiveType type = shape.element_type();
|
||||||
|
PrimitiveType unsigned_type;
|
||||||
|
if (type == U32 || type == S32) {
|
||||||
|
unsigned_type = U32;
|
||||||
|
} else {
|
||||||
|
DCHECK(type == U64 || type == S64);
|
||||||
|
unsigned_type = U64;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type),
|
||||||
|
new_state};
|
||||||
|
}
|
||||||
|
|
||||||
|
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
|
||||||
|
BitGeneratorTy bit_generator,
|
||||||
|
const Shape& shape) {
|
||||||
|
DCHECK_EQ(shape.element_type(), F32);
|
||||||
|
XlaBuilder* builder = key.builder();
|
||||||
|
RngOutput bits_state = UniformF32Distribution(
|
||||||
|
key, initial_state, bit_generator,
|
||||||
|
ConstantR0<float>(builder, std::nextafter(-1.0f, 0.0f)),
|
||||||
|
ConstantR0<float>(builder, 1.0), shape);
|
||||||
|
XlaOp normal = UniformToNormalUsingSqrtErfInv(bits_state.value);
|
||||||
|
return {normal, bits_state.state};
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -23,37 +23,52 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
// Records the bits and state generated by a random number generator.
|
||||||
|
struct RngOutput {
|
||||||
|
XlaOp value;
|
||||||
|
XlaOp state;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A BitGenerator returns random bits and updated random bit generator state.
|
||||||
|
//
|
||||||
|
// key: is a value input to a random number generator that can affect the
|
||||||
|
// sequence of number it will generate. A random number generator constructs
|
||||||
|
// its seed using the key and the initial state. The tf2xla bridge passes the
|
||||||
|
// seed operand of a tensorflow random operation as a key to the random bit
|
||||||
|
// generator, for example.
|
||||||
|
// initial_state: initial_state is the initial state of the current random
|
||||||
|
// number generation. It could be 0 for a stateless random operation, and
|
||||||
|
// the returned state from a previous execution for a stateful random
|
||||||
|
// operation.
|
||||||
|
// shape: the shape of the random bits.
|
||||||
|
using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state,
|
||||||
|
const xla::Shape& shape)>;
|
||||||
|
|
||||||
// Implements the ThreeFry counter-based PRNG algorithm.
|
// Implements the ThreeFry counter-based PRNG algorithm.
|
||||||
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
|
||||||
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
|
||||||
using ThreeFry2x32State = std::array<XlaOp, 2>;
|
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||||
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key);
|
const xla::Shape& shape);
|
||||||
|
|
||||||
// Returns a tensor containing 'shape' random values uniformly distributed in
|
// Uses the given bit generator to generate random bits and then converts the
|
||||||
// the range [minval, maxval). Requires 2 32-bit integer seeds.
|
// random bits to random numbers of uniform distribution in the given range.
|
||||||
// Currently only 'shape's of type F32, S32 and S64 are implemented.
|
// Returns the random numbers and the state of the random number generator.
|
||||||
XlaOp StatelessRngUniform(std::array<XlaOp, 2> seeds, const Shape& shape,
|
// This function is for shape with float element type.
|
||||||
XlaOp minval, XlaOp maxval);
|
RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state,
|
||||||
|
BitGeneratorTy bit_generator, XlaOp minval,
|
||||||
|
XlaOp maxval, const xla::Shape& shape);
|
||||||
|
|
||||||
// Converts a 32-bit (signed or unsigned) integer random number `bits` into a
|
// Similar to UniformF32Distribution but for shape with integer element types.
|
||||||
// float32 in the range [minval, maxval).
|
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
|
||||||
XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval);
|
BitGeneratorTy bit_generator, XlaOp minval,
|
||||||
|
XlaOp maxval, const xla::Shape& shape);
|
||||||
|
|
||||||
// Converts an integer random number 'bits' of type 'type' to a random number
|
// Uses the given bit generator to generate random bits and then converts the
|
||||||
// in the range [minval, maxval), of the same type. 'unsigned_type' is the
|
// random bits to random numbers of normal distribution.
|
||||||
// unsigned version of 'type' (could be the same) with the same bit width.
|
// Returns the random numbers and the state of the random number generator.
|
||||||
// The algorithm is the same one that TF uses right now, but it's
|
RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state,
|
||||||
// uniform only when maxval - minval is a divisor of the range that bits is
|
BitGeneratorTy bit_generator,
|
||||||
// generated from.
|
const xla::Shape& shape);
|
||||||
// TODO(b/72573764): Generate real uniform integer distribution.
|
|
||||||
XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
|
|
||||||
PrimitiveType type, PrimitiveType unsigned_type);
|
|
||||||
|
|
||||||
// The following 2 functions, for converting between one uint64 and two uint32s,
|
|
||||||
// use the contract "lower 32 bits for the first uint32, higher 32 bits for the
|
|
||||||
// second".
|
|
||||||
ThreeFry2x32State Uint64ToUint32s(XlaOp u64);
|
|
||||||
XlaOp Uint32sToUint64(ThreeFry2x32State u32s);
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
@ -140,5 +140,7 @@ def test_truncated_normal(assert_equal, assert_all_close, dtype, n, y):
|
|||||||
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
|
(alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - (
|
||||||
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
|
(normal_pdf(alpha) - normal_pdf(beta)) / z)**2)
|
||||||
actual_variance = np.var(y)
|
actual_variance = np.var(y)
|
||||||
assert_all_close(actual_variance, expected_variance,
|
assert_all_close(
|
||||||
rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3)
|
actual_variance,
|
||||||
|
expected_variance,
|
||||||
|
rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3)
|
||||||
|
Loading…
Reference in New Issue
Block a user