[XLA] Make liberal use of inlined vectors to avoid memory allocation inside
loops. PiperOrigin-RevId: 277799911 Change-Id: I313ac54c836f1cb6aeaa7900f177783396940435
This commit is contained in:
parent
bf9c54ae1f
commit
093da5c248
@ -2285,7 +2285,7 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
|
||||
// "operand". Precondition: "operand" is an operand of "instruction".
|
||||
bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
|
||||
HloInstruction* operand) {
|
||||
std::vector<int64> operand_indices = instruction->OperandIndices(operand);
|
||||
const auto operand_indices = instruction->OperandIndices(operand);
|
||||
CHECK(!operand_indices.empty());
|
||||
if (operand_indices.size() != 1) {
|
||||
return false;
|
||||
|
@ -1155,7 +1155,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
user->opcode() == HloOpcode::kWhile) {
|
||||
// We eliminated other users in HloOrdering::LiveRangeStrictlyBefore
|
||||
// so here we just need to check that the use is at the right operand index.
|
||||
std::vector<int64> operand_indices = user->OperandIndices(operand);
|
||||
const auto operand_indices = user->OperandIndices(operand);
|
||||
int64 operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
|
||||
return operand_indices.size() == 1 && operand_indices[0] == operand_no;
|
||||
}
|
||||
@ -1171,7 +1171,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
}
|
||||
CHECK(!user_index.empty());
|
||||
// Only share with the right tuple element buffer.
|
||||
std::vector<int64> operand_indices = user->OperandIndices(operand);
|
||||
const auto operand_indices = user->OperandIndices(operand);
|
||||
return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
|
||||
}
|
||||
if (user->opcode() == HloOpcode::kCall) {
|
||||
|
@ -542,9 +542,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
case HloOpcode::kGather: {
|
||||
TF_RET_CHECK(proto.has_gather_dimension_numbers())
|
||||
<< "Gather instruction should have GatherDimensionNumbers set.";
|
||||
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
|
||||
absl::make_unique<GatherDimensionNumbers>(
|
||||
proto.gather_dimension_numbers());
|
||||
auto gather_dimension_numbers = absl::make_unique<GatherDimensionNumbers>(
|
||||
proto.gather_dimension_numbers());
|
||||
std::vector<int64> gather_slice_sizes;
|
||||
for (int64 bound : proto.gather_slice_sizes()) {
|
||||
gather_slice_sizes.push_back(bound);
|
||||
@ -3063,9 +3062,9 @@ Status HloInstruction::AcceptWithOperandOrder(
|
||||
|
||||
const Shape& HloInstruction::shape() const { return shape_; }
|
||||
|
||||
std::vector<int64> HloInstruction::OperandIndices(
|
||||
absl::InlinedVector<int64, 4> HloInstruction::OperandIndices(
|
||||
const HloInstruction* operand) const {
|
||||
std::vector<int64> result;
|
||||
absl::InlinedVector<int64, 4> result;
|
||||
for (int64 i = 0; i < operand_count(); ++i) {
|
||||
if (this->operand(i) == operand) {
|
||||
result.push_back(i);
|
||||
|
@ -1410,7 +1410,8 @@ class HloInstruction {
|
||||
// Returns the indices that the given operand appear in the operand list of
|
||||
// this instruction. Note that an instruction can use the same operand
|
||||
// multiple times.
|
||||
std::vector<int64> OperandIndices(const HloInstruction* operand) const;
|
||||
absl::InlinedVector<int64, 4> OperandIndices(
|
||||
const HloInstruction* operand) const;
|
||||
|
||||
// Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
|
||||
// this reshape merely inserts or deletes 1-sized dimensions, return the input
|
||||
|
@ -43,7 +43,7 @@ using absl::StrJoin;
|
||||
|
||||
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
|
||||
const HloInstruction* operand) {
|
||||
std::vector<int64> operand_indices = instruction->OperandIndices(operand);
|
||||
const auto operand_indices = instruction->OperandIndices(operand);
|
||||
return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
|
||||
return instruction->IsElementwiseOnOperand(operand_index);
|
||||
});
|
||||
|
@ -137,7 +137,7 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
|
||||
const Shape& output_shape, const Shape& input_shape,
|
||||
llvm::IRBuilder<>* builder) const {
|
||||
CHECK_EQ(multidim_.size(), output_shape.rank());
|
||||
std::vector<std::pair<int64, int64>> common_factors =
|
||||
const auto common_factors =
|
||||
CommonFactors(AsInt64Slice(input_shape.dimensions()),
|
||||
AsInt64Slice(output_shape.dimensions()));
|
||||
std::vector<llvm::Value*> source_multidim_index(
|
||||
|
@ -1099,7 +1099,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
}
|
||||
// `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
|
||||
common_factors.pop_back();
|
||||
return common_factors;
|
||||
return std::vector<std::pair<int64, int64>>(common_factors.begin(),
|
||||
common_factors.end());
|
||||
}
|
||||
|
||||
/* static */ absl::optional<std::vector<int64>>
|
||||
|
@ -229,14 +229,14 @@ int64 Product(absl::Span<const int64> xs) {
|
||||
std::multiplies<int64>());
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
|
||||
absl::Span<const int64> b) {
|
||||
absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
|
||||
absl::Span<const int64> a, absl::Span<const int64> b) {
|
||||
CHECK_EQ(Product(a), Product(b));
|
||||
if (0 == Product(a)) {
|
||||
return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64, int64>> bounds;
|
||||
absl::InlinedVector<std::pair<int64, int64>, 8> bounds;
|
||||
for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
|
||||
partial_size_b = 1;
|
||||
;) {
|
||||
|
@ -484,8 +484,8 @@ int64 Product(absl::Span<const int64> xs);
|
||||
//
|
||||
// If the given shapes have non-zero size, returns the bounds of the shortest
|
||||
// possible such subsequences; else, returns `{(0, 0), (a.size, b.size)}`.
|
||||
std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
|
||||
absl::Span<const int64> b);
|
||||
absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
|
||||
absl::Span<const int64> a, absl::Span<const int64> b);
|
||||
|
||||
// Removes illegal characters from filenames.
|
||||
string SanitizeFileName(string file_name);
|
||||
|
@ -69,7 +69,7 @@ TEST(UtilTest, LogLines) {
|
||||
TEST(UtilTest, CommonFactors) {
|
||||
struct {
|
||||
std::vector<int64> a, b;
|
||||
std::vector<std::pair<int64, int64>> expected;
|
||||
absl::InlinedVector<std::pair<int64, int64>, 8> expected;
|
||||
} test_cases[] = {
|
||||
{/*.a =*/{0}, /*.b =*/{0}, /*.expected =*/{{0, 0}, {1, 1}}},
|
||||
{/*.a =*/{}, /*.b =*/{}, /*.expected =*/{{0, 0}}},
|
||||
|
Loading…
Reference in New Issue
Block a user