[XLA] Use absl::c_foo rather than std::foo.

No functional change.

PiperOrigin-RevId: 227896034
This commit is contained in:
Justin Lebar 2019-01-04 12:29:13 -08:00 committed by TensorFlower Gardener
parent f9bd1568aa
commit b4813a0cff
55 changed files with 305 additions and 380 deletions

View File

@ -717,6 +717,7 @@ cc_library(
":types", ":types",
":xla_data_proto", ":xla_data_proto",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],

View File

@ -77,7 +77,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) {
auto x = ConstantR1<float>(&builder, inputs); auto x = ConstantR1<float>(&builder, inputs);
xla::GetTupleElement(xla::TopK(x, kSize), 0); xla::GetTupleElement(xla::TopK(x, kSize), 0);
std::sort(inputs.begin(), inputs.end(), std::greater<float>()); absl::c_sort(inputs, std::greater<float>());
ComputeAndCompareR1<float>(&builder, inputs, {}); ComputeAndCompareR1<float>(&builder, inputs, {});
} }

View File

@ -290,7 +290,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
/* static */ bool LayoutUtil::HasLayout(const Shape& shape) { /* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
if (shape.IsTuple()) { if (shape.IsTuple()) {
// Tuple shape: all subshapes must have a layout. // Tuple shape: all subshapes must have a layout.
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), return absl::c_all_of(shape.tuple_shapes(),
[](const Shape& s) { return HasLayout(s); }); [](const Shape& s) { return HasLayout(s); });
} else if (!shape.IsArray()) { } else if (!shape.IsArray()) {
// Opaque, token types etc. ignore layout. // Opaque, token types etc. ignore layout.
@ -424,7 +424,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
positions_in_layout.push_back( positions_in_layout.push_back(
PositionInContainer(layout.minor_to_major(), dim)); PositionInContainer(layout.minor_to_major(), dim));
} }
std::sort(positions_in_layout.begin(), positions_in_layout.end()); absl::c_sort(positions_in_layout);
for (size_t i = 1; i < positions_in_layout.size(); ++i) { for (size_t i = 1; i < positions_in_layout.size(); ++i) {
if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) { if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
return false; return false;

View File

@ -55,7 +55,7 @@ string MetricTableReport::MakeReport(double expected_metric_sum) {
const auto metric_greater = [](const Entry& a, const Entry& b) { const auto metric_greater = [](const Entry& a, const Entry& b) {
return a.metric > b.metric; return a.metric > b.metric;
}; };
std::sort(entries_.begin(), entries_.end(), metric_greater); absl::c_sort(entries_, metric_greater);
// Create the report // Create the report
AppendLine(); AppendLine();
@ -117,7 +117,7 @@ std::vector<MetricTableReport::Category> MetricTableReport::MakeCategories(
auto metric_sum_greater = [](const Category& a, const Category& b) { auto metric_sum_greater = [](const Category& a, const Category& b) {
return a.metric_sum > b.metric_sum; return a.metric_sum > b.metric_sum;
}; };
std::sort(categories.begin(), categories.end(), metric_sum_greater); absl::c_sort(categories, metric_sum_greater);
return categories; return categories;
} }

View File

@ -3414,6 +3414,7 @@ cc_library(
":hlo_profile_printer_data", ":hlo_profile_printer_data",
":human_readable_profile_builder", ":human_readable_profile_builder",
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:types",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -2216,8 +2216,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
auto dim_is_one = [&](int64 i) -> bool { auto dim_is_one = [&](int64 i) -> bool {
return reverse->shape().dimensions(i) == 1; return reverse->shape().dimensions(i) == 1;
}; };
if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(), if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
dim_is_one)) {
return ReplaceInstruction(reverse, reverse->mutable_operand(0)); return ReplaceInstruction(reverse, reverse->mutable_operand(0));
} }
return Status::OK(); return Status::OK();
@ -2492,9 +2491,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
// Create a new reduce with the combined reduction dimensions of both // Create a new reduce with the combined reduction dimensions of both
// reduces. // reduces.
std::vector<int64> arg_dims = arg->dimensions(); std::vector<int64> arg_dims = arg->dimensions();
std::sort(arg_dims.begin(), arg_dims.end()); absl::c_sort(arg_dims);
std::vector<int64> reduce_dims = reduce->dimensions(); std::vector<int64> reduce_dims = reduce->dimensions();
std::sort(reduce_dims.begin(), reduce_dims.end()); absl::c_sort(reduce_dims);
// Transform reduce_dims to the same rank as the operand of the operand. // Transform reduce_dims to the same rank as the operand of the operand.
for (int64 arg_dim : arg_dims) { for (int64 arg_dim : arg_dims) {
for (int64& dim : reduce_dims) { for (int64& dim : reduce_dims) {

View File

@ -86,8 +86,7 @@ std::vector<int64> ColorInterferenceGraph(
// first, but it would be good to investigate other ordering heuristics too. // first, but it would be good to investigate other ordering heuristics too.
std::vector<int64> nodes(node_count); std::vector<int64> nodes(node_count);
std::iota(nodes.begin(), nodes.end(), 0); std::iota(nodes.begin(), nodes.end(), 0);
std::sort(nodes.begin(), nodes.end(), absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
[&interference_map](const int64 i, const int64 j) {
return interference_map[i].size() > interference_map[j].size(); return interference_map[i].size() > interference_map[j].size();
}); });
@ -272,10 +271,11 @@ BufferAllocationProto BufferAllocation::ToProto() const {
proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_offset(buffer_offset_size.second.offset);
proto_assigned->set_size(buffer_offset_size.second.size); proto_assigned->set_size(buffer_offset_size.second.size);
} }
std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(), absl::c_sort(*proto.mutable_assigned(),
[](const BufferAllocationProto::Assigned& assign1, [](const BufferAllocationProto::Assigned& assign1,
const BufferAllocationProto::Assigned& assign2) { const BufferAllocationProto::Assigned& assign2) {
return assign1.logical_buffer_id() < assign2.logical_buffer_id(); return assign1.logical_buffer_id() <
assign2.logical_buffer_id();
}); });
return proto; return proto;
} }
@ -308,7 +308,7 @@ string BufferAllocation::ToString() const {
for (const auto& buffer_offset_size : assigned_buffers_) { for (const auto& buffer_offset_size : assigned_buffers_) {
sorted_buffers.push_back(buffer_offset_size.first); sorted_buffers.push_back(buffer_offset_size.first);
} }
std::sort(sorted_buffers.begin(), sorted_buffers.end(), absl::c_sort(sorted_buffers,
[](const LogicalBuffer* a, const LogicalBuffer* b) { [](const LogicalBuffer* a, const LogicalBuffer* b) {
return a->id() < b->id(); return a->id() < b->id();
}); });
@ -479,9 +479,8 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
// didn't return the empty set) for both HLOs, and the two resulting sets of // didn't return the empty set) for both HLOs, and the two resulting sets of
// slices are disjoint. // slices are disjoint.
return !slices_a.empty() && !slices_b.empty() && return !slices_a.empty() && !slices_b.empty() &&
std::none_of(slices_a.begin(), slices_a.end(), absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
[&](const BufferAllocation::Slice& slice) { return slices_b.contains(slice);
return slices_b.count(slice) > 0;
}); });
} }
@ -952,9 +951,9 @@ Status BufferAssigner::AssignBuffersForComputation(
// operands (assuming operands are the same/larger size) enabling the // operands (assuming operands are the same/larger size) enabling the
// important reuse case where an elementwise instruction reuses one of its // important reuse case where an elementwise instruction reuses one of its
// operand's buffer. This improves locality. // operand's buffer. This improves locality.
std::sort(sorted_buffers.begin(), sorted_buffers.end(), absl::c_sort(sorted_buffers,
[has_sequential_order, &liveness, &post_order_position, assignment]( [has_sequential_order, &liveness, &post_order_position,
const LogicalBuffer* a, const LogicalBuffer* b) { assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
// Primary sort is by decreasing buffer size. // Primary sort is by decreasing buffer size.
const int64 a_size = assignment->buffer_size_(*a); const int64 a_size = assignment->buffer_size_(*a);
const int64 b_size = assignment->buffer_size_(*b); const int64 b_size = assignment->buffer_size_(*b);
@ -1305,7 +1304,7 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
live_buffers.end()); live_buffers.end());
// Stabily sort the live buffers. // Stabily sort the live buffers.
std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), absl::c_sort(live_buffers_vector,
[](const LogicalBuffer* a, const LogicalBuffer* b) { [](const LogicalBuffer* a, const LogicalBuffer* b) {
return a->id() < b->id(); return a->id() < b->id();
}); });

View File

@ -384,7 +384,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
// Verify visitation order of some computations in the graph. // Verify visitation order of some computations in the graph.
auto index_of = [&visited](const HloComputation* comp) { auto index_of = [&visited](const HloComputation* comp) {
auto it = std::find(visited.begin(), visited.end(), comp); auto it = absl::c_find(visited, comp);
EXPECT_NE(it, visited.end()); EXPECT_NE(it, visited.end());
return std::distance(visited.begin(), it); return std::distance(visited.begin(), it);
}; };

View File

@ -42,7 +42,7 @@ void ComputationLayout::SetToDefaultLayout() {
} }
bool ComputationLayout::LayoutIsSet() const { bool ComputationLayout::LayoutIsSet() const {
return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(), return absl::c_all_of(parameter_layouts_,
[](const ShapeLayout& s) { return s.LayoutIsSet(); }) && [](const ShapeLayout& s) { return s.LayoutIsSet(); }) &&
result_layout_.LayoutIsSet(); result_layout_.LayoutIsSet();
} }

View File

@ -539,8 +539,7 @@ class CopyRemover {
} }
std::vector<const HloValue*> values = buffer.values(); std::vector<const HloValue*> values = buffer.values();
std::sort(values.begin(), values.end(), absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
[this](const HloValue* a, const HloValue* b) {
return ordering_.IsDefinedBefore(*a, *b); return ordering_.IsDefinedBefore(*a, *b);
}); });
@ -842,9 +841,8 @@ class CopyRemover {
copy_value_node->next->prev = operand_node; copy_value_node->next->prev = operand_node;
// Patch up uses. Remove use of copy from operand_node uses. // Patch up uses. Remove use of copy from operand_node uses.
auto it = auto it = absl::c_find_if(
std::find_if(operand_node->uses.begin(), operand_node->uses.end(), operand_node->uses, [copy_value_node](const HloUse* use) {
[copy_value_node](const HloUse* use) {
return use->instruction == return use->instruction ==
copy_value_node->value->defining_instruction(); copy_value_node->value->defining_instruction();
}); });

View File

@ -77,9 +77,8 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile(
} }
// Sort the symbols in increasing address order. // Sort the symbols in increasing address order.
std::sort( absl::c_sort(symbols, [](const llvm::object::SymbolRef& a,
symbols.begin(), symbols.end(), const llvm::object::SymbolRef& b) {
[](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) {
// getAddress returns a Expected object. Assert there is no error // getAddress returns a Expected object. Assert there is no error
// before extracting the address. // before extracting the address.
llvm::Expected<uint64_t> a_address_or_error = a.getAddress(); llvm::Expected<uint64_t> a_address_or_error = a.getAddress();

View File

@ -1709,10 +1709,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
vectorization_factor_in_bytes / vectorization_factor_in_bytes /
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
bool is_reduction_over_minor_dimension = bool is_reduction_over_minor_dimension = absl::c_linear_search(
std::find(dimensions.begin(), dimensions.end(), dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
LayoutUtil::Minor(arg->shape().layout(), 0)) !=
dimensions.end();
unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>( unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
@ -2401,8 +2399,7 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
int64 concat_dim = concatenate->dimensions(0); int64 concat_dim = concatenate->dimensions(0);
const Layout& output_layout = output_shape.layout(); const Layout& output_layout = output_shape.layout();
auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
auto concat_dim_layout_itr = auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim);
std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr); std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
std::vector<int64> outer_dims(std::next(concat_dim_layout_itr), std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
@ -2956,8 +2953,7 @@ Status IrEmitter::ElementTypesSameAndSupported(
TF_RET_CHECK(!operands.empty()); TF_RET_CHECK(!operands.empty());
PrimitiveType primitive_type = operands[0]->shape().element_type(); PrimitiveType primitive_type = operands[0]->shape().element_type();
if (std::find(supported_types.begin(), supported_types.end(), if (!absl::c_linear_search(supported_types, primitive_type)) {
primitive_type) == supported_types.end()) {
return Unimplemented("unsupported operand type %s in op %s", return Unimplemented("unsupported operand type %s in op %s",
PrimitiveType_Name(primitive_type), PrimitiveType_Name(primitive_type),
HloOpcodeString(instruction.opcode())); HloOpcodeString(instruction.opcode()));

View File

@ -154,19 +154,16 @@ bool IsReductionToVector(const HloInstruction& reduce) {
const HloInstruction* input = reduce.operand(0); const HloInstruction* input = reduce.operand(0);
std::vector<int64> dims_to_keep; std::vector<int64> dims_to_keep;
for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) { for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(), if (!absl::c_linear_search(reduce.dimensions(), dim)) {
dim)) {
dims_to_keep.push_back(dim); dims_to_keep.push_back(dim);
} }
} }
return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
dims_to_keep) && dims_to_keep) &&
ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions( ShapeUtil::Equal(
[&dims_to_keep](int64 dim) { reduce.shape(),
return std::count( ShapeUtil::FilterDimensions(
dims_to_keep.begin(), [&](int64 dim) { return absl::c_count(dims_to_keep, dim); },
dims_to_keep.end(), dim);
},
input->shape())); input->shape()));
} }

View File

@ -1506,7 +1506,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
return !allocation->is_constant(); return !allocation->is_constant();
}); });
std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), absl::c_sort(non_constant_buffers,
[](const BufferAllocation* a, const BufferAllocation* b) { [](const BufferAllocation* a, const BufferAllocation* b) {
return a->index() < b->index(); return a->index() < b->index();
}); });

View File

@ -225,7 +225,7 @@ Status HeapSimulator::RunComputation(
} }
} }
// Sort to get a deterministic iteration order. // Sort to get a deterministic iteration order.
std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), absl::c_sort(operand_buffers_to_free,
[](const BufferValue* x, const BufferValue* y) { [](const BufferValue* x, const BufferValue* y) {
return x->id() < y->id(); return x->id() < y->id();
}); });
@ -335,8 +335,7 @@ Status HeapSimulator::RunComputation(
to_free.push_back(buffer); to_free.push_back(buffer);
} }
std::sort(to_free.begin(), to_free.end(), absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) {
[](const BufferValue* x, const BufferValue* y) {
return x->id() < y->id(); return x->id() < y->id();
}); });
for (const BufferValue* buffer : to_free) { for (const BufferValue* buffer : to_free) {
@ -596,7 +595,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() {
} }
// Call ops in the run sorted by decreasing size, breaking ties by buffer id. // Call ops in the run sorted by decreasing size, breaking ties by buffer id.
std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) { absl::c_sort(run_, [](const Op& a, const Op& b) {
if (a.size != b.size) { if (a.size != b.size) {
return a.size > b.size; return a.size > b.size;
} }
@ -866,7 +865,7 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
for (auto& entry : buffer_intervals_) { for (auto& entry : buffer_intervals_) {
sorted_buffer_intervals.push_back(entry.second); sorted_buffer_intervals.push_back(entry.second);
} }
std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(), absl::c_sort(sorted_buffer_intervals,
[](const BufferInterval& x, const BufferInterval& y) { [](const BufferInterval& x, const BufferInterval& y) {
if (x.size != y.size) { if (x.size != y.size) {
return x.size > y.size; return x.size > y.size;
@ -881,8 +880,8 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
for (auto& buffer_interval : sorted_buffer_intervals) { for (auto& buffer_interval : sorted_buffer_intervals) {
auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime( auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
buffer_interval.start, buffer_interval.end); buffer_interval.start, buffer_interval.end);
std::sort( absl::c_sort(
chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(), chunks_overlapping_in_time,
[](const Chunk& x, const Chunk& y) { return x.offset < y.offset; }); [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });
// Find the minimum free chunk that can hold this buffer. // Find the minimum free chunk that can hold this buffer.

View File

@ -117,7 +117,7 @@ class BufferValueMap {
for (const auto& pair : buffers_) { for (const auto& pair : buffers_) {
buffer_numbers.push_back(pair.first); buffer_numbers.push_back(pair.first);
} }
std::sort(buffer_numbers.begin(), buffer_numbers.end()); absl::c_sort(buffer_numbers);
return buffer_numbers; return buffer_numbers;
} }
@ -319,7 +319,7 @@ class BufferValueMap {
ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeWhileAliasedBuffers(value, &aliased_buffers);
ComputeConditionalAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers);
// Uniquify aliased buffers. // Uniquify aliased buffers.
std::sort(aliased_buffers.begin(), aliased_buffers.end()); absl::c_sort(aliased_buffers);
aliased_buffers.erase( aliased_buffers.erase(
std::unique(aliased_buffers.begin(), aliased_buffers.end()), std::unique(aliased_buffers.begin(), aliased_buffers.end()),
aliased_buffers.end()); aliased_buffers.end());
@ -367,7 +367,7 @@ std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
} }
// Sort and uniquify vector before returning. // Sort and uniquify vector before returning.
std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan); absl::c_sort(buffers, HloBuffer::IdLessThan);
buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end()); buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
return buffers; return buffers;
@ -430,8 +430,7 @@ Status HloAliasAnalysis::Verify() const {
for (const auto& pair : value_to_buffer_) { for (const auto& pair : value_to_buffer_) {
const HloValue* value = pair.first; const HloValue* value = pair.first;
const HloBuffer& buffer = *pair.second; const HloBuffer& buffer = *pair.second;
TF_RET_CHECK(std::find(buffer.values().begin(), buffer.values().end(), TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
value) != buffer.values().end());
} }
for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) { for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
@ -515,7 +514,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
auto& value_set = buffer_map.GetValuesInBuffer(buffer_number); auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
std::vector<const HloValue*> sorted_values(value_set.begin(), std::vector<const HloValue*> sorted_values(value_set.begin(),
value_set.end()); value_set.end());
std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan); absl::c_sort(sorted_values, HloValue::IdLessThan);
alias_analysis->buffers_.emplace_back(next_id++, sorted_values); alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
for (const HloValue* value : sorted_values) { for (const HloValue* value : sorted_values) {
alias_analysis->value_to_buffer_[value] = alias_analysis->value_to_buffer_[value] =
@ -547,8 +546,7 @@ bool HloAliasAnalysis::HasLiveRangeInterference(
// tie-break using value ID. The tie-break is necessary because we need a // tie-break using value ID. The tie-break is necessary because we need a
// strict weak order for std::sort. // strict weak order for std::sort.
std::vector<const HloValue*> values = buffer.values(); std::vector<const HloValue*> values = buffer.values();
std::sort(values.begin(), values.end(), absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) {
[&ordering](const HloValue* a, const HloValue* b) {
if (ordering.IsDefinedBefore(*a, *b)) { if (ordering.IsDefinedBefore(*a, *b)) {
return true; return true;
} else if (ordering.IsDefinedBefore(*b, *a)) { } else if (ordering.IsDefinedBefore(*b, *a)) {

View File

@ -49,7 +49,7 @@ std::vector<HloPosition> HloBuffer::ComputePositions() const {
value->positions().end()); value->positions().end());
} }
// Remove duplicates and sort positions. // Remove duplicates and sort positions.
std::sort(positions.begin(), positions.end()); absl::c_sort(positions);
positions.erase(std::unique(positions.begin(), positions.end()), positions.erase(std::unique(positions.begin(), positions.end()),
positions.end()); positions.end());
return positions; return positions;

View File

@ -531,8 +531,7 @@ HloComputation::CreateFromProto(
HloInstruction* root = instruction_map.at(proto.root_id()); HloInstruction* root = instruction_map.at(proto.root_id());
// Sort the instructions in the proto id's order. // Sort the instructions in the proto id's order.
std::sort(instructions.begin(), instructions.end(), absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
[&](const std::unique_ptr<HloInstruction>& a,
const std::unique_ptr<HloInstruction>& b) { const std::unique_ptr<HloInstruction>& b) {
return to_proto_id[a.get()] < to_proto_id[b.get()]; return to_proto_id[a.get()] < to_proto_id[b.get()];
}); });
@ -800,8 +799,7 @@ Status HloComputation::AcceptOrdered(
absl::Span<HloInstruction* const> order) const { absl::Span<HloInstruction* const> order) const {
VLOG(3) << "Accepting visitor with order."; VLOG(3) << "Accepting visitor with order.";
for (HloInstruction* root : CollectUnreachableRoots()) { for (HloInstruction* root : CollectUnreachableRoots()) {
TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end()) TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
<< root->ToString();
} }
TF_RET_CHECK(order.size() == instruction_count()); TF_RET_CHECK(order.size() == instruction_count());
absl::flat_hash_set<const HloInstruction*> visited; absl::flat_hash_set<const HloInstruction*> visited;

View File

@ -256,7 +256,7 @@ bool HloDataflowAnalysis::Phi(
input_value_ids.push_back(value->id()); input_value_ids.push_back(value->id());
} }
} }
std::sort(input_value_ids.begin(), input_value_ids.end()); absl::c_sort(input_value_ids);
input_value_ids.erase( input_value_ids.erase(
std::unique(input_value_ids.begin(), input_value_ids.end()), std::unique(input_value_ids.begin(), input_value_ids.end()),
input_value_ids.end()); input_value_ids.end());
@ -271,8 +271,7 @@ bool HloDataflowAnalysis::Phi(
if (current_value_defined_here) { if (current_value_defined_here) {
VLOG(5) << "current_value_defined_here: " << current_value->ToString(); VLOG(5) << "current_value_defined_here: " << current_value->ToString();
CHECK(current_value->is_phi()); CHECK(current_value->is_phi());
auto it = std::find(input_value_ids.begin(), input_value_ids.end(), auto it = absl::c_find(input_value_ids, current_value->id());
current_value->id());
if (it != input_value_ids.end()) { if (it != input_value_ids.end()) {
input_value_ids.erase(it); input_value_ids.erase(it);
} }
@ -921,8 +920,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
for (auto& pair : dataflow_analysis->values_) { for (auto& pair : dataflow_analysis->values_) {
dataflow_analysis->values_vector_.push_back(&pair.second); dataflow_analysis->values_vector_.push_back(&pair.second);
} }
std::sort(dataflow_analysis->values_vector_.begin(), absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
TF_DCHECK_OK(dataflow_analysis->Verify()); TF_DCHECK_OK(dataflow_analysis->Verify());
@ -937,9 +935,7 @@ Status HloDataflowAnalysis::Verify() const {
for (const HloValue* value : values()) { for (const HloValue* value : values()) {
for (const HloPosition& position : value->positions()) { for (const HloPosition& position : value->positions()) {
const HloValueSet& value_set = GetValueSet(position); const HloValueSet& value_set = GetValueSet(position);
TF_RET_CHECK(std::find(value_set.values().begin(), TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
value_set.values().end(),
value) != value_set.values().end())
<< "Value set at position " << position << " does not contain value " << "Value set at position " << position << " does not contain value "
<< value->ToShortString(); << value->ToShortString();
} }
@ -954,9 +950,7 @@ Status HloDataflowAnalysis::Verify() const {
const HloValueSet& value_set = pair.second; const HloValueSet& value_set = pair.second;
const HloPosition position{instruction, index}; const HloPosition position{instruction, index};
for (const HloValue* value : value_set.values()) { for (const HloValue* value : value_set.values()) {
TF_RET_CHECK(std::find(value->positions().begin(), TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
value->positions().end(),
position) != value->positions().end())
<< "Value set at position " << position << "Value set at position " << position
<< " unexpectedly contains value " << value->ToShortString(); << " unexpectedly contains value " << value->ToShortString();
} }
@ -1041,8 +1035,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// Check if one operand of kAdd fused root is kDot or kConvolution. // Check if one operand of kAdd fused root is kDot or kConvolution.
auto* add = user->fused_expression_root(); auto* add = user->fused_expression_root();
auto add_operand_it = auto add_operand_it =
std::find_if(add->operands().begin(), add->operands().end(), absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
[&](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConvolution || return operand->opcode() == HloOpcode::kConvolution ||
operand->opcode() == HloOpcode::kDot; operand->opcode() == HloOpcode::kDot;
}); });
@ -1100,13 +1093,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// *) The root instruction of the called computation is element-wise on // *) The root instruction of the called computation is element-wise on
// 'operand'. // 'operand'.
const bool found_caller_use = const bool found_caller_use =
std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { absl::c_find_if(uses, [user](const HloUse& use) {
return use.instruction == user; return use.instruction == user;
}) != uses.end(); }) != uses.end();
auto* callee_root = user->to_apply()->root_instruction(); auto* callee_root = user->to_apply()->root_instruction();
const bool found_elementwise_callee_use = const bool found_elementwise_callee_use =
std::find_if( absl::c_find_if(uses, [callee_root](const HloUse& use) {
uses.begin(), uses.end(), [callee_root](const HloUse& use) {
return use.instruction == callee_root && return use.instruction == callee_root &&
callee_root->IsElementwiseOnOperand(use.operand_number); callee_root->IsElementwiseOnOperand(use.operand_number);
}) != uses.end(); }) != uses.end();

View File

@ -43,9 +43,7 @@ class HloDceTest : public HloTestBase {
// Returns whether the given instruction exists in the given computation. // Returns whether the given instruction exists in the given computation.
bool HasInstruction(const HloComputation& computation, bool HasInstruction(const HloComputation& computation,
const HloInstruction* instruction) { const HloInstruction* instruction) {
return std::find(computation.instructions().begin(), return absl::c_linear_search(computation.instructions(), instruction);
computation.instructions().end(),
instruction) != computation.instructions().end();
} }
}; };

View File

@ -230,7 +230,7 @@ HloDomainMap::MakeNonDomainInstructions(
} }
} }
// sort instructions according to instructions_order // sort instructions according to instructions_order
std::sort(instructions.begin(), instructions.end(), absl::c_sort(instructions,
[&instructions_order](HloInstruction* a, HloInstruction* b) { [&instructions_order](HloInstruction* a, HloInstruction* b) {
return instructions_order.at(a) < instructions_order.at(b); return instructions_order.at(a) < instructions_order.at(b);
}); });

View File

@ -1248,8 +1248,7 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
// Extract a slice from the keys and values literals that correspond to // Extract a slice from the keys and values literals that correspond to
// exactly the row in dimension 'sort_dim'. // exactly the row in dimension 'sort_dim'.
std::vector<int64> limit_indices(indices.begin(), indices.end()); std::vector<int64> limit_indices(indices.begin(), indices.end());
std::for_each(limit_indices.begin(), limit_indices.end(), absl::c_for_each(limit_indices, [](int64& index) { ++index; });
[](int64& index) { ++index; });
limit_indices[sort_dim] = sort_dim_elements; limit_indices[sort_dim] = sort_dim_elements;
TF_ASSIGN_OR_RETURN(auto keys_to_sort, TF_ASSIGN_OR_RETURN(auto keys_to_sort,
keys_literal.Slice(indices, limit_indices) keys_literal.Slice(indices, limit_indices)

View File

@ -1673,8 +1673,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// Extract a slice from the literal that corresponds to exactly the // Extract a slice from the literal that corresponds to exactly the
// row in dimension 'sort_dim'. // row in dimension 'sort_dim'.
std::vector<int64> limit_indices(indices.begin(), indices.end()); std::vector<int64> limit_indices(indices.begin(), indices.end());
std::for_each(limit_indices.begin(), limit_indices.end(), absl::c_for_each(limit_indices, [](int64& index) { ++index; });
[](int64& index) { ++index; });
limit_indices[sort_dim] = sort_dim_elements; limit_indices[sort_dim] = sort_dim_elements;
TF_ASSIGN_OR_RETURN(auto row_to_sort, TF_ASSIGN_OR_RETURN(auto row_to_sort,
keys_literal.Slice(indices, limit_indices) keys_literal.Slice(indices, limit_indices)

View File

@ -561,8 +561,8 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
} }
// Show the subcomputation if we're showing any of its members. // Show the subcomputation if we're showing any of its members.
return std::any_of( return absl::c_any_of(
subcomp->instructions().begin(), subcomp->instructions().end(), subcomp->instructions(),
[&](const HloInstruction* instr) { return filter_.Show(instr); }); [&](const HloInstruction* instr) { return filter_.Show(instr); });
} }
@ -735,12 +735,11 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
const int kMinUsersToOmit = 3; const int kMinUsersToOmit = 3;
return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() && return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
!instr->IsFused() && !instr->IsFused() &&
std::count_if(instr->users().begin(), instr->users().end(), absl::c_count_if(instr->users(),
[&](const HloInstruction* user) { [&](const HloInstruction* user) {
return filter_.Show(user); return filter_.Show(user);
}) > kMinUsersToOmit && }) > kMinUsersToOmit &&
std::all_of(instr->users().begin(), instr->users().end(), absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
[&](const HloInstruction* user) {
return !filter_.Show(user) || return !filter_.Show(user) ||
user->opcode() == HloOpcode::kGetTupleElement; user->opcode() == HloOpcode::kGetTupleElement;
}); });
@ -900,8 +899,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
// the same color as a parameter. Unless the merged-in parameter is a // the same color as a parameter. Unless the merged-in parameter is a
// parameter to a fusion node that is bound to a constant -- these aren't // parameter to a fusion node that is bound to a constant -- these aren't
// "real" parameters from the user's perspective. // "real" parameters from the user's perspective.
if (std::any_of(instr->operands().begin(), instr->operands().end(), if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
[&](const HloInstruction* operand) {
return operand->opcode() == HloOpcode::kParameter && return operand->opcode() == HloOpcode::kParameter &&
ShouldMergeIntoUsers(operand) && ShouldMergeIntoUsers(operand) &&
TryGetFusionParameterConstant(operand) == nullptr; TryGetFusionParameterConstant(operand) == nullptr;
@ -1355,12 +1353,11 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
NodeFilterResult& filter_result = kv.second; NodeFilterResult& filter_result = kv.second;
const auto& operands = instr->operands(); const auto& operands = instr->operands();
if (std::any_of(operands.begin(), operands.end(), is_displayed) && if (absl::c_any_of(operands, is_displayed) &&
!std::all_of(operands.begin(), operands.end(), is_displayed)) { !absl::c_all_of(operands, is_displayed)) {
// Mark nodes with some operands omitted appropriately. // Mark nodes with some operands omitted appropriately.
filter_result = kSomeOperandsOmitted; filter_result = kSomeOperandsOmitted;
} else if (!operands.empty() && } else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
std::none_of(operands.begin(), operands.end(), is_displayed)) {
// Mark nodes with *all* operands omitted appropriately. // Mark nodes with *all* operands omitted appropriately.
filter_result = kOmitNodeOperands; filter_result = kOmitNodeOperands;
} }
@ -1368,8 +1365,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
// Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their // Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
// users made it into the graph. // users made it into the graph.
if (filter_result == kSomeUsersOmitted && if (filter_result == kSomeUsersOmitted &&
std::all_of(instr->users().begin(), instr->users().end(), absl::c_all_of(instr->users(), is_displayed)) {
is_displayed)) {
filter_result = kNormalNode; filter_result = kNormalNode;
} }
} }

View File

@ -83,15 +83,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
return computation_map.at(proto.called_computation_ids(index)); return computation_map.at(proto.called_computation_ids(index));
}; };
TF_RET_CHECK(std::all_of( TF_RET_CHECK(
proto.operand_ids().begin(), proto.operand_ids().end(), absl::c_all_of(proto.operand_ids(),
[&instruction_map](int64 id) { return instruction_map.contains(id); })) [&](int64 id) { return instruction_map.contains(id); }))
<< proto.name() << " instruction contains invalid operand id(s)"; << proto.name() << " instruction contains invalid operand id(s)";
TF_RET_CHECK(std::all_of( TF_RET_CHECK(
proto.called_computation_ids().begin(), absl::c_all_of(proto.called_computation_ids(),
proto.called_computation_ids().end(), [&](int64 id) { return computation_map.contains(id); }))
[&computation_map](int64 id) { return computation_map.contains(id); }))
<< proto.name() << " instruction references invalid computation id(s)"; << proto.name() << " instruction references invalid computation id(s)";
Shape shape(proto.shape()); Shape shape(proto.shape());
@ -1599,12 +1598,10 @@ HloInstruction::InstructionVector HloInstruction::unique_operands() const {
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
TF_RET_CHECK(instruction->parent() == parent()); TF_RET_CHECK(instruction->parent() == parent());
if (std::find(control_successors_.begin(), control_successors_.end(), if (!absl::c_linear_search(control_successors_, instruction)) {
instruction) == control_successors_.end()) {
control_successors_.push_back(instruction); control_successors_.push_back(instruction);
TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(), TF_RET_CHECK(
instruction->control_predecessors_.end(), !absl::c_linear_search(instruction->control_predecessors_, this));
this) == instruction->control_predecessors_.end());
instruction->control_predecessors_.push_back(this); instruction->control_predecessors_.push_back(this);
} }
return Status::OK(); return Status::OK();
@ -1853,7 +1850,7 @@ void HloInstruction::RemoveUser(HloInstruction* user) {
user_set_.erase(set_it); user_set_.erase(set_it);
// This is linear in the number of the users, but a vector provides a stable // This is linear in the number of the users, but a vector provides a stable
// iteration order and much faster traversal. // iteration order and much faster traversal.
auto vec_it = std::find(users_.begin(), users_.end(), user); auto vec_it = absl::c_find(users_, user);
CHECK(vec_it != users_.end()); CHECK(vec_it != users_.end());
users_.erase(vec_it); users_.erase(vec_it);
} }
@ -1871,8 +1868,7 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
RemoveUser(user); RemoveUser(user);
TF_RET_CHECK( TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
std::count(user->operands_.begin(), user->operands_.end(), this) >= 0);
std::replace(user->operands_.begin(), user->operands_.end(), this, std::replace(user->operands_.begin(), user->operands_.end(), this,
new_producer); new_producer);
new_producer->AddUser(user); new_producer->AddUser(user);
@ -1907,8 +1903,7 @@ Status HloInstruction::ReplaceOperandWithDifferentShape(
VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with " VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
<< new_operand->name() << ", was " << old_operand->name(); << new_operand->name() << ", was " << old_operand->name();
if (std::find(operands_.begin(), operands_.end(), old_operand) == if (!absl::c_linear_search(operands_, old_operand)) {
operands_.end()) {
old_operand->RemoveUser(this); old_operand->RemoveUser(this);
} }
new_operand->AddUser(this); new_operand->AddUser(this);
@ -2945,7 +2940,7 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
string PaddingConfigToString(const PaddingConfig& padding) { string PaddingConfigToString(const PaddingConfig& padding) {
bool has_interior_padding = bool has_interior_padding =
std::any_of(padding.dimensions().begin(), padding.dimensions().end(), absl::c_any_of(padding.dimensions(),
[](const PaddingConfig::PaddingConfigDimension& dim) { [](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.interior_padding() != 0; return dim.interior_padding() != 0;
}); });

View File

@ -42,9 +42,7 @@ using absl::StrJoin;
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
const HloInstruction* operand) { const HloInstruction* operand) {
std::vector<int64> operand_indices = instruction->OperandIndices(operand); std::vector<int64> operand_indices = instruction->OperandIndices(operand);
return std::all_of( return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
operand_indices.begin(), operand_indices.end(),
[instruction](int64 operand_index) {
return instruction->IsElementwiseOnOperand(operand_index); return instruction->IsElementwiseOnOperand(operand_index);
}); });
} }
@ -814,8 +812,7 @@ std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
std::vector<string> bounds; std::vector<string> bounds;
bounds.reserve(slice_starts_.size()); bounds.reserve(slice_starts_.size());
const bool omit_stride = const bool omit_stride =
std::all_of(slice_strides_.begin(), slice_strides_.end(), absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
[](int64 stride) { return stride == 1; });
for (int i = 0; i < slice_starts_.size(); ++i) { for (int i = 0; i < slice_starts_.size(); ++i) {
string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
bounds.push_back( bounds.push_back(
@ -1051,8 +1048,7 @@ HloInstruction* HloFusionInstruction::AddFusionOperand(
void HloFusionInstruction::MergeFusionInstruction( void HloFusionInstruction::MergeFusionInstruction(
HloFusionInstruction* instruction_to_merge) { HloFusionInstruction* instruction_to_merge) {
CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != CHECK(absl::c_linear_search(operands(), instruction_to_merge));
operands().end());
// Clone the instruction from which to merge fused instructions. // Clone the instruction from which to merge fused instructions.
std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone(); std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
HloFusionInstruction* cloned_fusion = HloFusionInstruction* cloned_fusion =
@ -1219,8 +1215,8 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
// corresponding fused parameter instruction. Renumber parameters as // corresponding fused parameter instruction. Renumber parameters as
// necessary to make parameter numbers consistent with their index in the // necessary to make parameter numbers consistent with their index in the
// fused_parameter_ vector. // fused_parameter_ vector.
bool in_operand_list = std::find(operands().begin(), operands().end(), bool in_operand_list =
instruction_to_fuse) != operands().end(); absl::c_linear_search(operands(), instruction_to_fuse);
CHECK(add_output || in_operand_list); CHECK(add_output || in_operand_list);
if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
// We assume all uses of a kTuple operation are GTE ops, not another // We assume all uses of a kTuple operation are GTE ops, not another

View File

@ -107,9 +107,8 @@ HloComputation* HloModule::AddEntryComputation(
} }
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
auto it = auto it = absl::c_find_if(
std::find_if(computations_.begin(), computations_.end(), computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
[&to_remove](const std::unique_ptr<HloComputation>& comp) {
return comp.get() == to_remove; return comp.get() == to_remove;
}); });
TF_RET_CHECK(it->get() == to_remove); TF_RET_CHECK(it->get() == to_remove);
@ -304,8 +303,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
auto module = absl::make_unique<HloModule>(proto.name(), module_config); auto module = absl::make_unique<HloModule>(proto.name(), module_config);
// Sort the computations in the proto id's order. // Sort the computations in the proto id's order.
std::sort(computations.begin(), computations.end(), absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
[&](const std::unique_ptr<HloComputation>& a,
const std::unique_ptr<HloComputation>& b) { const std::unique_ptr<HloComputation>& b) {
return to_proto_id[a.get()] < to_proto_id[b.get()]; return to_proto_id[a.get()] < to_proto_id[b.get()];
}); });

View File

@ -38,9 +38,7 @@ class HloModuleDceTest : public HloTestBase {
// Returns whether the given instruction exists in the given computation. // Returns whether the given instruction exists in the given computation.
bool HasInstruction(const HloComputation& computation, bool HasInstruction(const HloComputation& computation,
const HloInstruction* instruction) { const HloInstruction* instruction) {
return std::find(computation.instructions().begin(), return absl::c_linear_search(computation.instructions(), instruction);
computation.instructions().end(),
instruction) != computation.instructions().end();
} }
// Returns whether the while instruction with name 'while_name' in // Returns whether the while instruction with name 'while_name' in

View File

@ -2746,7 +2746,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
} }
auto is_unique = [](string str) -> bool { auto is_unique = [](string str) -> bool {
std::sort(str.begin(), str.end()); absl::c_sort(str);
return std::unique(str.begin(), str.end()) == str.end(); return std::unique(str.begin(), str.end()) == str.end();
}; };

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
@ -34,9 +35,8 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
for (const HloComputationInfo& computation_info : for (const HloComputationInfo& computation_info :
hlo_profile_printer_data.computation_infos()) { hlo_profile_printer_data.computation_infos()) {
const auto& instruction_infos = computation_info.instruction_infos(); const auto& instruction_infos = computation_info.instruction_infos();
bool any_instruction_profiled = bool any_instruction_profiled = absl::c_any_of(
std::any_of(instruction_infos.begin(), instruction_infos.end(), instruction_infos, [&](const HloInstructionInfo& instruction_info) {
[&](const HloInstructionInfo& instruction_info) {
return counters[instruction_info.profile_index()] != 0; return counters[instruction_info.profile_index()] != 0;
}); });

View File

@ -49,7 +49,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper(
absl::Span<const HloInstruction* const> inputs, absl::Span<const HloInstruction* const> inputs,
const HloInstruction* instruction, BitVector* bit_vector) { const HloInstruction* instruction, BitVector* bit_vector) {
// If instruction is part of inputs, don't reset the bit_vector. // If instruction is part of inputs, don't reset the bit_vector.
if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { if (!absl::c_linear_search(inputs, instruction)) {
bit_vector->SetToZero(); bit_vector->SetToZero();
} }
bit_vector->Set(GetIndex(instruction)); bit_vector->Set(GetIndex(instruction));

View File

@ -235,8 +235,7 @@ class InstructionList {
} }
// Now scan forwards until we find one of the before_instructions. // Now scan forwards until we find one of the before_instructions.
while (std::find(before_instructions.begin(), before_instructions.end(), while (!absl::c_linear_search(before_instructions, min_position_item)) {
min_position_item) == before_instructions.end()) {
min_position_item = min_position_item->next; min_position_item = min_position_item->next;
} }
return InsertBefore(to_insert, min_position_item); return InsertBefore(to_insert, min_position_item);
@ -302,7 +301,7 @@ ItemList GetUsers(const InstructionList& instruction_list,
// A buffer may be used by the instruction via more than one alias. For // A buffer may be used by the instruction via more than one alias. For
// example, a buffer which appears in more than one element of a tuple. // example, a buffer which appears in more than one element of a tuple.
Item* user_item = instruction_list.GetItem(user); Item* user_item = instruction_list.GetItem(user);
if (std::find(users.begin(), users.end(), user_item) == users.end()) { if (!absl::c_linear_search(users, user_item)) {
users.push_back(user_item); users.push_back(user_item);
} }
} }
@ -456,8 +455,7 @@ class MemoryUsageTracker {
return false; return false;
} }
const BufferIdList& in_progress_uses = in_progress_item_->buffers_used; const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
return std::find(in_progress_uses.begin(), in_progress_uses.end(), return absl::c_linear_search(in_progress_uses, buffer_id);
buffer_id) != in_progress_uses.end();
} }
// Returns whether the given instruction is live at the current program // Returns whether the given instruction is live at the current program
@ -535,8 +533,7 @@ MemoryUsageTracker::MemoryUsageTracker(
bool unused; bool unused;
for (Item* user_item : GetUsers(instruction_list_, logical_buffer, for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &unused)) { points_to_analysis, &unused)) {
if (std::find(buffer->users.begin(), buffer->users.end(), if (!absl::c_linear_search(buffer->users, user_item)) {
user_item) == buffer->users.end()) {
buffer->users.push_back(user_item); buffer->users.push_back(user_item);
buffer->unfinished_user_count++; buffer->unfinished_user_count++;
user_item->buffers_used.push_back(buffer->id); user_item->buffers_used.push_back(buffer->id);
@ -784,8 +781,7 @@ bool MemoryUsageTracker::Check() const {
for (const Buffer& buffer : buffers_) { for (const Buffer& buffer : buffers_) {
if (buffer.defining_instruction->instruction == instruction) { if (buffer.defining_instruction->instruction == instruction) {
CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), CHECK(absl::c_linear_search(defined_buffers, buffer.id))
buffer.id) != defined_buffers.end())
<< "Instruction " << instruction->name() << "Instruction " << instruction->name()
<< " defined buffers is missing: " << buffer.ToString(); << " defined buffers is missing: " << buffer.ToString();
} }
@ -808,8 +804,7 @@ bool MemoryUsageTracker::Check() const {
int64 unfinished_uses = 0; int64 unfinished_uses = 0;
for (Item* user : buffer.users) { for (Item* user : buffer.users) {
const BufferIdList& used_buffers = user->buffers_used; const BufferIdList& used_buffers = user->buffers_used;
CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) != CHECK(absl::c_linear_search(used_buffers, buffer.id))
used_buffers.end())
<< "Instruction " << user->instruction->name() << "Instruction " << user->instruction->name()
<< " used buffers is missing " << buffer.ToString(); << " used buffers is missing " << buffer.ToString();
if (!IsFinished(user)) { if (!IsFinished(user)) {
@ -836,7 +831,7 @@ int64 RematerializationCost(const HloInstruction* instruction,
// If none of the users of 'instruction' have been placed in the sequence (as // If none of the users of 'instruction' have been placed in the sequence (as
// tracked by memory_tracker), then rematerialization of 'instruction' is a // tracked by memory_tracker), then rematerialization of 'instruction' is a
// zero-cost move of 'instruction' in the sequence. // zero-cost move of 'instruction' in the sequence.
if (!std::any_of(instruction->users().begin(), instruction->users().end(), if (!absl::c_any_of(instruction->users(),
[&memory_tracker](const HloInstruction* inst) { [&memory_tracker](const HloInstruction* inst) {
return memory_tracker.IsPlaced(inst); return memory_tracker.IsPlaced(inst);
})) { })) {

View File

@ -107,13 +107,12 @@ string HloSharding::ToString() const {
bool HloSharding::UsesDevice(int64 device) const { bool HloSharding::UsesDevice(int64 device) const {
if (IsTuple()) { if (IsTuple()) {
return std::any_of( return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
tuple_elements_.begin(), tuple_elements_.end(), return s.UsesDevice(device);
[&](const HloSharding& s) { return s.UsesDevice(device); }); });
} }
const auto& devices = tile_assignment_; const auto& devices = tile_assignment_;
return replicated_ || return replicated_ || absl::c_linear_search(devices, device);
std::find(devices.begin(), devices.end(), device) != devices.end();
} }
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const { std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {

View File

@ -101,8 +101,8 @@ class HloSharding {
if (!IsTuple()) { if (!IsTuple()) {
return replicated_; return replicated_;
} }
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), return absl::c_all_of(
[](const HloSharding& s) { return s.IsReplicated(); }); tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
} }
// Returns true if the tile size is the same as the input size. // Returns true if the tile size is the same as the input size.
@ -110,8 +110,9 @@ class HloSharding {
if (!IsTuple()) { if (!IsTuple()) {
return maximal_; return maximal_;
} }
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
[](const HloSharding& s) { return s.IsTileMaximal(); }); return s.IsTileMaximal();
});
} }
// Returns true if the sharding defines an operation on the given device. // Returns true if the sharding defines an operation on the given device.

View File

@ -61,8 +61,7 @@ void CleanNodeName(string* name) {
name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
const string chars_to_replace = "<>[]"; const string chars_to_replace = "<>[]";
auto pred = [&](char c) { auto pred = [&](char c) {
return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) != return absl::c_linear_search(chars_to_replace, c);
chars_to_replace.end();
}; };
std::replace_if(name->begin(), name->end(), pred, '_'); std::replace_if(name->begin(), name->end(), pred, '_');
} }

View File

@ -209,7 +209,7 @@ std::ostream& operator<<(std::ostream& out, const HloValue& value) {
} }
void HloValueSet::SortAndUniquifyValues() { void HloValueSet::SortAndUniquifyValues() {
std::sort(values_.begin(), values_.end(), HloValue::IdLessThan); absl::c_sort(values_, HloValue::IdLessThan);
values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual), values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
values_.end()); values_.end());
} }

View File

@ -128,9 +128,9 @@ string HumanReadableProfileBuilder::ToString() const {
// Sort ops in decreasing order of cycles, and print them. // Sort ops in decreasing order of cycles, and print them.
std::vector<OpInfo> sorted_ops(op_infos_); std::vector<OpInfo> sorted_ops(op_infos_);
std::sort( absl::c_sort(sorted_ops, [](const OpInfo& a, const OpInfo& b) {
sorted_ops.begin(), sorted_ops.end(), return a.cycles > b.cycles;
[](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); });
for (const auto& op : sorted_ops) { for (const auto& op : sorted_ops) {
print_op(op); print_op(op);
} }

View File

@ -178,8 +178,8 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape)); output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
} }
}); });
return std::count_if(hlo->operands().begin(), hlo->operands().end(), return absl::c_count_if(
[output_rank](HloInstruction* operand) { hlo->operands(), [output_rank](HloInstruction* operand) {
if (operand->opcode() == HloOpcode::kBroadcast || if (operand->opcode() == HloOpcode::kBroadcast ||
operand->opcode() == HloOpcode::kIota) { operand->opcode() == HloOpcode::kIota) {
return false; return false;
@ -188,8 +188,7 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
ShapeUtil::IsEffectiveScalar(operand->shape())) { ShapeUtil::IsEffectiveScalar(operand->shape())) {
return false; return false;
} }
return ShapeUtil::TrueRank(operand->shape()) >= return ShapeUtil::TrueRank(operand->shape()) >= output_rank;
output_rank;
}) <= 1; }) <= 1;
} }
@ -409,9 +408,8 @@ class ReversePostOrderFusionQueue : public FusionQueue {
} }
sorted_operand_numbers.push_back(i); sorted_operand_numbers.push_back(i);
} }
std::sort( absl::c_sort(
sorted_operand_numbers.begin(), sorted_operand_numbers.end(), sorted_operand_numbers, [&](int64 i, int64 j) {
[&](int64 i, int64 j) {
// Instructions with higher priority in the queue come first. // Instructions with higher priority in the queue come first.
return ( return (
FindOrDie(post_order_index_, instruction->mutable_operand(i)) > FindOrDie(post_order_index_, instruction->mutable_operand(i)) >

View File

@ -147,12 +147,9 @@ bool LayoutConstraints::OperandBufferForwarded(
PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
PointsToSet::BufferSet* operand_buffers = PointsToSet::BufferSet* operand_buffers =
GetBufferSet(instruction->operand(operand_no)); GetBufferSet(instruction->operand(operand_no));
for (const LogicalBuffer* output_buffer : *output_buffers) { return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
if (operand_buffers->count(output_buffer) > 0) { return operand_buffers->count(b) > 0;
return true; });
}
}
return false;
} }
Status LayoutConstraints::SetBufferLayout(const Layout& layout, Status LayoutConstraints::SetBufferLayout(const Layout& layout,

View File

@ -81,9 +81,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
if (hlo.opcode() == HloOpcode::kParameter) { if (hlo.opcode() == HloOpcode::kParameter) {
const std::vector<HloInstruction*>& parameter_instructions = const std::vector<HloInstruction*>& parameter_instructions =
module_.entry_computation()->parameter_instructions(); module_.entry_computation()->parameter_instructions();
if (std::find(parameter_instructions.begin(), if (absl::c_linear_search(parameter_instructions, &hlo)) {
parameter_instructions.end(),
&hlo) != parameter_instructions.end()) {
array->MarkInvariantOverWholeProgram(context_); array->MarkInvariantOverWholeProgram(context_);
} }
} }

View File

@ -34,7 +34,7 @@ bool IsAllowed(char character) {
} // namespace } // namespace
NameUniquer::NameUniquer(const string& separator) { NameUniquer::NameUniquer(const string& separator) {
CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) CHECK(absl::c_all_of(separator, IsAllowed))
<< "separator should comprises allowed characters only"; << "separator should comprises allowed characters only";
separator_ = separator; separator_ = separator;
} }

View File

@ -260,7 +260,7 @@ PlatformUtil::GetStreamExecutors(
// Block here in thread_pool destructor until all devices are initialized. // Block here in thread_pool destructor until all devices are initialized.
} }
VLOG(1) << "Device initialization complete"; VLOG(1) << "Device initialization complete";
if (std::all_of(stream_executors.begin(), stream_executors.end(), if (absl::c_all_of(stream_executors,
[](se::StreamExecutor* s) { return s == nullptr; })) { [](se::StreamExecutor* s) { return s == nullptr; })) {
return InternalError("no supported devices found for platform %s", return InternalError("no supported devices found for platform %s",
platform->Name()); platform->Name());

View File

@ -534,9 +534,8 @@ Status ValidateDotDimensionNumbers(
absl::Span<const int64> contracting_dims, absl::Span<const int64> contracting_dims,
absl::Span<const int64> batch_dims) -> bool { absl::Span<const int64> batch_dims) -> bool {
auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
return std::all_of(contracting_dims.begin(), contracting_dims.end(), return absl::c_all_of(contracting_dims, in_range) &&
in_range) && absl::c_all_of(batch_dims, in_range);
std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
}; };
absl::Span<const int64> lhs_contracting_dimensions = absl::Span<const int64> lhs_contracting_dimensions =
@ -563,9 +562,8 @@ Status ValidateDotDimensionNumbers(
auto is_unique = [&dim_set](int64 i) -> bool { auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second; return dim_set.insert(i).second;
}; };
return std::all_of(contracting_dims.begin(), contracting_dims.end(), return absl::c_all_of(contracting_dims, is_unique) &&
is_unique) && absl::c_all_of(batch_dims, is_unique);
std::all_of(batch_dims.begin(), batch_dims.end(), is_unique);
}; };
if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
@ -1589,29 +1587,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
input_dnums[1] = dnums.input_feature_dimension(); input_dnums[1] = dnums.input_feature_dimension();
std::copy(dnums.input_spatial_dimensions().begin(), std::copy(dnums.input_spatial_dimensions().begin(),
dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2); dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
std::sort(input_dnums.begin(), input_dnums.end()); absl::c_sort(input_dnums);
std::vector<int64> window_dnums(num_dims); std::vector<int64> window_dnums(num_dims);
window_dnums[0] = dnums.kernel_input_feature_dimension(); window_dnums[0] = dnums.kernel_input_feature_dimension();
window_dnums[1] = dnums.kernel_output_feature_dimension(); window_dnums[1] = dnums.kernel_output_feature_dimension();
std::copy(dnums.kernel_spatial_dimensions().begin(), std::copy(dnums.kernel_spatial_dimensions().begin(),
dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
std::sort(window_dnums.begin(), window_dnums.end()); absl::c_sort(window_dnums);
std::vector<int64> output_dnums(num_dims); std::vector<int64> output_dnums(num_dims);
output_dnums[0] = dnums.output_batch_dimension(); output_dnums[0] = dnums.output_batch_dimension();
output_dnums[1] = dnums.output_feature_dimension(); output_dnums[1] = dnums.output_feature_dimension();
std::copy(dnums.output_spatial_dimensions().begin(), std::copy(dnums.output_spatial_dimensions().begin(),
dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2); dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
std::sort(output_dnums.begin(), output_dnums.end()); absl::c_sort(output_dnums);
std::vector<int64> expected_dnums(num_dims); std::vector<int64> expected_dnums(num_dims);
std::iota(expected_dnums.begin(), expected_dnums.end(), 0); std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; }; const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) || if (!absl::c_all_of(input_dnums, in_range) ||
!std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || !absl::c_all_of(window_dnums, in_range) ||
!std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { !absl::c_all_of(output_dnums, in_range)) {
return InvalidArgument( return InvalidArgument(
"A dimension number is out of range in convolution: %s.", "A dimension number is out of range in convolution: %s.",
dnums.DebugString()); dnums.DebugString());

View File

@ -130,8 +130,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
HloInstruction* new_lhs; HloInstruction* new_lhs;
const int64 kLhsIdx = 0; const int64 kLhsIdx = 0;
if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) != if (absl::c_linear_search(operand_indices, kLhsIdx)) {
operand_indices.end()) {
HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx); HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
const auto& transpose_dimensions = transpose.dimensions(); const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0); HloInstruction& transpose_operand = *transpose.mutable_operand(0);
@ -154,8 +153,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
HloInstruction* new_rhs; HloInstruction* new_rhs;
const int64 kRhsIdx = 1; const int64 kRhsIdx = 1;
if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) != if (absl::c_linear_search(operand_indices, kRhsIdx)) {
operand_indices.end()) {
HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx); HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
const auto& transpose_dimensions = transpose.dimensions(); const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0); HloInstruction& transpose_operand = *transpose.mutable_operand(0);

View File

@ -87,9 +87,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
bool found = false; bool found = false;
ForEachElement([&found, &buffer](const ShapeIndex& /*index*/, ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
const BufferList& pointed_to_buffers) { const BufferList& pointed_to_buffers) {
if (!found && if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
&buffer) != pointed_to_buffers.end()) {
found = true; found = true;
} }
}); });
@ -99,8 +97,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer, bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
const ShapeIndex& index) const { const ShapeIndex& index) const {
const auto& pointed_to_buffers = element(index); const auto& pointed_to_buffers = element(index);
return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(), return absl::c_linear_search(pointed_to_buffers, &buffer);
&buffer) != pointed_to_buffers.end();
} }
void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer, void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
@ -604,9 +601,8 @@ bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
} else if (user->opcode() == HloOpcode::kFusion && } else if (user->opcode() == HloOpcode::kFusion &&
user->fusion_kind() == HloInstruction::FusionKind::kLoop) { user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
// Find fusion parameter associated with 'operand'. // Find fusion parameter associated with 'operand'.
auto it = std::find_if( auto it = absl::c_find_if(
user->fused_parameters().begin(), user->fused_parameters().end(), user->fused_parameters(), [&](HloInstruction* fused_param) {
[=](HloInstruction* fused_param) {
return user->operand(fused_param->parameter_number()) == operand; return user->operand(fused_param->parameter_number()) == operand;
}); });
CHECK(it != user->fused_parameters().end()); CHECK(it != user->fused_parameters().end());
@ -672,9 +668,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
} }
// Find fusion parameter associated with 'operand'. // Find fusion parameter associated with 'operand'.
const auto& fused_params = fusion->fused_parameters(); const auto& fused_params = fusion->fused_parameters();
auto fused_param_it = std::find_if( auto fused_param_it =
fused_params.begin(), fused_params.end(), absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
[&](HloInstruction* fused_param) {
return fusion->operand(fused_param->parameter_number()) == operand; return fusion->operand(fused_param->parameter_number()) == operand;
}); });
if (fused_param_it == fused_params.end()) { if (fused_param_it == fused_params.end()) {
@ -743,8 +738,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
// Check if one operand of kAdd fused root is kDot or kConvolution. // Check if one operand of kAdd fused root is kDot or kConvolution.
auto* add = user->fused_expression_root(); auto* add = user->fused_expression_root();
auto add_operand_it = auto add_operand_it =
std::find_if(add->operands().begin(), add->operands().end(), absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
[&](HloInstruction* operand) {
return operand->opcode() == HloOpcode::kConvolution || return operand->opcode() == HloOpcode::kConvolution ||
operand->opcode() == HloOpcode::kDot; operand->opcode() == HloOpcode::kDot;
}); });

View File

@ -721,9 +721,8 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
// to fusion 'operand'. // to fusion 'operand'.
HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion, HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
HloInstruction* operand) { HloInstruction* operand) {
auto it = std::find_if( auto it = absl::c_find_if(
fusion->fused_instructions().begin(), fusion->fused_instructions(), [&](const HloInstruction* fused) {
fusion->fused_instructions().end(), [=](const HloInstruction* fused) {
return fused->opcode() == HloOpcode::kParameter && return fused->opcode() == HloOpcode::kParameter &&
fusion->operand(fused->parameter_number()) == operand; fusion->operand(fused->parameter_number()) == operand;
}); });

View File

@ -109,8 +109,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// operand appears in, but it may appear more than once! // operand appears in, but it may appear more than once!
if (user->user_count() == 1 && user->users().front() == while_body_root && if (user->user_count() == 1 && user->users().front() == while_body_root &&
while_body_root->operand_index(user) == user->tuple_index() && while_body_root->operand_index(user) == user->tuple_index() &&
std::count(while_body_root->operands().begin(), absl::c_count(while_body_root->operands(), user) == 1) {
while_body_root->operands().end(), user) == 1) {
continue; continue;
} }
@ -158,7 +157,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// Build up maps from the old/new to the new/old tuple indices. // Build up maps from the old/new to the new/old tuple indices.
std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(), std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
used_tuple_indices.end()); used_tuple_indices.end());
std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); absl::c_sort(new_to_old_tuple_idx);
absl::flat_hash_map<int64, int64> old_to_new_tuple_idx; absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {

View File

@ -407,10 +407,9 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
// The original while instruction is still left in the module as a dead // The original while instruction is still left in the module as a dead
// instruction, find a while instruction with a different name as the new // instruction, find a while instruction with a different name as the new
// while instruction. // while instruction.
const auto& instrs = m->entry_computation()->instructions();
HloInstruction* new_while_op = HloInstruction* new_while_op =
*std::find_if(m->entry_computation()->instructions().begin(), *absl::c_find_if(instrs, [&](const HloInstruction* instr) {
m->entry_computation()->instructions().end(),
[&](const HloInstruction* instr) {
return (instr->opcode() == HloOpcode::kWhile && return (instr->opcode() == HloOpcode::kWhile &&
instr->name() != "while"); instr->name() != "while");
}); });

View File

@ -87,8 +87,7 @@ bool Shape::is_static() const {
} }
} }
} }
return !std::any_of(dynamic_dimensions_.begin(), dynamic_dimensions_.end(), return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; });
[](bool b) { return b; });
} }
void Shape::DeleteDimension(int64 dim_to_delete) { void Shape::DeleteDimension(int64 dim_to_delete) {

View File

@ -405,8 +405,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
} }
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), return IsTuple(shape) && absl::c_any_of(shape.tuple_shapes(), IsTuple);
shape.tuple_shapes().end(), IsTuple);
} }
/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) { /* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {

View File

@ -135,7 +135,7 @@ void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
auto sort_order_less = [this](int64 lhs, int64 rhs) { auto sort_order_less = [this](int64 lhs, int64 rhs) {
return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
}; };
std::sort(sort_order.begin(), sort_order.end(), sort_order_less); absl::c_sort(sort_order, sort_order_less);
// Reorder the array elements according to sort_order. Work through the array // Reorder the array elements according to sort_order. Work through the array
// and follow cycles so we can do the reorder in-place. // and follow cycles so we can do the reorder in-place.

View File

@ -467,7 +467,7 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
// servers. The error message is missing the operator ++. // servers. The error message is missing the operator ++.
template <typename T> template <typename T>
void iota_int_init_value(std::vector<T>& values, int init_value) { void iota_int_init_value(std::vector<T>& values, int init_value) {
std::for_each(values.begin(), values.end(), absl::c_for_each(values,
[&](T& value) { value = static_cast<T>(init_value++); }); [&](T& value) { value = static_cast<T>(init_value++); });
} }

View File

@ -86,7 +86,7 @@ bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
CHECK_LT(index, rank); CHECK_LT(index, rank);
output[index] = 0; output[index] = 0;
} }
return std::find(output.begin(), output.end(), -1) == output.end(); return !absl::c_linear_search(output, -1);
} }
std::vector<int64> InversePermutation( std::vector<int64> InversePermutation(

View File

@ -324,8 +324,7 @@ bool IsIdentityPermutation(absl::Span<const int64> permutation);
template <typename Container> template <typename Container>
int64 PositionInContainer(const Container& container, int64 value) { int64 PositionInContainer(const Container& container, int64 value) {
return std::distance(container.begin(), return std::distance(container.begin(), absl::c_find(container, value));
std::find(container.begin(), container.end(), value));
} }
// Formats the container as a comma-separated string. StrAppend must support // Formats the container as a comma-separated string. StrAppend must support

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
@ -137,23 +138,21 @@ bool HasPadding(const Window& window) {
} }
bool HasSymmetricPadding(const Window& window) { bool HasSymmetricPadding(const Window& window) {
return std::all_of(window.dimensions().begin(), window.dimensions().end(), return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) {
[](const WindowDimension& dim) {
return dim.padding_low() == dim.padding_high(); return dim.padding_low() == dim.padding_high();
}); });
} }
bool HasSymmetricPadding(const PaddingConfig& padding_config) { bool HasSymmetricPadding(const PaddingConfig& padding_config) {
return std::all_of(padding_config.dimensions().begin(), return absl::c_all_of(padding_config.dimensions(),
padding_config.dimensions().end(),
[](const PaddingConfig::PaddingConfigDimension& dim) { [](const PaddingConfig::PaddingConfigDimension& dim) {
return dim.edge_padding_low() == dim.edge_padding_high(); return dim.edge_padding_low() ==
dim.edge_padding_high();
}); });
} }
bool HasNegativePadding(const Window& window) { bool HasNegativePadding(const Window& window) {
return std::any_of(window.dimensions().begin(), window.dimensions().end(), return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
[](const WindowDimension& dim) {
return dim.padding_low() < 0 || dim.padding_high() < 0; return dim.padding_low() < 0 || dim.padding_high() < 0;
}); });
} }
@ -190,8 +189,7 @@ bool AllOrNoneReversed(const Window& window) {
return true; return true;
} }
bool reversed = window.dimensions()[0].window_reversal(); bool reversed = window.dimensions()[0].window_reversal();
return std::all_of(window.dimensions().begin(), window.dimensions().end(), return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) {
[&](const WindowDimension& dim) {
return dim.window_reversal() == reversed; return dim.window_reversal() == reversed;
}); });
} }