[XLA] Avoid reshape to R1 in NormalFloatingPointDistribution

This refactors existing shape split code and reuses it for NormalFloatingPointDistribution

PiperOrigin-RevId: 347643318
Change-Id: Ie3d3c0145cf6a275a5f678ee4258f88c7de18a5a
This commit is contained in:
Yuanzhong Xu 2020-12-15 10:32:04 -08:00 committed by TensorFlower Gardener
parent c6228dd44a
commit 7c096232b2

View File

@@ -158,58 +158,96 @@ std::pair<ThreeFry2x32State, XlaOp> GetThreeFryInputsAndUpdatedState(
return std::make_pair(Uint64ToUint32s(input_u64), new_state);
}
// Result for SplitShapeIntoHalves().
struct SplitShapePair {
  // Shape of one half: the split dimension is halved (rounded up) and a
  // trivial size-1 dimension is inserted right after it, so the two halves
  // can later be concatenated along that trivial dimension.
  Shape half_shape;
  // Shape of the two halves concatenated: same as half_shape but with the
  // trivial dimension's size set to 2.
  Shape concat_shape;
  // Index of the dimension of the original shape that was split.
  int64 split_dim;
  // Dimension along which the two halves are concatenated (split_dim + 1
  // for non-scalar shapes; 0 when the original shape was a scalar).
  int64 new_concat_dim;
};
// Split the shape on a dimension > 1 into two halves.
SplitShapePair SplitShapeIntoHalves(const Shape& shape) {
  SplitShapePair result;
  if (shape.rank() == 0) {
    // A scalar is promoted to a one-element R1 so it can still be "split":
    // each half holds one element and the concat holds two.
    result.half_shape = ShapeUtil::MakeShape(shape.element_type(), {1});
    result.concat_shape = ShapeUtil::MakeShape(shape.element_type(), {2});
    result.split_dim = 0;
    result.new_concat_dim = 0;
    return result;
  }
  // Prefer the first even-sized dimension so no padding is needed.
  int64 chosen = -1;
  for (int64 dim = 0; dim < shape.rank(); ++dim) {
    if (shape.dimensions(dim) % 2 == 0) {
      chosen = dim;
      break;
    }
  }
  if (chosen < 0) {
    // No even dims. Fall back to the dimension with the maximum size.
    for (int64 dim = 0; dim < shape.rank(); ++dim) {
      if (chosen < 0 || shape.dimensions(dim) > shape.dimensions(chosen)) {
        chosen = dim;
      }
    }
  }
  CHECK_GE(chosen, 0);
  result.split_dim = chosen;
  std::vector<int64> half_dims;
  std::vector<int64> concat_dims;
  for (int64 dim = 0; dim < shape.rank(); ++dim) {
    const int64 size = shape.dimensions(dim);
    if (dim != chosen) {
      half_dims.push_back(size);
      concat_dims.push_back(size);
      continue;
    }
    // Halve the chosen dimension (rounding up for odd sizes) and create a
    // new trivial dim for the later concat, which is more friendly to
    // sharding propagation.
    const int64 half = CeilOfRatio<int64>(size, 2);
    half_dims.push_back(half);
    half_dims.push_back(1);
    concat_dims.push_back(half);
    concat_dims.push_back(2);
  }
  result.new_concat_dim = chosen + 1;
  result.half_shape = ShapeUtil::MakeShape(shape.element_type(), half_dims);
  result.concat_shape =
      ShapeUtil::MakeShape(shape.element_type(), concat_dims);
  return result;
}
// Combines a pair of split shapes. It works with scalar and non-scalar shapes.
XlaOp CombineShapePair(absl::Span<const XlaOp> pair,
                       const SplitShapePair& shape_pair,
                       const Shape& original_shape) {
  // A scalar was expanded to a one-element R1 by the split; only the first
  // half carries the value.
  if (original_shape.rank() == 0) {
    return Reshape(pair[0], {});
  }
  XlaBuilder* builder = pair[0].builder();
  XlaOp combined = ConcatInDim(builder, pair, shape_pair.new_concat_dim);
  const int64 orig_size = original_shape.dimensions(shape_pair.split_dim);
  // Collapse the trivial concat dim back into the split dim, padded up to
  // the next even size.
  std::vector<int64> padded_dims(original_shape.dimensions().begin(),
                                 original_shape.dimensions().end());
  padded_dims[shape_pair.split_dim] = RoundUpToNearest<int64>(orig_size, 2);
  combined = Reshape(combined, padded_dims);
  if (padded_dims[shape_pair.split_dim] != orig_size) {
    // The split dim was odd-sized, so one padding element must be sliced off.
    combined = Slice(combined, std::vector<int64>(original_shape.rank(), 0),
                     original_shape.dimensions(),
                     std::vector<int64>(original_shape.rank(), 1));
  }
  return combined;
}
// Generates random 32bits with the given shape using the Three Fry
// implementation. Returns the random bits and the new state.
//
// The shape is split into two halves via SplitShapeIntoHalves(); each half
// receives one of the two u32 outputs of ThreeFry2x32, and the halves are
// recombined with CombineShapePair() to form the requested shape. This
// avoids reshaping through an R1 intermediate, which is friendlier to
// sharding propagation.
RngOutput ThreeFryRngBit32(XlaOp key, XlaOp initial_state, const Shape& shape) {
  // NOTE(review): the original text here was a diff with its +/- markers
  // stripped, leaving old and new statements interleaved (a duplicated
  // GetThreeFryInputsAndUpdatedState call and a redeclared `result`). This
  // body is the coherent post-change version determined by the added lines.
  auto shape_pair = SplitShapeIntoHalves(shape);
  std::pair<ThreeFry2x32State, XlaOp> inputs_state =
      GetThreeFryInputsAndUpdatedState(initial_state, shape_pair.half_shape);
  ThreeFry2x32State inputs = inputs_state.first;
  ThreeFry2x32State outputs = ThreeFry2x32(inputs, Uint64ToUint32s(key));
  XlaOp result = CombineShapePair(outputs, shape_pair, shape);
  return {result, inputs_state.second};
}
// Generates random 64bits with the given shape using the Three Fry
@@ -577,27 +615,27 @@ RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state,
DCHECK(primitive_type == F32 || primitive_type == F64);
XlaBuilder* builder = key.builder();
const int64 num_elems = ShapeUtil::ElementsIn(shape);
const int64 num_pairs = CeilOfRatio<int64>(num_elems, 2);
auto shape_pair = SplitShapeIntoHalves(shape);
RngOutput bits_state = UniformFloatingPointDistribution(
key, initial_state, bit_generator,
xla::ConstantR0WithType(builder, primitive_type, 0.0),
xla::ConstantR0WithType(builder, primitive_type, 1.0),
ShapeUtil::MakeShape(primitive_type, {num_pairs * 2}));
shape_pair.concat_shape);
// Separate the bits into two groups to perform the Box-Muller transform.
XlaOp bits_0 = Slice(bits_state.value, {0}, {num_pairs}, {1});
XlaOp bits_1 = Slice(bits_state.value, {num_pairs}, {2 * num_pairs}, {1});
XlaOp bits_0 = Slice(bits_state.value,
std::vector<int64>(shape_pair.half_shape.rank(), 0),
shape_pair.half_shape.dimensions(),
std::vector<int64>(shape_pair.half_shape.rank(), 1));
std::vector<int64> bits_1_starts(shape_pair.half_shape.rank(), 0);
bits_1_starts[shape_pair.new_concat_dim] = 1;
XlaOp bits_1 = Slice(bits_state.value, bits_1_starts,
shape_pair.concat_shape.dimensions(),
std::vector<int64>(shape_pair.half_shape.rank(), 1));
std::tie(bits_0, bits_1) = BoxMullerTransform(bits_0, bits_1);
// Put the numbers in the two groups back to form the requested shape.
XlaOp normal = ConcatInDim(builder, {bits_0, bits_1}, /*dimension=*/0);
if (num_elems != num_pairs * 2) {
normal = Slice(normal, /*start_indices=*/{0}, /*limit_indices=*/{num_elems},
/*strides=*/{1});
}
normal = Reshape(normal, shape.dimensions());
XlaOp normal = CombineShapePair({bits_0, bits_1}, shape_pair, shape);
return {normal, bits_state.state};
}