Rewrite output state to match the input.
PiperOrigin-RevId: 309113452 Change-Id: I2e9f1fbeb6c874fe653b39cd9a284bc83f6d395d
This commit is contained in:
parent
6940d9171c
commit
06473adb2d
@ -32,15 +32,20 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
StatusOr<XlaOp> GetPhiloxStateOp(XlaOp input_state) {
|
||||
TF_ASSIGN_OR_RETURN(const Shape* shape,
|
||||
input_state.builder()->GetShapePtr(input_state));
|
||||
if (shape->dimensions(0) >= 3) {
|
||||
XlaOp GetPhiloxStateOp(XlaOp input_state, const Shape& state_shape) {
|
||||
if (state_shape.dimensions(0) >= 3) {
|
||||
return Slice(input_state, {1}, {3}, {1});
|
||||
}
|
||||
return Rev(input_state, {0});
|
||||
}
|
||||
|
||||
XlaOp GetPhiloxOutputStateOp(XlaOp output_state, const Shape& state_shape) {
|
||||
if (state_shape.dimensions(0) < 3) {
|
||||
output_state = Slice(output_state, {0}, {1}, {1});
|
||||
}
|
||||
return output_state;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool RngBitGeneratorExpander::InstructionMatchesPattern(
|
||||
@ -60,25 +65,22 @@ StatusOr<HloComputation*> RngBitGeneratorExpander::GetGeneratorComputation(
|
||||
XlaBuilder builder("rng");
|
||||
XlaOp state_param = Parameter(&builder, 0, state_shape, "state");
|
||||
XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {});
|
||||
XlaOp state_op;
|
||||
|
||||
BitGeneratorTy generator = nullptr;
|
||||
RngOutput output;
|
||||
switch (algorithm) {
|
||||
case RandomAlgorithm::RNG_THREE_FRY:
|
||||
generator = ThreeFryBitGenerator;
|
||||
state_op = Slice(state_param, {1}, {2}, {1});
|
||||
output = ThreeFryBitGenerator(key_op, Slice(state_param, {1}, {2}, {1}),
|
||||
data_shape);
|
||||
break;
|
||||
case RandomAlgorithm::RNG_PHILOX: {
|
||||
generator = PhiloxBitGenerator;
|
||||
TF_ASSIGN_OR_RETURN(state_op, GetPhiloxStateOp(state_param));
|
||||
case RandomAlgorithm::RNG_PHILOX:
|
||||
output = PhiloxBitGenerator(
|
||||
key_op, GetPhiloxStateOp(state_param, state_shape), data_shape);
|
||||
output.state = GetPhiloxOutputStateOp(output.state, state_shape);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return Unimplemented("Unsupported random algorthm: %s",
|
||||
RandomAlgorithm_Name(algorithm));
|
||||
}
|
||||
|
||||
RngOutput output = generator(key_op, state_op, data_shape);
|
||||
XlaOp final_state =
|
||||
ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0);
|
||||
Tuple(&builder, {final_state, output.value});
|
||||
|
Loading…
Reference in New Issue
Block a user