[XLA] Use absl::c_foo rather than std::foo.

No functional change.

PiperOrigin-RevId: 227896034
This commit is contained in:
Justin Lebar 2019-01-04 12:29:13 -08:00 committed by TensorFlower Gardener
parent f9bd1568aa
commit b4813a0cff
55 changed files with 305 additions and 380 deletions

View File

@ -717,6 +717,7 @@ cc_library(
":types",
":xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],

View File

@ -77,7 +77,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) {
auto x = ConstantR1<float>(&builder, inputs);
xla::GetTupleElement(xla::TopK(x, kSize), 0);
std::sort(inputs.begin(), inputs.end(), std::greater<float>());
absl::c_sort(inputs, std::greater<float>());
ComputeAndCompareR1<float>(&builder, inputs, {});
}

View File

@ -290,8 +290,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
/* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
if (shape.IsTuple()) {
// Tuple shape: all subshapes must have a layout.
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
[](const Shape& s) { return HasLayout(s); });
return absl::c_all_of(shape.tuple_shapes(),
[](const Shape& s) { return HasLayout(s); });
} else if (!shape.IsArray()) {
// Opaque, token types etc. ignore layout.
return true;
@ -424,7 +424,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
positions_in_layout.push_back(
PositionInContainer(layout.minor_to_major(), dim));
}
std::sort(positions_in_layout.begin(), positions_in_layout.end());
absl::c_sort(positions_in_layout);
for (size_t i = 1; i < positions_in_layout.size(); ++i) {
if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
return false;

View File

@ -55,7 +55,7 @@ string MetricTableReport::MakeReport(double expected_metric_sum) {
const auto metric_greater = [](const Entry& a, const Entry& b) {
return a.metric > b.metric;
};
std::sort(entries_.begin(), entries_.end(), metric_greater);
absl::c_sort(entries_, metric_greater);
// Create the report
AppendLine();
@ -117,7 +117,7 @@ std::vector<MetricTableReport::Category> MetricTableReport::MakeCategories(
auto metric_sum_greater = [](const Category& a, const Category& b) {
return a.metric_sum > b.metric_sum;
};
std::sort(categories.begin(), categories.end(), metric_sum_greater);
absl::c_sort(categories, metric_sum_greater);
return categories;
}

View File

@ -3414,6 +3414,7 @@ cc_library(
":hlo_profile_printer_data",
":human_readable_profile_builder",
"//tensorflow/compiler/xla:types",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)

View File

@ -2216,8 +2216,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
auto dim_is_one = [&](int64 i) -> bool {
return reverse->shape().dimensions(i) == 1;
};
if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(),
dim_is_one)) {
if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
return ReplaceInstruction(reverse, reverse->mutable_operand(0));
}
return Status::OK();
@ -2492,9 +2491,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
// Create a new reduce with the combined reduction dimensions of both
// reduces.
std::vector<int64> arg_dims = arg->dimensions();
std::sort(arg_dims.begin(), arg_dims.end());
absl::c_sort(arg_dims);
std::vector<int64> reduce_dims = reduce->dimensions();
std::sort(reduce_dims.begin(), reduce_dims.end());
absl::c_sort(reduce_dims);
// Transform reduce_dims to the same rank as the operand of the operand.
for (int64 arg_dim : arg_dims) {
for (int64& dim : reduce_dims) {

View File

@ -86,10 +86,9 @@ std::vector<int64> ColorInterferenceGraph(
// first, but it would be good to investigate other ordering heuristics too.
std::vector<int64> nodes(node_count);
std::iota(nodes.begin(), nodes.end(), 0);
std::sort(nodes.begin(), nodes.end(),
[&interference_map](const int64 i, const int64 j) {
return interference_map[i].size() > interference_map[j].size();
});
absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
return interference_map[i].size() > interference_map[j].size();
});
const int64 kColorUnassigned = -1;
std::vector<int64> assigned_colors(node_count, kColorUnassigned);
@ -272,11 +271,12 @@ BufferAllocationProto BufferAllocation::ToProto() const {
proto_assigned->set_offset(buffer_offset_size.second.offset);
proto_assigned->set_size(buffer_offset_size.second.size);
}
std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(),
[](const BufferAllocationProto::Assigned& assign1,
const BufferAllocationProto::Assigned& assign2) {
return assign1.logical_buffer_id() < assign2.logical_buffer_id();
});
absl::c_sort(*proto.mutable_assigned(),
[](const BufferAllocationProto::Assigned& assign1,
const BufferAllocationProto::Assigned& assign2) {
return assign1.logical_buffer_id() <
assign2.logical_buffer_id();
});
return proto;
}
@ -308,10 +308,10 @@ string BufferAllocation::ToString() const {
for (const auto& buffer_offset_size : assigned_buffers_) {
sorted_buffers.push_back(buffer_offset_size.first);
}
std::sort(sorted_buffers.begin(), sorted_buffers.end(),
[](const LogicalBuffer* a, const LogicalBuffer* b) {
return a->id() < b->id();
});
absl::c_sort(sorted_buffers,
[](const LogicalBuffer* a, const LogicalBuffer* b) {
return a->id() < b->id();
});
for (const LogicalBuffer* buffer : sorted_buffers) {
const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
StrAppend(&output, absl::StrFormat(
@ -479,10 +479,9 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
// didn't return the empty set) for both HLOs, and the two resulting sets of
// slices are disjoint.
return !slices_a.empty() && !slices_b.empty() &&
std::none_of(slices_a.begin(), slices_a.end(),
[&](const BufferAllocation::Slice& slice) {
return slices_b.count(slice) > 0;
});
absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
return slices_b.contains(slice);
});
}
StatusOr<BufferAllocation::Slice>
@ -952,28 +951,28 @@ Status BufferAssigner::AssignBuffersForComputation(
// operands (assuming operands are the same/larger size) enabling the
// important reuse case where an elementwise instruction reuses one of its
// operand's buffer. This improves locality.
std::sort(sorted_buffers.begin(), sorted_buffers.end(),
[has_sequential_order, &liveness, &post_order_position, assignment](
const LogicalBuffer* a, const LogicalBuffer* b) {
// Primary sort is by decreasing buffer size.
const int64 a_size = assignment->buffer_size_(*a);
const int64 b_size = assignment->buffer_size_(*b);
if (a_size != b_size) {
return a_size > b_size; // use ">" for decreasing size.
}
// Otherwise live out buffers come before others, if the
// instructions are sequentially ordered.
if (has_sequential_order) {
const bool a_live_out = liveness.MaybeLiveOut(*a);
const bool b_live_out = liveness.MaybeLiveOut(*b);
if (a_live_out != b_live_out) {
return a_live_out;
}
}
// Final tiebreaker is in instruction post order.
return post_order_position.at(a->instruction()) <
post_order_position.at(b->instruction());
});
absl::c_sort(sorted_buffers,
[has_sequential_order, &liveness, &post_order_position,
assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
// Primary sort is by decreasing buffer size.
const int64 a_size = assignment->buffer_size_(*a);
const int64 b_size = assignment->buffer_size_(*b);
if (a_size != b_size) {
return a_size > b_size; // use ">" for decreasing size.
}
// Otherwise live out buffers come before others, if the
// instructions are sequentially ordered.
if (has_sequential_order) {
const bool a_live_out = liveness.MaybeLiveOut(*a);
const bool b_live_out = liveness.MaybeLiveOut(*b);
if (a_live_out != b_live_out) {
return a_live_out;
}
}
// Final tiebreaker is in instruction post order.
return post_order_position.at(a->instruction()) <
post_order_position.at(b->instruction());
});
// BufferAllocations are necessarily created in decreasing size order. Keep
// indices of previously created BufferAllocations in allocation_indices.
@ -1305,10 +1304,10 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
live_buffers.end());
// Stably sort the live buffers.
std::sort(live_buffers_vector.begin(), live_buffers_vector.end(),
[](const LogicalBuffer* a, const LogicalBuffer* b) {
return a->id() < b->id();
});
absl::c_sort(live_buffers_vector,
[](const LogicalBuffer* a, const LogicalBuffer* b) {
return a->id() < b->id();
});
return live_buffers_vector;
}

View File

@ -384,7 +384,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
// Verify visitation order of some computations in the graph.
auto index_of = [&visited](const HloComputation* comp) {
auto it = std::find(visited.begin(), visited.end(), comp);
auto it = absl::c_find(visited, comp);
EXPECT_NE(it, visited.end());
return std::distance(visited.begin(), it);
};

View File

@ -42,8 +42,8 @@ void ComputationLayout::SetToDefaultLayout() {
}
bool ComputationLayout::LayoutIsSet() const {
return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(),
[](const ShapeLayout& s) { return s.LayoutIsSet(); }) &&
return absl::c_all_of(parameter_layouts_,
[](const ShapeLayout& s) { return s.LayoutIsSet(); }) &&
result_layout_.LayoutIsSet();
}

View File

@ -539,10 +539,9 @@ class CopyRemover {
}
std::vector<const HloValue*> values = buffer.values();
std::sort(values.begin(), values.end(),
[this](const HloValue* a, const HloValue* b) {
return ordering_.IsDefinedBefore(*a, *b);
});
absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
return ordering_.IsDefinedBefore(*a, *b);
});
// Create a list containing all of the values in the buffer.
AddValueList(values, &value_to_node);
@ -842,12 +841,11 @@ class CopyRemover {
copy_value_node->next->prev = operand_node;
// Patch up uses. Remove use of copy from operand_node uses.
auto it =
std::find_if(operand_node->uses.begin(), operand_node->uses.end(),
[copy_value_node](const HloUse* use) {
return use->instruction ==
copy_value_node->value->defining_instruction();
});
auto it = absl::c_find_if(
operand_node->uses, [copy_value_node](const HloUse* use) {
return use->instruction ==
copy_value_node->value->defining_instruction();
});
CHECK(it != operand_node->uses.end());
operand_node->uses.erase(it);

View File

@ -77,17 +77,16 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile(
}
// Sort the symbols in increasing address order.
std::sort(
symbols.begin(), symbols.end(),
[](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) {
// getAddress returns an Expected object. Assert there is no error
// before extracting the address.
llvm::Expected<uint64_t> a_address_or_error = a.getAddress();
CHECK(a_address_or_error);
llvm::Expected<uint64_t> b_address_or_error = b.getAddress();
CHECK(b_address_or_error);
return a_address_or_error.get() < b_address_or_error.get();
});
absl::c_sort(symbols, [](const llvm::object::SymbolRef& a,
const llvm::object::SymbolRef& b) {
// getAddress returns an Expected object. Assert there is no error
// before extracting the address.
llvm::Expected<uint64_t> a_address_or_error = a.getAddress();
CHECK(a_address_or_error);
llvm::Expected<uint64_t> b_address_or_error = b.getAddress();
CHECK(b_address_or_error);
return a_address_or_error.get() < b_address_or_error.get();
});
// Construct ArrayRef pointing to section contents.
llvm::StringRef section_content_string;

View File

@ -1709,10 +1709,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
vectorization_factor_in_bytes /
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
bool is_reduction_over_minor_dimension =
std::find(dimensions.begin(), dimensions.end(),
LayoutUtil::Minor(arg->shape().layout(), 0)) !=
dimensions.end();
bool is_reduction_over_minor_dimension = absl::c_linear_search(
dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
@ -2401,8 +2399,7 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
int64 concat_dim = concatenate->dimensions(0);
const Layout& output_layout = output_shape.layout();
auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
auto concat_dim_layout_itr =
std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim);
auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
@ -2956,8 +2953,7 @@ Status IrEmitter::ElementTypesSameAndSupported(
TF_RET_CHECK(!operands.empty());
PrimitiveType primitive_type = operands[0]->shape().element_type();
if (std::find(supported_types.begin(), supported_types.end(),
primitive_type) == supported_types.end()) {
if (!absl::c_linear_search(supported_types, primitive_type)) {
return Unimplemented("unsupported operand type %s in op %s",
PrimitiveType_Name(primitive_type),
HloOpcodeString(instruction.opcode()));

View File

@ -154,20 +154,17 @@ bool IsReductionToVector(const HloInstruction& reduce) {
const HloInstruction* input = reduce.operand(0);
std::vector<int64> dims_to_keep;
for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(),
dim)) {
if (!absl::c_linear_search(reduce.dimensions(), dim)) {
dims_to_keep.push_back(dim);
}
}
return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
dims_to_keep) &&
ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions(
[&dims_to_keep](int64 dim) {
return std::count(
dims_to_keep.begin(),
dims_to_keep.end(), dim);
},
input->shape()));
ShapeUtil::Equal(
reduce.shape(),
ShapeUtil::FilterDimensions(
[&](int64 dim) { return absl::c_count(dims_to_keep, dim); },
input->shape()));
}
// This emits a device-side call to

View File

@ -1506,10 +1506,10 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
return !allocation->is_constant();
});
std::sort(non_constant_buffers.begin(), non_constant_buffers.end(),
[](const BufferAllocation* a, const BufferAllocation* b) {
return a->index() < b->index();
});
absl::c_sort(non_constant_buffers,
[](const BufferAllocation* a, const BufferAllocation* b) {
return a->index() < b->index();
});
llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers);

View File

@ -225,10 +225,10 @@ Status HeapSimulator::RunComputation(
}
}
// Sort to get a deterministic iteration order.
std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(),
[](const BufferValue* x, const BufferValue* y) {
return x->id() < y->id();
});
absl::c_sort(operand_buffers_to_free,
[](const BufferValue* x, const BufferValue* y) {
return x->id() < y->id();
});
// Allocate buffers defined by this instruction. This is the latest point
// that we can allocate; right before the buffer is first used. This must
@ -335,10 +335,9 @@ Status HeapSimulator::RunComputation(
to_free.push_back(buffer);
}
std::sort(to_free.begin(), to_free.end(),
[](const BufferValue* x, const BufferValue* y) {
return x->id() < y->id();
});
absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) {
return x->id() < y->id();
});
for (const BufferValue* buffer : to_free) {
VLOG(3) << "Freeing pending: " << buffer->ToString();
Free(buffer, root);
@ -596,7 +595,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() {
}
// Call ops in the run sorted by decreasing size, breaking ties by buffer id.
std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) {
absl::c_sort(run_, [](const Op& a, const Op& b) {
if (a.size != b.size) {
return a.size > b.size;
}
@ -866,23 +865,23 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
for (auto& entry : buffer_intervals_) {
sorted_buffer_intervals.push_back(entry.second);
}
std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(),
[](const BufferInterval& x, const BufferInterval& y) {
if (x.size != y.size) {
return x.size > y.size;
}
if (x.end - x.start != y.end - y.start) {
return x.end - x.start > y.end - y.start;
}
return x.buffer->id() < y.buffer->id();
});
absl::c_sort(sorted_buffer_intervals,
[](const BufferInterval& x, const BufferInterval& y) {
if (x.size != y.size) {
return x.size > y.size;
}
if (x.end - x.start != y.end - y.start) {
return x.end - x.start > y.end - y.start;
}
return x.buffer->id() < y.buffer->id();
});
BufferIntervalTree interval_tree(sorted_buffer_intervals.size());
for (auto& buffer_interval : sorted_buffer_intervals) {
auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
buffer_interval.start, buffer_interval.end);
std::sort(
chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(),
absl::c_sort(
chunks_overlapping_in_time,
[](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });
// Find the minimum free chunk that can hold this buffer.

View File

@ -117,7 +117,7 @@ class BufferValueMap {
for (const auto& pair : buffers_) {
buffer_numbers.push_back(pair.first);
}
std::sort(buffer_numbers.begin(), buffer_numbers.end());
absl::c_sort(buffer_numbers);
return buffer_numbers;
}
@ -319,7 +319,7 @@ class BufferValueMap {
ComputeWhileAliasedBuffers(value, &aliased_buffers);
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
// Uniquify aliased buffers.
std::sort(aliased_buffers.begin(), aliased_buffers.end());
absl::c_sort(aliased_buffers);
aliased_buffers.erase(
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
aliased_buffers.end());
@ -367,7 +367,7 @@ std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
}
// Sort and uniquify vector before returning.
std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan);
absl::c_sort(buffers, HloBuffer::IdLessThan);
buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
return buffers;
@ -430,8 +430,7 @@ Status HloAliasAnalysis::Verify() const {
for (const auto& pair : value_to_buffer_) {
const HloValue* value = pair.first;
const HloBuffer& buffer = *pair.second;
TF_RET_CHECK(std::find(buffer.values().begin(), buffer.values().end(),
value) != buffer.values().end());
TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
}
for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
@ -515,7 +514,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
std::vector<const HloValue*> sorted_values(value_set.begin(),
value_set.end());
std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan);
absl::c_sort(sorted_values, HloValue::IdLessThan);
alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
for (const HloValue* value : sorted_values) {
alias_analysis->value_to_buffer_[value] =
@ -547,16 +546,15 @@ bool HloAliasAnalysis::HasLiveRangeInterference(
// tie-break using value ID. The tie-break is necessary because we need a
// strict weak order for absl::c_sort.
std::vector<const HloValue*> values = buffer.values();
std::sort(values.begin(), values.end(),
[&ordering](const HloValue* a, const HloValue* b) {
if (ordering.IsDefinedBefore(*a, *b)) {
return true;
} else if (ordering.IsDefinedBefore(*b, *a)) {
return false;
} else {
return a->id() < b->id();
}
});
absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) {
if (ordering.IsDefinedBefore(*a, *b)) {
return true;
} else if (ordering.IsDefinedBefore(*b, *a)) {
return false;
} else {
return a->id() < b->id();
}
});
// Walk through the ordered vector of values. First verify that the values
// are totally ordered with respect to 'ordering', then check that no

View File

@ -49,7 +49,7 @@ std::vector<HloPosition> HloBuffer::ComputePositions() const {
value->positions().end());
}
// Remove duplicates and sort positions.
std::sort(positions.begin(), positions.end());
absl::c_sort(positions);
positions.erase(std::unique(positions.begin(), positions.end()),
positions.end());
return positions;

View File

@ -531,11 +531,10 @@ HloComputation::CreateFromProto(
HloInstruction* root = instruction_map.at(proto.root_id());
// Sort the instructions in the proto id's order.
std::sort(instructions.begin(), instructions.end(),
[&](const std::unique_ptr<HloInstruction>& a,
const std::unique_ptr<HloInstruction>& b) {
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
const std::unique_ptr<HloInstruction>& b) {
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
TF_RETURN_IF_ERROR([&]() -> Status {
std::vector<bool> parameters_seen(parameter_count);
@ -800,8 +799,7 @@ Status HloComputation::AcceptOrdered(
absl::Span<HloInstruction* const> order) const {
VLOG(3) << "Accepting visitor with order.";
for (HloInstruction* root : CollectUnreachableRoots()) {
TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end())
<< root->ToString();
TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
}
TF_RET_CHECK(order.size() == instruction_count());
absl::flat_hash_set<const HloInstruction*> visited;

View File

@ -256,7 +256,7 @@ bool HloDataflowAnalysis::Phi(
input_value_ids.push_back(value->id());
}
}
std::sort(input_value_ids.begin(), input_value_ids.end());
absl::c_sort(input_value_ids);
input_value_ids.erase(
std::unique(input_value_ids.begin(), input_value_ids.end()),
input_value_ids.end());
@ -271,8 +271,7 @@ bool HloDataflowAnalysis::Phi(
if (current_value_defined_here) {
VLOG(5) << "current_value_defined_here: " << current_value->ToString();
CHECK(current_value->is_phi());
auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
current_value->id());
auto it = absl::c_find(input_value_ids, current_value->id());
if (it != input_value_ids.end()) {
input_value_ids.erase(it);
}
@ -921,8 +920,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
for (auto& pair : dataflow_analysis->values_) {
dataflow_analysis->values_vector_.push_back(&pair.second);
}
std::sort(dataflow_analysis->values_vector_.begin(),
dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
TF_DCHECK_OK(dataflow_analysis->Verify());
@ -937,9 +935,7 @@ Status HloDataflowAnalysis::Verify() const {
for (const HloValue* value : values()) {
for (const HloPosition& position : value->positions()) {
const HloValueSet& value_set = GetValueSet(position);
TF_RET_CHECK(std::find(value_set.values().begin(),
value_set.values().end(),
value) != value_set.values().end())
TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
<< "Value set at position " << position << " does not contain value "
<< value->ToShortString();
}
@ -954,9 +950,7 @@ Status HloDataflowAnalysis::Verify() const {
const HloValueSet& value_set = pair.second;
const HloPosition position{instruction, index};
for (const HloValue* value : value_set.values()) {
TF_RET_CHECK(std::find(value->positions().begin(),
value->positions().end(),
position) != value->positions().end())
TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
<< "Value set at position " << position
<< " unexpectedly contains value " << value->ToShortString();
}
@ -1041,11 +1035,10 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// Check if one operand of kAdd fused root is kDot or kConvolution.
auto* add = user->fused_expression_root();
auto add_operand_it =
std::find_if(add->operands().begin(), add->operands().end(),
[&](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConvolution ||
operand->opcode() == HloOpcode::kDot;
});
absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConvolution ||
operand->opcode() == HloOpcode::kDot;
});
if (add_operand_it == add->operands().end()) {
return false;
}
@ -1100,16 +1093,15 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// *) The root instruction of the called computation is element-wise on
// 'operand'.
const bool found_caller_use =
std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) {
absl::c_find_if(uses, [user](const HloUse& use) {
return use.instruction == user;
}) != uses.end();
auto* callee_root = user->to_apply()->root_instruction();
const bool found_elementwise_callee_use =
std::find_if(
uses.begin(), uses.end(), [callee_root](const HloUse& use) {
return use.instruction == callee_root &&
callee_root->IsElementwiseOnOperand(use.operand_number);
}) != uses.end();
absl::c_find_if(uses, [callee_root](const HloUse& use) {
return use.instruction == callee_root &&
callee_root->IsElementwiseOnOperand(use.operand_number);
}) != uses.end();
return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
}

View File

@ -43,9 +43,7 @@ class HloDceTest : public HloTestBase {
// Returns whether the given instruction exists in the given computation.
bool HasInstruction(const HloComputation& computation,
const HloInstruction* instruction) {
return std::find(computation.instructions().begin(),
computation.instructions().end(),
instruction) != computation.instructions().end();
return absl::c_linear_search(computation.instructions(), instruction);
}
};

View File

@ -230,10 +230,10 @@ HloDomainMap::MakeNonDomainInstructions(
}
}
// sort instructions according to instructions_order
std::sort(instructions.begin(), instructions.end(),
[&instructions_order](HloInstruction* a, HloInstruction* b) {
return instructions_order.at(a) < instructions_order.at(b);
});
absl::c_sort(instructions,
[&instructions_order](HloInstruction* a, HloInstruction* b) {
return instructions_order.at(a) < instructions_order.at(b);
});
return instructions;
}

View File

@ -1248,8 +1248,7 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
// Extract a slice from the keys and values literals that correspond to
// exactly the row in dimension 'sort_dim'.
std::vector<int64> limit_indices(indices.begin(), indices.end());
std::for_each(limit_indices.begin(), limit_indices.end(),
[](int64& index) { ++index; });
absl::c_for_each(limit_indices, [](int64& index) { ++index; });
limit_indices[sort_dim] = sort_dim_elements;
TF_ASSIGN_OR_RETURN(auto keys_to_sort,
keys_literal.Slice(indices, limit_indices)

View File

@ -1673,8 +1673,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Extract a slice from the literal that corresponds to exactly the
// row in dimension 'sort_dim'.
std::vector<int64> limit_indices(indices.begin(), indices.end());
std::for_each(limit_indices.begin(), limit_indices.end(),
[](int64& index) { ++index; });
absl::c_for_each(limit_indices, [](int64& index) { ++index; });
limit_indices[sort_dim] = sort_dim_elements;
TF_ASSIGN_OR_RETURN(auto row_to_sort,
keys_literal.Slice(indices, limit_indices)

View File

@ -561,8 +561,8 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
}
// Show the subcomputation if we're showing any of its members.
return std::any_of(
subcomp->instructions().begin(), subcomp->instructions().end(),
return absl::c_any_of(
subcomp->instructions(),
[&](const HloInstruction* instr) { return filter_.Show(instr); });
}
@ -735,15 +735,14 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
const int kMinUsersToOmit = 3;
return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
!instr->IsFused() &&
std::count_if(instr->users().begin(), instr->users().end(),
[&](const HloInstruction* user) {
return filter_.Show(user);
}) > kMinUsersToOmit &&
std::all_of(instr->users().begin(), instr->users().end(),
[&](const HloInstruction* user) {
return !filter_.Show(user) ||
user->opcode() == HloOpcode::kGetTupleElement;
});
absl::c_count_if(instr->users(),
[&](const HloInstruction* user) {
return filter_.Show(user);
}) > kMinUsersToOmit &&
absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
return !filter_.Show(user) ||
user->opcode() == HloOpcode::kGetTupleElement;
});
}
string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
@ -900,12 +899,11 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
// the same color as a parameter. Unless the merged-in parameter is a
// parameter to a fusion node that is bound to a constant -- these aren't
// "real" parameters from the user's perspective.
if (std::any_of(instr->operands().begin(), instr->operands().end(),
[&](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kParameter &&
ShouldMergeIntoUsers(operand) &&
TryGetFusionParameterConstant(operand) == nullptr;
})) {
if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kParameter &&
ShouldMergeIntoUsers(operand) &&
TryGetFusionParameterConstant(operand) == nullptr;
})) {
return parameter_color;
}
@ -1355,12 +1353,11 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
NodeFilterResult& filter_result = kv.second;
const auto& operands = instr->operands();
if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
!std::all_of(operands.begin(), operands.end(), is_displayed)) {
if (absl::c_any_of(operands, is_displayed) &&
!absl::c_all_of(operands, is_displayed)) {
// Mark nodes with some operands omitted appropriately.
filter_result = kSomeOperandsOmitted;
} else if (!operands.empty() &&
std::none_of(operands.begin(), operands.end(), is_displayed)) {
} else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
// Mark nodes with *all* operands omitted appropriately.
filter_result = kOmitNodeOperands;
}
@ -1368,8 +1365,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
// Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
// users made it into the graph.
if (filter_result == kSomeUsersOmitted &&
std::all_of(instr->users().begin(), instr->users().end(),
is_displayed)) {
absl::c_all_of(instr->users(), is_displayed)) {
filter_result = kNormalNode;
}
}

View File

@ -83,15 +83,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
return computation_map.at(proto.called_computation_ids(index));
};
TF_RET_CHECK(std::all_of(
proto.operand_ids().begin(), proto.operand_ids().end(),
[&instruction_map](int64 id) { return instruction_map.contains(id); }))
TF_RET_CHECK(
absl::c_all_of(proto.operand_ids(),
[&](int64 id) { return instruction_map.contains(id); }))
<< proto.name() << " instruction contains invalid operand id(s)";
TF_RET_CHECK(std::all_of(
proto.called_computation_ids().begin(),
proto.called_computation_ids().end(),
[&computation_map](int64 id) { return computation_map.contains(id); }))
TF_RET_CHECK(
absl::c_all_of(proto.called_computation_ids(),
[&](int64 id) { return computation_map.contains(id); }))
<< proto.name() << " instruction references invalid computation id(s)";
Shape shape(proto.shape());
@ -1599,12 +1598,10 @@ HloInstruction::InstructionVector HloInstruction::unique_operands() const {
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
TF_RET_CHECK(instruction->parent() == parent());
if (std::find(control_successors_.begin(), control_successors_.end(),
instruction) == control_successors_.end()) {
if (!absl::c_linear_search(control_successors_, instruction)) {
control_successors_.push_back(instruction);
TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(),
instruction->control_predecessors_.end(),
this) == instruction->control_predecessors_.end());
TF_RET_CHECK(
!absl::c_linear_search(instruction->control_predecessors_, this));
instruction->control_predecessors_.push_back(this);
}
return Status::OK();
@ -1853,7 +1850,7 @@ void HloInstruction::RemoveUser(HloInstruction* user) {
user_set_.erase(set_it);
// This is linear in the number of the users, but a vector provides a stable
// iteration order and much faster traversal.
auto vec_it = std::find(users_.begin(), users_.end(), user);
auto vec_it = absl::c_find(users_, user);
CHECK(vec_it != users_.end());
users_.erase(vec_it);
}
@ -1871,8 +1868,7 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
RemoveUser(user);
TF_RET_CHECK(
std::count(user->operands_.begin(), user->operands_.end(), this) >= 0);
TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
std::replace(user->operands_.begin(), user->operands_.end(), this,
new_producer);
new_producer->AddUser(user);
@ -1907,8 +1903,7 @@ Status HloInstruction::ReplaceOperandWithDifferentShape(
VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
<< new_operand->name() << ", was " << old_operand->name();
if (std::find(operands_.begin(), operands_.end(), old_operand) ==
operands_.end()) {
if (!absl::c_linear_search(operands_, old_operand)) {
old_operand->RemoveUser(this);
}
new_operand->AddUser(this);
@ -2945,10 +2940,10 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
string PaddingConfigToString(const PaddingConfig& padding) {
bool has_interior_padding =
std::any_of(padding.dimensions().begin(), padding.dimensions().end(),
[](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.interior_padding() != 0;
});
absl::c_any_of(padding.dimensions(),
[](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.interior_padding() != 0;
});
return StrJoin(
padding.dimensions(), "x",
[&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {

View File

@ -42,11 +42,9 @@ using absl::StrJoin;
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
const HloInstruction* operand) {
std::vector<int64> operand_indices = instruction->OperandIndices(operand);
return std::all_of(
operand_indices.begin(), operand_indices.end(),
[instruction](int64 operand_index) {
return instruction->IsElementwiseOnOperand(operand_index);
});
return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
return instruction->IsElementwiseOnOperand(operand_index);
});
}
string PrecisionConfigToString(const PrecisionConfig& precision_config) {
@ -814,8 +812,7 @@ std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
std::vector<string> bounds;
bounds.reserve(slice_starts_.size());
const bool omit_stride =
std::all_of(slice_strides_.begin(), slice_strides_.end(),
[](int64 stride) { return stride == 1; });
absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
for (int i = 0; i < slice_starts_.size(); ++i) {
string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
bounds.push_back(
@ -1051,8 +1048,7 @@ HloInstruction* HloFusionInstruction::AddFusionOperand(
void HloFusionInstruction::MergeFusionInstruction(
HloFusionInstruction* instruction_to_merge) {
CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) !=
operands().end());
CHECK(absl::c_linear_search(operands(), instruction_to_merge));
// Clone the instruction from which to merge fused instructions.
std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
HloFusionInstruction* cloned_fusion =
@ -1219,8 +1215,8 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
// corresponding fused parameter instruction. Renumber parameters as
// necessary to make parameter numbers consistent with their index in the
// fused_parameter_ vector.
bool in_operand_list = std::find(operands().begin(), operands().end(),
instruction_to_fuse) != operands().end();
bool in_operand_list =
absl::c_linear_search(operands(), instruction_to_fuse);
CHECK(add_output || in_operand_list);
if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
// We assume all uses of a kTuple operation are GTE ops, not another

View File

@ -107,11 +107,10 @@ HloComputation* HloModule::AddEntryComputation(
}
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
auto it =
std::find_if(computations_.begin(), computations_.end(),
[&to_remove](const std::unique_ptr<HloComputation>& comp) {
return comp.get() == to_remove;
});
auto it = absl::c_find_if(
computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
return comp.get() == to_remove;
});
TF_RET_CHECK(it->get() == to_remove);
computations_.erase(it);
return Status::OK();
@ -304,11 +303,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
auto module = absl::make_unique<HloModule>(proto.name(), module_config);
// Sort the computations in the proto id's order.
std::sort(computations.begin(), computations.end(),
[&](const std::unique_ptr<HloComputation>& a,
const std::unique_ptr<HloComputation>& b) {
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
const std::unique_ptr<HloComputation>& b) {
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
// Add sorted computations to the module.
for (auto& computation : computations) {

View File

@ -38,9 +38,7 @@ class HloModuleDceTest : public HloTestBase {
// Returns whether the given instruction exists in the given computation.
bool HasInstruction(const HloComputation& computation,
const HloInstruction* instruction) {
return std::find(computation.instructions().begin(),
computation.instructions().end(),
instruction) != computation.instructions().end();
return absl::c_linear_search(computation.instructions(), instruction);
}
// Returns whether the while instruction with name 'while_name' in

View File

@ -2746,7 +2746,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
}
auto is_unique = [](string str) -> bool {
std::sort(str.begin(), str.end());
absl::c_sort(str);
return std::unique(str.begin(), str.end()) == str.end();
};

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
@ -34,11 +35,10 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
for (const HloComputationInfo& computation_info :
hlo_profile_printer_data.computation_infos()) {
const auto& instruction_infos = computation_info.instruction_infos();
bool any_instruction_profiled =
std::any_of(instruction_infos.begin(), instruction_infos.end(),
[&](const HloInstructionInfo& instruction_info) {
return counters[instruction_info.profile_index()] != 0;
});
bool any_instruction_profiled = absl::c_any_of(
instruction_infos, [&](const HloInstructionInfo& instruction_info) {
return counters[instruction_info.profile_index()] != 0;
});
if (!any_instruction_profiled) {
continue;

View File

@ -49,7 +49,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper(
absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction, BitVector* bit_vector) {
// If instruction is part of inputs, don't reset the bit_vector.
if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) {
if (!absl::c_linear_search(inputs, instruction)) {
bit_vector->SetToZero();
}
bit_vector->Set(GetIndex(instruction));

View File

@ -235,8 +235,7 @@ class InstructionList {
}
// Now scan forwards until we find one of the before_instructions.
while (std::find(before_instructions.begin(), before_instructions.end(),
min_position_item) == before_instructions.end()) {
while (!absl::c_linear_search(before_instructions, min_position_item)) {
min_position_item = min_position_item->next;
}
return InsertBefore(to_insert, min_position_item);
@ -302,7 +301,7 @@ ItemList GetUsers(const InstructionList& instruction_list,
// A buffer may be used by the instruction via more than one alias. For
// example, a buffer which appears in more than one element of a tuple.
Item* user_item = instruction_list.GetItem(user);
if (std::find(users.begin(), users.end(), user_item) == users.end()) {
if (!absl::c_linear_search(users, user_item)) {
users.push_back(user_item);
}
}
@ -456,8 +455,7 @@ class MemoryUsageTracker {
return false;
}
const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
return std::find(in_progress_uses.begin(), in_progress_uses.end(),
buffer_id) != in_progress_uses.end();
return absl::c_linear_search(in_progress_uses, buffer_id);
}
// Returns whether the given instruction is live at the current program
@ -535,8 +533,7 @@ MemoryUsageTracker::MemoryUsageTracker(
bool unused;
for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &unused)) {
if (std::find(buffer->users.begin(), buffer->users.end(),
user_item) == buffer->users.end()) {
if (!absl::c_linear_search(buffer->users, user_item)) {
buffer->users.push_back(user_item);
buffer->unfinished_user_count++;
user_item->buffers_used.push_back(buffer->id);
@ -784,8 +781,7 @@ bool MemoryUsageTracker::Check() const {
for (const Buffer& buffer : buffers_) {
if (buffer.defining_instruction->instruction == instruction) {
CHECK(std::find(defined_buffers.begin(), defined_buffers.end(),
buffer.id) != defined_buffers.end())
CHECK(absl::c_linear_search(defined_buffers, buffer.id))
<< "Instruction " << instruction->name()
<< " defined buffers is missing: " << buffer.ToString();
}
@ -808,8 +804,7 @@ bool MemoryUsageTracker::Check() const {
int64 unfinished_uses = 0;
for (Item* user : buffer.users) {
const BufferIdList& used_buffers = user->buffers_used;
CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) !=
used_buffers.end())
CHECK(absl::c_linear_search(used_buffers, buffer.id))
<< "Instruction " << user->instruction->name()
<< " used buffers is missing " << buffer.ToString();
if (!IsFinished(user)) {
@ -836,10 +831,10 @@ int64 RematerializationCost(const HloInstruction* instruction,
// If none of the users of 'instruction' have been placed in the sequence (as
// tracked by memory_tracker), then rematerialization of 'instruction' is a
// zero-cost move of 'instruction' in the sequence.
if (!std::any_of(instruction->users().begin(), instruction->users().end(),
[&memory_tracker](const HloInstruction* inst) {
return memory_tracker.IsPlaced(inst);
})) {
if (!absl::c_any_of(instruction->users(),
[&memory_tracker](const HloInstruction* inst) {
return memory_tracker.IsPlaced(inst);
})) {
return 0;
}

View File

@ -107,13 +107,12 @@ string HloSharding::ToString() const {
bool HloSharding::UsesDevice(int64 device) const {
if (IsTuple()) {
return std::any_of(
tuple_elements_.begin(), tuple_elements_.end(),
[&](const HloSharding& s) { return s.UsesDevice(device); });
return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
return s.UsesDevice(device);
});
}
const auto& devices = tile_assignment_;
return replicated_ ||
std::find(devices.begin(), devices.end(), device) != devices.end();
return replicated_ || absl::c_linear_search(devices, device);
}
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {

View File

@ -101,8 +101,8 @@ class HloSharding {
if (!IsTuple()) {
return replicated_;
}
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
[](const HloSharding& s) { return s.IsReplicated(); });
return absl::c_all_of(
tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
}
// Returns true if the tile size is the same as the input size.
@ -110,8 +110,9 @@ class HloSharding {
if (!IsTuple()) {
return maximal_;
}
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
[](const HloSharding& s) { return s.IsTileMaximal(); });
return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
return s.IsTileMaximal();
});
}
// Returns true if the sharding defines an operation on the given device.

View File

@ -61,8 +61,7 @@ void CleanNodeName(string* name) {
name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
const string chars_to_replace = "<>[]";
auto pred = [&](char c) {
return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) !=
chars_to_replace.end();
return absl::c_linear_search(chars_to_replace, c);
};
std::replace_if(name->begin(), name->end(), pred, '_');
}

View File

@ -209,7 +209,7 @@ std::ostream& operator<<(std::ostream& out, const HloValue& value) {
}
void HloValueSet::SortAndUniquifyValues() {
std::sort(values_.begin(), values_.end(), HloValue::IdLessThan);
absl::c_sort(values_, HloValue::IdLessThan);
values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
values_.end());
}

View File

@ -128,9 +128,9 @@ string HumanReadableProfileBuilder::ToString() const {
// Sort ops in decreasing order of cycles, and print them.
std::vector<OpInfo> sorted_ops(op_infos_);
std::sort(
sorted_ops.begin(), sorted_ops.end(),
[](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; });
absl::c_sort(sorted_ops, [](const OpInfo& a, const OpInfo& b) {
return a.cycles > b.cycles;
});
for (const auto& op : sorted_ops) {
print_op(op);
}

View File

@ -178,19 +178,18 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
}
});
return std::count_if(hlo->operands().begin(), hlo->operands().end(),
[output_rank](HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kBroadcast ||
operand->opcode() == HloOpcode::kIota) {
return false;
}
if (operand->opcode() == HloOpcode::kConstant &&
ShapeUtil::IsEffectiveScalar(operand->shape())) {
return false;
}
return ShapeUtil::TrueRank(operand->shape()) >=
output_rank;
}) <= 1;
return absl::c_count_if(
hlo->operands(), [output_rank](HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kBroadcast ||
operand->opcode() == HloOpcode::kIota) {
return false;
}
if (operand->opcode() == HloOpcode::kConstant &&
ShapeUtil::IsEffectiveScalar(operand->shape())) {
return false;
}
return ShapeUtil::TrueRank(operand->shape()) >= output_rank;
}) <= 1;
}
bool InstructionFusion::CanFuseOnAllPaths(
@ -409,9 +408,8 @@ class ReversePostOrderFusionQueue : public FusionQueue {
}
sorted_operand_numbers.push_back(i);
}
std::sort(
sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
[&](int64 i, int64 j) {
absl::c_sort(
sorted_operand_numbers, [&](int64 i, int64 j) {
// Instructions with higher priority in the queue come first.
return (
FindOrDie(post_order_index_, instruction->mutable_operand(i)) >

View File

@ -147,12 +147,9 @@ bool LayoutConstraints::OperandBufferForwarded(
PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
PointsToSet::BufferSet* operand_buffers =
GetBufferSet(instruction->operand(operand_no));
for (const LogicalBuffer* output_buffer : *output_buffers) {
if (operand_buffers->count(output_buffer) > 0) {
return true;
}
}
return false;
return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
return operand_buffers->count(b) > 0;
});
}
Status LayoutConstraints::SetBufferLayout(const Layout& layout,

View File

@ -81,9 +81,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
if (hlo.opcode() == HloOpcode::kParameter) {
const std::vector<HloInstruction*>& parameter_instructions =
module_.entry_computation()->parameter_instructions();
if (std::find(parameter_instructions.begin(),
parameter_instructions.end(),
&hlo) != parameter_instructions.end()) {
if (absl::c_linear_search(parameter_instructions, &hlo)) {
array->MarkInvariantOverWholeProgram(context_);
}
}

View File

@ -34,7 +34,7 @@ bool IsAllowed(char character) {
} // namespace
NameUniquer::NameUniquer(const string& separator) {
CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed))
CHECK(absl::c_all_of(separator, IsAllowed))
<< "separator should comprises allowed characters only";
separator_ = separator;
}

View File

@ -260,8 +260,8 @@ PlatformUtil::GetStreamExecutors(
// Block here in thread_pool destructor until all devices are initialized.
}
VLOG(1) << "Device initialization complete";
if (std::all_of(stream_executors.begin(), stream_executors.end(),
[](se::StreamExecutor* s) { return s == nullptr; })) {
if (absl::c_all_of(stream_executors,
[](se::StreamExecutor* s) { return s == nullptr; })) {
return InternalError("no supported devices found for platform %s",
platform->Name());
}

View File

@ -534,9 +534,8 @@ Status ValidateDotDimensionNumbers(
absl::Span<const int64> contracting_dims,
absl::Span<const int64> batch_dims) -> bool {
auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
in_range) &&
std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
return absl::c_all_of(contracting_dims, in_range) &&
absl::c_all_of(batch_dims, in_range);
};
absl::Span<const int64> lhs_contracting_dimensions =
@ -563,9 +562,8 @@ Status ValidateDotDimensionNumbers(
auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second;
};
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
is_unique) &&
std::all_of(batch_dims.begin(), batch_dims.end(), is_unique);
return absl::c_all_of(contracting_dims, is_unique) &&
absl::c_all_of(batch_dims, is_unique);
};
if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
@ -1589,29 +1587,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
input_dnums[1] = dnums.input_feature_dimension();
std::copy(dnums.input_spatial_dimensions().begin(),
dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
std::sort(input_dnums.begin(), input_dnums.end());
absl::c_sort(input_dnums);
std::vector<int64> window_dnums(num_dims);
window_dnums[0] = dnums.kernel_input_feature_dimension();
window_dnums[1] = dnums.kernel_output_feature_dimension();
std::copy(dnums.kernel_spatial_dimensions().begin(),
dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
std::sort(window_dnums.begin(), window_dnums.end());
absl::c_sort(window_dnums);
std::vector<int64> output_dnums(num_dims);
output_dnums[0] = dnums.output_batch_dimension();
output_dnums[1] = dnums.output_feature_dimension();
std::copy(dnums.output_spatial_dimensions().begin(),
dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
std::sort(output_dnums.begin(), output_dnums.end());
absl::c_sort(output_dnums);
std::vector<int64> expected_dnums(num_dims);
std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
!std::all_of(window_dnums.begin(), window_dnums.end(), in_range) ||
!std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
if (!absl::c_all_of(input_dnums, in_range) ||
!absl::c_all_of(window_dnums, in_range) ||
!absl::c_all_of(output_dnums, in_range)) {
return InvalidArgument(
"A dimension number is out of range in convolution: %s.",
dnums.DebugString());

View File

@ -130,8 +130,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
HloInstruction* new_lhs;
const int64 kLhsIdx = 0;
if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
operand_indices.end()) {
if (absl::c_linear_search(operand_indices, kLhsIdx)) {
HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
@ -154,8 +153,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
HloInstruction* new_rhs;
const int64 kRhsIdx = 1;
if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
operand_indices.end()) {
if (absl::c_linear_search(operand_indices, kRhsIdx)) {
HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0);

View File

@ -87,9 +87,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
bool found = false;
ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
const BufferList& pointed_to_buffers) {
if (!found &&
std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
&buffer) != pointed_to_buffers.end()) {
if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
found = true;
}
});
@ -99,8 +97,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
const ShapeIndex& index) const {
const auto& pointed_to_buffers = element(index);
return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
&buffer) != pointed_to_buffers.end();
return absl::c_linear_search(pointed_to_buffers, &buffer);
}
void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
@ -604,9 +601,8 @@ bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
} else if (user->opcode() == HloOpcode::kFusion &&
user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
// Find fusion parameter associated with 'operand'.
auto it = std::find_if(
user->fused_parameters().begin(), user->fused_parameters().end(),
[=](HloInstruction* fused_param) {
auto it = absl::c_find_if(
user->fused_parameters(), [&](HloInstruction* fused_param) {
return user->operand(fused_param->parameter_number()) == operand;
});
CHECK(it != user->fused_parameters().end());
@ -672,9 +668,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
}
// Find fusion parameter associated with 'operand'.
const auto& fused_params = fusion->fused_parameters();
auto fused_param_it = std::find_if(
fused_params.begin(), fused_params.end(),
[&](HloInstruction* fused_param) {
auto fused_param_it =
absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
return fusion->operand(fused_param->parameter_number()) == operand;
});
if (fused_param_it == fused_params.end()) {
@ -743,11 +738,10 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
// Check if one operand of kAdd fused root is kDot or kConvolution.
auto* add = user->fused_expression_root();
auto add_operand_it =
std::find_if(add->operands().begin(), add->operands().end(),
[&](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConvolution ||
operand->opcode() == HloOpcode::kDot;
});
absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConvolution ||
operand->opcode() == HloOpcode::kDot;
});
if (add_operand_it == add->operands().end()) {
return false;
}

View File

@ -721,9 +721,8 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
// to fusion 'operand'.
HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
HloInstruction* operand) {
auto it = std::find_if(
fusion->fused_instructions().begin(),
fusion->fused_instructions().end(), [=](const HloInstruction* fused) {
auto it = absl::c_find_if(
fusion->fused_instructions(), [&](const HloInstruction* fused) {
return fused->opcode() == HloOpcode::kParameter &&
fusion->operand(fused->parameter_number()) == operand;
});

View File

@ -109,8 +109,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// operand appears in, but it may appear more than once!
if (user->user_count() == 1 && user->users().front() == while_body_root &&
while_body_root->operand_index(user) == user->tuple_index() &&
std::count(while_body_root->operands().begin(),
while_body_root->operands().end(), user) == 1) {
absl::c_count(while_body_root->operands(), user) == 1) {
continue;
}
@ -158,7 +157,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// Build up maps from the old/new to the new/old tuple indices.
std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
used_tuple_indices.end());
std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());
absl::c_sort(new_to_old_tuple_idx);
absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {

View File

@ -407,13 +407,12 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
// The original while instruction is still left in the module as a dead
// instruction, find a while instruction with a different name as the new
// while instruction.
const auto& instrs = m->entry_computation()->instructions();
HloInstruction* new_while_op =
*std::find_if(m->entry_computation()->instructions().begin(),
m->entry_computation()->instructions().end(),
[&](const HloInstruction* instr) {
return (instr->opcode() == HloOpcode::kWhile &&
instr->name() != "while");
});
*absl::c_find_if(instrs, [&](const HloInstruction* instr) {
return (instr->opcode() == HloOpcode::kWhile &&
instr->name() != "while");
});
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
EXPECT_TRUE(

View File

@ -87,8 +87,7 @@ bool Shape::is_static() const {
}
}
}
return !std::any_of(dynamic_dimensions_.begin(), dynamic_dimensions_.end(),
[](bool b) { return b; });
return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; });
}
void Shape::DeleteDimension(int64 dim_to_delete) {

View File

@ -405,8 +405,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
shape.tuple_shapes().end(), IsTuple);
return IsTuple(shape) && absl::c_any_of(shape.tuple_shapes(), IsTuple);
}
/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {

View File

@ -135,7 +135,7 @@ void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
auto sort_order_less = [this](int64 lhs, int64 rhs) {
return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
};
std::sort(sort_order.begin(), sort_order.end(), sort_order_less);
absl::c_sort(sort_order, sort_order_less);
// Reorder the array elements according to sort_order. Work through the array
// and follow cycles so we can do the reorder in-place.

View File

@ -467,8 +467,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
// servers. The error message is missing the operator ++.
template <typename T>
void iota_int_init_value(std::vector<T>& values, int init_value) {
std::for_each(values.begin(), values.end(),
[&](T& value) { value = static_cast<T>(init_value++); });
absl::c_for_each(values,
[&](T& value) { value = static_cast<T>(init_value++); });
}
template <typename T>

View File

@ -86,7 +86,7 @@ bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
CHECK_LT(index, rank);
output[index] = 0;
}
return std::find(output.begin(), output.end(), -1) == output.end();
return !absl::c_linear_search(output, -1);
}
std::vector<int64> InversePermutation(

View File

@ -324,8 +324,7 @@ bool IsIdentityPermutation(absl::Span<const int64> permutation);
template <typename Container>
int64 PositionInContainer(const Container& container, int64 value) {
return std::distance(container.begin(),
std::find(container.begin(), container.end(), value));
return std::distance(container.begin(), absl::c_find(container, value));
}
// Formats the container as a comma-separated string. StrAppend must support

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -137,25 +138,23 @@ bool HasPadding(const Window& window) {
}
bool HasSymmetricPadding(const Window& window) {
return std::all_of(window.dimensions().begin(), window.dimensions().end(),
[](const WindowDimension& dim) {
return dim.padding_low() == dim.padding_high();
});
return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) {
return dim.padding_low() == dim.padding_high();
});
}
bool HasSymmetricPadding(const PaddingConfig& padding_config) {
return std::all_of(padding_config.dimensions().begin(),
padding_config.dimensions().end(),
[](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.edge_padding_low() == dim.edge_padding_high();
});
return absl::c_all_of(padding_config.dimensions(),
[](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.edge_padding_low() ==
dim.edge_padding_high();
});
}
bool HasNegativePadding(const Window& window) {
return std::any_of(window.dimensions().begin(), window.dimensions().end(),
[](const WindowDimension& dim) {
return dim.padding_low() < 0 || dim.padding_high() < 0;
});
return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
return dim.padding_low() < 0 || dim.padding_high() < 0;
});
}
bool HasBaseDilation(const Window& window) {
@ -190,10 +189,9 @@ bool AllOrNoneReversed(const Window& window) {
return true;
}
bool reversed = window.dimensions()[0].window_reversal();
return std::all_of(window.dimensions().begin(), window.dimensions().end(),
[&](const WindowDimension& dim) {
return dim.window_reversal() == reversed;
});
return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) {
return dim.window_reversal() == reversed;
});
}
bool HasDilation(const Window& window) {