Introduce new RngBitGenerator HLO

The new instruction has the same signature as xla::BitGeneratorTy: it takes a key and a state and returns uniformly distributed random bits together with a new value for the state. Its aim is to enable backend-specific lowering for the various random bit generator algorithms, which should unlock optimization opportunities.

PiperOrigin-RevId: 293569472
Change-Id: I4f69d4f9858378fb1241435032ef75657933c056
parent fc1961d9c1
commit a1c59c73da
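For orientation, here is a minimal editorial sketch (not part of the commit) of how client code drives the new instruction through the builder API added below; the function name and its arguments are assumed placeholders, and ThreeFry is picked arbitrarily:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Sketch: emit rng-bit-generator and unpack its (new_state, data) tuple,
// mirroring the xla::BitGeneratorTy contract described in the commit message.
xla::XlaOp EmitRandomBits(xla::XlaOp initial_state, const xla::Shape& shape) {
  xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_THREE_FRY,
                                           initial_state, shape);
  xla::XlaOp new_state = xla::GetTupleElement(result, 0);  // updated RNG state
  xla::XlaOp bits = xla::GetTupleElement(result, 1);       // uniform random bits
  (void)new_state;  // thread into the next call to continue the stream
  return bits;
}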
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/prng.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -38,11 +39,27 @@ namespace {
 
 xla::BitGeneratorTy BitGen(Algorithm alg) {
   if (alg == RNG_ALG_PHILOX) {
-    return [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
-      return xla::PhiloxBitGenerator(key, state, shape, /*scramble=*/false);
+    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
+      state =
+          xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
+      xla::XlaOp result =
+          xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, state, shape);
+      xla::XlaOp data = xla::GetTupleElement(result, 1);
+      xla::XlaOp new_state =
+          xla::Slice(xla::GetTupleElement(result, 0), {1}, {3}, {1});
+      return xla::RngOutput{data, new_state};
     };
+  } else {
+    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
+      state = xla::ConcatScalars(key.builder(), {key, state});
+      xla::XlaOp result = xla::RngBitGenerator(
+          xla::RandomAlgorithm::RNG_THREE_FRY, state, shape);
+      xla::XlaOp data = xla::GetTupleElement(result, 1);
+      xla::XlaOp new_state = xla::Reshape(
+          xla::Slice(xla::GetTupleElement(result, 0), {1}, {2}, {1}), {});
+      return xla::RngOutput{data, new_state};
+    };
   }
-  return xla::ThreeFryBitGenerator;
 }
 
 xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key,
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/prng.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
@@ -41,11 +42,23 @@ xla::BitGeneratorTy GetBitGeneratorForDevice(
   // Turn on the Philox algorithm for the CPU and GPU backends only.
   if (device_type_string == DEVICE_GPU_XLA_JIT ||
       device_type_string == DEVICE_CPU_XLA_JIT) {
-    return [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
-      return xla::PhiloxBitGenerator(key, state, shape, /*scramble=*/true);
+    return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
+      std::tie(state, key) = xla::ScramblePhiloxKey(key);
+      xla::XlaOp philox_state =
+          xla::ConcatInDim(key.builder(), {xla::Reshape(key, {1}), state}, 0);
+      xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX,
+                                               philox_state, shape);
+      return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
+                            /*state=*/xla::GetTupleElement(result, 0)};
     };
   }
-  return xla::ThreeFryBitGenerator;
+  return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
+    state = xla::ConcatScalars(key.builder(), {key, state});
+    xla::XlaOp result =
+        xla::RngBitGenerator(xla::RandomAlgorithm::RNG_DEFAULT, state, shape);
+    return xla::RngOutput{/*value=*/xla::GetTupleElement(result, 1),
+                          /*state=*/xla::GetTupleElement(result, 0)};
+  };
 }
 
 }  // namespace
@ -305,17 +305,9 @@ std::pair<Philox4x32State, XlaOp> GetPhiloxInputsAndUpdatedState(
|
||||
// numbers are generated in the unit of 128bits.
|
||||
std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64 num_elems,
|
||||
XlaOp initial_state,
|
||||
Philox4x32Key key,
|
||||
bool scramble) {
|
||||
Philox4x32Key key) {
|
||||
Philox4x32State state;
|
||||
if (scramble) {
|
||||
// When `scramble` is true, `initial_state` is not used. This is because
|
||||
// scramble is true only when this function is called by stateless random
|
||||
// ops, for which `initial_state` is always zero.
|
||||
std::tie(state, key) = ScramblePhiloxKey(key);
|
||||
} else {
|
||||
state = Uint128ToUint32s(Uint128FromOp(initial_state));
|
||||
}
|
||||
state = Uint128ToUint32s(Uint128FromOp(initial_state));
|
||||
const int64 num_vector4 = CeilOfRatio<int64>(num_elems, 4);
|
||||
Philox4x32State inputs;
|
||||
XlaOp new_state;
|
||||
@ -328,16 +320,15 @@ std::pair<Philox4x32State, XlaOp> GeneratePhiloxBits(int64 num_elems,
|
||||
// Generates an array of primitive type U32 with the given shape containing
|
||||
// random bits generated by the Philox algorithm. Returns the array and the new
|
||||
// state of the random number generator.
|
||||
RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, const Shape& shape,
|
||||
bool scramble) {
|
||||
RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state,
|
||||
const Shape& shape) {
|
||||
XlaBuilder* builder = op_key.builder();
|
||||
const int64 num_elems = ShapeUtil::ElementsIn(shape);
|
||||
|
||||
Philox4x32Key key = Uint64ToUint32s(op_key);
|
||||
Philox4x32State bits;
|
||||
XlaOp new_state;
|
||||
std::tie(bits, new_state) =
|
||||
GeneratePhiloxBits(num_elems, initial_state, key, scramble);
|
||||
std::tie(bits, new_state) = GeneratePhiloxBits(num_elems, initial_state, key);
|
||||
// Combining bits[i] in a round-robin fashion, to align with non-XLA
|
||||
// implementations
|
||||
int64 bits_len = (num_elems + 3) / 4;
|
||||
@ -356,8 +347,8 @@ RngOutput PhiloxRngBit32(XlaOp op_key, XlaOp initial_state, const Shape& shape,
|
||||
// Generates an array of primitive type U64 with the given shape containing
|
||||
// random bits generated by the Philox algorithm. Returns the array and the new
|
||||
// state of the random number generator.
|
||||
RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, const Shape& shape,
|
||||
bool scramble) {
|
||||
RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state,
|
||||
const Shape& shape) {
|
||||
XlaBuilder* builder = op_key.builder();
|
||||
const int64 num_elems = ShapeUtil::ElementsIn(shape);
|
||||
|
||||
@ -365,7 +356,7 @@ RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, const Shape& shape,
|
||||
Philox4x32State bits32;
|
||||
XlaOp new_state;
|
||||
std::tie(bits32, new_state) =
|
||||
GeneratePhiloxBits(num_elems * 2, initial_state, key, scramble);
|
||||
GeneratePhiloxBits(num_elems * 2, initial_state, key);
|
||||
|
||||
std::array<XlaOp, 2> bits64;
|
||||
bits64[0] = Uint32sToUint64({bits32[0], bits32[1]});
|
||||
@ -463,18 +454,18 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
}
|
||||
}
|
||||
|
||||
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape,
|
||||
bool scramble) {
|
||||
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
const Shape& shape) {
|
||||
PrimitiveType type = shape.element_type();
|
||||
switch (type) {
|
||||
case F32:
|
||||
case U32:
|
||||
case S32:
|
||||
return PhiloxRngBit32(key, initial_state, shape, scramble);
|
||||
return PhiloxRngBit32(key, initial_state, shape);
|
||||
case F64:
|
||||
case U64:
|
||||
case S64:
|
||||
return PhiloxRngBit64(key, initial_state, shape, scramble);
|
||||
return PhiloxRngBit64(key, initial_state, shape);
|
||||
default:
|
||||
return {key.builder()->ReportError(Unimplemented(
|
||||
"Types other than F32, F64, U32, S32, U64 and S64 "
|
||||
@ -484,6 +475,13 @@ RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape,
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key) {
|
||||
Philox4x32Key pkey = Uint64ToUint32s(key);
|
||||
auto state_key = ScramblePhiloxKey(pkey);
|
||||
return std::make_pair(Uint128ToOp(Uint32sToUint128(state_key.first)),
|
||||
Uint32sToUint64(state_key.second));
|
||||
}
|
||||
|
||||
RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
|
||||
BitGeneratorTy bit_generator,
|
||||
XlaOp minval, XlaOp maxval,
|
||||
|
@ -58,10 +58,10 @@ RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
// 4x32_10 version of the algorithm for the following reasons:
|
||||
// . 4x32 uses 32-bit multiplication which is fast on GPUs.
|
||||
// . The authors recommend the 10-round variant, and TensorFlow also uses it.
|
||||
// 'scramble` controls whether to scramble 'key' and 'initial_state' to form
|
||||
// the actual key and state fed to the Philox algorithm.
|
||||
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, const Shape& shape,
|
||||
bool scramble);
|
||||
RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
|
||||
const Shape& shape);
|
||||
// Returns a scrambled pair of (state, key) from a single key.
|
||||
std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key);
|
||||
|
||||
// Uses the given bit generator to generate random bits and then converts the
|
||||
// random bits to random numbers of uniform distribution in the given range.
|
||||
|
@ -37,6 +37,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/shape_inference.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@@ -1752,6 +1753,36 @@ XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) {
   return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
 }
 
+XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm,
+                                  XlaOp initial_state, const Shape& shape) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
+    TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state));
+    Shape output_shape = shape;
+    switch (output_shape.element_type()) {
+      case PrimitiveType::F32:
+      case PrimitiveType::S32:
+      case PrimitiveType::U32:
+        output_shape.set_element_type(PrimitiveType::U32);
+        break;
+      case PrimitiveType::F64:
+      case PrimitiveType::S64:
+      case PrimitiveType::U64:
+        output_shape.set_element_type(PrimitiveType::U64);
+        break;
+      default:
+        return InvalidArgument("Unsupported shape for RngBitGenerator: %s",
+                               PrimitiveType_Name(output_shape.element_type()));
+    }
+    *instr.mutable_shape() =
+        ShapeUtil::MakeTupleShape({state_shape, output_shape}).ToProto();
+    instr.set_rng_algorithm(algorithm);
+    return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator,
+                          {initial_state});
+  });
+}
+
 XlaOp XlaBuilder::While(const XlaComputation& condition,
                         const XlaComputation& body, XlaOp init) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -3460,6 +3491,12 @@ XlaOp RngUniform(const XlaOp a, const XlaOp b, const Shape& shape) {
   return a.builder()->RngUniform(a, b, shape);
 }
 
+XlaOp RngBitGenerator(RandomAlgorithm algorithm, const XlaOp initial_state,
+                      const Shape& shape) {
+  return initial_state.builder()->RngBitGenerator(algorithm, initial_state,
+                                                  shape);
+}
+
 XlaOp While(const XlaComputation& condition, const XlaComputation& body,
             const XlaOp init) {
   return init.builder()->While(condition, body, init);
@ -565,6 +565,9 @@ class XlaBuilder {
|
||||
|
||||
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
|
||||
|
||||
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
|
||||
const Shape& shape);
|
||||
|
||||
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
|
||||
XlaOp init);
|
||||
|
||||
@ -985,6 +988,8 @@ class XlaBuilder {
|
||||
absl::Span<const XlaOp> static_operands);
|
||||
friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
|
||||
friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
|
||||
friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
|
||||
const Shape& shape);
|
||||
friend XlaOp While(const XlaComputation& condition,
|
||||
const XlaComputation& body, XlaOp init);
|
||||
friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
|
||||
@ -1856,6 +1861,11 @@ XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
|
||||
// computation. Returns values in the semi-open interval [a, b).
|
||||
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
|
||||
|
||||
// Enqueues a B(initial_state) random bit generation instruction onto the
|
||||
// computation. Returns the new key and random bits with the specified shape.
|
||||
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
|
||||
const Shape& shape);
|
||||
|
||||
// Enqueues a while node onto the computation.
|
||||
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
|
||||
XlaOp init);
|
||||
|
@@ -2271,6 +2271,34 @@ implementation-defined.
 :                 :         : limit of interval                      :
 | `shape`         | `Shape` | Output shape of type T                 |
 
+## RngBitGenerator
+
+Generates an output with a given shape filled with uniform random bits using
+the specified algorithm (or the backend default) and returns an updated state
+(with the same shape as the initial state) together with the generated random
+data.
+
+The initial state is the starting state of the random number generation; its
+required shape and valid values depend on the algorithm used.
+
+The output is guaranteed to be a deterministic function of the initial state,
+but it is *not* guaranteed to be deterministic between backends and different
+compiler versions.
+
+<b>`RngBitGenerator(algorithm, key, shape)`</b>
+
+| Arguments       | Type              | Semantics                             |
+| --------------- | ----------------- | ------------------------------------- |
+| `algorithm`     | `RandomAlgorithm` | PRNG algorithm to be used.            |
+| `initial_state` | `XlaOp`           | Initial state for the PRNG algorithm. |
+| `shape`         | `Shape`           | Output shape for generated data.      |
+
+Available values for `algorithm`:
+
+*   `rng_default`: Backend specific algorithm with backend specific shape
+    requirements.
+*   `rng_three_fry`: ThreeFry counter-based PRNG algorithm. The `initial_state`
+    shape is `u64[2]` with arbitrary values.
+    [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
+*   `rng_philox`: Philox algorithm to generate random numbers in parallel. The
+    `initial_state` shape is `u64[3]` with arbitrary values.
+    [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
+
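As an editorial illustration (not part of this change), the following sketch exercises the state-shape requirements above through the client API added in this commit; the builder name, the u32[11,17] output shape, and the constant state values are arbitrary placeholders:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Illustrative only: build rng-bit-generator with each algorithm's expected
// initial_state shape (u64[2] for ThreeFry, u64[3] for Philox).
void BuildRngBitGeneratorExamples() {
  xla::XlaBuilder b("rng_bit_generator_example");  // placeholder builder name
  const xla::Shape out_shape = xla::ShapeUtil::MakeShape(xla::U32, {11, 17});

  // rng_three_fry: initial_state is u64[2] with arbitrary values.
  xla::XlaOp three_fry_state = xla::ConstantR1<xla::uint64>(&b, {42, 0});
  xla::XlaOp three_fry = xla::RngBitGenerator(
      xla::RandomAlgorithm::RNG_THREE_FRY, three_fry_state, out_shape);

  // rng_philox: initial_state is u64[3] with arbitrary values.
  xla::XlaOp philox_state = xla::ConstantR1<xla::uint64>(&b, {42, 0, 0});
  xla::XlaOp philox = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX,
                                           philox_state, out_shape);

  // Both results are (new_state, data) tuples.
  xla::XlaOp three_fry_bits = xla::GetTupleElement(three_fry, 1);
  xla::XlaOp philox_bits = xla::GetTupleElement(philox, 1);
  (void)three_fry_bits;
  (void)philox_bits;
}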
 ## Scatter
 
 The XLA scatter operation generates a result which is the value of the input
@ -4492,6 +4492,25 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rng_bit_generator_expander",
|
||||
srcs = ["rng_bit_generator_expander.cc"],
|
||||
hdrs = ["rng_bit_generator_expander.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":op_expander_pass",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:prng",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "slow_operation_alarm",
|
||||
srcs = ["slow_operation_alarm.cc"],
|
||||
|
@ -94,6 +94,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
||||
"//tensorflow/compiler/xla/service:dump",
|
||||
"//tensorflow/compiler/xla/service:map_inliner",
|
||||
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
|
||||
"//tensorflow/compiler/xla/service:tree_reduction_rewriter",
|
||||
"//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter",
|
||||
"//tensorflow/compiler/xla/service:conditional_to_select",
|
||||
|
@ -93,6 +93,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/map_inliner.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/rng_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/scatter_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/slice_sinker.h"
|
||||
@ -241,6 +242,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
|
||||
|
||||
// Expand random number generation.
|
||||
pipeline.AddPass<RngExpander>();
|
||||
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
|
||||
|
||||
// Remove zero-sized HLO from the input so that other passes don't have to
|
||||
// handle it.
|
||||
|
@ -225,6 +225,7 @@ class DfsHloVisitorBase {
|
||||
virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleRng(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleRngBitGenerator(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleRngGetAndUpdateState(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
|
||||
virtual Status HandleSort(HloInstructionPtr hlo) = 0;
|
||||
|
@ -116,6 +116,9 @@ class DfsHloVisitorWithDefaultBase
|
||||
Status HandleRng(HloInstructionPtr random) override {
|
||||
return DefaultAction(random);
|
||||
}
|
||||
Status HandleRngBitGenerator(HloInstructionPtr random) override {
|
||||
return DefaultAction(random);
|
||||
}
|
||||
Status HandleRngGetAndUpdateState(HloInstructionPtr random) override {
|
||||
return DefaultAction(random);
|
||||
}
|
||||
|
@ -1155,6 +1155,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/service:llvm_compiler",
|
||||
"//tensorflow/compiler/xla/service:reshape_mover",
|
||||
"//tensorflow/compiler/xla/service:rng_bit_generator_expander",
|
||||
"//tensorflow/compiler/xla/service:rng_expander",
|
||||
"//tensorflow/compiler/xla/service:slice_sinker",
|
||||
"//tensorflow/compiler/xla/service:slow_operation_alarm",
|
||||
|
@ -80,6 +80,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/rng_expander.h"
|
||||
#include "tensorflow/compiler/xla/service/slice_sinker.h"
|
||||
#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
|
||||
@ -127,6 +128,7 @@ Status GpuCompiler::OptimizeHloModule(
|
||||
|
||||
// Expand random number generation.
|
||||
pipeline.AddPass<RngExpander>();
|
||||
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
|
||||
|
||||
// Remove zero-sized HLO from the input so that other passes don't have to
|
||||
// handle it.
|
||||
|
@@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
 option cc_enable_arenas = true;
 
 // Serialization of HloInstruction.
-// Next ID: 70
+// Next ID: 71
 message HloInstructionProto {
   reserved 10;
   reserved "parameter_name";
@@ -241,6 +241,9 @@ message HloInstructionProto {
   // Specifies if all elements updated are guaranteed to be unique by
   // the caller.
   bool unique_indices = 69;
+
+  // RNG algorithm used by kRngBitGenerator.
+  xla.RandomAlgorithm rng_algorithm = 70;
 }
 
 // Serialization of HloComputation.
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleRngBitGenerator(const HloInstruction* random) {
|
||||
// TODO(b/26346211): Implement better estimates for the RNG cost, since the
|
||||
// cost changes with the implementation and the distribution. For now, assume
|
||||
// the cost of each RNG is same as a transcendental operation.
|
||||
current_properties_[kTranscendentalsKey] =
|
||||
ShapeUtil::ElementsIn(random->shape());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloCostAnalysis::HandleRngGetAndUpdateState(
|
||||
const HloInstruction* random) {
|
||||
return Status::OK();
|
||||
|
@ -84,6 +84,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
|
||||
Status HandleInfeed(const HloInstruction* infeed) override;
|
||||
Status HandleOutfeed(const HloInstruction* outfeed) override;
|
||||
Status HandleRng(const HloInstruction* random) override;
|
||||
Status HandleRngBitGenerator(const HloInstruction* random) override;
|
||||
Status HandleRngGetAndUpdateState(const HloInstruction* random) override;
|
||||
Status HandleReverse(const HloInstruction* reverse) override;
|
||||
Status HandleSort(const HloInstruction* sort) override;
|
||||
|
@ -39,6 +39,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
@ -967,6 +968,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kRemainder:
|
||||
case HloOpcode::kRng:
|
||||
case HloOpcode::kRngGetAndUpdateState:
|
||||
case HloOpcode::kRngBitGenerator:
|
||||
case HloOpcode::kRoundNearestAfz:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kSelect:
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
@ -46,6 +47,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/human_readable_json.h"
|
||||
@ -346,6 +348,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
case HloOpcode::kRng:
|
||||
instruction = CreateRng(shape, proto.distribution(), all_operands());
|
||||
break;
|
||||
case HloOpcode::kRngBitGenerator:
|
||||
instruction =
|
||||
CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm());
|
||||
break;
|
||||
case HloOpcode::kRngGetAndUpdateState:
|
||||
instruction = CreateRngGetAndUpdateState(shape, proto.delta());
|
||||
break;
|
||||
@ -743,6 +749,13 @@ HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64 delta) {
|
||||
return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction>
|
||||
HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
|
||||
RandomAlgorithm algorithm) {
|
||||
return absl::make_unique<HloRngBitGeneratorInstruction>(shape, state,
|
||||
algorithm);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
|
||||
const Shape& shape, HloOpcode opcode,
|
||||
absl::Span<HloInstruction* const> operands) {
|
||||
@ -1480,6 +1493,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
case HloOpcode::kTrace:
|
||||
case HloOpcode::kFusion:
|
||||
case HloOpcode::kRng:
|
||||
case HloOpcode::kRngBitGenerator:
|
||||
case HloOpcode::kRngGetAndUpdateState:
|
||||
case HloOpcode::kParameter:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
@ -1950,6 +1964,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kTrace:
|
||||
case HloOpcode::kFusion:
|
||||
case HloOpcode::kRng:
|
||||
case HloOpcode::kRngBitGenerator:
|
||||
case HloOpcode::kRngGetAndUpdateState:
|
||||
case HloOpcode::kParameter:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
@ -2882,6 +2897,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
||||
return visitor->HandleOutfeed(this);
|
||||
case HloOpcode::kRng:
|
||||
return visitor->HandleRng(this);
|
||||
case HloOpcode::kRngBitGenerator:
|
||||
return visitor->HandleRngBitGenerator(this);
|
||||
case HloOpcode::kRngGetAndUpdateState:
|
||||
return visitor->HandleRngGetAndUpdateState(this);
|
||||
case HloOpcode::kWhile:
|
||||
@ -3357,6 +3374,9 @@ string OpMetadataToString(const OpMetadata& metadata) {
|
||||
string RandomDistributionToString(const RandomDistribution& distribution) {
|
||||
return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
|
||||
}
|
||||
string RandomAlgorithmToString(const RandomAlgorithm& algorithm) {
|
||||
return absl::AsciiStrToLower(RandomAlgorithm_Name(algorithm));
|
||||
}
|
||||
|
||||
string PrecisionToString(const PrecisionConfig::Precision& precision) {
|
||||
return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
|
||||
@ -3401,6 +3421,24 @@ string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups) {
|
||||
return StrCat("{", StrJoin(replica_group_str, ","), "}");
|
||||
}
|
||||
|
||||
StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name) {
|
||||
static std::unordered_map<string, RandomAlgorithm>* map = [] {
|
||||
static auto* map = new std::unordered_map<string, RandomAlgorithm>;
|
||||
for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; i++) {
|
||||
if (RandomAlgorithm_IsValid(i)) {
|
||||
auto value = static_cast<RandomAlgorithm>(i);
|
||||
(*map)[RandomAlgorithmToString(value)] = value;
|
||||
}
|
||||
}
|
||||
return map;
|
||||
}();
|
||||
auto found = map->find(absl::AsciiStrToLower(name));
|
||||
if (found == map->end()) {
|
||||
return InvalidArgument("Unknown algorithm");
|
||||
}
|
||||
return found->second;
|
||||
}
|
||||
|
||||
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
|
||||
static std::unordered_map<string, RandomDistribution>* map = [] {
|
||||
static auto* map = new std::unordered_map<string, RandomDistribution>;
|
||||
|
@ -512,6 +512,11 @@ class HloInstruction {
|
||||
const Shape& shape, RandomDistribution distribution,
|
||||
absl::Span<HloInstruction* const> parameters);
|
||||
|
||||
// Creates a stateless random bit generator instruction that fills a shape
|
||||
// with random bits.
|
||||
static std::unique_ptr<HloInstruction> CreateRngBitGenerator(
|
||||
const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm);
|
||||
|
||||
// Creates an instruction to update the random number generator state to
|
||||
// reflect the new state after `delta` units of 32 random bits are generated
|
||||
// and returns the old state.
|
||||
@ -2025,12 +2030,14 @@ string PaddingConfigToString(const PaddingConfig& padding);
|
||||
string FrontendAttributesToString(
|
||||
const FrontendAttributes& frontend_attributes);
|
||||
string OpMetadataToString(const OpMetadata& metadata);
|
||||
string RandomAlgorithmToString(const RandomAlgorithm& algorithm);
|
||||
string RandomDistributionToString(const RandomDistribution& distribution);
|
||||
string PrecisionToString(const PrecisionConfig::Precision& precision);
|
||||
string ConvolutionDimensionNumbersToString(
|
||||
const ConvolutionDimensionNumbers& dnums);
|
||||
string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups);
|
||||
|
||||
StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name);
|
||||
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
|
||||
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);
|
||||
|
||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
namespace xla {
|
||||
@ -2902,4 +2903,40 @@ HloRngGetAndUpdateStateInstruction::CloneWithNewOperandsImpl(
|
||||
return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta());
|
||||
}
|
||||
|
||||
HloRngBitGeneratorInstruction::HloRngBitGeneratorInstruction(
|
||||
const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm)
|
||||
: HloInstruction(HloOpcode::kRngBitGenerator, shape),
|
||||
algorithm_(algorithm) {
|
||||
AppendOperand(state);
|
||||
}
|
||||
|
||||
HloInstructionProto HloRngBitGeneratorInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloInstruction::ToProto();
|
||||
proto.set_rng_algorithm(algorithm_);
|
||||
return proto;
|
||||
}
|
||||
|
||||
std::vector<string> HloRngBitGeneratorInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const {
|
||||
return {StrCat("algorithm=", RandomAlgorithmToString(algorithm_))};
|
||||
}
|
||||
|
||||
bool HloRngBitGeneratorInstruction::IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const {
|
||||
const auto& casted_other =
|
||||
static_cast<const HloRngBitGeneratorInstruction&>(other);
|
||||
return algorithm() == casted_other.algorithm();
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction>
|
||||
HloRngBitGeneratorInstruction::CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* /*context*/) const {
|
||||
CHECK_EQ(new_operands.size(), 1);
|
||||
return absl::make_unique<HloRngBitGeneratorInstruction>(
|
||||
shape, new_operands[0], algorithm());
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -1721,6 +1722,28 @@ class HloRngGetAndUpdateStateInstruction : public HloInstruction {
|
||||
int64 delta_;
|
||||
};
|
||||
|
||||
class HloRngBitGeneratorInstruction : public HloInstruction {
|
||||
public:
|
||||
HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state,
|
||||
RandomAlgorithm algorithm);
|
||||
|
||||
RandomAlgorithm algorithm() const { return algorithm_; }
|
||||
HloInstructionProto ToProto() const override;
|
||||
|
||||
private:
|
||||
std::vector<string> ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const override;
|
||||
bool IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const override;
|
||||
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* context) const override;
|
||||
|
||||
RandomAlgorithm algorithm_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
|
||||
|
@ -122,6 +122,7 @@ namespace xla {
|
||||
V(kReverse, "reverse", 1) \
|
||||
V(kRng, "rng", kHloOpcodeIsVariadic) \
|
||||
V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \
|
||||
V(kRngBitGenerator, "rng-bit-generator", 1) \
|
||||
V(kRoundNearestAfz, "round-nearest-afz", 1) \
|
||||
V(kRsqrt, "rsqrt", 1) \
|
||||
V(kScatter, "scatter", 3) \
|
||||
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_lexer.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
@ -203,7 +204,8 @@ class HloParserImpl : public HloParser {
|
||||
kDomain,
|
||||
kPrecisionList,
|
||||
kShapeList,
|
||||
kEnum
|
||||
kEnum,
|
||||
kRandomAlgorithm,
|
||||
};
|
||||
|
||||
struct AttrConfig {
|
||||
@ -322,6 +324,7 @@ class HloParserImpl : public HloParser {
|
||||
bool ParseComparisonDirection(ComparisonDirection* result);
|
||||
bool ParseFusionKind(HloInstruction::FusionKind* result);
|
||||
bool ParseRandomDistribution(RandomDistribution* result);
|
||||
bool ParseRandomAlgorithm(RandomAlgorithm* result);
|
||||
bool ParsePrecision(PrecisionConfig::Precision* result);
|
||||
bool ParseInt64(int64* result);
|
||||
bool ParseDouble(double* result);
|
||||
@ -1501,6 +1504,18 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
HloInstruction::CreateRngGetAndUpdateState(shape, *delta));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kRngBitGenerator: {
|
||||
optional<RandomAlgorithm> algorithm;
|
||||
attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm,
|
||||
&algorithm};
|
||||
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction =
|
||||
builder->AddInstruction(HloInstruction::CreateRngBitGenerator(
|
||||
shape, operands[0], *algorithm));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kReducePrecision: {
|
||||
optional<int64> exponent_bits;
|
||||
optional<int64> mantissa_bits;
|
||||
@ -2972,6 +2987,14 @@ bool HloParserImpl::ParseAttributeHelper(
|
||||
->emplace(result);
|
||||
return true;
|
||||
}
|
||||
case AttrTy::kRandomAlgorithm: {
|
||||
RandomAlgorithm result;
|
||||
if (!ParseRandomAlgorithm(&result)) {
|
||||
return false;
|
||||
}
|
||||
static_cast<optional<RandomAlgorithm>*>(attr_out_ptr)->emplace(result);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}();
|
||||
if (!success) {
|
||||
@ -3963,6 +3986,23 @@ bool HloParserImpl::ParseRandomDistribution(RandomDistribution* result) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParserImpl::ParseRandomAlgorithm(RandomAlgorithm* result) {
|
||||
VLOG(3) << "ParseRandomAlgorithm";
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
return TokenError("expects random algorithm");
|
||||
}
|
||||
std::string val = lexer_.GetStrVal();
|
||||
auto status_or_result = StringToRandomAlgorithm(val);
|
||||
if (!status_or_result.ok()) {
|
||||
return TokenError(
|
||||
StrFormat("expects random algorithm but sees: %s, error: %s", val,
|
||||
status_or_result.status().error_message()));
|
||||
}
|
||||
*result = status_or_result.ValueOrDie();
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParserImpl::ParsePrecision(PrecisionConfig::Precision* result) {
|
||||
VLOG(3) << "ParsePrecision";
|
||||
if (lexer_.GetKind() != TokKind::kIdent) {
|
||||
|
@ -1055,6 +1055,17 @@ ENTRY %RngGetAndUpdateState () -> u64[2] {
|
||||
|
||||
)"
|
||||
},
|
||||
{
|
||||
"RngBitGenerator",
|
||||
R"(HloModule gng_bit_generator
|
||||
|
||||
ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[11,17]) {
|
||||
%p0 = u64[2]{0} parameter(0)
|
||||
ROOT %rand = (u64[2]{0}, u32[11,17]{1,0}) rng-bit-generator(u64[2]{0} %p0), algorithm=rng_three_fry
|
||||
}
|
||||
|
||||
)"
|
||||
}
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -415,6 +415,23 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) {
|
||||
if (!hlo->shape().IsTuple() || hlo->shape().tuple_shapes_size() != 2) {
|
||||
return InternalError(
|
||||
"Expected tuple shape with 2 elements for RngBitGenerator. Got: %s",
|
||||
hlo->shape().ToString());
|
||||
}
|
||||
if (!ShapeUtil::Compatible(hlo->operand(0)->shape(),
|
||||
hlo->shape().tuple_shapes(0))) {
|
||||
return InternalError(
|
||||
"Expected state shape to match between input and output for "
|
||||
"RngBitGenerator. Got %s vs. %s",
|
||||
hlo->operand(0)->shape().ToString(),
|
||||
hlo->shape().tuple_shapes(0).ToString());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleRngGetAndUpdateState(HloInstruction* instruction) {
|
||||
TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
|
||||
const Shape& result_shape = instruction->shape();
|
||||
|
@ -65,6 +65,7 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
Status HandleInfeed(HloInstruction*) override;
|
||||
Status HandleOutfeed(HloInstruction*) override;
|
||||
Status HandleRng(HloInstruction*) override;
|
||||
Status HandleRngBitGenerator(HloInstruction*) override;
|
||||
Status HandleRngGetAndUpdateState(HloInstruction*) override;
|
||||
Status HandleReverse(HloInstruction* reverse) override;
|
||||
Status HandleSort(HloInstruction* sort) override;
|
||||
|
@ -167,6 +167,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
||||
case HloOpcode::kReduceWindow:
|
||||
case HloOpcode::kRng:
|
||||
case HloOpcode::kRngGetAndUpdateState:
|
||||
case HloOpcode::kRngBitGenerator:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kScatter:
|
||||
case HloOpcode::kSelectAndScatter:
|
||||
|
@ -2260,6 +2260,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
||||
case HloOpcode::kReplicaId:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kRng:
|
||||
case HloOpcode::kRngBitGenerator:
|
||||
case HloOpcode::kRngGetAndUpdateState:
|
||||
case HloOpcode::kSend:
|
||||
case HloOpcode::kSendDone:
|
||||
|
tensorflow/compiler/xla/service/rng_bit_generator_expander.cc (new file, 105 lines)
@@ -0,0 +1,105 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/prng.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
bool RngBitGeneratorExpander::InstructionMatchesPattern(
|
||||
HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kRngBitGenerator;
|
||||
}
|
||||
|
||||
StatusOr<HloComputation*> RngBitGeneratorExpander::GetGeneratorComputation(
|
||||
const Shape& data_shape, const Shape& state_shape,
|
||||
RandomAlgorithm algorithm, HloModule* module) {
|
||||
RngGeneratorKey cache_key{data_shape, state_shape, algorithm, module};
|
||||
auto it = computation_cache_.find(cache_key);
|
||||
if (it != computation_cache_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
XlaBuilder builder("rng");
|
||||
XlaOp state_param = Parameter(&builder, 0, state_shape, "state");
|
||||
XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {});
|
||||
XlaOp state_op;
|
||||
|
||||
BitGeneratorTy generator = nullptr;
|
||||
switch (algorithm) {
|
||||
case RandomAlgorithm::RNG_THREE_FRY:
|
||||
generator = ThreeFryBitGenerator;
|
||||
state_op = Slice(state_param, {1}, {2}, {1});
|
||||
break;
|
||||
case RandomAlgorithm::RNG_PHILOX:
|
||||
generator = PhiloxBitGenerator;
|
||||
state_op = Slice(state_param, {1}, {3}, {1});
|
||||
break;
|
||||
default:
|
||||
return Unimplemented("Unsupported random algorthm: %s",
|
||||
RandomAlgorithm_Name(algorithm));
|
||||
}
|
||||
|
||||
RngOutput output = generator(key_op, state_op, data_shape);
|
||||
XlaOp final_state =
|
||||
ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0);
|
||||
Tuple(&builder, {final_state, output.value});
|
||||
TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
|
||||
xla_computation.GetProgramShape());
|
||||
HloModuleConfig config(program_shape);
|
||||
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
|
||||
xla_computation.proto(), config));
|
||||
HloCloneContext context(module);
|
||||
HloComputation* new_computation =
|
||||
module->DeepCloneComputation(new_module->entry_computation(), &context);
|
||||
computation_cache_.emplace(cache_key, new_computation);
|
||||
return new_computation;
|
||||
}
|
||||
|
||||
StatusOr<HloInstruction*> RngBitGeneratorExpander::ExpandInstruction(
|
||||
HloInstruction* hlo) {
|
||||
HloRngBitGeneratorInstruction* rng = Cast<HloRngBitGeneratorInstruction>(hlo);
|
||||
RandomAlgorithm algorithm = rng->algorithm();
|
||||
if (algorithm == RandomAlgorithm::RNG_DEFAULT) {
|
||||
algorithm = default_algorithm_;
|
||||
}
|
||||
|
||||
HloModule* module = hlo->parent()->parent();
|
||||
const Shape& data_shape = rng->shape().tuple_shapes(1);
|
||||
const Shape& state_shape = rng->operand(0)->shape();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloComputation * generator_computation,
|
||||
GetGeneratorComputation(data_shape, state_shape, algorithm, module));
|
||||
return hlo->parent()->AddInstruction(HloInstruction::CreateCall(
|
||||
ShapeUtil::MakeTupleShape({state_shape, data_shape}),
|
||||
{hlo->mutable_operand(0)}, generator_computation));
|
||||
}
|
||||
|
||||
} // namespace xla
|
tensorflow/compiler/xla/service/rng_bit_generator_expander.h (new file, 72 lines)
@@ -0,0 +1,72 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
class RngBitGeneratorExpander : public OpExpanderPass {
|
||||
public:
|
||||
explicit RngBitGeneratorExpander(RandomAlgorithm default_algorithm)
|
||||
: default_algorithm_(default_algorithm) {
|
||||
CHECK_NE(default_algorithm_, RandomAlgorithm::RNG_DEFAULT);
|
||||
}
|
||||
|
||||
absl::string_view name() const override {
|
||||
return "rng-bit-generator-expander";
|
||||
}
|
||||
|
||||
protected:
|
||||
struct RngGeneratorKey {
|
||||
Shape data_shape;
|
||||
Shape state_shape;
|
||||
RandomAlgorithm algorithm;
|
||||
HloModule* module;
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const RngGeneratorKey& c) {
|
||||
return H::combine(std::move(h), c.state_shape, c.data_shape, c.algorithm,
|
||||
c.module);
|
||||
}
|
||||
|
||||
bool operator==(const RngGeneratorKey& o) const {
|
||||
return data_shape == o.data_shape && state_shape == o.state_shape &&
|
||||
algorithm == o.algorithm && module == o.module;
|
||||
}
|
||||
};
|
||||
|
||||
bool InstructionMatchesPattern(HloInstruction* instruction) override;
|
||||
StatusOr<HloInstruction*> ExpandInstruction(HloInstruction* hlo) override;
|
||||
StatusOr<HloComputation*> GetGeneratorComputation(const Shape& data_shape,
|
||||
const Shape& state_shape,
|
||||
RandomAlgorithm algorithm,
|
||||
HloModule* module);
|
||||
|
||||
const RandomAlgorithm default_algorithm_;
|
||||
absl::flat_hash_map<RngGeneratorKey, HloComputation*> computation_cache_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_
|
@@ -84,7 +84,7 @@ StatusOr<HloComputation*> GetComputationForRng(HloInstruction* rng) {
 
   auto generator = [](xla::XlaOp key, xla::XlaOp state,
                       const xla::Shape& shape) {
-    return PhiloxBitGenerator(key, state, shape, /*scramble=*/false);
+    return PhiloxBitGenerator(key, state, shape);
   };
 
   XlaOp result;
@@ -559,6 +559,13 @@ enum RandomDistribution {
   // Next: 4
 }
 
+enum RandomAlgorithm {
+  RNG_DEFAULT = 0;  // Backend dependent default algorithm.
+  RNG_THREE_FRY = 1;
+  RNG_PHILOX = 2;
+  // Next: 3
+}
+
 message TriangularSolveOptions {
   // If true, solves ax = b. If false, solves xa = b.
   bool left_side = 1;