[XLA] Migrate a function from AlgebraicSimplifier to ShapeUtil

PiperOrigin-RevId: 243440029
This commit is contained in:
David Majnemer 2019-04-13 13:10:36 -07:00 committed by TensorFlower Gardener
parent 6e4a2891c2
commit fb695b89e0
3 changed files with 43 additions and 32 deletions

View File

@ -1992,40 +1992,11 @@ Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
namespace { namespace {
// Return whether the given reshape instruction leaves the dimensions at the
// given input indices unmodified, and returns their output indices.
//
// Example:
// input_dim_indices = {2, 3}
// input shape = T[a, b, x, y, cd]
// output shape = T[ab, x, 1, y, c, d]
// return value = {1, 3}
//
// Precondition: input_dim_indices is sorted.
absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified( absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) { const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); CHECK_EQ(hlo->opcode(), HloOpcode::kReshape);
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end())); return ShapeUtil::ReshapeLeavesDimensionsUnmodified(
hlo->operand(0)->shape(), hlo->shape(), input_dim_indices);
std::vector<int64> output_dim_indices;
std::vector<std::pair<int64, int64>> unmodified_dims =
ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(),
hlo->shape());
size_t i = 0; // index to unmodified_dims
for (int64 input_dim_index : input_dim_indices) {
// Search unmodified_dims for input_dim_index. We can search from the last
// matching position because input_dim_indices is guaranteed to be sorted.
while (i < unmodified_dims.size() &&
unmodified_dims[i].first < input_dim_index) {
++i;
}
if (i >= unmodified_dims.size() ||
unmodified_dims[i].first != input_dim_index) {
return absl::nullopt;
}
output_dim_indices.push_back(unmodified_dims[i].second);
}
return output_dim_indices;
} }
// Returns true if the output of "instruction" is a permutation of the // Returns true if the output of "instruction" is a permutation of the

View File

@ -1078,6 +1078,32 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
return common_factors; return common_factors;
} }
/* static */ absl::optional<std::vector<int64>>
ShapeUtil::ReshapeLeavesDimensionsUnmodified(
const Shape& from_shape, const Shape& to_shape,
absl::Span<const int64> input_dim_indices) {
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
std::vector<int64> output_dim_indices;
std::vector<std::pair<int64, int64>> unmodified_dims =
ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
size_t i = 0; // index to unmodified_dims
for (int64 input_dim_index : input_dim_indices) {
// Search unmodified_dims for input_dim_index. We can search from the last
// matching position because input_dim_indices is guaranteed to be sorted.
while (i < unmodified_dims.size() &&
unmodified_dims[i].first < input_dim_index) {
++i;
}
if (i >= unmodified_dims.size() ||
unmodified_dims[i].first != input_dim_index) {
return absl::nullopt;
}
output_dim_indices.push_back(unmodified_dims[i].second);
}
return output_dim_indices;
}
/* static */ bool ShapeUtil::TransposeIsBitcast( /* static */ bool ShapeUtil::TransposeIsBitcast(
const Shape& input_shape, const Shape& output_shape, const Shape& input_shape, const Shape& output_shape,
absl::Span<const int64> dimension_mapping) { absl::Span<const int64> dimension_mapping) {

View File

@ -578,6 +578,20 @@ class ShapeUtil {
static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape( static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
const Shape& input_shape, const Shape& output_shape); const Shape& input_shape, const Shape& output_shape);
// Return whether the given reshape instruction leaves the dimensions at the
// given input indices unmodified, and returns their output indices.
//
// Example:
// input_dim_indices = {2, 3}
// input shape = T[a, b, x, y, cd]
// output shape = T[ab, x, 1, y, c, d]
// return value = {1, 3}
//
// Precondition: input_dim_indices is sorted.
static absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
const Shape& from_shape, const Shape& to_shape,
absl::Span<const int64> input_dim_indices);
// Returns whether a transpose from input_shape to output_shape with dimension // Returns whether a transpose from input_shape to output_shape with dimension
// mapping "dimension_mapping" produces a result which is bit-wise identical // mapping "dimension_mapping" produces a result which is bit-wise identical
// to its input and thus may be replaced with a bitcast. // to its input and thus may be replaced with a bitcast.