diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index a99c6ee4431..88c2ec0ed5c 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA implementations of Categorical op. +#include "tensorflow/compiler/tf2xla/kernels/random_ops_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -140,8 +141,6 @@ class StatelessCategoricalOp : public CategoricalOp { xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type, XlaOpKernelContext* ctx) override { xla::XlaOp seed = ctx->Input(2); - auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); xla::XlaBuilder* builder = ctx->builder(); if (uniform_shape.element_type() == xla::BF16) { @@ -150,8 +149,8 @@ class StatelessCategoricalOp : public CategoricalOp { // We want a number in (0, 1) rather than [0, 1) or (0, 1]: // * log(-log(0)) is ∞. // * log(-log(1)) is -∞. - auto uniforms = xla::StatelessRngUniform( - {seed0, seed1}, uniform_shape, + xla::XlaOp uniforms = StatelessRngUniform( + seed, uniform_shape, xla::MinPositiveNormalValue(builder, uniform_shape.element_type()), xla::One(builder, uniform_shape.element_type())); return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h index d107be6f13c..c7b1fca78d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -22,6 +22,13 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" namespace tensorflow { +// Returns a tensor containing 'shape' random values uniformly distributed in +// the range [minval, maxval). The raw random bits are generated by the +// ThreeFry bit generator and converted to the requested data type and range. +// This routine requires 2 32-bit integer seeds and currently only supports +// 'shape's of type F32, S32 and S64. +xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape, + xla::XlaOp minval, xla::XlaOp maxval); // Converts to bfloat16 if `dtype` equals DT_BFLOAT16, no-op otherwise. // It masks the last 16 bit. With normal rounding, values near "maxval" would be diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index ccdd1194916..cd9a6ee9f4b 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -35,127 +35,50 @@ limitations under the License. namespace tensorflow { namespace { -std::pair GetInputsFromCounter( - xla::XlaOp counter, const int64 size) { - auto builder = counter.builder(); - auto input_u64 = Iota(builder, xla::U64, size); - input_u64 = input_u64 + counter; - counter = counter + xla::ConstantR0(builder, size); - return std::make_pair(xla::Uint64ToUint32s(input_u64), counter); -} - -// `StatelessRngUniformU32` uses ThreeFry2x32’s counter space too -// wastefully, only able to generate 2^32*2 int32 numbers for each key, while -// the real capacity is 2^64*2. Counter-space efficiency is important for -// stateful ops, hence the following 2 new functions. 
-std::pair StatefulRngUniformU32( - xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { - auto builder = key.builder(); - const int64 size = xla::ShapeUtil::ElementsIn(shape); - const int64 half_size = xla::CeilOfRatio(size, 2); - const bool size_is_odd = (half_size * 2 != size); - auto inputs_counter = GetInputsFromCounter(counter, half_size); - auto inputs = inputs_counter.first; - counter = inputs_counter.second; - auto outputs = xla::ThreeFry2x32(inputs, xla::Uint64ToUint32s(key)); - if (size_is_odd) { - outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); - } - auto result = ConcatInDim(builder, outputs, 0); - return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())), - counter); -} - -std::pair StatefulRngUniformU64( - xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { - const int64 size = xla::ShapeUtil::ElementsIn(shape); - auto inputs_counter = GetInputsFromCounter(counter, size); - auto inputs = inputs_counter.first; - counter = inputs_counter.second; - auto outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); - auto result = Uint32sToUint64(outputs); - return std::make_pair(Reshape(result, xla::AsInt64Slice(shape.dimensions())), - counter); -} - -std::pair StatefulRngUniform(xla::XlaOp key, - xla::XlaOp counter, - const xla::Shape& shape, - xla::XlaOp minval, - xla::XlaOp maxval) { - auto builder = key.builder(); +xla::RngOutput StatefulRngUniform(xla::XlaOp key, xla::XlaOp initial_state, + const xla::Shape& shape, xla::XlaOp minval, + xla::XlaOp maxval) { xla::PrimitiveType type = shape.element_type(); switch (type) { - case xla::F32: { - auto bits_counter = StatefulRngUniformU32(key, counter, shape); - auto bits = bits_counter.first; - counter = bits_counter.second; - return std::make_pair(xla::StatelessRngUniformF32(bits, minval, maxval), - counter); - } - case xla::U32: // fall through - case xla::S32: { - auto bits_counter = StatefulRngUniformU32(key, counter, shape); - auto bits = bits_counter.first; 
- counter = bits_counter.second; - return std::make_pair( - xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U32), - counter); - } - case xla::U64: // fall through - case xla::S64: { - auto bits_counter = StatefulRngUniformU64(key, counter, shape); - auto bits = bits_counter.first; - counter = bits_counter.second; - return std::make_pair( - xla::StatelessRngUniformInt(bits, minval, maxval, type, xla::U64), - counter); - } + case xla::F32: + return xla::UniformF32Distribution( + key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape); + case xla::U32: + case xla::S32: + case xla::U64: + case xla::S64: + return UniformIntDistribution( + key, initial_state, xla::ThreeFryBitGenerator, minval, maxval, shape); default: - return std::make_pair( - builder->ReportError(xla::Unimplemented( - "Types other than F32, U32, S32, U64 and S64 " - "are not implemented by " - "StatefulRngUniform; got: %s", - xla::primitive_util::LowercasePrimitiveTypeName(type))), - counter); + return {key.builder()->ReportError(xla::Unimplemented( + "Types other than F32, U32, S32, U64 and S64 " + "are not implemented by " + "StatefulRngUniform; got %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; } } -template -std::pair map_first(std::function f, std::pair p) { - return std::make_pair(f(p.first), p.second); -} - -std::pair StatefulRngUniformFullInt( - xla::XlaOp key, xla::XlaOp counter, const xla::Shape& shape) { +xla::RngOutput StatefulRngUniformFullInt(xla::XlaOp key, + xla::XlaOp initial_state, + const xla::Shape& shape) { xla::PrimitiveType type = shape.element_type(); + xla::RngOutput output = xla::ThreeFryBitGenerator(key, initial_state, shape); switch (type) { case xla::U32: - return StatefulRngUniformU32(key, counter, shape); - case xla::S32: { - // Needs explicit function type because of type-inference failure. 
- std::function f = [](xla::XlaOp x) { - return BitcastConvertType(x, xla::S32); - }; - return map_first(f, StatefulRngUniformU32(key, counter, shape)); - } case xla::U64: - return StatefulRngUniformU64(key, counter, shape); - case xla::S64: { - std::function f = [](xla::XlaOp x) { - return BitcastConvertType(x, xla::S64); - }; - return map_first(f, StatefulRngUniformU64(key, counter, shape)); - } + return output; + case xla::S32: + case xla::S64: + output.value = BitcastConvertType(output.value, type); + return output; default: - auto builder = key.builder(); - return std::make_pair( - builder->ReportError(xla::Unimplemented( + return { + key.builder()->ReportError(xla::Unimplemented( "Types other than U32, S32, U64 and S64 are not implemented by " "StatefulRngUniformFullInt; got: %s", xla::primitive_util::LowercasePrimitiveTypeName(type))), - counter); + initial_state}; } } @@ -177,15 +100,15 @@ xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, 0); } -using sampler_return_type = xla::StatusOr>; +using SamplerReturnType = xla::StatusOr; // A helper function containing the common part of several kernels below. // Precondition: 'algorithm' and 'shape' are compile-time constants. 
-Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx, - int alg_input_idx, int shape_input_idx, - std::function const& - sample_with_threefry) { +Status CompileImpl( + XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx, + int shape_input_idx, + std::function const& + sampler) { auto alg_shape = ctx->InputShape(alg_input_idx); if (alg_shape.dims() != 0) { return errors::InvalidArgument("algorithm must be of shape [], not ", @@ -215,24 +138,22 @@ Status CompileImpl(XlaOpKernelContext* ctx, int state_input_idx, TensorShape shape; TF_RETURN_IF_ERROR(ctx->ConstantInputAsShape(shape_input_idx, &shape)); - static constexpr int COUNTER_SIZE = 1; - auto counter = BitcastConvertType( - xla::Reshape(xla::Slice(var, {0}, {COUNTER_SIZE}, {1}), {}), xla::U64); + static constexpr int kStateSize = 1; + auto state = BitcastConvertType( + xla::Reshape(xla::Slice(var, {0}, {kStateSize}, {1}), {}), xla::U64); auto key = BitcastConvertType( - xla::Reshape(xla::Slice(var, {COUNTER_SIZE}, {COUNTER_SIZE + 1}, {1}), - {}), + xla::Reshape(xla::Slice(var, {kStateSize}, {kStateSize + 1}, {1}), {}), xla::U64); - auto status_or_value = sample_with_threefry(counter, key, shape); + auto status_or_value = sampler(state, key, shape); if (!status_or_value.ok()) { return status_or_value.status(); } - auto output_counter = status_or_value.ConsumeValueOrDie(); - auto output = output_counter.first; - counter = output_counter.second; - ctx->SetOutput(0, output); - auto builder = ctx->builder(); - var = ConcatScalars(builder, {counter, key}); + xla::RngOutput value_state = status_or_value.ConsumeValueOrDie(); + state = value_state.state; + ctx->SetOutput(0, value_state.value); + xla::XlaBuilder* builder = ctx->builder(); + var = ConcatScalars(builder, {state, key}); xla::PrimitiveType state_element_type; TF_RETURN_IF_ERROR( DataTypeToPrimitiveType(STATE_ELEMENT_DTYPE, &state_element_type)); @@ -252,23 +173,22 @@ class StatefulUniformOp : public XlaOpKernel { } void 
Compile(XlaOpKernelContext* ctx) override { - auto builder = ctx->builder(); - auto sample_with_threefry = [builder, this]( - xla::XlaOp counter, xla::XlaOp key, - TensorShape shape) -> sampler_return_type { + xla::XlaBuilder* builder = ctx->builder(); + auto sampler = [builder, this](xla::XlaOp state, xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - auto uniform_counter = StatefulRngUniform( - key, counter, xla_shape, xla::ConstantR0(builder, 0.0), + xla::RngOutput uniform_state = StatefulRngUniform( + key, state, xla_shape, xla::ConstantR0(builder, 0.0), xla::ConstantR0(builder, 1.0)); - auto uniform = uniform_counter.first; - counter = uniform_counter.second; + xla::XlaOp uniform = uniform_state.value; + state = uniform_state.state; uniform = MaybeConvertF32ToBF16(uniform, dtype_); - return {{uniform, counter}}; + return {{uniform, state}}; }; OP_REQUIRES_OK(ctx, CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, - /*shape_input_idx=*/2, sample_with_threefry)); + /*shape_input_idx=*/2, sampler)); } private: @@ -293,30 +213,20 @@ class StatefulStandardNormalOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto builder = ctx->builder(); - auto sample_with_threefry = + auto sampler = // Needs explicit lambda return type because it fails to be inferred. 
- [builder, this](xla::XlaOp counter, xla::XlaOp key, - TensorShape shape) -> sampler_return_type { + [this](xla::XlaOp state, xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - - auto uniform_counter = StatefulRngUniform( - key, counter, xla_shape, - xla::ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), - xla::ConstantR0(builder, 1.0)); - auto uniform = uniform_counter.first; - counter = uniform_counter.second; - // Convert uniform distribution to normal distribution by computing - // sqrt(2) * erfinv(x) - auto normal = - xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); - normal = MaybeConvertF32ToBF16(normal, dtype_); - return {{normal, counter}}; + xla::RngOutput value_state = xla::NormalF32Distribution( + key, state, xla::ThreeFryBitGenerator, xla_shape); + xla::XlaOp normal = MaybeConvertF32ToBF16(value_state.value, dtype_); + return {{normal, value_state.state}}; }; OP_REQUIRES_OK(ctx, CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, - /*shape_input_idx=*/2, sample_with_threefry)); + /*shape_input_idx=*/2, sampler)); } private: @@ -341,27 +251,27 @@ class StatefulTruncatedNormalOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto builder = ctx->builder(); - auto sample_with_threefry = + xla::XlaBuilder* builder = ctx->builder(); + auto sampler = // Needs explicit lambda return type because it fails to be inferred. 
- [builder, this](xla::XlaOp counter, xla::XlaOp key, - TensorShape shape) -> sampler_return_type { + [builder, this](xla::XlaOp state, xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - auto uniform_counter = StatefulRngUniform( - key, counter, xla_shape, + xla::RngOutput uniform_result = StatefulRngUniform( + key, state, xla_shape, xla::MinPositiveNormalValue(builder, xla_shape.element_type()), xla::One(builder, xla_shape.element_type())); - auto uniform = uniform_counter.first; - counter = uniform_counter.second; + xla::XlaOp uniform = uniform_result.value; + state = uniform_result.state; xla::XlaOp truncated_normal = TruncatedNormal(uniform); truncated_normal = MaybeConvertF32ToBF16(truncated_normal, dtype_); - return {{truncated_normal, counter}}; + return {{truncated_normal, state}}; }; OP_REQUIRES_OK(ctx, CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, - /*shape_input_idx=*/2, sample_with_threefry)); + /*shape_input_idx=*/2, sampler)); } private: @@ -388,11 +298,11 @@ class StatefulUniformIntOp : public XlaOpKernel { xla::XlaOp minval = ctx->Input(3); xla::XlaOp maxval = ctx->Input(4); auto sample_with_threefry = [minval, maxval, this]( - xla::XlaOp counter, xla::XlaOp key, - TensorShape shape) -> sampler_return_type { + xla::XlaOp state, xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape)); - return StatefulRngUniform(key, counter, xla_shape, minval, maxval); + return StatefulRngUniform(key, state, xla_shape, minval, maxval); }; OP_REQUIRES_OK(ctx, CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, @@ -420,12 +330,11 @@ class StatefulUniformFullIntOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto sample_with_threefry = [this]( - xla::XlaOp counter, xla::XlaOp key, - TensorShape shape) -> 
sampler_return_type { + auto sample_with_threefry = [this](xla::XlaOp state, xla::XlaOp key, + TensorShape shape) -> SamplerReturnType { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype_, shape, &xla_shape)); - return StatefulRngUniformFullInt(key, counter, xla_shape); + return StatefulRngUniformFullInt(key, state, xla_shape); }; OP_REQUIRES_OK(ctx, CompileImpl(ctx, /*state_input_idx=*/0, /*alg_input_idx=*/1, diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 91230de0029..ea6a260f022 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -36,8 +36,8 @@ namespace tensorflow { xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { if (dtype == DT_BFLOAT16) { xla::XlaBuilder* builder = input.builder(); - auto output = xla::BitcastConvertType(input, xla::U32) & - xla::ConstantR0(builder, 0xFFFF0000); + xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) & + xla::ConstantR0(builder, 0xFFFF0000); return xla::ConvertElementType(xla::BitcastConvertType(output, xla::F32), xla::BF16); } else { @@ -45,22 +45,36 @@ xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { } } -xla::XlaOp Uniform2NormalUsingSqrtErfinv(xla::XlaOp uniform) { - // Convert uniform distribution to normal distribution by computing - // sqrt(2) * erfinv(x) - return xla::ScalarLike(uniform, std::sqrt(2.0)) * xla::ErfInv(uniform); -} +xla::XlaOp StatelessRngUniform(xla::XlaOp seeds, const xla::Shape& shape, + xla::XlaOp minval, xla::XlaOp maxval) { + xla::XlaBuilder* builder = seeds.builder(); -// A wrapper of xla::StatelessRngUniform. Returns an op that produces random -// values with uniform distribution in the range [minval, maxval) for the given -// shape and given two 32-bit seeds. Currently only shapes of type F32, S32 and -// S64 are implemented. 
-xla::XlaOp StatelessRandomUniformImpl(const xla::Shape& shape, DataType unused, - xla::XlaOp seed, xla::XlaOp minval, - xla::XlaOp maxval) { - xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); - xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); - return xla::StatelessRngUniform({seed0, seed1}, shape, minval, maxval); + xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {}); + xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {}); + xla::XlaOp key = ConvertElementType(seed0, xla::U64) | + ShiftLeft(ConvertElementType(seed1, xla::U64), + ConstantR0WithType(builder, xla::U64, 32)); + xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0); + xla::PrimitiveType type = shape.element_type(); + switch (type) { + case xla::F32: + return xla::UniformF32Distribution(key, initial_state, + xla::ThreeFryBitGenerator, minval, + maxval, shape) + .value; + case xla::S32: // fall through + case xla::S64: + return UniformIntDistribution(key, initial_state, + xla::ThreeFryBitGenerator, minval, maxval, + shape) + .value; + break; + default: + return builder->ReportError(xla::Unimplemented( + "Types other than F32, S32 and S64 are not implemented by " + "StatelessRngUniform; got %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))); + } } namespace { @@ -86,8 +100,8 @@ class StatelessRandomUniformOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - xla::XlaOp uniform = StatelessRandomUniformImpl( - xla_shape, dtype_, seed, xla::ConstantR0(builder, 0.0), + xla::XlaOp uniform = StatelessRngUniform( + seed, xla_shape, xla::ConstantR0(builder, 0.0), xla::ConstantR0(builder, 1.0)); uniform = MaybeConvertF32ToBF16(uniform, dtype_); ctx->SetOutput(0, uniform); @@ -136,8 +150,8 @@ class StatelessRandomUniformIntOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype_, shape, &xla_shape)); 
- xla::XlaOp uniform = - StatelessRandomUniformImpl(xla_shape, dtype_, seed, minval, maxval); + xla::XlaOp uniform = StatelessRngUniform(seed, xla_shape, minval, maxval); + ctx->SetOutput(0, uniform); } @@ -170,14 +184,20 @@ class StatelessRandomNormalOp : public XlaOpKernel { errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); - xla::XlaBuilder* builder = ctx->builder(); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - xla::XlaOp uniform = StatelessRandomUniformImpl( - xla_shape, dtype_, seed, - xla::ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), - xla::ConstantR0(builder, 1.0)); - xla::XlaOp normal = Uniform2NormalUsingSqrtErfinv(uniform); + + xla::XlaBuilder* builder = seed.builder(); + xla::XlaOp seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {}); + xla::XlaOp seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {}); + xla::XlaOp initial_state = xla::ConstantR0WithType(builder, xla::U64, 0); + xla::XlaOp key = ConvertElementType(seed0, xla::U64) | + ShiftLeft(ConvertElementType(seed1, xla::U64), + ConstantR0WithType(builder, xla::U64, 32)); + xla::XlaOp normal = + xla::NormalF32Distribution(key, initial_state, + xla::ThreeFryBitGenerator, xla_shape) + .value; normal = MaybeConvertF32ToBF16(normal, dtype_); ctx->SetOutput(0, normal); } @@ -215,8 +235,8 @@ class StatelessTruncatedNormalOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(DT_FLOAT, shape, &xla_shape)); - xla::XlaOp uniform = StatelessRandomUniformImpl( - xla_shape, dtype_, seed, + xla::XlaOp uniform = StatelessRngUniform( + seed, xla_shape, xla::MinPositiveNormalValue(builder, xla_shape.element_type()), xla::One(builder, xla_shape.element_type())); xla::XlaOp truncated_normal = TruncatedNormal(uniform); diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 63b3b07ddc2..2785e3176c7 
100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -32,8 +32,12 @@ XlaOp RotateLeftU32(XlaOp v, int distance) { ShiftRightLogical(v, ConstantR0(v.builder(), 32 - distance)); } -} // namespace +// The internal state of the Three Fry implementation. +using ThreeFry2x32State = std::array; +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { XlaBuilder* builder = input[0].builder(); key[0] = BitcastConvertType(key[0], U32); @@ -104,56 +108,68 @@ ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key) { return x; } -// Returns the inputs with unique counter values for ThreeFry2x32. -ThreeFry2x32State GetInputs(const int64 size, XlaBuilder* builder) { - ThreeFry2x32State inputs; - inputs[0] = Iota(builder, U32, size); - inputs[1] = inputs[0] + ConstantR0(builder, size); - return inputs; -} - -XlaOp StatelessRngUniformU32(std::array key, const Shape& shape) { - XlaBuilder* builder = key[0].builder(); - const int64 size = ShapeUtil::ElementsIn(shape); - const int64 half_size = CeilOfRatio(size, 2); - const bool size_is_odd = (half_size * 2 != size); - ThreeFry2x32State inputs = GetInputs(half_size, builder); - ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); - if (size_is_odd) { - outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); - } - auto result = ConcatInDim(builder, outputs, 0); - return Reshape(result, AsInt64Slice(shape.dimensions())); -} - +// Converts a uint64 to two uint32s. 
ThreeFry2x32State Uint64ToUint32s(XlaOp u64) { - auto builder = u64.builder(); - auto const32 = ConstantR0WithType(builder, U64, 32); - auto fst = ConvertElementType(u64, U32); - auto snd = ConvertElementType(ShiftRightLogical(u64, const32), U32); + XlaBuilder* builder = u64.builder(); + XlaOp const32 = ConstantR0WithType(builder, U64, 32); + XlaOp fst = ConvertElementType(u64, U32); + XlaOp snd = ConvertElementType(ShiftRightLogical(u64, const32), U32); return {fst, snd}; } +// Converts two uint32s to a uint64. XlaOp Uint32sToUint64(ThreeFry2x32State u32s) { - auto builder = u32s[0].builder(); + XlaBuilder* builder = u32s[0].builder(); return ConvertElementType(u32s[0], U64) | ShiftLeft(ConvertElementType(u32s[1], U64), ConstantR0WithType(builder, U64, 32)); } -XlaOp StatelessRngUniformU64(std::array key, const Shape& shape) { - XlaBuilder* builder = key[0].builder(); - const int64 size = ShapeUtil::ElementsIn(shape); - ThreeFry2x32State inputs = GetInputs(size, builder); - ThreeFry2x32State outputs = ThreeFry2x32(inputs, key); - // low 32 bit: outputs[0], high 32 bit: outputs[1] - auto result = Uint32sToUint64(outputs); - return Reshape(result, AsInt64Slice(shape.dimensions())); +// Given the initial state and the requested number of random numbers to be +// generated, returns the input for the random number generator and a new state. +std::pair GetThreeFryInputsAndUpdatedState( + XlaOp initial_state, const int64 size) { + XlaBuilder* builder = initial_state.builder(); + XlaOp input_u64 = Iota(builder, U64, size); + input_u64 = input_u64 + initial_state; + XlaOp new_state = initial_state + ConstantR0(builder, size); + return std::make_pair(Uint64ToUint32s(input_u64), new_state); } -XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) { - XlaBuilder* builder = bits.builder(); +// Generates random 32bits with the given shape using the Three Fry +// implementation. Returns the random bits and the new state. 
+RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) { + XlaBuilder* builder = key.builder(); + const int64 size = ShapeUtil::ElementsIn(shape); + const int64 half_size = CeilOfRatio(size, 2); + const bool size_is_odd = (half_size * 2 != size); + std::pair inputs_state = + GetThreeFryInputsAndUpdatedState(initial_state, half_size); + ThreeFry2x32State inputs = inputs_state.first; + ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); + if (size_is_odd) { + outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + XlaOp result = ConcatInDim(builder, outputs, 0); + return {Reshape(result, AsInt64Slice(shape.dimensions())), + inputs_state.second}; +} +// Generates random 64bits with the given shape using the Three Fry +// implementation. Returns the random bits and the new state. +RngOutput ThreeFryRngBit64(XlaOp key, XlaOp initial_state, const Shape& shape) { + const int64 size = ShapeUtil::ElementsIn(shape); + std::pair inputs_state = + GetThreeFryInputsAndUpdatedState(initial_state, size); + ThreeFry2x32State inputs = inputs_state.first; + ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key)); + XlaOp result = Uint32sToUint64(outputs); + return {Reshape(result, AsInt64Slice(shape.dimensions())), + inputs_state.second}; +} + +XlaOp ConvertRandomBitsToUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) { + XlaBuilder* builder = bits.builder(); // Form 23 random mantissa bits, with a leading 1 bit. The leading 1 bit // forces the random bits into the mantissa. constexpr int kFloatBits = 32; @@ -161,50 +177,95 @@ XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval) { bits = ShiftRightLogical( bits, ConstantR0(builder, kFloatBits - kMantissaBits)) | ConstantR0(builder, absl::bit_cast(1.0f)); - auto floats = BitcastConvertType(bits, F32); + XlaOp values = BitcastConvertType(bits, F32); // We have a floating point number in the range [1.0, 2.0). 
// Subtract 1.0f to shift to the range [0.0, 1.0) - floats = floats - ConstantR0(builder, 1.0f); + values = values - ConstantR0(builder, 1.0f); // Multiply and add to shift to the range [minval, maxval). - return floats * (maxval - minval) + minval; + return values * (maxval - minval) + minval; } -XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, - PrimitiveType type, PrimitiveType unsigned_type) { +XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, + PrimitiveType type, + PrimitiveType unsigned_type) { XlaBuilder* builder = bits.builder(); - auto range = BitcastConvertType(maxval, unsigned_type) - - BitcastConvertType(minval, unsigned_type); - auto dist = Rem(bits, range); - auto dist_div_2 = + XlaOp range = BitcastConvertType(maxval, unsigned_type) - + BitcastConvertType(minval, unsigned_type); + XlaOp dist = Rem(bits, range); + XlaOp dist_div_2 = ShiftRightLogical(dist, ConstantR0WithType(builder, unsigned_type, 1)); return minval + BitcastConvertType(dist_div_2, type) + BitcastConvertType(dist - dist_div_2, type); } -XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, - XlaOp minval, XlaOp maxval) { - XlaBuilder* builder = seeds[0].builder(); +XlaOp UniformToNormalUsingSqrtErfInv(XlaOp uniform) { + return ScalarLike(uniform, std::sqrt(2.0)) * ErfInv(uniform); +} + +} // namespace + +RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, + const Shape& shape) { PrimitiveType type = shape.element_type(); switch (type) { - case F32: { - auto bits = StatelessRngUniformU32(seeds, shape); - return StatelessRngUniformF32(bits, minval, maxval); - } - case S32: { - auto bits = StatelessRngUniformU32(seeds, shape); - return StatelessRngUniformInt(bits, minval, maxval, type, U32); - } - case S64: { - auto bits = StatelessRngUniformU64(seeds, shape); - return StatelessRngUniformInt(bits, minval, maxval, type, U64); - } + case F32: + case U32: + case S32: + return ThreeFryRngBit32(key, initial_state, shape); + 
case U64: + case S64: + return ThreeFryRngBit64(key, initial_state, shape); default: - return builder->ReportError(Unimplemented( - "Types other than F32, S32 and S64 are not implemented by " - "StatelessRngUniform.")); + return {key.builder()->ReportError(Unimplemented( + "Types other than F32, U32, S32, U64 and S64 " + "are not implemented by ThreeFryBitGenerator; got %s", + xla::primitive_util::LowercasePrimitiveTypeName(type))), + initial_state}; } } +RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const Shape& shape) { + DCHECK_EQ(shape.element_type(), F32); + RngOutput bits_state = bit_generator(key, initial_state, shape); + XlaOp bits = bits_state.value; + XlaOp new_state = bits_state.state; + return {ConvertRandomBitsToUniformF32(bits, minval, maxval), new_state}; +} + +RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const Shape& shape) { + RngOutput bits_state = bit_generator(key, initial_state, shape); + XlaOp bits = bits_state.value; + XlaOp new_state = bits_state.state; + PrimitiveType type = shape.element_type(); + PrimitiveType unsigned_type; + if (type == U32 || type == S32) { + unsigned_type = U32; + } else { + DCHECK(type == U64 || type == S64); + unsigned_type = U64; + } + return { + ConvertRandomBitsToUniformInt(bits, minval, maxval, type, unsigned_type), + new_state}; +} + +RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, + const Shape& shape) { + DCHECK_EQ(shape.element_type(), F32); + XlaBuilder* builder = key.builder(); + RngOutput bits_state = UniformF32Distribution( + key, initial_state, bit_generator, + ConstantR0(builder, std::nextafter(-1.0f, 0.0f)), + ConstantR0(builder, 1.0), shape); + XlaOp normal = UniformToNormalUsingSqrtErfInv(bits_state.value); + return {normal, bits_state.state}; +} + } // namespace xla diff --git 
a/tensorflow/compiler/xla/client/lib/prng.h b/tensorflow/compiler/xla/client/lib/prng.h index 7b0b4c2439e..4cca47c0c4a 100644 --- a/tensorflow/compiler/xla/client/lib/prng.h +++ b/tensorflow/compiler/xla/client/lib/prng.h @@ -23,37 +23,52 @@ limitations under the License. namespace xla { +// Records the bits and state generated by a random number generator. +struct RngOutput { + XlaOp value; + XlaOp state; +}; + +// A BitGenerator returns random bits and updated random bit generator state. +// +// key: a value input to a random number generator that can affect the +// sequence of numbers it will generate. A random number generator constructs +// its seed using the key and the initial state. The tf2xla bridge passes the +// seed operand of a tensorflow random operation as a key to the random bit +// generator, for example. +// initial_state: the initial state of the current random +// number generation. It could be 0 for a stateless random operation, and +// the returned state from a previous execution for a stateful random +// operation. +// shape: the shape of the random bits. +using BitGeneratorTy = std::function; + // Implements the ThreeFry counter-based PRNG algorithm. // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -using ThreeFry2x32State = std::array; -ThreeFry2x32State ThreeFry2x32(ThreeFry2x32State input, ThreeFry2x32State key); +RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, + const xla::Shape& shape); -// Returns a tensor containing 'shape' random values uniformly distributed in -// the range [minval, maxval). Requires 2 32-bit integer seeds. -// Currently only 'shape's of type F32, S32 and S64 are implemented. 
-XlaOp StatelessRngUniform(std::array seeds, const Shape& shape, - XlaOp minval, XlaOp maxval); +// Uses the given bit generator to generate random bits and then converts the +// random bits to random numbers of uniform distribution in the given range. +// Returns the random numbers and the state of the random number generator. +// This function is for shape with float element type. +RngOutput UniformF32Distribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const xla::Shape& shape); -// Converts a 32-bit (signed or unsigned) integer random number `bits` into a -// float32 in the range [minval, maxval). -XlaOp StatelessRngUniformF32(XlaOp bits, XlaOp minval, XlaOp maxval); +// Similar to UniformF32Distribution but for shape with integer element types. +RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const xla::Shape& shape); -// Converts an integer random number 'bits' of type 'type' to a random number -// in the range [minval, maxval), of the same type. 'unsigned_type' is the -// unsigned version of 'type' (could be the same) with the same bit width. -// The algorithm is the same one that TF uses right now, but it's -// uniform only when maxval - minval is a divisor of the range that bits is -// generated from. -// TODO(b/72573764): Generate real uniform integer distribution. -XlaOp StatelessRngUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, - PrimitiveType type, PrimitiveType unsigned_type); - -// The following 2 functions, for converting between one uint64 and two uint32s, -// use the contract "lower 32 bits for the first uint32, higher 32 bits for the -// second". -ThreeFry2x32State Uint64ToUint32s(XlaOp u64); -XlaOp Uint32sToUint64(ThreeFry2x32State u32s); +// Uses the given bit generator to generate random bits and then converts the +// random bits to random numbers of normal distribution. 
+// Returns the random numbers and the state of the random number generator. +RngOutput NormalF32Distribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, + const xla::Shape& shape); } // namespace xla diff --git a/tensorflow/python/kernel_tests/random/util.py b/tensorflow/python/kernel_tests/random/util.py index d8ece405cf5..6041da6bcfe 100644 --- a/tensorflow/python/kernel_tests/random/util.py +++ b/tensorflow/python/kernel_tests/random/util.py @@ -140,5 +140,7 @@ def test_truncated_normal(assert_equal, assert_all_close, dtype, n, y): (alpha * normal_pdf(alpha) - beta * normal_pdf(beta)) / z) - ( (normal_pdf(alpha) - normal_pdf(beta)) / z)**2) actual_variance = np.var(y) - assert_all_close(actual_variance, expected_variance, - rtol=5e-3 if dtype == dtypes.bfloat16 else 1e-3) + assert_all_close( + actual_variance, + expected_variance, + rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3)