[XLA] Migrate a function from AlgebraicSimplifier to ShapeUtil
PiperOrigin-RevId: 243440029
This commit is contained in:
parent
6e4a2891c2
commit
fb695b89e0
@ -1992,40 +1992,11 @@ Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
|
||||
|
||||
namespace {
|
||||
|
||||
// Return whether the given reshape instruction leaves the dimensions at the
|
||||
// given input indices unmodified, and returns their output indices.
|
||||
//
|
||||
// Example:
|
||||
// input_dim_indices = {2, 3}
|
||||
// input shape = T[a, b, x, y, cd]
|
||||
// output shape = T[ab, x, 1, y, c, d]
|
||||
// return value = {1, 3}
|
||||
//
|
||||
// Precondition: input_dim_indices is sorted.
|
||||
absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
|
||||
const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
|
||||
CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
|
||||
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
|
||||
|
||||
std::vector<int64> output_dim_indices;
|
||||
std::vector<std::pair<int64, int64>> unmodified_dims =
|
||||
ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(),
|
||||
hlo->shape());
|
||||
size_t i = 0; // index to unmodified_dims
|
||||
for (int64 input_dim_index : input_dim_indices) {
|
||||
// Search unmodified_dims for input_dim_index. We can search from the last
|
||||
// matching position because input_dim_indices is guaranteed to be sorted.
|
||||
while (i < unmodified_dims.size() &&
|
||||
unmodified_dims[i].first < input_dim_index) {
|
||||
++i;
|
||||
}
|
||||
if (i >= unmodified_dims.size() ||
|
||||
unmodified_dims[i].first != input_dim_index) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
output_dim_indices.push_back(unmodified_dims[i].second);
|
||||
}
|
||||
return output_dim_indices;
|
||||
CHECK_EQ(hlo->opcode(), HloOpcode::kReshape);
|
||||
return ShapeUtil::ReshapeLeavesDimensionsUnmodified(
|
||||
hlo->operand(0)->shape(), hlo->shape(), input_dim_indices);
|
||||
}
|
||||
|
||||
// Returns true if the output of "instruction" is a permutation of the
|
||||
|
@ -1078,6 +1078,32 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
return common_factors;
|
||||
}
|
||||
|
||||
/* static */ absl::optional<std::vector<int64>>
|
||||
ShapeUtil::ReshapeLeavesDimensionsUnmodified(
|
||||
const Shape& from_shape, const Shape& to_shape,
|
||||
absl::Span<const int64> input_dim_indices) {
|
||||
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
|
||||
|
||||
std::vector<int64> output_dim_indices;
|
||||
std::vector<std::pair<int64, int64>> unmodified_dims =
|
||||
ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
|
||||
size_t i = 0; // index to unmodified_dims
|
||||
for (int64 input_dim_index : input_dim_indices) {
|
||||
// Search unmodified_dims for input_dim_index. We can search from the last
|
||||
// matching position because input_dim_indices is guaranteed to be sorted.
|
||||
while (i < unmodified_dims.size() &&
|
||||
unmodified_dims[i].first < input_dim_index) {
|
||||
++i;
|
||||
}
|
||||
if (i >= unmodified_dims.size() ||
|
||||
unmodified_dims[i].first != input_dim_index) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
output_dim_indices.push_back(unmodified_dims[i].second);
|
||||
}
|
||||
return output_dim_indices;
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::TransposeIsBitcast(
|
||||
const Shape& input_shape, const Shape& output_shape,
|
||||
absl::Span<const int64> dimension_mapping) {
|
||||
|
@ -578,6 +578,20 @@ class ShapeUtil {
|
||||
static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
|
||||
const Shape& input_shape, const Shape& output_shape);
|
||||
|
||||
// Return whether the given reshape instruction leaves the dimensions at the
|
||||
// given input indices unmodified, and returns their output indices.
|
||||
//
|
||||
// Example:
|
||||
// input_dim_indices = {2, 3}
|
||||
// input shape = T[a, b, x, y, cd]
|
||||
// output shape = T[ab, x, 1, y, c, d]
|
||||
// return value = {1, 3}
|
||||
//
|
||||
// Precondition: input_dim_indices is sorted.
|
||||
static absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
|
||||
const Shape& from_shape, const Shape& to_shape,
|
||||
absl::Span<const int64> input_dim_indices);
|
||||
|
||||
// Returns whether a transpose from input_shape to output_shape with dimension
|
||||
// mapping "dimension_mapping" produces a result which is bit-wise identical
|
||||
// to its input and thus may be replaced with a bitcast.
|
||||
|
Loading…
Reference in New Issue
Block a user