[XLA] Migrate a function from AlgebraicSimplifier to ShapeUtil
PiperOrigin-RevId: 243440029
This commit is contained in:
parent
6e4a2891c2
commit
fb695b89e0
@ -1992,40 +1992,11 @@ Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Return whether the given reshape instruction leaves the dimensions at the
|
|
||||||
// given input indices unmodified, and returns their output indices.
|
|
||||||
//
|
|
||||||
// Example:
|
|
||||||
// input_dim_indices = {2, 3}
|
|
||||||
// input shape = T[a, b, x, y, cd]
|
|
||||||
// output shape = T[ab, x, 1, y, c, d]
|
|
||||||
// return value = {1, 3}
|
|
||||||
//
|
|
||||||
// Precondition: input_dim_indices is sorted.
|
|
||||||
absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
|
absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
|
||||||
const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
|
const HloInstruction* hlo, absl::Span<const int64> input_dim_indices) {
|
||||||
CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
|
CHECK_EQ(hlo->opcode(), HloOpcode::kReshape);
|
||||||
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
|
return ShapeUtil::ReshapeLeavesDimensionsUnmodified(
|
||||||
|
hlo->operand(0)->shape(), hlo->shape(), input_dim_indices);
|
||||||
std::vector<int64> output_dim_indices;
|
|
||||||
std::vector<std::pair<int64, int64>> unmodified_dims =
|
|
||||||
ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(),
|
|
||||||
hlo->shape());
|
|
||||||
size_t i = 0; // index to unmodified_dims
|
|
||||||
for (int64 input_dim_index : input_dim_indices) {
|
|
||||||
// Search unmodified_dims for input_dim_index. We can search from the last
|
|
||||||
// matching position because input_dim_indices is guaranteed to be sorted.
|
|
||||||
while (i < unmodified_dims.size() &&
|
|
||||||
unmodified_dims[i].first < input_dim_index) {
|
|
||||||
++i;
|
|
||||||
}
|
|
||||||
if (i >= unmodified_dims.size() ||
|
|
||||||
unmodified_dims[i].first != input_dim_index) {
|
|
||||||
return absl::nullopt;
|
|
||||||
}
|
|
||||||
output_dim_indices.push_back(unmodified_dims[i].second);
|
|
||||||
}
|
|
||||||
return output_dim_indices;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if the output of "instruction" is a permutation of the
|
// Returns true if the output of "instruction" is a permutation of the
|
||||||
|
@ -1078,6 +1078,32 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
|||||||
return common_factors;
|
return common_factors;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */ absl::optional<std::vector<int64>>
|
||||||
|
ShapeUtil::ReshapeLeavesDimensionsUnmodified(
|
||||||
|
const Shape& from_shape, const Shape& to_shape,
|
||||||
|
absl::Span<const int64> input_dim_indices) {
|
||||||
|
CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
|
||||||
|
|
||||||
|
std::vector<int64> output_dim_indices;
|
||||||
|
std::vector<std::pair<int64, int64>> unmodified_dims =
|
||||||
|
ShapeUtil::DimensionsUnmodifiedByReshape(from_shape, to_shape);
|
||||||
|
size_t i = 0; // index to unmodified_dims
|
||||||
|
for (int64 input_dim_index : input_dim_indices) {
|
||||||
|
// Search unmodified_dims for input_dim_index. We can search from the last
|
||||||
|
// matching position because input_dim_indices is guaranteed to be sorted.
|
||||||
|
while (i < unmodified_dims.size() &&
|
||||||
|
unmodified_dims[i].first < input_dim_index) {
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
if (i >= unmodified_dims.size() ||
|
||||||
|
unmodified_dims[i].first != input_dim_index) {
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
output_dim_indices.push_back(unmodified_dims[i].second);
|
||||||
|
}
|
||||||
|
return output_dim_indices;
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ bool ShapeUtil::TransposeIsBitcast(
|
/* static */ bool ShapeUtil::TransposeIsBitcast(
|
||||||
const Shape& input_shape, const Shape& output_shape,
|
const Shape& input_shape, const Shape& output_shape,
|
||||||
absl::Span<const int64> dimension_mapping) {
|
absl::Span<const int64> dimension_mapping) {
|
||||||
|
@ -578,6 +578,20 @@ class ShapeUtil {
|
|||||||
static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
|
static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
|
||||||
const Shape& input_shape, const Shape& output_shape);
|
const Shape& input_shape, const Shape& output_shape);
|
||||||
|
|
||||||
|
// Return whether the given reshape instruction leaves the dimensions at the
|
||||||
|
// given input indices unmodified, and returns their output indices.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
// input_dim_indices = {2, 3}
|
||||||
|
// input shape = T[a, b, x, y, cd]
|
||||||
|
// output shape = T[ab, x, 1, y, c, d]
|
||||||
|
// return value = {1, 3}
|
||||||
|
//
|
||||||
|
// Precondition: input_dim_indices is sorted.
|
||||||
|
static absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
|
||||||
|
const Shape& from_shape, const Shape& to_shape,
|
||||||
|
absl::Span<const int64> input_dim_indices);
|
||||||
|
|
||||||
// Returns whether a transpose from input_shape to output_shape with dimension
|
// Returns whether a transpose from input_shape to output_shape with dimension
|
||||||
// mapping "dimension_mapping" produces a result which is bit-wise identical
|
// mapping "dimension_mapping" produces a result which is bit-wise identical
|
||||||
// to its input and thus may be replaced with a bitcast.
|
// to its input and thus may be replaced with a bitcast.
|
||||||
|
Loading…
Reference in New Issue
Block a user