Rewrite output state to match the input.

PiperOrigin-RevId: 309113452
Change-Id: I2e9f1fbeb6c874fe653b39cd9a284bc83f6d395d
This commit is contained in:
Davide Libenzi 2020-04-29 15:47:11 -07:00 committed by TensorFlower Gardener
parent 6940d9171c
commit 06473adb2d

View File

@ -32,15 +32,20 @@ limitations under the License.
namespace xla {
namespace {
StatusOr<XlaOp> GetPhiloxStateOp(XlaOp input_state) {
TF_ASSIGN_OR_RETURN(const Shape* shape,
input_state.builder()->GetShapePtr(input_state));
if (shape->dimensions(0) >= 3) {
XlaOp GetPhiloxStateOp(XlaOp input_state, const Shape& state_shape) {
if (state_shape.dimensions(0) >= 3) {
return Slice(input_state, {1}, {3}, {1});
}
return Rev(input_state, {0});
}
XlaOp GetPhiloxOutputStateOp(XlaOp output_state, const Shape& state_shape) {
if (state_shape.dimensions(0) < 3) {
output_state = Slice(output_state, {0}, {1}, {1});
}
return output_state;
}
} // namespace
bool RngBitGeneratorExpander::InstructionMatchesPattern(
@ -60,25 +65,22 @@ StatusOr<HloComputation*> RngBitGeneratorExpander::GetGeneratorComputation(
XlaBuilder builder("rng");
XlaOp state_param = Parameter(&builder, 0, state_shape, "state");
XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {});
XlaOp state_op;
BitGeneratorTy generator = nullptr;
RngOutput output;
switch (algorithm) {
case RandomAlgorithm::RNG_THREE_FRY:
generator = ThreeFryBitGenerator;
state_op = Slice(state_param, {1}, {2}, {1});
output = ThreeFryBitGenerator(key_op, Slice(state_param, {1}, {2}, {1}),
data_shape);
break;
case RandomAlgorithm::RNG_PHILOX: {
generator = PhiloxBitGenerator;
TF_ASSIGN_OR_RETURN(state_op, GetPhiloxStateOp(state_param));
case RandomAlgorithm::RNG_PHILOX:
output = PhiloxBitGenerator(
key_op, GetPhiloxStateOp(state_param, state_shape), data_shape);
output.state = GetPhiloxOutputStateOp(output.state, state_shape);
break;
}
default:
return Unimplemented("Unsupported random algorthm: %s",
RandomAlgorithm_Name(algorithm));
}
RngOutput output = generator(key_op, state_op, data_shape);
XlaOp final_state =
ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0);
Tuple(&builder, {final_state, output.value});