[XLA] Make liberal use of inlined vectors to avoid memory allocation inside

loops.

PiperOrigin-RevId: 277799911
Change-Id: I313ac54c836f1cb6aeaa7900f177783396940435
This commit is contained in:
Blake Hechtman 2019-10-31 14:19:57 -07:00 committed by TensorFlower Gardener
parent bf9c54ae1f
commit 093da5c248
10 changed files with 19 additions and 18 deletions

View File

@ -2285,7 +2285,7 @@ bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
// "operand". Precondition: "operand" is an operand of "instruction".
bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
HloInstruction* operand) {
std::vector<int64> operand_indices = instruction->OperandIndices(operand);
const auto operand_indices = instruction->OperandIndices(operand);
CHECK(!operand_indices.empty());
if (operand_indices.size() != 1) {
return false;

View File

@ -1155,7 +1155,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in HloOrdering::LiveRangeStrictlyBefore
// so here we just need to check that the use is at the right operand index.
std::vector<int64> operand_indices = user->OperandIndices(operand);
const auto operand_indices = user->OperandIndices(operand);
int64 operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
return operand_indices.size() == 1 && operand_indices[0] == operand_no;
}
@ -1171,7 +1171,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}
CHECK(!user_index.empty());
// Only share with the right tuple element buffer.
std::vector<int64> operand_indices = user->OperandIndices(operand);
const auto operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
}
if (user->opcode() == HloOpcode::kCall) {

View File

@ -542,9 +542,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kGather: {
TF_RET_CHECK(proto.has_gather_dimension_numbers())
<< "Gather instruction should have GatherDimensionNumbers set.";
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
absl::make_unique<GatherDimensionNumbers>(
proto.gather_dimension_numbers());
auto gather_dimension_numbers = absl::make_unique<GatherDimensionNumbers>(
proto.gather_dimension_numbers());
std::vector<int64> gather_slice_sizes;
for (int64 bound : proto.gather_slice_sizes()) {
gather_slice_sizes.push_back(bound);
@ -3063,9 +3062,9 @@ Status HloInstruction::AcceptWithOperandOrder(
const Shape& HloInstruction::shape() const { return shape_; }
std::vector<int64> HloInstruction::OperandIndices(
absl::InlinedVector<int64, 4> HloInstruction::OperandIndices(
const HloInstruction* operand) const {
std::vector<int64> result;
absl::InlinedVector<int64, 4> result;
for (int64 i = 0; i < operand_count(); ++i) {
if (this->operand(i) == operand) {
result.push_back(i);

View File

@ -1410,7 +1410,8 @@ class HloInstruction {
// Returns the indices that the given operand appear in the operand list of
// this instruction. Note that an instruction can use the same operand
// multiple times.
std::vector<int64> OperandIndices(const HloInstruction* operand) const;
absl::InlinedVector<int64, 4> OperandIndices(
const HloInstruction* operand) const;
// Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
// this reshape merely inserts or deletes 1-sized dimensions, return the input

View File

@ -43,7 +43,7 @@ using absl::StrJoin;
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
const HloInstruction* operand) {
std::vector<int64> operand_indices = instruction->OperandIndices(operand);
const auto operand_indices = instruction->OperandIndices(operand);
return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
return instruction->IsElementwiseOnOperand(operand_index);
});

View File

@ -137,7 +137,7 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape(
const Shape& output_shape, const Shape& input_shape,
llvm::IRBuilder<>* builder) const {
CHECK_EQ(multidim_.size(), output_shape.rank());
std::vector<std::pair<int64, int64>> common_factors =
const auto common_factors =
CommonFactors(AsInt64Slice(input_shape.dimensions()),
AsInt64Slice(output_shape.dimensions()));
std::vector<llvm::Value*> source_multidim_index(

View File

@ -1099,7 +1099,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
}
// `CommonFactors(a, b).back() == (a.rank, b.rank)` so we must pop it.
common_factors.pop_back();
return common_factors;
return std::vector<std::pair<int64, int64>>(common_factors.begin(),
common_factors.end());
}
/* static */ absl::optional<std::vector<int64>>

View File

@ -229,14 +229,14 @@ int64 Product(absl::Span<const int64> xs) {
std::multiplies<int64>());
}
std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
absl::Span<const int64> b) {
absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
absl::Span<const int64> a, absl::Span<const int64> b) {
CHECK_EQ(Product(a), Product(b));
if (0 == Product(a)) {
return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
}
std::vector<std::pair<int64, int64>> bounds;
absl::InlinedVector<std::pair<int64, int64>, 8> bounds;
for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
partial_size_b = 1;
;) {

View File

@ -484,8 +484,8 @@ int64 Product(absl::Span<const int64> xs);
//
// If the given shapes have non-zero size, returns the bounds of the shortest
// possible such subsequences; else, returns `{(0, 0), (a.size, b.size)}`.
std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
absl::Span<const int64> b);
absl::InlinedVector<std::pair<int64, int64>, 8> CommonFactors(
absl::Span<const int64> a, absl::Span<const int64> b);
// Removes illegal characters from filenames.
string SanitizeFileName(string file_name);

View File

@ -69,7 +69,7 @@ TEST(UtilTest, LogLines) {
TEST(UtilTest, CommonFactors) {
struct {
std::vector<int64> a, b;
std::vector<std::pair<int64, int64>> expected;
absl::InlinedVector<std::pair<int64, int64>, 8> expected;
} test_cases[] = {
{/*.a =*/{0}, /*.b =*/{0}, /*.expected =*/{{0, 0}, {1, 1}}},
{/*.a =*/{}, /*.b =*/{}, /*.expected =*/{{0, 0}}},