Introduce new RngBitGenerator HLO

The new instruction has the same signature as xla::BitGeneratorTy,
which takes a key and a state and returns uniformly distributed
random bits and a new value for the state.

Its aim is to enable backend-specific lowering for the various random
bit generator algorithms, which should unlock optimization opportunities.

PiperOrigin-RevId: 293569472
Change-Id: I4f69d4f9858378fb1241435032ef75657933c056
This commit is contained in:
A. Unique TensorFlower 2020-02-06 05:03:54 -08:00 committed by TensorFlower Gardener
parent fc1961d9c1
commit a1c59c73da
33 changed files with 542 additions and 34 deletions

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@ -38,11 +39,27 @@ namespace {
xla::BitGeneratorTy BitGen(Algorithm alg) {
if (alg == RNG_ALG_PHILOX) {
return [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
return xla::PhiloxBitGenerator(key, state, shape, /*scramble=*/false);
return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
state =
xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
xla::XlaOp result =
xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, state, shape);
xla::XlaOp data = xla::GetTupleElement(result, 1);
xla::XlaOp new_state =
xla::Slice(xla::GetTupleElement(result, 0), {1}, {3}, {1});
return xla::RngOutput{data, new_state};
};
} else {
return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
state = xla::ConcatScalars(key.builder(), {key, state});
xla::XlaOp result = xla::RngBitGenerator(
xla::RandomAlgorithm::RNG_THREE_FRY, state, shape);
xla::XlaOp data = xla::GetTupleElement(result, 1);
xla::XlaOp new_state = xla::Reshape(
xla::Slice(xla::GetTupleElement(result, 0), {1}, {2}, {1}), {});
return xla::RngOutput{data, new_state};
};
}
return xla::ThreeFryBitGenerator;
}
xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key,

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@ -41,11 +42,23 @@ xla::BitGeneratorTy GetBitGeneratorForDevice(
// Turn on the Philox algorithm for the CPU and GPU backends only.
if (device_type_string == DEVICE_GPU_XLA_JIT ||
device_type_string == DEVICE_CPU_XLA_JIT) {
return [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
return xla::PhiloxBitGenerator(key, state, shape, /*scramble=*/true);
return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
std::tie(state, key) = xla::ScramblePhiloxKey(key);
xla::XlaOp philox_state =
xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX,
philox_state, shape);
return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
/*state=*/xla::GetTupleElement(result, 0)};
};
}
return xla::ThreeFryBitGenerator;
return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
state = xla::ConcatScalars(key.builder(), {key, state});
xla::XlaOp result =
xla::RngBitGenerator(xla::RandomAlgorithm::RNG_DEFAULT, state, shape);
return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
/*state=*/xla::GetTupleElement(result, 0)};
};
}
} // namespace

View File

@ -305,17 +305,9 @@ std::pair<Philox4x32State, XlaOp> GetPhiloxInputsAndUpdatedState(
// numbers are generated in the unit of 128bits.
std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64 num_elems,
XlaOp initial_state,
Philox4x32Key key,
bool scramble) {
Philox4x32Key key) {
Philox4x32State state;
if (scramble) {
// When `scramble` is true, `initial_state` is not used. This is because
// scramble is true only when this function is called by stateless random
// ops, for which `initial_state` is always zero.
std::tie(state, key) = ScramblePhiloxKey(key);
} else {
state = Uint128ToUint32s(Uint128FromOp(initial_state));
}
state = Uint128ToUint32s(Uint128FromOp(initial_state));
const int64 num_vector4 = CeilOfRatio<int64>(num_elems, 4);
Philox4x32State inputs;
XlaOp new_state;
@ -328,16 +320,15 @@ std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64 num_elems,
// Generates an array of primitive type U32 with the given shape containing
// random bits generated by the Philox algorithm. Returns the array and the new
// state of the random number generator.
RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, const Shape& shape,
bool scramble) {
RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state,
const Shape& shape) {
XlaBuilder* builder = op_key.builder();
const int64 num_elems = ShapeUtil::ElementsIn(shape);
Philox4x32Key key = Uint64ToUint32s(op_key);
Philox4x32State bits;
XlaOp new_state;
std::tie(bits, new_state) =
GeneratePhiloxBits(num_elems, initial_state, key, scramble);
std::tie(bits, new_state) = GeneratePhiloxBits(num_elems, initial_state, key);
// Combining bits[i] in a round-robin fashion, to align with non-XLA
// implementations
int64 bits_len = (num_elems + 3) / 4;
@ -356,8 +347,8 @@ RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, const Shape& shape,
// Generates an array of primitive type U64 with the given shape containing
// random bits generated by the Philox algorithm. Returns the array and the new
// state of the random number generator.
RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, const Shape& shape,
bool scramble) {
RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state,
const Shape& shape) {
XlaBuilder* builder = op_key.builder();
const int64 num_elems = ShapeUtil::ElementsIn(shape);
@ -365,7 +356,7 @@ RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, const Shape& shape,
Philox4x32State bits32;
XlaOp new_state;
std::tie(bits32, new_state) =
GeneratePhiloxBits(num_elems * 2, initial_state, key, scramble);
GeneratePhiloxBits(num_elems * 2, initial_state, key);
std::array<XlaOp, 2> bits64;
bits64[0] = Uint32sToUint64({bits32[0], bits32[1]});
@ -463,18 +454,18 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
}
}
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape,
bool scramble) {
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
const Shape& shape) {
PrimitiveType type = shape.element_type();
switch (type) {
case F32:
case U32:
case S32:
return PhiloxRngBit32(key, initial_state, shape, scramble);
return PhiloxRngBit32(key, initial_state, shape);
case F64:
case U64:
case S64:
return PhiloxRngBit64(key, initial_state, shape, scramble);
return PhiloxRngBit64(key, initial_state, shape);
default:
return {key.builder()->ReportError(Unimplemented(
"Types other than F32, F64, U32, S32, U64 and S64 "
@ -484,6 +475,13 @@ RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape,
}
}
// Derives a scrambled (state, key) pair of XlaOps from a single u64 key op.
// Delegates to the Philox4x32Key overload of ScramblePhiloxKey and packs the
// resulting 4x32 state back into a single u128 op and the 2x32 key into a
// single u64 op.
std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key) {
  Philox4x32Key raw_key = Uint64ToUint32s(key);
  auto scrambled = ScramblePhiloxKey(raw_key);
  XlaOp state_op = Uint128ToOp(Uint32sToUint128(scrambled.first));
  XlaOp key_op = Uint32sToUint64(scrambled.second);
  return std::make_pair(state_op, key_op);
}
RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
BitGeneratorTy bit_generator,
XlaOp minval, XlaOp maxval,

View File

@ -58,10 +58,10 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
// 4x32_10 version of the algorithm for the following reasons:
// . 4x32 uses 32-bit multiplication which is fast on GPUs.
// . The authors recommend the 10-round variant, and TensorFlow also uses it.
// 'scramble` controls whether to scramble 'key' and 'initial_state' to form
// the actual key and state fed to the Philox algorithm.
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape,
bool scramble);
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
const Shape& shape);
// Returns a scrambled pair of (state, key) from a single key.
std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key);
// Uses the given bit generator to generate random bits and then converts the
// random bits to random numbers of uniform distribution in the given range.

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@ -1752,6 +1753,36 @@ XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}
// Enqueues a kRngBitGenerator instruction producing a
// (new_state, generated_bits) tuple from `initial_state` using `algorithm`.
// The data element type is canonicalized to the unsigned integer type of the
// same bit width (32- or 64-bit); other element types are rejected.
XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
                                  XlaOp initial_state, const Shape& shape) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
    TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
    Shape output_shape = shape;
    const PrimitiveType element_type = output_shape.element_type();
    if (element_type == PrimitiveType::F32 ||
        element_type == PrimitiveType::S32 ||
        element_type == PrimitiveType::U32) {
      output_shape.set_element_type(PrimitiveType::U32);
    } else if (element_type == PrimitiveType::F64 ||
               element_type == PrimitiveType::S64 ||
               element_type == PrimitiveType::U64) {
      output_shape.set_element_type(PrimitiveType::U64);
    } else {
      return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
                             PrimitiveType_Name(element_type));
    }
    HloInstructionProto instr;
    // Result shape is a tuple of (state, data).
    *instr.mutable_shape() =
        ShapeUtil::MakeTupleShape({state_shape, output_shape}).ToProto();
    instr.set_rng_algorithm(algorithm);
    return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator,
                          {initial_state});
  });
}
XlaOp XlaBuilder::While(const XlaComputation& condition,
const XlaComputation& body, XlaOp init) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -3460,6 +3491,12 @@ XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
return a.builder()->RngUniform(a, b, shape);
}
// Free-function wrapper: forwards to the builder that owns `initial_state`.
XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state,
                      const Shape& shape) {
  XlaBuilder* builder = initial_state.builder();
  return builder->RngBitGenerator(algorithm, initial_state, shape);
}
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
const XlaOp init) {
return init.builder()->While(condition, body, init);

View File

@ -565,6 +565,9 @@ class XlaBuilder {
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
const Shape& shape);
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
XlaOp init);
@ -985,6 +988,8 @@ class XlaBuilder {
absl::Span<const XlaOp> static_operands);
friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
const Shape& shape);
friend XlaOp While(const XlaComputation& condition,
const XlaComputation& body, XlaOp init);
friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
@ -1856,6 +1861,11 @@ XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
// Enqueues a B(initial_state) random bit generation instruction onto the
// computation. Returns the new key and random bits with the specified shape.
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
const Shape& shape);
// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
XlaOp init);

View File

@ -2271,6 +2271,34 @@ implementation-defined.
: : : limit of interval :
| `shape` | `Shape` | Output shape of type T |
## RngBitGenerator
Generates an output with a given shape filled with uniform random bits using the
specified algorithm (or backend default) and returns an updated state (with the
same shape as initial state) and the generated random data.
The initial state is the starting state of the current random number
generation; its required shape and valid values depend on the algorithm used.
The output is guaranteed to be a deterministic function of the initial state but
it is *not* guaranteed to be deterministic between backends and different
compiler versions.
<b> `RngBitGenerator(algorithm, initial_state, shape)` </b>

| Arguments       | Type              | Semantics                             |
|---------------- | ----------------- | ------------------------------------- |
| `algorithm`     | `RandomAlgorithm` | PRNG algorithm to be used.            |
| `initial_state` | `XlaOp`           | Initial state for the PRNG algorithm. |
| `shape`         | `Shape`           | Output shape for generated data.      |

Available values for `algorithm`:

*   `rng_default`: Backend specific algorithm with backend specific shape
    requirements.
*   `rng_three_fry`: ThreeFry counter-based PRNG algorithm. The
    `initial_state` shape is `u64[2]` with arbitrary values.
    [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
*   `rng_philox`: Philox algorithm to generate random numbers in parallel. The
    `initial_state` shape is `u64[3]` with arbitrary values.
    [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
## Scatter
The XLA scatter operation generates a result which is the value of the input

View File

@ -4492,6 +4492,25 @@ cc_library(
],
)
cc_library(
name = "rng_bit_generator_expander",
srcs = ["rng_bit_generator_expander.cc"],
hdrs = ["rng_bit_generator_expander.h"],
deps = [
":hlo",
":hlo_casting_utils",
":op_expander_pass",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
],
)
cc_library(
name = "slow_operation_alarm",
srcs = ["slow_operation_alarm.cc"],

View File

@ -94,6 +94,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:dump",
"//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
"//tensorflow/compiler/xla/service:tree_reduction_rewriter",
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
"//tensorflow/compiler/xla/service:conditional_to_select",

View File

@ -93,6 +93,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
#include "tensorflow/compiler/xla/service/rng_expander.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/slice_sinker.h"
@ -241,6 +242,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
// Expand random number generation.
pipeline.AddPass<RngExpander>();
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
// Remove zero-sized HLO from the input so that other passes don't have to
// handle it.

View File

@ -225,6 +225,7 @@ class DfsHloVisitorBase {
virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
virtual Status HandleRng(HloInstructionPtr hlo) = 0;
virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0;
virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0;
virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
virtual Status HandleSort(HloInstructionPtr hlo) = 0;

View File

@ -116,6 +116,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleRng(HloInstructionPtr random) override {
return DefaultAction(random);
}
Status HandleRngBitGenerator(HloInstructionPtr random) override {
return DefaultAction(random);
}
Status HandleRngGetAndUpdateState(HloInstructionPtr random) override {
return DefaultAction(random);
}

View File

@ -1155,6 +1155,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
"//tensorflow/compiler/xla/service:rng_expander",
"//tensorflow/compiler/xla/service:slice_sinker",
"//tensorflow/compiler/xla/service:slow_operation_alarm",

View File

@ -80,6 +80,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
#include "tensorflow/compiler/xla/service/rng_expander.h"
#include "tensorflow/compiler/xla/service/slice_sinker.h"
#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
@ -127,6 +128,7 @@ Status GpuCompiler::OptimizeHloModule(
// Expand random number generation.
pipeline.AddPass<RngExpander>();
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
// Remove zero-sized HLO from the input so that other passes don't have to
// handle it.

View File

@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
// Next ID: 70
// Next ID: 71
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@ -241,6 +241,9 @@ message HloInstructionProto {
// Specifies if all elements updated are guaranteed to be unique by
// the caller.
bool unique_indices = 69;
// RNG algorithm used by kRngBitGenerator.
xla.RandomAlgorithm rng_algorithm = 70;
}
// Serialization of HloComputation.

View File

@ -746,6 +746,15 @@ Status HloCostAnalysis::HandleRng(const HloInstruction* random) {
return Status::OK();
}
// Approximates the cost of rng-bit-generator as one transcendental operation
// per element of the output shape.
Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) {
  // TODO(b/26346211): Implement better estimates for the RNG cost, since the
  // cost changes with the implementation and the distribution. For now, assume
  // the cost of each RNG is same as a transcendental operation.
  const auto element_count = ShapeUtil::ElementsIn(random->shape());
  current_properties_[kTranscendentalsKey] = element_count;
  return Status::OK();
}
Status HloCostAnalysis::HandleRngGetAndUpdateState(
const HloInstruction* random) {
return Status::OK();

View File

@ -84,6 +84,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleInfeed(const HloInstruction* infeed) override;
Status HandleOutfeed(const HloInstruction* outfeed) override;
Status HandleRng(const HloInstruction* random) override;
Status HandleRngBitGenerator(const HloInstruction* random) override;
Status HandleRngGetAndUpdateState(const HloInstruction* random) override;
Status HandleReverse(const HloInstruction* reverse) override;
Status HandleSort(const HloInstruction* sort) override;

View File

@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
@ -967,6 +968,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kRemainder:
case HloOpcode::kRng:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kRsqrt:
case HloOpcode::kSelect:

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@ -46,6 +47,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
@ -346,6 +348,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kRng:
instruction = CreateRng(shape, proto.distribution(), all_operands());
break;
case HloOpcode::kRngBitGenerator:
instruction =
CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm());
break;
case HloOpcode::kRngGetAndUpdateState:
instruction = CreateRngGetAndUpdateState(shape, proto.delta());
break;
@ -743,6 +749,13 @@ HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64 delta) {
return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
RandomAlgorithm algorithm) {
return absl::make_unique<HloRngBitGeneratorInstruction>(shape, state,
algorithm);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
const Shape& shape, HloOpcode opcode,
absl::Span<HloInstruction* const> operands) {
@ -1480,6 +1493,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kTrace:
case HloOpcode::kFusion:
case HloOpcode::kRng:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kParameter:
case HloOpcode::kGetTupleElement:
@ -1950,6 +1964,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kTrace:
case HloOpcode::kFusion:
case HloOpcode::kRng:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kParameter:
case HloOpcode::kGetTupleElement:
@ -2882,6 +2897,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleOutfeed(this);
case HloOpcode::kRng:
return visitor->HandleRng(this);
case HloOpcode::kRngBitGenerator:
return visitor->HandleRngBitGenerator(this);
case HloOpcode::kRngGetAndUpdateState:
return visitor->HandleRngGetAndUpdateState(this);
case HloOpcode::kWhile:
@ -3357,6 +3374,9 @@ string OpMetadataToString(const OpMetadata& metadata) {
string RandomDistributionToString(const RandomDistribution& distribution) {
return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
}
// Returns the lower-cased proto enum name, e.g. "rng_three_fry".
string RandomAlgorithmToString(const RandomAlgorithm& algorithm) {
  const string enum_name = RandomAlgorithm_Name(algorithm);
  return absl::AsciiStrToLower(enum_name);
}
string PrecisionToString(const PrecisionConfig::Precision& precision) {
return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
@ -3401,6 +3421,24 @@ string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups) {
return StrCat("{", StrJoin(replica_group_str, ","), "}");
}
// Parses the lower-case name produced by RandomAlgorithmToString back into a
// RandomAlgorithm enum value. The name-to-enum table is built once, lazily,
// and intentionally leaked (it lives for the process lifetime).
StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name) {
  static std::unordered_map<string, RandomAlgorithm>* algorithm_map = [] {
    auto* result = new std::unordered_map<string, RandomAlgorithm>;
    for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; ++i) {
      if (!RandomAlgorithm_IsValid(i)) {
        continue;
      }
      const auto algorithm = static_cast<RandomAlgorithm>(i);
      (*result)[RandomAlgorithmToString(algorithm)] = algorithm;
    }
    return result;
  }();
  const auto it = algorithm_map->find(absl::AsciiStrToLower(name));
  if (it == algorithm_map->end()) {
    return InvalidArgument("Unknown algorithm");
  }
  return it->second;
}
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
static std::unordered_map<string, RandomDistribution>* map = [] {
static auto* map = new std::unordered_map<string, RandomDistribution>;

View File

@ -512,6 +512,11 @@ class HloInstruction {
const Shape& shape, RandomDistribution distribution,
absl::Span<HloInstruction* const> parameters);
// Creates a stateless random bit generator instruction that fills a shape
// with random bits.
static std::unique_ptr<HloInstruction> CreateRngBitGenerator(
const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm);
// Creates an instruction to update the random number generator state to
// reflect the new state after `delta` units of 32 random bits are generated
// and returns the old state.
@ -2025,12 +2030,14 @@ string PaddingConfigToString(const PaddingConfig& padding);
string FrontendAttributesToString(
const FrontendAttributes& frontend_attributes);
string OpMetadataToString(const OpMetadata& metadata);
string RandomAlgorithmToString(const RandomAlgorithm& algorithm);
string RandomDistributionToString(const RandomDistribution& distribution);
string PrecisionToString(const PrecisionConfig::Precision& precision);
string ConvolutionDimensionNumbersToString(
const ConvolutionDimensionNumbers& dnums);
string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups);
StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
@ -2902,4 +2903,40 @@ HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
}
// Constructs a kRngBitGenerator instruction: takes the RNG `state` as its
// single operand and records the PRNG `algorithm` used to expand it.
HloRngBitGeneratorInstruction::HloRngBitGeneratorInstruction(
    const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm)
    : HloInstruction(HloOpcode::kRngBitGenerator, shape),
      algorithm_(algorithm) {
  AppendOperand(state);
}
// Serializes the base instruction and attaches the rng_algorithm attribute.
HloInstructionProto HloRngBitGeneratorInstruction::ToProto() const {
  auto proto = HloInstruction::ToProto();
  proto.set_rng_algorithm(algorithm_);
  return proto;
}
std::vector<string> HloRngBitGeneratorInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return {StrCat("algorithm=", RandomAlgorithmToString(algorithm_))};
}
bool HloRngBitGeneratorInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other =
static_cast<const HloRngBitGeneratorInstruction&>(other);
return algorithm() == casted_other.algorithm();
}
// Clones this instruction, replacing the single state operand; the algorithm
// is carried over unchanged.
std::unique_ptr<HloInstruction>
HloRngBitGeneratorInstruction::CloneWithNewOperandsImpl(
    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
    HloCloneContext* /*context*/) const {
  CHECK_EQ(new_operands.size(), 1);
  return absl::make_unique<HloRngBitGeneratorInstruction>(
      shape, new_operands[0], algorithm());
}
} // namespace xla

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@ -1721,6 +1722,28 @@ class HloRngGetAndUpdateStateInstruction : public HloInstruction {
int64 delta_;
};
// HLO instruction for kRngBitGenerator: consumes an RNG state operand and
// produces a (new_state, generated_bits) tuple using the configured PRNG
// algorithm.
class HloRngBitGeneratorInstruction : public HloInstruction {
 public:
  HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state,
                                RandomAlgorithm algorithm);

  // PRNG algorithm used to generate the random bits.
  RandomAlgorithm algorithm() const { return algorithm_; }
  HloInstructionProto ToProto() const override;

 private:
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Algorithm recorded at construction; serialized via the rng_algorithm
  // proto field.
  RandomAlgorithm algorithm_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_

View File

@ -122,6 +122,7 @@ namespace xla {
V(kReverse, "reverse", 1) \
V(kRng, "rng", kHloOpcodeIsVariadic) \
V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \
V(kRngBitGenerator, "rng-bit-generator", 1) \
V(kRoundNearestAfz, "round-nearest-afz", 1) \
V(kRsqrt, "rsqrt", 1) \
V(kScatter, "scatter", 3) \

View File

@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@ -203,7 +204,8 @@ class HloParserImpl : public HloParser {
kDomain,
kPrecisionList,
kShapeList,
kEnum
kEnum,
kRandomAlgorithm,
};
struct AttrConfig {
@ -322,6 +324,7 @@ class HloParserImpl : public HloParser {
bool ParseComparisonDirection(ComparisonDirection* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
bool ParseRandomAlgorithm(RandomAlgorithm* result);
bool ParsePrecision(PrecisionConfig::Precision* result);
bool ParseInt64(int64* result);
bool ParseDouble(double* result);
@ -1501,6 +1504,18 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
break;
}
case HloOpcode::kRngBitGenerator: {
optional<RandomAlgorithm> algorithm;
attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm,
&algorithm};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction =
builder->AddInstruction(HloInstruction::CreateRngBitGenerator(
shape, operands[0], *algorithm));
break;
}
case HloOpcode::kReducePrecision: {
optional<int64> exponent_bits;
optional<int64> mantissa_bits;
@ -2972,6 +2987,14 @@ bool HloParserImpl::ParseAttributeHelper(
->emplace(result);
return true;
}
case AttrTy::kRandomAlgorithm: {
RandomAlgorithm result;
if (!ParseRandomAlgorithm(&result)) {
return false;
}
static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result);
return true;
}
}
}();
if (!success) {
@ -3963,6 +3986,23 @@ bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) {
return true;
}
// Parses a random-algorithm attribute value (e.g. "rng_three_fry") from the
// current identifier token; reports a TokenError on any failure.
bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) {
  VLOG(3) << "ParseRandomAlgorithm";
  if (lexer_.GetKind() != TokKind::kIdent) {
    return TokenError("expects random algorithm");
  }
  const std::string val = lexer_.GetStrVal();
  const auto parsed = StringToRandomAlgorithm(val);
  if (!parsed.ok()) {
    return TokenError(
        StrFormat("expects random algorithm but sees: %s, error: %s", val,
                  parsed.status().error_message()));
  }
  *result = parsed.ValueOrDie();
  lexer_.Lex();
  return true;
}
bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) {
VLOG(3) << "ParsePrecision";
if (lexer_.GetKind() != TokKind::kIdent) {

View File

@ -1055,6 +1055,17 @@ ENTRY %RngGetAndUpdateState () -> u64[2] {
)"
},
{
"RngBitGenerator",
R"(HloModule gng_bit_generator
ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[11,17]) {
%p0 = u64[2]{0} parameter(0)
ROOT %rand = (u64[2]{0}, u32[11,17]{1,0}) rng-bit-generator(u64[2]{0} %p0), algorithm=rng_three_fry
}
)"
}
});
// clang-format on
}

View File

@ -415,6 +415,23 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
return Status::OK();
}
// Verifies the shape contract of rng-bit-generator: the result must be a
// 2-element tuple whose first element (the new state) has a shape compatible
// with the input state operand.
Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) {
  const Shape& result_shape = hlo->shape();
  if (!result_shape.IsTuple() || result_shape.tuple_shapes_size() != 2) {
    return InternalError(
        "Expected tuple shape with 2 elements for RngBitGenerator. Got: %s",
        result_shape.ToString());
  }
  const Shape& input_state_shape = hlo->operand(0)->shape();
  const Shape& output_state_shape = result_shape.tuple_shapes(0);
  if (!ShapeUtil::Compatible(input_state_shape, output_state_shape)) {
    return InternalError(
        "Expected state shape to match between input and output for "
        "RngBitGenerator. Got %s vs. %s",
        input_state_shape.ToString(), output_state_shape.ToString());
  }
  return Status::OK();
}
Status ShapeVerifier::HandleRngGetAndUpdateState(HloInstruction* instruction) {
TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
const Shape& result_shape = instruction->shape();

View File

@ -65,6 +65,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleInfeed(HloInstruction*) override;
Status HandleOutfeed(HloInstruction*) override;
Status HandleRng(HloInstruction*) override;
Status HandleRngBitGenerator(HloInstruction*) override;
Status HandleRngGetAndUpdateState(HloInstruction*) override;
Status HandleReverse(HloInstruction* reverse) override;
Status HandleSort(HloInstruction* sort) override;

View File

@ -167,6 +167,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kReduceWindow:
case HloOpcode::kRng:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRsqrt:
case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:

View File

@ -2260,6 +2260,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kReplicaId:
case HloOpcode::kReshape:
case HloOpcode::kRng:
case HloOpcode::kRngBitGenerator:
case HloOpcode::kRngGetAndUpdateState:
case HloOpcode::kSend:
case HloOpcode::kSendDone:

View File

@ -0,0 +1,105 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
// Every rng-bit-generator instruction is a candidate for expansion.
bool RngBitGeneratorExpander::InstructionMatchesPattern(
    HloInstruction* instruction) {
  return HloOpcode::kRngBitGenerator == instruction->opcode();
}
// Builds (or fetches from the per-pass cache) an HLO computation that
// implements `algorithm` for the given data/state shapes. The computation
// takes the combined state vector as its single parameter and returns a
// (new_state, data) tuple whose state element has the same shape as the
// input state.
StatusOr<HloComputation*> RngBitGeneratorExpander::GetGeneratorComputation(
    const Shape& data_shape, const Shape& state_shape,
    RandomAlgorithm algorithm, HloModule* module) {
  RngGeneratorKey cache_key{data_shape, state_shape, algorithm, module};
  auto it = computation_cache_.find(cache_key);
  if (it != computation_cache_.end()) {
    return it->second;
  }

  // Element 0 of the state vector is the key; the remaining elements are the
  // algorithm-specific counter state.
  XlaBuilder builder("rng");
  XlaOp state_param = Parameter(&builder, 0, state_shape, "state");
  XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {});
  XlaOp state_op;

  BitGeneratorTy generator = nullptr;
  switch (algorithm) {
    case RandomAlgorithm::RNG_THREE_FRY:
      generator = ThreeFryBitGenerator;
      state_op = Slice(state_param, {1}, {2}, {1});  // One counter element.
      break;
    case RandomAlgorithm::RNG_PHILOX:
      generator = PhiloxBitGenerator;
      state_op = Slice(state_param, {1}, {3}, {1});  // Two counter elements.
      break;
    default:
      return Unimplemented("Unsupported random algorithm: %s",
                           RandomAlgorithm_Name(algorithm));
  }

  RngOutput output = generator(key_op, state_op, data_shape);
  // Re-attach the (unchanged) key so the returned state has the same shape
  // as the input state.
  XlaOp final_state =
      ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0);
  Tuple(&builder, {final_state, output.value});
  TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

  // Lower the client-side computation into an HloModule and clone its entry
  // computation into the module being expanded.
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      xla_computation.GetProgramShape());
  HloModuleConfig config(program_shape);
  TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                           xla_computation.proto(), config));
  HloCloneContext context(module);
  HloComputation* new_computation =
      module->DeepCloneComputation(new_module->entry_computation(), &context);
  computation_cache_.emplace(cache_key, new_computation);
  return new_computation;
}
// Replaces an rng-bit-generator instruction with a call to an equivalent
// computation built from the client-library bit generators. RNG_DEFAULT is
// resolved to the pass's configured default algorithm.
StatusOr<HloInstruction*> RngBitGeneratorExpander::ExpandInstruction(
    HloInstruction* hlo) {
  HloRngBitGeneratorInstruction* rng = Cast<HloRngBitGeneratorInstruction>(hlo);
  RandomAlgorithm algorithm = rng->algorithm();
  if (algorithm == RandomAlgorithm::RNG_DEFAULT) {
    algorithm = default_algorithm_;
  }

  const Shape& state_shape = rng->operand(0)->shape();
  const Shape& data_shape = rng->shape().tuple_shapes(1);
  HloComputation* computation = hlo->parent();
  HloModule* module = computation->parent();
  TF_ASSIGN_OR_RETURN(
      HloComputation * generator_computation,
      GetGeneratorComputation(data_shape, state_shape, algorithm, module));
  return computation->AddInstruction(HloInstruction::CreateCall(
      ShapeUtil::MakeTupleShape({state_shape, data_shape}),
      {hlo->mutable_operand(0)}, generator_computation));
}
} // namespace xla

View File

@ -0,0 +1,72 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// Expands rng-bit-generator HLO instructions into equivalent computations
// built from the client-library PRNG implementations (ThreeFry / Philox).
class RngBitGeneratorExpander : public OpExpanderPass {
 public:
  // `default_algorithm` is substituted for RNG_DEFAULT and must itself be a
  // concrete algorithm.
  explicit RngBitGeneratorExpander(RandomAlgorithm default_algorithm)
      : default_algorithm_(default_algorithm) {
    CHECK_NE(default_algorithm_, RandomAlgorithm::RNG_DEFAULT);
  }

  absl::string_view name() const override {
    return "rng-bit-generator-expander";
  }

 protected:
  // Cache key: one generator computation is shared per
  // (shapes, algorithm, module) combination.
  struct RngGeneratorKey {
    Shape data_shape;
    Shape state_shape;
    RandomAlgorithm algorithm;
    HloModule* module;

    template <typename H>
    friend H AbslHashValue(H h, const RngGeneratorKey& key) {
      return H::combine(std::move(h), key.state_shape, key.data_shape,
                        key.algorithm, key.module);
    }

    bool operator==(const RngGeneratorKey& other) const {
      return module == other.module && algorithm == other.algorithm &&
             data_shape == other.data_shape &&
             state_shape == other.state_shape;
    }
  };

  bool InstructionMatchesPattern(HloInstruction* instruction) override;
  StatusOr<HloInstruction*> ExpandInstruction(HloInstruction* hlo) override;
  // Builds or fetches the cached computation implementing `algorithm` for the
  // given shapes.
  StatusOr<HloComputation*> GetGeneratorComputation(const Shape& data_shape,
                                                    const Shape& state_shape,
                                                    RandomAlgorithm algorithm,
                                                    HloModule* module);

  const RandomAlgorithm default_algorithm_;
  absl::flat_hash_map<RngGeneratorKey, HloComputation*> computation_cache_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_

View File

@ -84,7 +84,7 @@ StatusOr<HloComputation*> GetComputationForRng(HloInstruction* rng) {
auto generator = [](xla::XlaOp key, xla::XlaOp state,
const xla::Shape& shape) {
return PhiloxBitGenerator(key, state, shape, /*scramble=*/false);
return PhiloxBitGenerator(key, state, shape);
};
XlaOp result;

View File

@ -559,6 +559,13 @@ enum RandomDistribution {
// Next: 4
}
enum RandomAlgorithm {
RNG_DEFAULT = 0; // Backend dependent default algorithm.
RNG_THREE_FRY = 1;
RNG_PHILOX = 2;
// Next: 3
}
message TriangularSolveOptions {
// If true, solves ax = b. If false, solves xa = b.
bool left_side = 1;