[XLA] Avoid reshape to R1 in NormalFloatingPointDistribution
This refactors the existing shape-splitting code and reuses it for NormalFloatingPointDistribution. PiperOrigin-RevId: 347643318 Change-Id: Ie3d3c0145cf6a275a5f678ee4258f88c7de18a5a
This commit is contained in:
parent
c6228dd44a
commit
7c096232b2
@ -158,58 +158,96 @@ std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
|
||||
return std::make_pair(Uint64ToUint32s(input_u64), new_state);
|
||||
}
|
||||
|
||||
// Result for SplitShapeIntoHalves().
|
||||
struct SplitShapePair {
|
||||
Shape half_shape;
|
||||
Shape concat_shape;
|
||||
int64 split_dim;
|
||||
int64 new_concat_dim;
|
||||
};
|
||||
|
||||
// Split the shape on a dimension > 1 into two halves.
|
||||
SplitShapePair SplitShapeIntoHalves(const Shape& shape) {
|
||||
SplitShapePair pair;
|
||||
if (shape.rank() == 0) {
|
||||
pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), {1});
|
||||
pair.concat_shape = ShapeUtil::MakeShape(shape.element_type(), {2});
|
||||
pair.split_dim = 0;
|
||||
pair.new_concat_dim = 0;
|
||||
return pair;
|
||||
}
|
||||
pair.split_dim = -1;
|
||||
for (int64 i = 0; i < shape.rank(); ++i) {
|
||||
if (shape.dimensions(i) % 2 == 0) {
|
||||
pair.split_dim = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (pair.split_dim == -1) {
|
||||
// No even dims. Find a dimension with maximum size.
|
||||
for (int64 i = 0; i < shape.rank(); ++i) {
|
||||
if (pair.split_dim == -1 ||
|
||||
shape.dimensions(i) > shape.dimensions(pair.split_dim)) {
|
||||
pair.split_dim = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK_GE(pair.split_dim, 0);
|
||||
std::vector<int64> half_shape_dims;
|
||||
std::vector<int64> concat_shape_dims;
|
||||
for (int64 i = 0; i < shape.rank(); ++i) {
|
||||
if (i == pair.split_dim) {
|
||||
// Create a new trivial dim for the later concat, which is more friendly
|
||||
// to sharding propagation.
|
||||
half_shape_dims.push_back(CeilOfRatio<int64>(shape.dimensions(i), 2));
|
||||
half_shape_dims.push_back(1);
|
||||
concat_shape_dims.push_back(half_shape_dims[i]);
|
||||
concat_shape_dims.push_back(2);
|
||||
} else {
|
||||
half_shape_dims.push_back(shape.dimensions(i));
|
||||
concat_shape_dims.push_back(shape.dimensions(i));
|
||||
}
|
||||
}
|
||||
pair.new_concat_dim = pair.split_dim + 1;
|
||||
pair.half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims);
|
||||
pair.concat_shape =
|
||||
ShapeUtil::MakeShape(shape.element_type(), concat_shape_dims);
|
||||
return pair;
|
||||
}
|
||||
|
||||
// Combines a pair of split shapes. It works with scalar and non-scalar shapes.
|
||||
XlaOp CombineShapePair(absl::Span<const XlaOp> pair,
|
||||
const SplitShapePair& shape_pair,
|
||||
const Shape& original_shape) {
|
||||
if (original_shape.rank() == 0) {
|
||||
return Reshape(pair[0], {});
|
||||
}
|
||||
XlaBuilder* builder = pair[0].builder();
|
||||
XlaOp result = ConcatInDim(builder, pair, shape_pair.new_concat_dim);
|
||||
const int64 pre_split_size = original_shape.dimensions(shape_pair.split_dim);
|
||||
std::vector<int64> reshape_dims(original_shape.dimensions().begin(),
|
||||
original_shape.dimensions().end());
|
||||
reshape_dims[shape_pair.split_dim] =
|
||||
RoundUpToNearest<int64>(pre_split_size, 2);
|
||||
result = Reshape(result, reshape_dims);
|
||||
if (reshape_dims[shape_pair.split_dim] != pre_split_size) {
|
||||
result = Slice(result, std::vector<int64>(original_shape.rank(), 0),
|
||||
original_shape.dimensions(),
|
||||
std::vector<int64>(original_shape.rank(), 1));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Generates random 32-bit values with the given shape using the Three Fry
|
||||
// implementation. Returns the random bits and the new state.
//
// NOTE(review): the body below appears to interleave the removed and added
// lines of a diff without +/- markers: `inputs_state` is followed by two
// initializer expressions, `result` is declared twice in the same scope, and
// statements follow a `return`. Confirm against the real file which lines
// belong to the current version before relying on this text.
|
||||
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
|
||||
XlaBuilder* builder = key.builder();
|
||||
// Try to split the shape on a dimension > 1 into two halves, each
|
||||
// representing a U32 value.
|
||||
std::vector<int64> half_shape_dims;
|
||||
std::vector<int64> padded_full_shape_dims;
|
||||
int64 split_dim = -1;
|
||||
for (int64 i = 0; i < shape.rank(); ++i) {
|
||||
// Only the first dimension larger than 1 is split; later ones are copied.
if (shape.dimensions(i) > 1 && split_dim < 0) {
|
||||
half_shape_dims.push_back(CeilOfRatio<int64>(shape.dimensions(i), 2));
|
||||
// Create a new trivial dim for the later concat, which is more friendly
|
||||
// to sharding propagation.
|
||||
half_shape_dims.push_back(1);
|
||||
split_dim = i;
|
||||
// Full size of the split dimension, rounded up to an even number.
padded_full_shape_dims.push_back(half_shape_dims[i] * 2);
|
||||
} else {
|
||||
half_shape_dims.push_back(shape.dimensions(i));
|
||||
padded_full_shape_dims.push_back(shape.dimensions(i));
|
||||
}
|
||||
}
|
||||
auto half_shape = ShapeUtil::MakeShape(shape.element_type(), half_shape_dims);
|
||||
if (split_dim >= 0) {
|
||||
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||
GetThreeFryInputsAndUpdatedState(initial_state, half_shape);
|
||||
ThreeFry2x32State inputs = inputs_state.first;
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
// Concatenate the two outputs on the trivial dim added after split_dim.
XlaOp result = ConcatInDim(builder, outputs, split_dim + 1);
|
||||
result = Reshape(result, padded_full_shape_dims);
|
||||
// An odd split dimension was padded up by one element; slice back down to
// the requested dimensions.
if (shape.dimensions(split_dim) % 2 != 0) {
|
||||
result = Slice(result, std::vector<int64>(shape.rank(), 0),
|
||||
shape.dimensions(), std::vector<int64>(shape.rank(), 1));
|
||||
}
|
||||
return {result, inputs_state.second};
|
||||
}
|
||||
// Use an R1 shape if the previous attempt failed.
|
||||
const int64 size = ShapeUtil::ElementsIn(shape);
|
||||
const int64 half_size = CeilOfRatio<int64>(size, 2);
|
||||
const bool size_is_odd = (half_size * 2 != size);
|
||||
auto shape_pair = SplitShapeIntoHalves(shape);
|
||||
std::pair<ThreeFry2x32State, XlaOp> inputs_state =
|
||||
GetThreeFryInputsAndUpdatedState(
|
||||
initial_state,
|
||||
ShapeUtil::MakeShape(shape.element_type(), {half_size}));
|
||||
GetThreeFryInputsAndUpdatedState(initial_state, shape_pair.half_shape);
|
||||
ThreeFry2x32State inputs = inputs_state.first;
|
||||
ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
|
||||
// For an odd element count, drop the last element of the second output so
// the concatenation below has exactly `size` elements.
if (size_is_odd) {
|
||||
outputs[1] = Slice(outputs[1], {0}, {half_size - 1}, {1});
|
||||
}
|
||||
XlaOp result = ConcatInDim(builder, outputs, 0);
|
||||
return {Reshape(result, AsInt64Slice(shape.dimensions())),
|
||||
inputs_state.second};
|
||||
XlaOp result = CombineShapePair(outputs, shape_pair, shape);
|
||||
return {result, inputs_state.second};
|
||||
}
|
||||
|
||||
// Generates random 64bits with the given shape using the Three Fry
|
||||
@ -577,27 +615,27 @@ RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
|
||||
DCHECK(primitive_type == F32 || primitive_type == F64);
|
||||
|
||||
XlaBuilder* builder = key.builder();
|
||||
const int64 num_elems = ShapeUtil::ElementsIn(shape);
|
||||
const int64 num_pairs = CeilOfRatio<int64>(num_elems, 2);
|
||||
auto shape_pair = SplitShapeIntoHalves(shape);
|
||||
RngOutput bits_state = UniformFloatingPointDistribution(
|
||||
key, initial_state, bit_generator,
|
||||
xla::ConstantR0WithType(builder, primitive_type, 0.0),
|
||||
xla::ConstantR0WithType(builder, primitive_type, 1.0),
|
||||
ShapeUtil::MakeShape(primitive_type, {num_pairs * 2}));
|
||||
shape_pair.concat_shape);
|
||||
|
||||
// Separate the bits into two groups to perform the Box-Muller transform.
|
||||
XlaOp bits_0 = Slice(bits_state.value, {0}, {num_pairs}, {1});
|
||||
XlaOp bits_1 = Slice(bits_state.value, {num_pairs}, {2 * num_pairs}, {1});
|
||||
XlaOp bits_0 = Slice(bits_state.value,
|
||||
std::vector<int64>(shape_pair.half_shape.rank(), 0),
|
||||
shape_pair.half_shape.dimensions(),
|
||||
std::vector<int64>(shape_pair.half_shape.rank(), 1));
|
||||
std::vector<int64> bits_1_starts(shape_pair.half_shape.rank(), 0);
|
||||
bits_1_starts[shape_pair.new_concat_dim] = 1;
|
||||
XlaOp bits_1 = Slice(bits_state.value, bits_1_starts,
|
||||
shape_pair.concat_shape.dimensions(),
|
||||
std::vector<int64>(shape_pair.half_shape.rank(), 1));
|
||||
std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
|
||||
|
||||
// Put the numbers in the two groups back to form the requested shape.
|
||||
XlaOp normal = ConcatInDim(builder, {bits_0, bits_1}, /*dimension=*/0);
|
||||
if (num_elems != num_pairs * 2) {
|
||||
normal = Slice(normal, /*start_indices=*/{0}, /*limit_indices=*/{num_elems},
|
||||
/*strides=*/{1});
|
||||
}
|
||||
normal = Reshape(normal, shape.dimensions());
|
||||
|
||||
XlaOp normal = CombineShapePair({bits_0, bits_1}, shape_pair, shape);
|
||||
return {normal, bits_state.state};
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user