Yuanzhong Xu 7c096232b2 [XLA] Avoid reshape to R1 in NormalFloatingPointDistribution
This refactors existing shape split code and reuses it for NormalFloatingPointDistribution

PiperOrigin-RevId: 347643318
Change-Id: Ie3d3c0145cf6a275a5f678ee4258f88c7de18a5a
2020-12-15 11:13:11 -08:00

643 lines
25 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include <cmath>
#include <vector>
#include "absl/base/casts.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
// Reshapes every op in `scalars` to a one-element vector and concatenates the
// resulting vectors along dimension 0, preserving the order of `scalars`.
xla::XlaOp ConcatScalars(xla::XlaBuilder* builder,
                         absl::Span<const xla::XlaOp> scalars) {
  std::vector<xla::XlaOp> rank1_ops;
  for (const xla::XlaOp& scalar : scalars) {
    rank1_ops.push_back(xla::Reshape(scalar, {1}));
  }
  return ConcatInDim(builder, rank1_ops, 0);
}
namespace {
// Rotates the 32-bit integer 'v' left by 'distance' bits. The result is built
// as (v << distance) | (v logically shifted right by 32 - distance).
XlaOp RotateLeftU32(XlaOp v, int distance) {
  XlaBuilder* builder = v.builder();
  XlaOp shifted_up = v << ConstantR0<uint32>(builder, distance);
  XlaOp wrapped_around =
      ShiftRightLogical(v, ConstantR0<uint32>(builder, 32 - distance));
  return shifted_up | wrapped_around;
}
// The internal state of the Three Fry implementation: two XlaOps, used below
// as the two words that ThreeFry2x32 mixes.
using ThreeFry2x32State = std::array<XlaOp, 2>;
// Implements the ThreeFry counter-based PRNG algorithm.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
//
// `input` supplies the two words to be mixed and `key` the two key words; the
// key words are bitcast to U32 before use. Returns the two mixed words.
ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) {
  XlaBuilder* builder = input[0].builder();
  key[0] = BitcastConvertType(key[0], U32);
  key[1] = BitcastConvertType(key[1], U32);

  // Rotation distances specified by the Threefry2x32 algorithm.
  constexpr std::array<int, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};

  // Key schedule: the two key words plus a third word that XORs both key
  // words into 0x1BD11BDA, the parity constant specified by the ThreeFry2x32
  // algorithm.
  std::array<XlaOp, 3> ks;
  ks[0] = key[0];
  ks[1] = key[1];
  ks[2] = ConstantR0<uint32>(builder, 0x1BD11BDA) ^ key[0] ^ key[1];

  // Initial key injection.
  ThreeFry2x32State x = {input[0] + ks[0], input[1] + ks[1]};

  // Performs a single round of the Threefry2x32 algorithm, with a rotation
  // amount 'rotation'.
  auto mix = [](ThreeFry2x32State words, int rotation) {
    XlaOp sum = words[0] + words[1];
    XlaOp rotated = RotateLeftU32(words[1], rotation);
    return ThreeFry2x32State{sum, sum ^ rotated};
  };

  // There are no known statistical flaws with 13 rounds of Threefry2x32.
  // We are conservative and use 20 rounds: five groups of four rounds. Even
  // groups use rotations[0..3], odd groups use rotations[4..7]. After each
  // group the key schedule is injected, rotating through ks and adding the
  // 1-based group number to the second word.
  constexpr int kRoundsPerGroup = 4;
  constexpr int kNumGroups = 5;
  for (int group = 0; group < kNumGroups; ++group) {
    const int rotation_base = (group % 2 == 0) ? 0 : kRoundsPerGroup;
    for (int r = 0; r < kRoundsPerGroup; ++r) {
      x = mix(x, rotations[rotation_base + r]);
    }
    x[0] = x[0] + ks[(group + 1) % 3];
    x[1] = x[1] + ks[(group + 2) % 3] + ConstantR0<uint32>(builder, group + 1);
  }
  return x;
}
// Converts a uint64 to two uint32s. Element 0 is `u64` converted to U32;
// element 1 is `u64` logically shifted right by 32 bits, converted to U32.
std::array<XlaOp, 2> Uint64ToUint32s(XlaOp u64) {
  XlaOp shift_by_32 = ConstantR0WithType(u64.builder(), U64, 32);
  XlaOp low_word = ConvertElementType(u64, U32);
  XlaOp upper_bits = ShiftRightLogical(u64, shift_by_32);
  XlaOp high_word = ConvertElementType(upper_bits, U32);
  return {low_word, high_word};
}
// Converts two uint32s to a uint64: u32s[0] supplies the low 32 bits and
// u32s[1], shifted left by 32, supplies the high 32 bits.
XlaOp Uint32sToUint64(std::array<XlaOp, 2> u32s) {
  XlaBuilder* builder = u32s[0].builder();
  XlaOp low_part = ConvertElementType(u32s[0], U64);
  XlaOp high_part = ShiftLeft(ConvertElementType(u32s[1], U64),
                              ConstantR0WithType(builder, U64, 32));
  return low_part | high_part;
}
// Given the initial state and the request shape of random numbers to be
// generated, returns the input for the random number generator and a new state.
//
// The input is a U64 array of `shape`'s dimensions holding the initial state
// plus each element's row-major linear index, split into two uint32 arrays.
// The new state is the initial state advanced by the element count of `shape`.
std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
    XlaOp initial_state, const Shape& shape) {
  XlaBuilder* builder = initial_state.builder();
  const Shape counter_shape = ShapeUtil::MakeShape(U64, shape.dimensions());
  // initial_state is an R1, so reshape it to a scalar.
  XlaOp counters = Broadcast(Reshape(initial_state, {}), shape.dimensions());
  // Walk the dimensions from minor to major, accumulating the stride of each.
  int64 stride = 1;
  for (int64 dim = shape.rank() - 1; dim >= 0; --dim) {
    const int64 dim_size = shape.dimensions(dim);
    if (dim_size >= 2) {
      // Dimensions of size 0 or 1 contribute nothing to the linear index.
      counters = counters + (Iota(builder, counter_shape, dim) *
                             ConstantR0<uint64>(builder, stride));
      stride *= dim_size;
    }
  }
  XlaOp advanced_state =
      initial_state + ConstantR0<uint64>(builder, ShapeUtil::ElementsIn(shape));
  return std::make_pair(Uint64ToUint32s(counters), advanced_state);
}
// Result for SplitShapeIntoHalves().
struct SplitShapePair {
  // Shape of one half: the split dimension is replaced by
  // [ceil(size / 2), 1]; a scalar input yields {1}.
  Shape half_shape;
  // Shape of both halves together: same as half_shape, but the inserted
  // trivial dimension has size 2; a scalar input yields {2}.
  Shape concat_shape;
  // Index (in the original shape) of the dimension that was split. Defaults
  // to -1 so a default-constructed value never holds an indeterminate index.
  int64 split_dim = -1;
  // Index (in half_shape / concat_shape) of the dimension along which the two
  // halves are concatenated. Defaults to -1 for the same reason.
  int64 new_concat_dim = -1;
};
// Splits `shape` into two halves along one dimension.
//
// A scalar is treated specially: half_shape is {1}, concat_shape is {2}, and
// both dimension indices are 0. Otherwise the first even-sized dimension is
// split; if there is none, the first dimension of maximum size is split. The
// split dimension of size d becomes [ceil(d / 2), 1] in half_shape and
// [ceil(d / 2), 2] in concat_shape; all other dimensions are copied.
SplitShapePair SplitShapeIntoHalves(const Shape& shape) {
  const PrimitiveType element_type = shape.element_type();
  const int64 rank = shape.rank();
  SplitShapePair result;
  if (rank == 0) {
    result.half_shape = ShapeUtil::MakeShape(element_type, {1});
    result.concat_shape = ShapeUtil::MakeShape(element_type, {2});
    result.split_dim = 0;
    result.new_concat_dim = 0;
    return result;
  }

  // Prefer the first dimension whose size is even.
  int64 chosen = -1;
  for (int64 d = 0; d < rank && chosen == -1; ++d) {
    if (shape.dimensions(d) % 2 == 0) {
      chosen = d;
    }
  }
  if (chosen == -1) {
    // No even dims. Find a dimension with maximum size (first one on ties).
    chosen = 0;
    for (int64 d = 1; d < rank; ++d) {
      if (shape.dimensions(d) > shape.dimensions(chosen)) {
        chosen = d;
      }
    }
  }
  CHECK_GE(chosen, 0);

  std::vector<int64> half_dims;
  std::vector<int64> concat_dims;
  for (int64 d = 0; d < rank; ++d) {
    const int64 dim_size = shape.dimensions(d);
    if (d != chosen) {
      half_dims.push_back(dim_size);
      concat_dims.push_back(dim_size);
      continue;
    }
    // Create a new trivial dim for the later concat, which is more friendly
    // to sharding propagation.
    const int64 half_size = CeilOfRatio<int64>(dim_size, 2);
    half_dims.push_back(half_size);
    half_dims.push_back(1);
    concat_dims.push_back(half_size);
    concat_dims.push_back(2);
  }

  result.split_dim = chosen;
  result.new_concat_dim = chosen + 1;
  result.half_shape = ShapeUtil::MakeShape(element_type, half_dims);
  result.concat_shape = ShapeUtil::MakeShape(element_type, concat_dims);
  return result;
}
// Combines a pair of split shapes. It works with scalar and non-scalar shapes.
//
// For a scalar `original_shape` only pair[0] is used, reshaped to a scalar.
// Otherwise the halves are concatenated along shape_pair.new_concat_dim and
// reshaped so the split dimension has its size rounded up to an even number;
// if that differs from the original size, the result is sliced back down to
// `original_shape`'s dimensions.
XlaOp CombineShapePair(absl::Span<const XlaOp> pair,
                       const SplitShapePair& shape_pair,
                       const Shape& original_shape) {
  const int64 rank = original_shape.rank();
  if (rank == 0) {
    return Reshape(pair[0], {});
  }
  const int64 split_dim = shape_pair.split_dim;
  XlaOp combined =
      ConcatInDim(pair[0].builder(), pair, shape_pair.new_concat_dim);
  const int64 original_size = original_shape.dimensions(split_dim);
  const int64 padded_size = RoundUpToNearest<int64>(original_size, 2);
  std::vector<int64> padded_dims(original_shape.dimensions().begin(),
                                 original_shape.dimensions().end());
  padded_dims[split_dim] = padded_size;
  combined = Reshape(combined, padded_dims);
  if (padded_size == original_size) {
    return combined;
  }
  // The split dimension was odd: drop the extra trailing element.
  return Slice(combined, std::vector<int64>(rank, 0),
               original_shape.dimensions(), std::vector<int64>(rank, 1));
}
// Generates random 32bits with the given shape using the Three Fry
// implementation. Returns the random bits and the new state.
//
// ThreeFry2x32 yields two 32-bit words per input, so inputs are generated for
// only one half of `shape` and the two output words are combined back into
// the full shape.
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
  const SplitShapePair split = SplitShapeIntoHalves(shape);
  const std::pair<ThreeFry2x32State, XlaOp> inputs_and_state =
      GetThreeFryInputsAndUpdatedState(initial_state, split.half_shape);
  const ThreeFry2x32State words =
      ThreeFry2x32(inputs_and_state.first, Uint64ToUint32s(key));
  XlaOp random_bits = CombineShapePair(words, split, shape);
  return {random_bits, inputs_and_state.second};
}
// Generates random 64bits with the given shape using the Three Fry
// implementation. Returns the random bits and the new state.
//
// One ThreeFry2x32 input is generated per element of `shape`; the two 32-bit
// output words are joined into a single 64-bit value.
RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) {
  const std::pair<ThreeFry2x32State, XlaOp> inputs_and_state =
      GetThreeFryInputsAndUpdatedState(initial_state, shape);
  const ThreeFry2x32State words =
      ThreeFry2x32(inputs_and_state.first, Uint64ToUint32s(key));
  XlaOp random_bits = Uint32sToUint64(words);
  return {random_bits, inputs_and_state.second};
}
// The key of the Philox random number generator: two XlaOps, one per key
// word.
using Philox4x32Key = std::array<XlaOp, 2>;
// The internal state of the Philox random number generator: four XlaOps, one
// per state word.
using Philox4x32State = std::array<XlaOp, 4>;
// Computes the Philox4x32 algorithm using 10 rounds. `state` is mixed once
// per round with the current key, and the key is raised after every round.
Philox4x32State Philox4x32(Philox4x32State state, Philox4x32Key key) {
  // Constants specified by the Philox algorithm.
  static const uint32 kPhiloxW32A = 0x9E3779B9;
  static const uint32 kPhiloxW32B = 0xBB67AE85;
  static const uint32 kPhiloxM4x32A = 0xD2511F53;
  static const uint32 kPhiloxM4x32B = 0xCD9E8D57;

  struct HighLowPair {
    XlaOp high;
    XlaOp low;
  };

  // Compute the high and low words from multiplying two 32-bit integers. The
  // product is formed in 64 bits and then split.
  auto mul_hi_low = [](XlaOp x, uint32 k) {
    XlaBuilder* b = x.builder();
    XlaOp wide_product = ConvertElementType(x, U64) * ConstantR0<uint64>(b, k);
    XlaOp low_word = ConvertElementType(wide_product, U32);
    XlaOp high_word =
        ConvertElementType(wide_product >> ConstantR0<uint64>(b, 32), U32);
    return HighLowPair{high_word, low_word};
  };

  // Perform a single round of the Philox algorithm.
  auto philox_round = [&](Philox4x32State x, Philox4x32Key round_key) {
    const HighLowPair p0 = mul_hi_low(x[0], kPhiloxM4x32A);
    const HighLowPair p1 = mul_hi_low(x[2], kPhiloxM4x32B);
    return Philox4x32State{p1.high ^ x[1] ^ round_key[0], p1.low,
                           p0.high ^ x[3] ^ round_key[1], p0.low};
  };

  // Update the key after a round of Philox algorithm.
  auto raise_key = [](Philox4x32Key current) {
    XlaBuilder* b = current[0].builder();
    return Philox4x32Key{current[0] + ConstantR0<uint32>(b, kPhiloxW32A),
                         current[1] + ConstantR0<uint32>(b, kPhiloxW32B)};
  };

  static const int kNumRounds = 10;
  for (int round = 0; round < kNumRounds; ++round) {
    state = philox_round(state, key);
    key = raise_key(key);
  }
  return state;
}
// Scrambles the input key so that users don't need to worry about which part
// of the key needs to be strong.
//
// The two key words are widened to U64 and split into four U32 words, which
// are run through Philox4x32 under a fixed key. Returns (state, key) where the
// state is {0, 0, out[2], out[3]} and the key is {out[0], out[1]}.
std::pair<Philox4x32State, Philox4x32Key> ScramblePhiloxKey(Philox4x32Key key) {
  XlaBuilder* builder = key[0].builder();
  XlaOp wide0 = ConvertElementType(key[0], U64);
  XlaOp wide1 = ConvertElementType(key[1], U64);

  // Splits a U64 op into its {low, high} U32 words.
  auto split_words = [](XlaOp wide) {
    XlaOp low = ConvertElementType(wide, U32);
    XlaOp high = ConvertElementType(wide >> ScalarLike(wide, 32), U32);
    return std::array<XlaOp, 2>{low, high};
  };
  const std::array<XlaOp, 2> words0 = split_words(wide0);
  const std::array<XlaOp, 2> words1 = split_words(wide1);
  const Philox4x32State unscrambled = {words0[0], words0[1], words1[0],
                                       words1[1]};

  const Philox4x32Key fixed_key = {ConstantR0<uint32>(builder, 0x3ec8f720),
                                   ConstantR0<uint32>(builder, 0x02461e29)};
  const Philox4x32State scrambled = Philox4x32(unscrambled, fixed_key);

  XlaOp zero = ConstantR0<uint32>(builder, 0);
  return {Philox4x32State{zero, zero, scrambled[2], scrambled[3]},
          Philox4x32Key{scrambled[0], scrambled[1]}};
}
// Adds an U128 tensor with an U64 tensor. The U128 tensor is represented as two
// U64s with the low 64bits in the front. This routine supports explicit
// broadcasting of the U128 tensor, with `broadcast_sizes` representing the
// dimensions prepended to its shape.
std::array<XlaOp, 2> Uint128AddUint64(
const std::array<XlaOp, 2>& u128, XlaOp u64,
absl::Span<const int64> broadcast_sizes = {}) {
auto u128_low = u128[0];
auto u128_high = u128[1];
XlaOp new_u128_low = u128_low + u64;
XlaOp one = ConstantR0<uint64>(u128[0].builder(), 1);
XlaOp new_u128_high = Select(Lt(new_u128_low, u128_low),
Broadcast(u128_high + one, broadcast_sizes),
Broadcast(u128_high, broadcast_sizes));
return {new_u128_low, new_u128_high};
}
std::array<XlaOp, 2> Uint32sToUint128(const std::array<XlaOp, 4>& u32s) {
return {Uint32sToUint64({u32s[0], u32s[1]}),
Uint32sToUint64({u32s[2], u32s[3]})};
}
// Unpacks a U128 (two U64s) into four uint32s: the two words of u128[0]
// followed by the two words of u128[1].
std::array<XlaOp, 4> Uint128ToUint32s(const std::array<XlaOp, 2>& u128) {
  std::array<XlaOp, 4> words;
  for (int half = 0; half < 2; ++half) {
    const std::array<XlaOp, 2> half_words = Uint64ToUint32s(u128[half]);
    words[2 * half] = half_words[0];
    words[2 * half + 1] = half_words[1];
  }
  return words;
}
// Extracts elements 0 and 1 of the rank-1 `op` as two scalars, forming the
// {low, high} representation of a U128.
std::array<XlaOp, 2> Uint128FromOp(XlaOp op) {
  // Slices out the single element at `index` and reshapes it to a scalar.
  auto scalar_at = [&op](int64 index) {
    return xla::Reshape(xla::Slice(op, {index}, {index + 1}, {1}), {});
  };
  XlaOp low = scalar_at(0);
  XlaOp high = scalar_at(1);
  return {low, high};
}
// Concatenates the two scalar halves of a U128 into one rank-1 op with two
// elements, in the order {u128[0], u128[1]}.
XlaOp Uint128ToOp(std::array<XlaOp, 2> u128) {
  XlaBuilder* builder = u128[0].builder();
  return ConcatScalars(builder, u128);
}
// Returns the pair (state + [0, 1, ..., n-1], state + n), which should be used
// as the inputs fed to `Philox4x32` and the updated state. `state` is an U128
// represented as 4 U32s in the order from the least significant one to the most
// significant one.
std::pair<Philox4x32State, XlaOp> GetPhiloxInputsAndUpdatedState(
const Philox4x32State& state, int64 n) {
XlaBuilder* builder = state[0].builder();
XlaOp iota = Iota(builder, U64, n);
auto state_u128 = Uint32sToUint128(state);
auto inputs = Uint128ToUint32s(Uint128AddUint64(state_u128, iota, {n}));
XlaOp new_state =
Uint128ToOp(Uint128AddUint64(state_u128, ConstantR0<uint64>(builder, n)));
return std::make_pair(inputs, new_state);
}
// Generates CeilOfRatio(num_elems, 4)*4 32bit Philox random numbers, as Philox
// numbers are generated in the unit of 128bits. Returns the four output word
// arrays together with the updated state.
std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64 num_elems,
                                                     XlaOp initial_state,
                                                     Philox4x32Key key) {
  const Philox4x32State counter_words =
      Uint128ToUint32s(Uint128FromOp(initial_state));
  const int64 num_vector4 = CeilOfRatio<int64>(num_elems, 4);
  const std::pair<Philox4x32State, XlaOp> inputs_and_state =
      GetPhiloxInputsAndUpdatedState(counter_words, num_vector4);
  const Philox4x32State random_words = Philox4x32(inputs_and_state.first, key);
  return std::make_pair(random_words, inputs_and_state.second);
}
// Generates an array of primitive type U32 with the given shape containing
// random bits generated by the Philox algorithm. Returns the array and the new
// state of the random number generator.
RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state,
                         const Shape& shape) {
  XlaBuilder* builder = op_key.builder();
  const int64 num_elems = ShapeUtil::ElementsIn(shape);
  const Philox4x32Key key = Uint64ToUint32s(op_key);
  const std::pair<Philox4x32State, XlaOp> words_and_state =
      GeneratePhiloxBits(num_elems, initial_state, key);

  // Combining the four word arrays in a round-robin fashion, to align with
  // non-XLA implementations: each array becomes one column of a
  // [num_rows, 4] matrix that is then flattened row by row.
  const int64 num_rows = (num_elems + 3) / 4;
  std::vector<XlaOp> columns;
  for (const XlaOp& word : words_and_state.first) {
    columns.push_back(Reshape(word, {num_rows, 1}));
  }
  XlaOp flat = ConcatInDim(builder, columns, /*dimension=*/1);
  flat = Reshape(flat, {num_rows * 4});
  // Drop the surplus values produced by rounding up to a multiple of four.
  flat = Slice(flat, /*start_indices=*/{0},
               /*limit_indices=*/{num_elems},
               /*strides=*/{1});
  XlaOp shaped = Reshape(flat, AsInt64Slice(shape.dimensions()));
  return {shaped, words_and_state.second};
}
// Generates an array of primitive type U64 with the given shape containing
// random bits generated by the Philox algorithm. Returns the array and the new
// state of the random number generator.
RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state,
                         const Shape& shape) {
  XlaBuilder* builder = op_key.builder();
  const int64 num_elems = ShapeUtil::ElementsIn(shape);
  const Philox4x32Key key = Uint64ToUint32s(op_key);
  // Each 64-bit value consumes two 32-bit words.
  const std::pair<Philox4x32State, XlaOp> words_and_state =
      GeneratePhiloxBits(num_elems * 2, initial_state, key);
  const Philox4x32State& words = words_and_state.first;

  XlaOp first_u64 = Uint32sToUint64({words[0], words[1]});
  XlaOp second_u64 = Uint32sToUint64({words[2], words[3]});

  // Combining the two 64-bit arrays in a round-robin fashion, to align with
  // non-XLA implementations: each array becomes one column of a
  // [num_rows, 2] matrix that is then flattened row by row.
  const int64 num_rows = (num_elems + 1) / 2;
  std::vector<XlaOp> columns;
  columns.push_back(Reshape(first_u64, {num_rows, 1}));
  columns.push_back(Reshape(second_u64, {num_rows, 1}));
  XlaOp flat = ConcatInDim(builder, columns, /*dimension=*/1);
  flat = Reshape(flat, {num_rows * 2});
  // Drop the surplus value produced when num_elems is odd.
  flat = Slice(flat, /*start_indices=*/{0},
               /*limit_indices=*/{num_elems},
               /*strides=*/{1});
  XlaOp shaped = Reshape(flat, AsInt64Slice(shape.dimensions()));
  return {shaped, words_and_state.second};
}
// Converts random bits into floating-point values in the range
// [minval, maxval).
//
// The floating-point type is taken from `minval`. `bits` must be U32 for F32
// values or U64 for F64 values; any other combination is reported through the
// builder as an InvalidArgument error instead of aborting the process.
XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval,
                                              XlaOp maxval) {
  XlaBuilder* builder = bits.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* minval_shape,
                        builder->GetShapePtr(minval));
    TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits));
    PrimitiveType value_type = minval_shape->element_type();
    PrimitiveType bit_type = bits_shape->element_type();
    const bool supported = (value_type == F32 && bit_type == U32) ||
                           (value_type == F64 && bit_type == U64);
    if (!supported) {
      return InvalidArgument(
          "ConvertRandomBitsToUniformFloatingPoint requires F32 values with "
          "U32 bits or F64 values with U64 bits; got %s values with %s bits",
          primitive_util::LowercasePrimitiveTypeName(value_type),
          primitive_util::LowercasePrimitiveTypeName(bit_type));
    }

    int num_float_bits = primitive_util::BitWidth(value_type);
    // Subtract one as SignificandWidth includes the leading 1 bit.
    int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1;

    // Keep only the top num_mantissa_bits of each random word by shifting out
    // the low (num_float_bits - num_mantissa_bits) bits.
    XlaOp mantissa = ShiftRightLogical(
        bits, ScalarLike(bits, num_float_bits - num_mantissa_bits));

    // We have an integer-valued floating point number in the range
    // [0, 2**{num_mantissa_bits}).
    XlaOp values = ConvertElementType(mantissa, value_type);

    // Multiply by 2**{-num_mantissa_bits} to get a number in the range
    // [0.0, 1.0).
    values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits));

    // Multiply and add to shift to the range [minval, maxval).
    return values * (maxval - minval) + minval;
  });
}
// Maps random bits onto integers of `type` starting at `minval`.
//
// The width of the interval is computed as maxval - minval after bitcasting
// both to `unsigned_type`; the offset into the interval is `bits` modulo that
// width. The offset is added to `minval` in two pieces, floor(offset / 2) and
// the remainder, each bitcast to `type`.
// NOTE(review): the two-piece addition presumably keeps each addend within
// the positive range of a signed `type` -- confirm.
XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval,
                                    PrimitiveType type,
                                    PrimitiveType unsigned_type) {
  XlaBuilder* builder = bits.builder();
  XlaOp unsigned_max = BitcastConvertType(maxval, unsigned_type);
  XlaOp unsigned_min = BitcastConvertType(minval, unsigned_type);
  XlaOp width = unsigned_max - unsigned_min;
  XlaOp offset = Rem(bits, width);
  XlaOp lower_half =
      ShiftRightLogical(offset, ConstantR0WithType(builder, unsigned_type, 1));
  XlaOp upper_half = offset - lower_half;
  return minval + BitcastConvertType(lower_half, type) +
         BitcastConvertType(upper_half, type);
}
// Implements the Box-Muller transform, which converts random floats in the
// range of [0, 1] from uniform distribution to normal distribution with mean 0
// and variance 1. For more detail on the Box-Muller transform, see
// http://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform#Basic_form
//
// Returns (sin(2*pi*x1) * r, cos(2*pi*x1) * r) with r = sqrt(-2 * log(x0)),
// where x0 is first clamped from below to 1e-7.
std::pair<XlaOp, XlaOp> BoxMullerTransform(XlaOp x0, XlaOp x1) {
  // Do not send a really small number to log().
  XlaOp clamped = Max(x0, ScalarLike(x0, 1.0e-7f));
  XlaOp angle = ScalarLike(x1, 2.0f * M_PI) * x1;
  XlaOp radius = Sqrt(ScalarLike(clamped, -2.0f) * Log(clamped));
  XlaOp sin_sample = Sin(angle) * radius;
  XlaOp cos_sample = Cos(angle) * radius;
  return {sin_sample, cos_sample};
}
} // namespace
// Adds the U64 `delta` to the U128 `counter` (a rank-1 op holding two U64s)
// and returns the sum in the same two-element representation.
XlaOp PhiloxIncreaseCounter(XlaOp counter, XlaOp delta) {
  const std::array<XlaOp, 2> current = Uint128FromOp(counter);
  const std::array<XlaOp, 2> advanced = Uint128AddUint64(current, delta);
  return Uint128ToOp(advanced);
}
// Generates random bits with the Three Fry implementation, choosing the
// 32-bit generator for F32/U32/S32 shapes and the 64-bit generator for
// F64/U64/S64 shapes. Any other element type is reported to the builder as
// Unimplemented, and `initial_state` is returned unchanged as the state.
RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
                               const Shape& shape) {
  const PrimitiveType type = shape.element_type();
  const bool wants_32_bits = type == F32 || type == U32 || type == S32;
  if (wants_32_bits) {
    return ThreeFryRngBit32(key, initial_state, shape);
  }
  const bool wants_64_bits = type == F64 || type == U64 || type == S64;
  if (wants_64_bits) {
    return ThreeFryRngBit64(key, initial_state, shape);
  }
  XlaOp error = key.builder()->ReportError(
      Unimplemented("Types other than F32, F64, U32, S32, U64 and S64 "
                    "are not implemented by ThreeFryBitGenerator; got %s",
                    primitive_util::LowercasePrimitiveTypeName(type)));
  return {error, initial_state};
}
// Generates random bits with the Philox algorithm, choosing the 32-bit
// generator for F32/U32/S32 shapes and the 64-bit generator for F64/U64/S64
// shapes. Any other element type is reported to the builder as Unimplemented,
// and `initial_state` is returned unchanged as the state.
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
                             const Shape& shape) {
  PrimitiveType type = shape.element_type();
  switch (type) {
    case F32:
    case U32:
    case S32:
      return PhiloxRngBit32(key, initial_state, shape);
    case F64:
    case U64:
    case S64:
      return PhiloxRngBit64(key, initial_state, shape);
    default:
      // The message names this function so the failing generator can be
      // identified from the error text alone.
      return {key.builder()->ReportError(Unimplemented(
                  "Types other than F32, F64, U32, S32, U64 and S64 "
                  "are not implemented by PhiloxBitGenerator; got %s",
                  primitive_util::LowercasePrimitiveTypeName(type))),
              initial_state};
  }
}
// Scrambles a U64 Philox key. Returns (state, key): the state is a rank-1 op
// holding the scrambled U128 state as two U64s, and the key is the scrambled
// key as a single U64.
std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key) {
  const Philox4x32Key key_words = Uint64ToUint32s(key);
  const std::pair<Philox4x32State, Philox4x32Key> scrambled =
      ScramblePhiloxKey(key_words);
  XlaOp state_op = Uint128ToOp(Uint32sToUint128(scrambled.first));
  XlaOp key_op = Uint32sToUint64(scrambled.second);
  return {state_op, key_op};
}
// Draws random bits of the given `shape` from `bit_generator` and converts
// them to floating-point values in [minval, maxval). Returns the values
// together with the generator's updated state.
RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
                                           BitGeneratorTy bit_generator,
                                           XlaOp minval, XlaOp maxval,
                                           const Shape& shape) {
  const RngOutput raw = bit_generator(key, initial_state, shape);
  XlaOp uniform =
      ConvertRandomBitsToUniformFloatingPoint(raw.value, minval, maxval);
  return {uniform, raw.state};
}
// Draws random bits of the given `shape` from `bit_generator` and maps them
// onto integers starting at `minval` (see ConvertRandomBitsToUniformInt).
// Returns the values together with the generator's updated state.
//
// U32/S32 shapes use U32 arithmetic; every other element type uses U64
// arithmetic and is expected (DCHECK) to be U64 or S64.
RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
                                 BitGeneratorTy bit_generator, XlaOp minval,
                                 XlaOp maxval, const Shape& shape) {
  const RngOutput raw = bit_generator(key, initial_state, shape);
  const PrimitiveType type = shape.element_type();
  const bool is_32_bit = (type == U32 || type == S32);
  if (!is_32_bit) {
    DCHECK(type == U64 || type == S64);
  }
  const PrimitiveType unsigned_type = is_32_bit ? U32 : U64;
  XlaOp uniform = ConvertRandomBitsToUniformInt(raw.value, minval, maxval,
                                                type, unsigned_type);
  return {uniform, raw.state};
}
// Generates normally distributed values of the given `shape` (expected to be
// F32 or F64) by drawing uniform values in [0, 1) and applying the Box-Muller
// transform. Returns the values together with the generator's updated state.
//
// The uniform values are drawn with the concat shape from
// SplitShapeIntoHalves(), sliced into two halves along the new concat
// dimension, transformed pairwise, and combined back into `shape`.
RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
                                          BitGeneratorTy bit_generator,
                                          const Shape& shape) {
  const PrimitiveType primitive_type = shape.element_type();
  DCHECK(primitive_type == F32 || primitive_type == F64);

  XlaBuilder* builder = key.builder();
  const SplitShapePair split = SplitShapeIntoHalves(shape);
  XlaOp zero = xla::ConstantR0WithType(builder, primitive_type, 0.0);
  XlaOp one = xla::ConstantR0WithType(builder, primitive_type, 1.0);
  const RngOutput uniform = UniformFloatingPointDistribution(
      key, initial_state, bit_generator, zero, one, split.concat_shape);

  // Separate the bits into two groups to perform the Box-Muller transform.
  // The first group starts at index 0 of the concat dimension and the second
  // at index 1.
  const int64 half_rank = split.half_shape.rank();
  const std::vector<int64> unit_strides(half_rank, 1);
  std::vector<int64> starts(half_rank, 0);
  XlaOp first_half = Slice(uniform.value, starts,
                           split.half_shape.dimensions(), unit_strides);
  starts[split.new_concat_dim] = 1;
  XlaOp second_half = Slice(uniform.value, starts,
                            split.concat_shape.dimensions(), unit_strides);
  const std::pair<XlaOp, XlaOp> normals =
      BoxMullerTransform(first_half, second_half);

  // Put the numbers in the two groups back to form the requested shape.
  XlaOp normal =
      CombineShapePair({normals.first, normals.second}, split, shape);
  return {normal, uniform.state};
}
} // namespace xla