[XLA] Use absl::c_foo rather than std::foo.
No functional change. PiperOrigin-RevId: 227896034
This commit is contained in:
parent
f9bd1568aa
commit
b4813a0cff
@ -717,6 +717,7 @@ cc_library(
|
||||
":types",
|
||||
":xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
|
@ -77,7 +77,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) {
|
||||
auto x = ConstantR1<float>(&builder, inputs);
|
||||
xla::GetTupleElement(xla::TopK(x, kSize), 0);
|
||||
|
||||
std::sort(inputs.begin(), inputs.end(), std::greater<float>());
|
||||
absl::c_sort(inputs, std::greater<float>());
|
||||
ComputeAndCompareR1<float>(&builder, inputs, {});
|
||||
}
|
||||
|
||||
|
@ -290,8 +290,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
||||
/* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
|
||||
if (shape.IsTuple()) {
|
||||
// Tuple shape: all subshapes must have a layout.
|
||||
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
|
||||
[](const Shape& s) { return HasLayout(s); });
|
||||
return absl::c_all_of(shape.tuple_shapes(),
|
||||
[](const Shape& s) { return HasLayout(s); });
|
||||
} else if (!shape.IsArray()) {
|
||||
// Opaque, token types etc. ignore layout.
|
||||
return true;
|
||||
@ -424,7 +424,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
|
||||
positions_in_layout.push_back(
|
||||
PositionInContainer(layout.minor_to_major(), dim));
|
||||
}
|
||||
std::sort(positions_in_layout.begin(), positions_in_layout.end());
|
||||
absl::c_sort(positions_in_layout);
|
||||
for (size_t i = 1; i < positions_in_layout.size(); ++i) {
|
||||
if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
|
||||
return false;
|
||||
|
@ -55,7 +55,7 @@ string MetricTableReport::MakeReport(double expected_metric_sum) {
|
||||
const auto metric_greater = [](const Entry& a, const Entry& b) {
|
||||
return a.metric > b.metric;
|
||||
};
|
||||
std::sort(entries_.begin(), entries_.end(), metric_greater);
|
||||
absl::c_sort(entries_, metric_greater);
|
||||
|
||||
// Create the report
|
||||
AppendLine();
|
||||
@ -117,7 +117,7 @@ std::vector<MetricTableReport::Category> MetricTableReport::MakeCategories(
|
||||
auto metric_sum_greater = [](const Category& a, const Category& b) {
|
||||
return a.metric_sum > b.metric_sum;
|
||||
};
|
||||
std::sort(categories.begin(), categories.end(), metric_sum_greater);
|
||||
absl::c_sort(categories, metric_sum_greater);
|
||||
|
||||
return categories;
|
||||
}
|
||||
|
@ -3414,6 +3414,7 @@ cc_library(
|
||||
":hlo_profile_printer_data",
|
||||
":human_readable_profile_builder",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -2216,8 +2216,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
|
||||
auto dim_is_one = [&](int64 i) -> bool {
|
||||
return reverse->shape().dimensions(i) == 1;
|
||||
};
|
||||
if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(),
|
||||
dim_is_one)) {
|
||||
if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
|
||||
return ReplaceInstruction(reverse, reverse->mutable_operand(0));
|
||||
}
|
||||
return Status::OK();
|
||||
@ -2492,9 +2491,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
|
||||
// Create a new reduce with the combined reduction dimensions of both
|
||||
// reduces.
|
||||
std::vector<int64> arg_dims = arg->dimensions();
|
||||
std::sort(arg_dims.begin(), arg_dims.end());
|
||||
absl::c_sort(arg_dims);
|
||||
std::vector<int64> reduce_dims = reduce->dimensions();
|
||||
std::sort(reduce_dims.begin(), reduce_dims.end());
|
||||
absl::c_sort(reduce_dims);
|
||||
// Transform reduce_dims to the same rank as the operand of the operand.
|
||||
for (int64 arg_dim : arg_dims) {
|
||||
for (int64& dim : reduce_dims) {
|
||||
|
@ -86,10 +86,9 @@ std::vector<int64> ColorInterferenceGraph(
|
||||
// first, but it would be good to investigate other ordering heuristics too.
|
||||
std::vector<int64> nodes(node_count);
|
||||
std::iota(nodes.begin(), nodes.end(), 0);
|
||||
std::sort(nodes.begin(), nodes.end(),
|
||||
[&interference_map](const int64 i, const int64 j) {
|
||||
return interference_map[i].size() > interference_map[j].size();
|
||||
});
|
||||
absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
|
||||
return interference_map[i].size() > interference_map[j].size();
|
||||
});
|
||||
|
||||
const int64 kColorUnassigned = -1;
|
||||
std::vector<int64> assigned_colors(node_count, kColorUnassigned);
|
||||
@ -272,11 +271,12 @@ BufferAllocationProto BufferAllocation::ToProto() const {
|
||||
proto_assigned->set_offset(buffer_offset_size.second.offset);
|
||||
proto_assigned->set_size(buffer_offset_size.second.size);
|
||||
}
|
||||
std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(),
|
||||
[](const BufferAllocationProto::Assigned& assign1,
|
||||
const BufferAllocationProto::Assigned& assign2) {
|
||||
return assign1.logical_buffer_id() < assign2.logical_buffer_id();
|
||||
});
|
||||
absl::c_sort(*proto.mutable_assigned(),
|
||||
[](const BufferAllocationProto::Assigned& assign1,
|
||||
const BufferAllocationProto::Assigned& assign2) {
|
||||
return assign1.logical_buffer_id() <
|
||||
assign2.logical_buffer_id();
|
||||
});
|
||||
return proto;
|
||||
}
|
||||
|
||||
@ -308,10 +308,10 @@ string BufferAllocation::ToString() const {
|
||||
for (const auto& buffer_offset_size : assigned_buffers_) {
|
||||
sorted_buffers.push_back(buffer_offset_size.first);
|
||||
}
|
||||
std::sort(sorted_buffers.begin(), sorted_buffers.end(),
|
||||
[](const LogicalBuffer* a, const LogicalBuffer* b) {
|
||||
return a->id() < b->id();
|
||||
});
|
||||
absl::c_sort(sorted_buffers,
|
||||
[](const LogicalBuffer* a, const LogicalBuffer* b) {
|
||||
return a->id() < b->id();
|
||||
});
|
||||
for (const LogicalBuffer* buffer : sorted_buffers) {
|
||||
const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer);
|
||||
StrAppend(&output, absl::StrFormat(
|
||||
@ -479,10 +479,9 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
|
||||
// didn't return the empty set) for both HLOs, and the two resulting sets of
|
||||
// slices are disjoint.
|
||||
return !slices_a.empty() && !slices_b.empty() &&
|
||||
std::none_of(slices_a.begin(), slices_a.end(),
|
||||
[&](const BufferAllocation::Slice& slice) {
|
||||
return slices_b.count(slice) > 0;
|
||||
});
|
||||
absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
|
||||
return slices_b.contains(slice);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<BufferAllocation::Slice>
|
||||
@ -952,28 +951,28 @@ Status BufferAssigner::AssignBuffersForComputation(
|
||||
// operands (assuming operands are the same/larger size) enabling the
|
||||
// important reuse case where an elementwise instruction reuses one of its
|
||||
// operand's buffer. This improves locality.
|
||||
std::sort(sorted_buffers.begin(), sorted_buffers.end(),
|
||||
[has_sequential_order, &liveness, &post_order_position, assignment](
|
||||
const LogicalBuffer* a, const LogicalBuffer* b) {
|
||||
// Primary sort is by decreasing buffer size.
|
||||
const int64 a_size = assignment->buffer_size_(*a);
|
||||
const int64 b_size = assignment->buffer_size_(*b);
|
||||
if (a_size != b_size) {
|
||||
return a_size > b_size; // use ">" for decreasing size.
|
||||
}
|
||||
// Otherwise live out buffers come before others, if the
|
||||
// instructions are sequentially ordered.
|
||||
if (has_sequential_order) {
|
||||
const bool a_live_out = liveness.MaybeLiveOut(*a);
|
||||
const bool b_live_out = liveness.MaybeLiveOut(*b);
|
||||
if (a_live_out != b_live_out) {
|
||||
return a_live_out;
|
||||
}
|
||||
}
|
||||
// Final tiebreaker is in instruction post order.
|
||||
return post_order_position.at(a->instruction()) <
|
||||
post_order_position.at(b->instruction());
|
||||
});
|
||||
absl::c_sort(sorted_buffers,
|
||||
[has_sequential_order, &liveness, &post_order_position,
|
||||
assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
|
||||
// Primary sort is by decreasing buffer size.
|
||||
const int64 a_size = assignment->buffer_size_(*a);
|
||||
const int64 b_size = assignment->buffer_size_(*b);
|
||||
if (a_size != b_size) {
|
||||
return a_size > b_size; // use ">" for decreasing size.
|
||||
}
|
||||
// Otherwise live out buffers come before others, if the
|
||||
// instructions are sequentially ordered.
|
||||
if (has_sequential_order) {
|
||||
const bool a_live_out = liveness.MaybeLiveOut(*a);
|
||||
const bool b_live_out = liveness.MaybeLiveOut(*b);
|
||||
if (a_live_out != b_live_out) {
|
||||
return a_live_out;
|
||||
}
|
||||
}
|
||||
// Final tiebreaker is in instruction post order.
|
||||
return post_order_position.at(a->instruction()) <
|
||||
post_order_position.at(b->instruction());
|
||||
});
|
||||
|
||||
// BufferAllocations are necessarily created in decreasing size order. Keep
|
||||
// indices of previously created BufferAllocations in allocation_indices.
|
||||
@ -1305,10 +1304,10 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
|
||||
live_buffers.end());
|
||||
|
||||
// Stabily sort the live buffers.
|
||||
std::sort(live_buffers_vector.begin(), live_buffers_vector.end(),
|
||||
[](const LogicalBuffer* a, const LogicalBuffer* b) {
|
||||
return a->id() < b->id();
|
||||
});
|
||||
absl::c_sort(live_buffers_vector,
|
||||
[](const LogicalBuffer* a, const LogicalBuffer* b) {
|
||||
return a->id() < b->id();
|
||||
});
|
||||
return live_buffers_vector;
|
||||
}
|
||||
|
||||
|
@ -384,7 +384,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
|
||||
|
||||
// Verify visitation order of some computations in the graph.
|
||||
auto index_of = [&visited](const HloComputation* comp) {
|
||||
auto it = std::find(visited.begin(), visited.end(), comp);
|
||||
auto it = absl::c_find(visited, comp);
|
||||
EXPECT_NE(it, visited.end());
|
||||
return std::distance(visited.begin(), it);
|
||||
};
|
||||
|
@ -42,8 +42,8 @@ void ComputationLayout::SetToDefaultLayout() {
|
||||
}
|
||||
|
||||
bool ComputationLayout::LayoutIsSet() const {
|
||||
return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(),
|
||||
[](const ShapeLayout& s) { return s.LayoutIsSet(); }) &&
|
||||
return absl::c_all_of(parameter_layouts_,
|
||||
[](const ShapeLayout& s) { return s.LayoutIsSet(); }) &&
|
||||
result_layout_.LayoutIsSet();
|
||||
}
|
||||
|
||||
|
@ -539,10 +539,9 @@ class CopyRemover {
|
||||
}
|
||||
|
||||
std::vector<const HloValue*> values = buffer.values();
|
||||
std::sort(values.begin(), values.end(),
|
||||
[this](const HloValue* a, const HloValue* b) {
|
||||
return ordering_.IsDefinedBefore(*a, *b);
|
||||
});
|
||||
absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
|
||||
return ordering_.IsDefinedBefore(*a, *b);
|
||||
});
|
||||
|
||||
// Create a list containing all of the values in the buffer.
|
||||
AddValueList(values, &value_to_node);
|
||||
@ -842,12 +841,11 @@ class CopyRemover {
|
||||
copy_value_node->next->prev = operand_node;
|
||||
|
||||
// Patch up uses. Remove use of copy from operand_node uses.
|
||||
auto it =
|
||||
std::find_if(operand_node->uses.begin(), operand_node->uses.end(),
|
||||
[copy_value_node](const HloUse* use) {
|
||||
return use->instruction ==
|
||||
copy_value_node->value->defining_instruction();
|
||||
});
|
||||
auto it = absl::c_find_if(
|
||||
operand_node->uses, [copy_value_node](const HloUse* use) {
|
||||
return use->instruction ==
|
||||
copy_value_node->value->defining_instruction();
|
||||
});
|
||||
CHECK(it != operand_node->uses.end());
|
||||
operand_node->uses.erase(it);
|
||||
|
||||
|
@ -77,17 +77,16 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile(
|
||||
}
|
||||
|
||||
// Sort the symbols in increasing address order.
|
||||
std::sort(
|
||||
symbols.begin(), symbols.end(),
|
||||
[](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) {
|
||||
// getAddress returns a Expected object. Assert there is no error
|
||||
// before extracting the address.
|
||||
llvm::Expected<uint64_t> a_address_or_error = a.getAddress();
|
||||
CHECK(a_address_or_error);
|
||||
llvm::Expected<uint64_t> b_address_or_error = b.getAddress();
|
||||
CHECK(b_address_or_error);
|
||||
return a_address_or_error.get() < b_address_or_error.get();
|
||||
});
|
||||
absl::c_sort(symbols, [](const llvm::object::SymbolRef& a,
|
||||
const llvm::object::SymbolRef& b) {
|
||||
// getAddress returns a Expected object. Assert there is no error
|
||||
// before extracting the address.
|
||||
llvm::Expected<uint64_t> a_address_or_error = a.getAddress();
|
||||
CHECK(a_address_or_error);
|
||||
llvm::Expected<uint64_t> b_address_or_error = b.getAddress();
|
||||
CHECK(b_address_or_error);
|
||||
return a_address_or_error.get() < b_address_or_error.get();
|
||||
});
|
||||
|
||||
// Construct ArrayRef pointing to section contents.
|
||||
llvm::StringRef section_content_string;
|
||||
|
@ -1709,10 +1709,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
|
||||
vectorization_factor_in_bytes /
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
|
||||
|
||||
bool is_reduction_over_minor_dimension =
|
||||
std::find(dimensions.begin(), dimensions.end(),
|
||||
LayoutUtil::Minor(arg->shape().layout(), 0)) !=
|
||||
dimensions.end();
|
||||
bool is_reduction_over_minor_dimension = absl::c_linear_search(
|
||||
dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
|
||||
|
||||
unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
|
||||
@ -2401,8 +2399,7 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
|
||||
int64 concat_dim = concatenate->dimensions(0);
|
||||
const Layout& output_layout = output_shape.layout();
|
||||
auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
|
||||
auto concat_dim_layout_itr =
|
||||
std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim);
|
||||
auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
|
||||
|
||||
std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
|
||||
std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
|
||||
@ -2956,8 +2953,7 @@ Status IrEmitter::ElementTypesSameAndSupported(
|
||||
|
||||
TF_RET_CHECK(!operands.empty());
|
||||
PrimitiveType primitive_type = operands[0]->shape().element_type();
|
||||
if (std::find(supported_types.begin(), supported_types.end(),
|
||||
primitive_type) == supported_types.end()) {
|
||||
if (!absl::c_linear_search(supported_types, primitive_type)) {
|
||||
return Unimplemented("unsupported operand type %s in op %s",
|
||||
PrimitiveType_Name(primitive_type),
|
||||
HloOpcodeString(instruction.opcode()));
|
||||
|
@ -154,20 +154,17 @@ bool IsReductionToVector(const HloInstruction& reduce) {
|
||||
const HloInstruction* input = reduce.operand(0);
|
||||
std::vector<int64> dims_to_keep;
|
||||
for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
|
||||
if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(),
|
||||
dim)) {
|
||||
if (!absl::c_linear_search(reduce.dimensions(), dim)) {
|
||||
dims_to_keep.push_back(dim);
|
||||
}
|
||||
}
|
||||
return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
|
||||
dims_to_keep) &&
|
||||
ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions(
|
||||
[&dims_to_keep](int64 dim) {
|
||||
return std::count(
|
||||
dims_to_keep.begin(),
|
||||
dims_to_keep.end(), dim);
|
||||
},
|
||||
input->shape()));
|
||||
ShapeUtil::Equal(
|
||||
reduce.shape(),
|
||||
ShapeUtil::FilterDimensions(
|
||||
[&](int64 dim) { return absl::c_count(dims_to_keep, dim); },
|
||||
input->shape()));
|
||||
}
|
||||
|
||||
// This emits a device-side call to
|
||||
|
@ -1506,10 +1506,10 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
|
||||
return !allocation->is_constant();
|
||||
});
|
||||
|
||||
std::sort(non_constant_buffers.begin(), non_constant_buffers.end(),
|
||||
[](const BufferAllocation* a, const BufferAllocation* b) {
|
||||
return a->index() < b->index();
|
||||
});
|
||||
absl::c_sort(non_constant_buffers,
|
||||
[](const BufferAllocation* a, const BufferAllocation* b) {
|
||||
return a->index() < b->index();
|
||||
});
|
||||
|
||||
llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers);
|
||||
|
||||
|
@ -225,10 +225,10 @@ Status HeapSimulator::RunComputation(
|
||||
}
|
||||
}
|
||||
// Sort to get a deterministic iteration order.
|
||||
std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(),
|
||||
[](const BufferValue* x, const BufferValue* y) {
|
||||
return x->id() < y->id();
|
||||
});
|
||||
absl::c_sort(operand_buffers_to_free,
|
||||
[](const BufferValue* x, const BufferValue* y) {
|
||||
return x->id() < y->id();
|
||||
});
|
||||
|
||||
// Allocate buffers defined by this instruction. This is the latest point
|
||||
// that we can allocate; right before the buffer is first used. This must
|
||||
@ -335,10 +335,9 @@ Status HeapSimulator::RunComputation(
|
||||
to_free.push_back(buffer);
|
||||
}
|
||||
|
||||
std::sort(to_free.begin(), to_free.end(),
|
||||
[](const BufferValue* x, const BufferValue* y) {
|
||||
return x->id() < y->id();
|
||||
});
|
||||
absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) {
|
||||
return x->id() < y->id();
|
||||
});
|
||||
for (const BufferValue* buffer : to_free) {
|
||||
VLOG(3) << "Freeing pending: " << buffer->ToString();
|
||||
Free(buffer, root);
|
||||
@ -596,7 +595,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() {
|
||||
}
|
||||
|
||||
// Call ops in the run sorted by decreasing size, breaking ties by buffer id.
|
||||
std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) {
|
||||
absl::c_sort(run_, [](const Op& a, const Op& b) {
|
||||
if (a.size != b.size) {
|
||||
return a.size > b.size;
|
||||
}
|
||||
@ -866,23 +865,23 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
|
||||
for (auto& entry : buffer_intervals_) {
|
||||
sorted_buffer_intervals.push_back(entry.second);
|
||||
}
|
||||
std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(),
|
||||
[](const BufferInterval& x, const BufferInterval& y) {
|
||||
if (x.size != y.size) {
|
||||
return x.size > y.size;
|
||||
}
|
||||
if (x.end - x.start != y.end - y.start) {
|
||||
return x.end - x.start > y.end - y.start;
|
||||
}
|
||||
return x.buffer->id() < y.buffer->id();
|
||||
});
|
||||
absl::c_sort(sorted_buffer_intervals,
|
||||
[](const BufferInterval& x, const BufferInterval& y) {
|
||||
if (x.size != y.size) {
|
||||
return x.size > y.size;
|
||||
}
|
||||
if (x.end - x.start != y.end - y.start) {
|
||||
return x.end - x.start > y.end - y.start;
|
||||
}
|
||||
return x.buffer->id() < y.buffer->id();
|
||||
});
|
||||
|
||||
BufferIntervalTree interval_tree(sorted_buffer_intervals.size());
|
||||
for (auto& buffer_interval : sorted_buffer_intervals) {
|
||||
auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
|
||||
buffer_interval.start, buffer_interval.end);
|
||||
std::sort(
|
||||
chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(),
|
||||
absl::c_sort(
|
||||
chunks_overlapping_in_time,
|
||||
[](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });
|
||||
|
||||
// Find the minimum free chunk that can hold this buffer.
|
||||
|
@ -117,7 +117,7 @@ class BufferValueMap {
|
||||
for (const auto& pair : buffers_) {
|
||||
buffer_numbers.push_back(pair.first);
|
||||
}
|
||||
std::sort(buffer_numbers.begin(), buffer_numbers.end());
|
||||
absl::c_sort(buffer_numbers);
|
||||
return buffer_numbers;
|
||||
}
|
||||
|
||||
@ -319,7 +319,7 @@ class BufferValueMap {
|
||||
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
||||
// Uniquify aliased buffers.
|
||||
std::sort(aliased_buffers.begin(), aliased_buffers.end());
|
||||
absl::c_sort(aliased_buffers);
|
||||
aliased_buffers.erase(
|
||||
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
|
||||
aliased_buffers.end());
|
||||
@ -367,7 +367,7 @@ std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
|
||||
}
|
||||
|
||||
// Sort and uniquify vector before returning.
|
||||
std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan);
|
||||
absl::c_sort(buffers, HloBuffer::IdLessThan);
|
||||
buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
|
||||
|
||||
return buffers;
|
||||
@ -430,8 +430,7 @@ Status HloAliasAnalysis::Verify() const {
|
||||
for (const auto& pair : value_to_buffer_) {
|
||||
const HloValue* value = pair.first;
|
||||
const HloBuffer& buffer = *pair.second;
|
||||
TF_RET_CHECK(std::find(buffer.values().begin(), buffer.values().end(),
|
||||
value) != buffer.values().end());
|
||||
TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
|
||||
}
|
||||
|
||||
for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
|
||||
@ -515,7 +514,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
|
||||
auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
|
||||
std::vector<const HloValue*> sorted_values(value_set.begin(),
|
||||
value_set.end());
|
||||
std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan);
|
||||
absl::c_sort(sorted_values, HloValue::IdLessThan);
|
||||
alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
|
||||
for (const HloValue* value : sorted_values) {
|
||||
alias_analysis->value_to_buffer_[value] =
|
||||
@ -547,16 +546,15 @@ bool HloAliasAnalysis::HasLiveRangeInterference(
|
||||
// tie-break using value ID. The tie-break is necessary because we need a
|
||||
// strict weak order for std::sort.
|
||||
std::vector<const HloValue*> values = buffer.values();
|
||||
std::sort(values.begin(), values.end(),
|
||||
[&ordering](const HloValue* a, const HloValue* b) {
|
||||
if (ordering.IsDefinedBefore(*a, *b)) {
|
||||
return true;
|
||||
} else if (ordering.IsDefinedBefore(*b, *a)) {
|
||||
return false;
|
||||
} else {
|
||||
return a->id() < b->id();
|
||||
}
|
||||
});
|
||||
absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) {
|
||||
if (ordering.IsDefinedBefore(*a, *b)) {
|
||||
return true;
|
||||
} else if (ordering.IsDefinedBefore(*b, *a)) {
|
||||
return false;
|
||||
} else {
|
||||
return a->id() < b->id();
|
||||
}
|
||||
});
|
||||
|
||||
// Walk through the ordered vector of values. First verify that the values
|
||||
// are totally ordered with respect to 'ordering', then check that no
|
||||
|
@ -49,7 +49,7 @@ std::vector<HloPosition> HloBuffer::ComputePositions() const {
|
||||
value->positions().end());
|
||||
}
|
||||
// Remove duplicates and sort positions.
|
||||
std::sort(positions.begin(), positions.end());
|
||||
absl::c_sort(positions);
|
||||
positions.erase(std::unique(positions.begin(), positions.end()),
|
||||
positions.end());
|
||||
return positions;
|
||||
|
@ -531,11 +531,10 @@ HloComputation::CreateFromProto(
|
||||
HloInstruction* root = instruction_map.at(proto.root_id());
|
||||
|
||||
// Sort the instructions in the proto id's order.
|
||||
std::sort(instructions.begin(), instructions.end(),
|
||||
[&](const std::unique_ptr<HloInstruction>& a,
|
||||
const std::unique_ptr<HloInstruction>& b) {
|
||||
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
||||
});
|
||||
absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
|
||||
const std::unique_ptr<HloInstruction>& b) {
|
||||
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
||||
});
|
||||
|
||||
TF_RETURN_IF_ERROR([&]() -> Status {
|
||||
std::vector<bool> parameters_seen(parameter_count);
|
||||
@ -800,8 +799,7 @@ Status HloComputation::AcceptOrdered(
|
||||
absl::Span<HloInstruction* const> order) const {
|
||||
VLOG(3) << "Accepting visitor with order.";
|
||||
for (HloInstruction* root : CollectUnreachableRoots()) {
|
||||
TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end())
|
||||
<< root->ToString();
|
||||
TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
|
||||
}
|
||||
TF_RET_CHECK(order.size() == instruction_count());
|
||||
absl::flat_hash_set<const HloInstruction*> visited;
|
||||
|
@ -256,7 +256,7 @@ bool HloDataflowAnalysis::Phi(
|
||||
input_value_ids.push_back(value->id());
|
||||
}
|
||||
}
|
||||
std::sort(input_value_ids.begin(), input_value_ids.end());
|
||||
absl::c_sort(input_value_ids);
|
||||
input_value_ids.erase(
|
||||
std::unique(input_value_ids.begin(), input_value_ids.end()),
|
||||
input_value_ids.end());
|
||||
@ -271,8 +271,7 @@ bool HloDataflowAnalysis::Phi(
|
||||
if (current_value_defined_here) {
|
||||
VLOG(5) << "current_value_defined_here: " << current_value->ToString();
|
||||
CHECK(current_value->is_phi());
|
||||
auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
|
||||
current_value->id());
|
||||
auto it = absl::c_find(input_value_ids, current_value->id());
|
||||
if (it != input_value_ids.end()) {
|
||||
input_value_ids.erase(it);
|
||||
}
|
||||
@ -921,8 +920,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
||||
for (auto& pair : dataflow_analysis->values_) {
|
||||
dataflow_analysis->values_vector_.push_back(&pair.second);
|
||||
}
|
||||
std::sort(dataflow_analysis->values_vector_.begin(),
|
||||
dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
|
||||
absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
|
||||
|
||||
TF_DCHECK_OK(dataflow_analysis->Verify());
|
||||
|
||||
@ -937,9 +935,7 @@ Status HloDataflowAnalysis::Verify() const {
|
||||
for (const HloValue* value : values()) {
|
||||
for (const HloPosition& position : value->positions()) {
|
||||
const HloValueSet& value_set = GetValueSet(position);
|
||||
TF_RET_CHECK(std::find(value_set.values().begin(),
|
||||
value_set.values().end(),
|
||||
value) != value_set.values().end())
|
||||
TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
|
||||
<< "Value set at position " << position << " does not contain value "
|
||||
<< value->ToShortString();
|
||||
}
|
||||
@ -954,9 +950,7 @@ Status HloDataflowAnalysis::Verify() const {
|
||||
const HloValueSet& value_set = pair.second;
|
||||
const HloPosition position{instruction, index};
|
||||
for (const HloValue* value : value_set.values()) {
|
||||
TF_RET_CHECK(std::find(value->positions().begin(),
|
||||
value->positions().end(),
|
||||
position) != value->positions().end())
|
||||
TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
|
||||
<< "Value set at position " << position
|
||||
<< " unexpectedly contains value " << value->ToShortString();
|
||||
}
|
||||
@ -1041,11 +1035,10 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
// Check if one operand of kAdd fused root is kDot or kConvolution.
|
||||
auto* add = user->fused_expression_root();
|
||||
auto add_operand_it =
|
||||
std::find_if(add->operands().begin(), add->operands().end(),
|
||||
[&](HloInstruction* operand) {
|
||||
return operand->opcode() == HloOpcode::kConvolution ||
|
||||
operand->opcode() == HloOpcode::kDot;
|
||||
});
|
||||
absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
|
||||
return operand->opcode() == HloOpcode::kConvolution ||
|
||||
operand->opcode() == HloOpcode::kDot;
|
||||
});
|
||||
if (add_operand_it == add->operands().end()) {
|
||||
return false;
|
||||
}
|
||||
@ -1100,16 +1093,15 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
// *) The root instruction of the called computation is element-wise on
|
||||
// 'operand'.
|
||||
const bool found_caller_use =
|
||||
std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) {
|
||||
absl::c_find_if(uses, [user](const HloUse& use) {
|
||||
return use.instruction == user;
|
||||
}) != uses.end();
|
||||
auto* callee_root = user->to_apply()->root_instruction();
|
||||
const bool found_elementwise_callee_use =
|
||||
std::find_if(
|
||||
uses.begin(), uses.end(), [callee_root](const HloUse& use) {
|
||||
return use.instruction == callee_root &&
|
||||
callee_root->IsElementwiseOnOperand(use.operand_number);
|
||||
}) != uses.end();
|
||||
absl::c_find_if(uses, [callee_root](const HloUse& use) {
|
||||
return use.instruction == callee_root &&
|
||||
callee_root->IsElementwiseOnOperand(use.operand_number);
|
||||
}) != uses.end();
|
||||
return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
|
||||
}
|
||||
|
||||
|
@ -43,9 +43,7 @@ class HloDceTest : public HloTestBase {
|
||||
// Returns whether the given instruction exists in the given computation.
|
||||
bool HasInstruction(const HloComputation& computation,
|
||||
const HloInstruction* instruction) {
|
||||
return std::find(computation.instructions().begin(),
|
||||
computation.instructions().end(),
|
||||
instruction) != computation.instructions().end();
|
||||
return absl::c_linear_search(computation.instructions(), instruction);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -230,10 +230,10 @@ HloDomainMap::MakeNonDomainInstructions(
|
||||
}
|
||||
}
|
||||
// sort instructions according to instructions_order
|
||||
std::sort(instructions.begin(), instructions.end(),
|
||||
[&instructions_order](HloInstruction* a, HloInstruction* b) {
|
||||
return instructions_order.at(a) < instructions_order.at(b);
|
||||
});
|
||||
absl::c_sort(instructions,
|
||||
[&instructions_order](HloInstruction* a, HloInstruction* b) {
|
||||
return instructions_order.at(a) < instructions_order.at(b);
|
||||
});
|
||||
return instructions;
|
||||
}
|
||||
|
||||
|
@ -1248,8 +1248,7 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
|
||||
// Extract a slice from the keys and values literals that correspond to
|
||||
// exactly the row in dimension 'sort_dim'.
|
||||
std::vector<int64> limit_indices(indices.begin(), indices.end());
|
||||
std::for_each(limit_indices.begin(), limit_indices.end(),
|
||||
[](int64& index) { ++index; });
|
||||
absl::c_for_each(limit_indices, [](int64& index) { ++index; });
|
||||
limit_indices[sort_dim] = sort_dim_elements;
|
||||
TF_ASSIGN_OR_RETURN(auto keys_to_sort,
|
||||
keys_literal.Slice(indices, limit_indices)
|
||||
|
@ -1673,8 +1673,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
// Extract a slice from the literal that corresponds to exactly the
|
||||
// row in dimension 'sort_dim'.
|
||||
std::vector<int64> limit_indices(indices.begin(), indices.end());
|
||||
std::for_each(limit_indices.begin(), limit_indices.end(),
|
||||
[](int64& index) { ++index; });
|
||||
absl::c_for_each(limit_indices, [](int64& index) { ++index; });
|
||||
limit_indices[sort_dim] = sort_dim_elements;
|
||||
TF_ASSIGN_OR_RETURN(auto row_to_sort,
|
||||
keys_literal.Slice(indices, limit_indices)
|
||||
|
@ -561,8 +561,8 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
|
||||
}
|
||||
|
||||
// Show the subcomputation if we're showing any of its members.
|
||||
return std::any_of(
|
||||
subcomp->instructions().begin(), subcomp->instructions().end(),
|
||||
return absl::c_any_of(
|
||||
subcomp->instructions(),
|
||||
[&](const HloInstruction* instr) { return filter_.Show(instr); });
|
||||
}
|
||||
|
||||
@ -735,15 +735,14 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
|
||||
const int kMinUsersToOmit = 3;
|
||||
return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
|
||||
!instr->IsFused() &&
|
||||
std::count_if(instr->users().begin(), instr->users().end(),
|
||||
[&](const HloInstruction* user) {
|
||||
return filter_.Show(user);
|
||||
}) > kMinUsersToOmit &&
|
||||
std::all_of(instr->users().begin(), instr->users().end(),
|
||||
[&](const HloInstruction* user) {
|
||||
return !filter_.Show(user) ||
|
||||
user->opcode() == HloOpcode::kGetTupleElement;
|
||||
});
|
||||
absl::c_count_if(instr->users(),
|
||||
[&](const HloInstruction* user) {
|
||||
return filter_.Show(user);
|
||||
}) > kMinUsersToOmit &&
|
||||
absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
|
||||
return !filter_.Show(user) ||
|
||||
user->opcode() == HloOpcode::kGetTupleElement;
|
||||
});
|
||||
}
|
||||
|
||||
string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
|
||||
@ -900,12 +899,11 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
// the same color as a parameter. Unless the merged-in parameter is a
|
||||
// parameter to a fusion node that is bound to a constant -- these aren't
|
||||
// "real" parameters from the user's perspective.
|
||||
if (std::any_of(instr->operands().begin(), instr->operands().end(),
|
||||
[&](const HloInstruction* operand) {
|
||||
return operand->opcode() == HloOpcode::kParameter &&
|
||||
ShouldMergeIntoUsers(operand) &&
|
||||
TryGetFusionParameterConstant(operand) == nullptr;
|
||||
})) {
|
||||
if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
|
||||
return operand->opcode() == HloOpcode::kParameter &&
|
||||
ShouldMergeIntoUsers(operand) &&
|
||||
TryGetFusionParameterConstant(operand) == nullptr;
|
||||
})) {
|
||||
return parameter_color;
|
||||
}
|
||||
|
||||
@ -1355,12 +1353,11 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
|
||||
NodeFilterResult& filter_result = kv.second;
|
||||
const auto& operands = instr->operands();
|
||||
|
||||
if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
|
||||
!std::all_of(operands.begin(), operands.end(), is_displayed)) {
|
||||
if (absl::c_any_of(operands, is_displayed) &&
|
||||
!absl::c_all_of(operands, is_displayed)) {
|
||||
// Mark nodes with some operands omitted appropriately.
|
||||
filter_result = kSomeOperandsOmitted;
|
||||
} else if (!operands.empty() &&
|
||||
std::none_of(operands.begin(), operands.end(), is_displayed)) {
|
||||
} else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
|
||||
// Mark nodes with *all* operands omitted appropriately.
|
||||
filter_result = kOmitNodeOperands;
|
||||
}
|
||||
@ -1368,8 +1365,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
|
||||
// Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
|
||||
// users made it into the graph.
|
||||
if (filter_result == kSomeUsersOmitted &&
|
||||
std::all_of(instr->users().begin(), instr->users().end(),
|
||||
is_displayed)) {
|
||||
absl::c_all_of(instr->users(), is_displayed)) {
|
||||
filter_result = kNormalNode;
|
||||
}
|
||||
}
|
||||
|
@ -83,15 +83,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
return computation_map.at(proto.called_computation_ids(index));
|
||||
};
|
||||
|
||||
TF_RET_CHECK(std::all_of(
|
||||
proto.operand_ids().begin(), proto.operand_ids().end(),
|
||||
[&instruction_map](int64 id) { return instruction_map.contains(id); }))
|
||||
TF_RET_CHECK(
|
||||
absl::c_all_of(proto.operand_ids(),
|
||||
[&](int64 id) { return instruction_map.contains(id); }))
|
||||
<< proto.name() << " instruction contains invalid operand id(s)";
|
||||
|
||||
TF_RET_CHECK(std::all_of(
|
||||
proto.called_computation_ids().begin(),
|
||||
proto.called_computation_ids().end(),
|
||||
[&computation_map](int64 id) { return computation_map.contains(id); }))
|
||||
TF_RET_CHECK(
|
||||
absl::c_all_of(proto.called_computation_ids(),
|
||||
[&](int64 id) { return computation_map.contains(id); }))
|
||||
<< proto.name() << " instruction references invalid computation id(s)";
|
||||
|
||||
Shape shape(proto.shape());
|
||||
@ -1599,12 +1598,10 @@ HloInstruction::InstructionVector HloInstruction::unique_operands() const {
|
||||
|
||||
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
|
||||
TF_RET_CHECK(instruction->parent() == parent());
|
||||
if (std::find(control_successors_.begin(), control_successors_.end(),
|
||||
instruction) == control_successors_.end()) {
|
||||
if (!absl::c_linear_search(control_successors_, instruction)) {
|
||||
control_successors_.push_back(instruction);
|
||||
TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(),
|
||||
instruction->control_predecessors_.end(),
|
||||
this) == instruction->control_predecessors_.end());
|
||||
TF_RET_CHECK(
|
||||
!absl::c_linear_search(instruction->control_predecessors_, this));
|
||||
instruction->control_predecessors_.push_back(this);
|
||||
}
|
||||
return Status::OK();
|
||||
@ -1853,7 +1850,7 @@ void HloInstruction::RemoveUser(HloInstruction* user) {
|
||||
user_set_.erase(set_it);
|
||||
// This is linear in the number of the users, but a vector provides a stable
|
||||
// iteration order and much faster traversal.
|
||||
auto vec_it = std::find(users_.begin(), users_.end(), user);
|
||||
auto vec_it = absl::c_find(users_, user);
|
||||
CHECK(vec_it != users_.end());
|
||||
users_.erase(vec_it);
|
||||
}
|
||||
@ -1871,8 +1868,7 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
|
||||
|
||||
RemoveUser(user);
|
||||
|
||||
TF_RET_CHECK(
|
||||
std::count(user->operands_.begin(), user->operands_.end(), this) >= 0);
|
||||
TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
|
||||
std::replace(user->operands_.begin(), user->operands_.end(), this,
|
||||
new_producer);
|
||||
new_producer->AddUser(user);
|
||||
@ -1907,8 +1903,7 @@ Status HloInstruction::ReplaceOperandWithDifferentShape(
|
||||
VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
|
||||
<< new_operand->name() << ", was " << old_operand->name();
|
||||
|
||||
if (std::find(operands_.begin(), operands_.end(), old_operand) ==
|
||||
operands_.end()) {
|
||||
if (!absl::c_linear_search(operands_, old_operand)) {
|
||||
old_operand->RemoveUser(this);
|
||||
}
|
||||
new_operand->AddUser(this);
|
||||
@ -2945,10 +2940,10 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
|
||||
|
||||
string PaddingConfigToString(const PaddingConfig& padding) {
|
||||
bool has_interior_padding =
|
||||
std::any_of(padding.dimensions().begin(), padding.dimensions().end(),
|
||||
[](const PaddingConfig::PaddingConfigDimension& dim) {
|
||||
return dim.interior_padding() != 0;
|
||||
});
|
||||
absl::c_any_of(padding.dimensions(),
|
||||
[](const PaddingConfig::PaddingConfigDimension& dim) {
|
||||
return dim.interior_padding() != 0;
|
||||
});
|
||||
return StrJoin(
|
||||
padding.dimensions(), "x",
|
||||
[&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
|
||||
|
@ -42,11 +42,9 @@ using absl::StrJoin;
|
||||
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
|
||||
const HloInstruction* operand) {
|
||||
std::vector<int64> operand_indices = instruction->OperandIndices(operand);
|
||||
return std::all_of(
|
||||
operand_indices.begin(), operand_indices.end(),
|
||||
[instruction](int64 operand_index) {
|
||||
return instruction->IsElementwiseOnOperand(operand_index);
|
||||
});
|
||||
return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
|
||||
return instruction->IsElementwiseOnOperand(operand_index);
|
||||
});
|
||||
}
|
||||
|
||||
string PrecisionConfigToString(const PrecisionConfig& precision_config) {
|
||||
@ -814,8 +812,7 @@ std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
|
||||
std::vector<string> bounds;
|
||||
bounds.reserve(slice_starts_.size());
|
||||
const bool omit_stride =
|
||||
std::all_of(slice_strides_.begin(), slice_strides_.end(),
|
||||
[](int64 stride) { return stride == 1; });
|
||||
absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
|
||||
for (int i = 0; i < slice_starts_.size(); ++i) {
|
||||
string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
|
||||
bounds.push_back(
|
||||
@ -1051,8 +1048,7 @@ HloInstruction* HloFusionInstruction::AddFusionOperand(
|
||||
|
||||
void HloFusionInstruction::MergeFusionInstruction(
|
||||
HloFusionInstruction* instruction_to_merge) {
|
||||
CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) !=
|
||||
operands().end());
|
||||
CHECK(absl::c_linear_search(operands(), instruction_to_merge));
|
||||
// Clone the instruction from which to merge fused instructions.
|
||||
std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
|
||||
HloFusionInstruction* cloned_fusion =
|
||||
@ -1219,8 +1215,8 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
|
||||
// corresponding fused parameter instruction. Renumber parameters as
|
||||
// necessary to make parameter numbers consistent with their index in the
|
||||
// fused_parameter_ vector.
|
||||
bool in_operand_list = std::find(operands().begin(), operands().end(),
|
||||
instruction_to_fuse) != operands().end();
|
||||
bool in_operand_list =
|
||||
absl::c_linear_search(operands(), instruction_to_fuse);
|
||||
CHECK(add_output || in_operand_list);
|
||||
if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
|
||||
// We assume all uses of a kTuple operation are GTE ops, not another
|
||||
|
@ -107,11 +107,10 @@ HloComputation* HloModule::AddEntryComputation(
|
||||
}
|
||||
|
||||
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
|
||||
auto it =
|
||||
std::find_if(computations_.begin(), computations_.end(),
|
||||
[&to_remove](const std::unique_ptr<HloComputation>& comp) {
|
||||
return comp.get() == to_remove;
|
||||
});
|
||||
auto it = absl::c_find_if(
|
||||
computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
|
||||
return comp.get() == to_remove;
|
||||
});
|
||||
TF_RET_CHECK(it->get() == to_remove);
|
||||
computations_.erase(it);
|
||||
return Status::OK();
|
||||
@ -304,11 +303,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||
auto module = absl::make_unique<HloModule>(proto.name(), module_config);
|
||||
|
||||
// Sort the computations in the proto id's order.
|
||||
std::sort(computations.begin(), computations.end(),
|
||||
[&](const std::unique_ptr<HloComputation>& a,
|
||||
const std::unique_ptr<HloComputation>& b) {
|
||||
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
||||
});
|
||||
absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
|
||||
const std::unique_ptr<HloComputation>& b) {
|
||||
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
||||
});
|
||||
|
||||
// Add sorted computations to the module.
|
||||
for (auto& computation : computations) {
|
||||
|
@ -38,9 +38,7 @@ class HloModuleDceTest : public HloTestBase {
|
||||
// Returns whether the given instruction exists in the given computation.
|
||||
bool HasInstruction(const HloComputation& computation,
|
||||
const HloInstruction* instruction) {
|
||||
return std::find(computation.instructions().begin(),
|
||||
computation.instructions().end(),
|
||||
instruction) != computation.instructions().end();
|
||||
return absl::c_linear_search(computation.instructions(), instruction);
|
||||
}
|
||||
|
||||
// Returns whether the while instruction with name 'while_name' in
|
||||
|
@ -2746,7 +2746,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
|
||||
}
|
||||
|
||||
auto is_unique = [](string str) -> bool {
|
||||
std::sort(str.begin(), str.end());
|
||||
absl::c_sort(str);
|
||||
return std::unique(str.begin(), str.end()) == str.end();
|
||||
};
|
||||
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
|
||||
|
||||
@ -34,11 +35,10 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
|
||||
for (const HloComputationInfo& computation_info :
|
||||
hlo_profile_printer_data.computation_infos()) {
|
||||
const auto& instruction_infos = computation_info.instruction_infos();
|
||||
bool any_instruction_profiled =
|
||||
std::any_of(instruction_infos.begin(), instruction_infos.end(),
|
||||
[&](const HloInstructionInfo& instruction_info) {
|
||||
return counters[instruction_info.profile_index()] != 0;
|
||||
});
|
||||
bool any_instruction_profiled = absl::c_any_of(
|
||||
instruction_infos, [&](const HloInstructionInfo& instruction_info) {
|
||||
return counters[instruction_info.profile_index()] != 0;
|
||||
});
|
||||
|
||||
if (!any_instruction_profiled) {
|
||||
continue;
|
||||
|
@ -49,7 +49,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper(
|
||||
absl::Span<const HloInstruction* const> inputs,
|
||||
const HloInstruction* instruction, BitVector* bit_vector) {
|
||||
// If instruction is part of inputs, don't reset the bit_vector.
|
||||
if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) {
|
||||
if (!absl::c_linear_search(inputs, instruction)) {
|
||||
bit_vector->SetToZero();
|
||||
}
|
||||
bit_vector->Set(GetIndex(instruction));
|
||||
|
@ -235,8 +235,7 @@ class InstructionList {
|
||||
}
|
||||
|
||||
// Now scan forwards until we find one of the before_instructions.
|
||||
while (std::find(before_instructions.begin(), before_instructions.end(),
|
||||
min_position_item) == before_instructions.end()) {
|
||||
while (!absl::c_linear_search(before_instructions, min_position_item)) {
|
||||
min_position_item = min_position_item->next;
|
||||
}
|
||||
return InsertBefore(to_insert, min_position_item);
|
||||
@ -302,7 +301,7 @@ ItemList GetUsers(const InstructionList& instruction_list,
|
||||
// A buffer may be used by the instruction via more than one alias. For
|
||||
// example, a buffer which appears in more than one element of a tuple.
|
||||
Item* user_item = instruction_list.GetItem(user);
|
||||
if (std::find(users.begin(), users.end(), user_item) == users.end()) {
|
||||
if (!absl::c_linear_search(users, user_item)) {
|
||||
users.push_back(user_item);
|
||||
}
|
||||
}
|
||||
@ -456,8 +455,7 @@ class MemoryUsageTracker {
|
||||
return false;
|
||||
}
|
||||
const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
|
||||
return std::find(in_progress_uses.begin(), in_progress_uses.end(),
|
||||
buffer_id) != in_progress_uses.end();
|
||||
return absl::c_linear_search(in_progress_uses, buffer_id);
|
||||
}
|
||||
|
||||
// Returns whether the given instruction is live at the current program
|
||||
@ -535,8 +533,7 @@ MemoryUsageTracker::MemoryUsageTracker(
|
||||
bool unused;
|
||||
for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
|
||||
points_to_analysis, &unused)) {
|
||||
if (std::find(buffer->users.begin(), buffer->users.end(),
|
||||
user_item) == buffer->users.end()) {
|
||||
if (!absl::c_linear_search(buffer->users, user_item)) {
|
||||
buffer->users.push_back(user_item);
|
||||
buffer->unfinished_user_count++;
|
||||
user_item->buffers_used.push_back(buffer->id);
|
||||
@ -784,8 +781,7 @@ bool MemoryUsageTracker::Check() const {
|
||||
|
||||
for (const Buffer& buffer : buffers_) {
|
||||
if (buffer.defining_instruction->instruction == instruction) {
|
||||
CHECK(std::find(defined_buffers.begin(), defined_buffers.end(),
|
||||
buffer.id) != defined_buffers.end())
|
||||
CHECK(absl::c_linear_search(defined_buffers, buffer.id))
|
||||
<< "Instruction " << instruction->name()
|
||||
<< " defined buffers is missing: " << buffer.ToString();
|
||||
}
|
||||
@ -808,8 +804,7 @@ bool MemoryUsageTracker::Check() const {
|
||||
int64 unfinished_uses = 0;
|
||||
for (Item* user : buffer.users) {
|
||||
const BufferIdList& used_buffers = user->buffers_used;
|
||||
CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) !=
|
||||
used_buffers.end())
|
||||
CHECK(absl::c_linear_search(used_buffers, buffer.id))
|
||||
<< "Instruction " << user->instruction->name()
|
||||
<< " used buffers is missing " << buffer.ToString();
|
||||
if (!IsFinished(user)) {
|
||||
@ -836,10 +831,10 @@ int64 RematerializationCost(const HloInstruction* instruction,
|
||||
// If none of the users of 'instruction' have been placed in the sequence (as
|
||||
// tracked by memory_tracker), then rematerialization of 'instruction' is a
|
||||
// zero-cost move of 'instruction' in the sequence.
|
||||
if (!std::any_of(instruction->users().begin(), instruction->users().end(),
|
||||
[&memory_tracker](const HloInstruction* inst) {
|
||||
return memory_tracker.IsPlaced(inst);
|
||||
})) {
|
||||
if (!absl::c_any_of(instruction->users(),
|
||||
[&memory_tracker](const HloInstruction* inst) {
|
||||
return memory_tracker.IsPlaced(inst);
|
||||
})) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -107,13 +107,12 @@ string HloSharding::ToString() const {
|
||||
|
||||
bool HloSharding::UsesDevice(int64 device) const {
|
||||
if (IsTuple()) {
|
||||
return std::any_of(
|
||||
tuple_elements_.begin(), tuple_elements_.end(),
|
||||
[&](const HloSharding& s) { return s.UsesDevice(device); });
|
||||
return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
|
||||
return s.UsesDevice(device);
|
||||
});
|
||||
}
|
||||
const auto& devices = tile_assignment_;
|
||||
return replicated_ ||
|
||||
std::find(devices.begin(), devices.end(), device) != devices.end();
|
||||
return replicated_ || absl::c_linear_search(devices, device);
|
||||
}
|
||||
|
||||
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
|
||||
|
@ -101,8 +101,8 @@ class HloSharding {
|
||||
if (!IsTuple()) {
|
||||
return replicated_;
|
||||
}
|
||||
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
|
||||
[](const HloSharding& s) { return s.IsReplicated(); });
|
||||
return absl::c_all_of(
|
||||
tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
|
||||
}
|
||||
|
||||
// Returns true if the tile size is the same as the input size.
|
||||
@ -110,8 +110,9 @@ class HloSharding {
|
||||
if (!IsTuple()) {
|
||||
return maximal_;
|
||||
}
|
||||
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
|
||||
[](const HloSharding& s) { return s.IsTileMaximal(); });
|
||||
return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
|
||||
return s.IsTileMaximal();
|
||||
});
|
||||
}
|
||||
|
||||
// Returns true if the sharding defines an operation on the given device.
|
||||
|
@ -61,8 +61,7 @@ void CleanNodeName(string* name) {
|
||||
name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
|
||||
const string chars_to_replace = "<>[]";
|
||||
auto pred = [&](char c) {
|
||||
return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) !=
|
||||
chars_to_replace.end();
|
||||
return absl::c_linear_search(chars_to_replace, c);
|
||||
};
|
||||
std::replace_if(name->begin(), name->end(), pred, '_');
|
||||
}
|
||||
|
@ -209,7 +209,7 @@ std::ostream& operator<<(std::ostream& out, const HloValue& value) {
|
||||
}
|
||||
|
||||
void HloValueSet::SortAndUniquifyValues() {
|
||||
std::sort(values_.begin(), values_.end(), HloValue::IdLessThan);
|
||||
absl::c_sort(values_, HloValue::IdLessThan);
|
||||
values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
|
||||
values_.end());
|
||||
}
|
||||
|
@ -128,9 +128,9 @@ string HumanReadableProfileBuilder::ToString() const {
|
||||
|
||||
// Sort ops in decreasing order of cycles, and print them.
|
||||
std::vector<OpInfo> sorted_ops(op_infos_);
|
||||
std::sort(
|
||||
sorted_ops.begin(), sorted_ops.end(),
|
||||
[](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; });
|
||||
absl::c_sort(sorted_ops, [](const OpInfo& a, const OpInfo& b) {
|
||||
return a.cycles > b.cycles;
|
||||
});
|
||||
for (const auto& op : sorted_ops) {
|
||||
print_op(op);
|
||||
}
|
||||
|
@ -178,19 +178,18 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
|
||||
output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
|
||||
}
|
||||
});
|
||||
return std::count_if(hlo->operands().begin(), hlo->operands().end(),
|
||||
[output_rank](HloInstruction* operand) {
|
||||
if (operand->opcode() == HloOpcode::kBroadcast ||
|
||||
operand->opcode() == HloOpcode::kIota) {
|
||||
return false;
|
||||
}
|
||||
if (operand->opcode() == HloOpcode::kConstant &&
|
||||
ShapeUtil::IsEffectiveScalar(operand->shape())) {
|
||||
return false;
|
||||
}
|
||||
return ShapeUtil::TrueRank(operand->shape()) >=
|
||||
output_rank;
|
||||
}) <= 1;
|
||||
return absl::c_count_if(
|
||||
hlo->operands(), [output_rank](HloInstruction* operand) {
|
||||
if (operand->opcode() == HloOpcode::kBroadcast ||
|
||||
operand->opcode() == HloOpcode::kIota) {
|
||||
return false;
|
||||
}
|
||||
if (operand->opcode() == HloOpcode::kConstant &&
|
||||
ShapeUtil::IsEffectiveScalar(operand->shape())) {
|
||||
return false;
|
||||
}
|
||||
return ShapeUtil::TrueRank(operand->shape()) >= output_rank;
|
||||
}) <= 1;
|
||||
}
|
||||
|
||||
bool InstructionFusion::CanFuseOnAllPaths(
|
||||
@ -409,9 +408,8 @@ class ReversePostOrderFusionQueue : public FusionQueue {
|
||||
}
|
||||
sorted_operand_numbers.push_back(i);
|
||||
}
|
||||
std::sort(
|
||||
sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
|
||||
[&](int64 i, int64 j) {
|
||||
absl::c_sort(
|
||||
sorted_operand_numbers, [&](int64 i, int64 j) {
|
||||
// Instructions with higher priority in the queue come first.
|
||||
return (
|
||||
FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
|
||||
|
@ -147,12 +147,9 @@ bool LayoutConstraints::OperandBufferForwarded(
|
||||
PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
|
||||
PointsToSet::BufferSet* operand_buffers =
|
||||
GetBufferSet(instruction->operand(operand_no));
|
||||
for (const LogicalBuffer* output_buffer : *output_buffers) {
|
||||
if (operand_buffers->count(output_buffer) > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
|
||||
return operand_buffers->count(b) > 0;
|
||||
});
|
||||
}
|
||||
|
||||
Status LayoutConstraints::SetBufferLayout(const Layout& layout,
|
||||
|
@ -81,9 +81,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
|
||||
if (hlo.opcode() == HloOpcode::kParameter) {
|
||||
const std::vector<HloInstruction*>& parameter_instructions =
|
||||
module_.entry_computation()->parameter_instructions();
|
||||
if (std::find(parameter_instructions.begin(),
|
||||
parameter_instructions.end(),
|
||||
&hlo) != parameter_instructions.end()) {
|
||||
if (absl::c_linear_search(parameter_instructions, &hlo)) {
|
||||
array->MarkInvariantOverWholeProgram(context_);
|
||||
}
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ bool IsAllowed(char character) {
|
||||
} // namespace
|
||||
|
||||
NameUniquer::NameUniquer(const string& separator) {
|
||||
CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed))
|
||||
CHECK(absl::c_all_of(separator, IsAllowed))
|
||||
<< "separator should comprises allowed characters only";
|
||||
separator_ = separator;
|
||||
}
|
||||
|
@ -260,8 +260,8 @@ PlatformUtil::GetStreamExecutors(
|
||||
// Block here in thread_pool destructor until all devices are initialized.
|
||||
}
|
||||
VLOG(1) << "Device initialization complete";
|
||||
if (std::all_of(stream_executors.begin(), stream_executors.end(),
|
||||
[](se::StreamExecutor* s) { return s == nullptr; })) {
|
||||
if (absl::c_all_of(stream_executors,
|
||||
[](se::StreamExecutor* s) { return s == nullptr; })) {
|
||||
return InternalError("no supported devices found for platform %s",
|
||||
platform->Name());
|
||||
}
|
||||
|
@ -534,9 +534,8 @@ Status ValidateDotDimensionNumbers(
|
||||
absl::Span<const int64> contracting_dims,
|
||||
absl::Span<const int64> batch_dims) -> bool {
|
||||
auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
|
||||
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
|
||||
in_range) &&
|
||||
std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
|
||||
return absl::c_all_of(contracting_dims, in_range) &&
|
||||
absl::c_all_of(batch_dims, in_range);
|
||||
};
|
||||
|
||||
absl::Span<const int64> lhs_contracting_dimensions =
|
||||
@ -563,9 +562,8 @@ Status ValidateDotDimensionNumbers(
|
||||
auto is_unique = [&dim_set](int64 i) -> bool {
|
||||
return dim_set.insert(i).second;
|
||||
};
|
||||
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
|
||||
is_unique) &&
|
||||
std::all_of(batch_dims.begin(), batch_dims.end(), is_unique);
|
||||
return absl::c_all_of(contracting_dims, is_unique) &&
|
||||
absl::c_all_of(batch_dims, is_unique);
|
||||
};
|
||||
|
||||
if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
|
||||
@ -1589,29 +1587,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
input_dnums[1] = dnums.input_feature_dimension();
|
||||
std::copy(dnums.input_spatial_dimensions().begin(),
|
||||
dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
|
||||
std::sort(input_dnums.begin(), input_dnums.end());
|
||||
absl::c_sort(input_dnums);
|
||||
|
||||
std::vector<int64> window_dnums(num_dims);
|
||||
window_dnums[0] = dnums.kernel_input_feature_dimension();
|
||||
window_dnums[1] = dnums.kernel_output_feature_dimension();
|
||||
std::copy(dnums.kernel_spatial_dimensions().begin(),
|
||||
dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
|
||||
std::sort(window_dnums.begin(), window_dnums.end());
|
||||
absl::c_sort(window_dnums);
|
||||
|
||||
std::vector<int64> output_dnums(num_dims);
|
||||
output_dnums[0] = dnums.output_batch_dimension();
|
||||
output_dnums[1] = dnums.output_feature_dimension();
|
||||
std::copy(dnums.output_spatial_dimensions().begin(),
|
||||
dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
|
||||
std::sort(output_dnums.begin(), output_dnums.end());
|
||||
absl::c_sort(output_dnums);
|
||||
|
||||
std::vector<int64> expected_dnums(num_dims);
|
||||
std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
|
||||
|
||||
const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
|
||||
if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
|
||||
!std::all_of(window_dnums.begin(), window_dnums.end(), in_range) ||
|
||||
!std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
|
||||
if (!absl::c_all_of(input_dnums, in_range) ||
|
||||
!absl::c_all_of(window_dnums, in_range) ||
|
||||
!absl::c_all_of(output_dnums, in_range)) {
|
||||
return InvalidArgument(
|
||||
"A dimension number is out of range in convolution: %s.",
|
||||
dnums.DebugString());
|
||||
|
@ -130,8 +130,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
|
||||
|
||||
HloInstruction* new_lhs;
|
||||
const int64 kLhsIdx = 0;
|
||||
if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
|
||||
operand_indices.end()) {
|
||||
if (absl::c_linear_search(operand_indices, kLhsIdx)) {
|
||||
HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
|
||||
const auto& transpose_dimensions = transpose.dimensions();
|
||||
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
|
||||
@ -154,8 +153,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
|
||||
|
||||
HloInstruction* new_rhs;
|
||||
const int64 kRhsIdx = 1;
|
||||
if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
|
||||
operand_indices.end()) {
|
||||
if (absl::c_linear_search(operand_indices, kRhsIdx)) {
|
||||
HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
|
||||
const auto& transpose_dimensions = transpose.dimensions();
|
||||
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
|
||||
|
@ -87,9 +87,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
|
||||
bool found = false;
|
||||
ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
|
||||
const BufferList& pointed_to_buffers) {
|
||||
if (!found &&
|
||||
std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
|
||||
&buffer) != pointed_to_buffers.end()) {
|
||||
if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
@ -99,8 +97,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
|
||||
bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
|
||||
const ShapeIndex& index) const {
|
||||
const auto& pointed_to_buffers = element(index);
|
||||
return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
|
||||
&buffer) != pointed_to_buffers.end();
|
||||
return absl::c_linear_search(pointed_to_buffers, &buffer);
|
||||
}
|
||||
|
||||
void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
|
||||
@ -604,9 +601,8 @@ bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
|
||||
} else if (user->opcode() == HloOpcode::kFusion &&
|
||||
user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||
// Find fusion parameter associated with 'operand'.
|
||||
auto it = std::find_if(
|
||||
user->fused_parameters().begin(), user->fused_parameters().end(),
|
||||
[=](HloInstruction* fused_param) {
|
||||
auto it = absl::c_find_if(
|
||||
user->fused_parameters(), [&](HloInstruction* fused_param) {
|
||||
return user->operand(fused_param->parameter_number()) == operand;
|
||||
});
|
||||
CHECK(it != user->fused_parameters().end());
|
||||
@ -672,9 +668,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
|
||||
}
|
||||
// Find fusion parameter associated with 'operand'.
|
||||
const auto& fused_params = fusion->fused_parameters();
|
||||
auto fused_param_it = std::find_if(
|
||||
fused_params.begin(), fused_params.end(),
|
||||
[&](HloInstruction* fused_param) {
|
||||
auto fused_param_it =
|
||||
absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
|
||||
return fusion->operand(fused_param->parameter_number()) == operand;
|
||||
});
|
||||
if (fused_param_it == fused_params.end()) {
|
||||
@ -743,11 +738,10 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
|
||||
// Check if one operand of kAdd fused root is kDot or kConvolution.
|
||||
auto* add = user->fused_expression_root();
|
||||
auto add_operand_it =
|
||||
std::find_if(add->operands().begin(), add->operands().end(),
|
||||
[&](HloInstruction* operand) {
|
||||
return operand->opcode() == HloOpcode::kConvolution ||
|
||||
operand->opcode() == HloOpcode::kDot;
|
||||
});
|
||||
absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
|
||||
return operand->opcode() == HloOpcode::kConvolution ||
|
||||
operand->opcode() == HloOpcode::kDot;
|
||||
});
|
||||
if (add_operand_it == add->operands().end()) {
|
||||
return false;
|
||||
}
|
||||
|
@ -721,9 +721,8 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
|
||||
// to fusion 'operand'.
|
||||
HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
|
||||
HloInstruction* operand) {
|
||||
auto it = std::find_if(
|
||||
fusion->fused_instructions().begin(),
|
||||
fusion->fused_instructions().end(), [=](const HloInstruction* fused) {
|
||||
auto it = absl::c_find_if(
|
||||
fusion->fused_instructions(), [&](const HloInstruction* fused) {
|
||||
return fused->opcode() == HloOpcode::kParameter &&
|
||||
fusion->operand(fused->parameter_number()) == operand;
|
||||
});
|
||||
|
@ -109,8 +109,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
||||
// operand appears in, but it may appear more than once!
|
||||
if (user->user_count() == 1 && user->users().front() == while_body_root &&
|
||||
while_body_root->operand_index(user) == user->tuple_index() &&
|
||||
std::count(while_body_root->operands().begin(),
|
||||
while_body_root->operands().end(), user) == 1) {
|
||||
absl::c_count(while_body_root->operands(), user) == 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -158,7 +157,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
||||
// Build up maps from the old/new to the new/old tuple indices.
|
||||
std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
|
||||
used_tuple_indices.end());
|
||||
std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());
|
||||
absl::c_sort(new_to_old_tuple_idx);
|
||||
|
||||
absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
|
||||
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
|
||||
|
@ -407,13 +407,12 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
|
||||
// The original while instruction is still left in the module as a dead
|
||||
// instruction, find a while instruction with a different name as the new
|
||||
// while instruction.
|
||||
const auto& instrs = m->entry_computation()->instructions();
|
||||
HloInstruction* new_while_op =
|
||||
*std::find_if(m->entry_computation()->instructions().begin(),
|
||||
m->entry_computation()->instructions().end(),
|
||||
[&](const HloInstruction* instr) {
|
||||
return (instr->opcode() == HloOpcode::kWhile &&
|
||||
instr->name() != "while");
|
||||
});
|
||||
*absl::c_find_if(instrs, [&](const HloInstruction* instr) {
|
||||
return (instr->opcode() == HloOpcode::kWhile &&
|
||||
instr->name() != "while");
|
||||
});
|
||||
|
||||
auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
|
||||
EXPECT_TRUE(
|
||||
|
@ -87,8 +87,7 @@ bool Shape::is_static() const {
|
||||
}
|
||||
}
|
||||
}
|
||||
return !std::any_of(dynamic_dimensions_.begin(), dynamic_dimensions_.end(),
|
||||
[](bool b) { return b; });
|
||||
return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; });
|
||||
}
|
||||
|
||||
void Shape::DeleteDimension(int64 dim_to_delete) {
|
||||
|
@ -405,8 +405,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
|
||||
return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
|
||||
shape.tuple_shapes().end(), IsTuple);
|
||||
return IsTuple(shape) && absl::c_any_of(shape.tuple_shapes(), IsTuple);
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
|
||||
|
@ -135,7 +135,7 @@ void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
|
||||
auto sort_order_less = [this](int64 lhs, int64 rhs) {
|
||||
return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
|
||||
};
|
||||
std::sort(sort_order.begin(), sort_order.end(), sort_order_less);
|
||||
absl::c_sort(sort_order, sort_order_less);
|
||||
|
||||
// Reorder the array elements according to sort_order. Work through the array
|
||||
// and follow cycles so we can do the reorder in-place.
|
||||
|
@ -467,8 +467,8 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
|
||||
// servers. The error message is missing the operator ++.
|
||||
template <typename T>
|
||||
void iota_int_init_value(std::vector<T>& values, int init_value) {
|
||||
std::for_each(values.begin(), values.end(),
|
||||
[&](T& value) { value = static_cast<T>(init_value++); });
|
||||
absl::c_for_each(values,
|
||||
[&](T& value) { value = static_cast<T>(init_value++); });
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -86,7 +86,7 @@ bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
|
||||
CHECK_LT(index, rank);
|
||||
output[index] = 0;
|
||||
}
|
||||
return std::find(output.begin(), output.end(), -1) == output.end();
|
||||
return !absl::c_linear_search(output, -1);
|
||||
}
|
||||
|
||||
std::vector<int64> InversePermutation(
|
||||
|
@ -324,8 +324,7 @@ bool IsIdentityPermutation(absl::Span<const int64> permutation);
|
||||
|
||||
template <typename Container>
|
||||
int64 PositionInContainer(const Container& container, int64 value) {
|
||||
return std::distance(container.begin(),
|
||||
std::find(container.begin(), container.end(), value));
|
||||
return std::distance(container.begin(), absl::c_find(container, value));
|
||||
}
|
||||
|
||||
// Formats the container as a comma-separated string. StrAppend must support
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -137,25 +138,23 @@ bool HasPadding(const Window& window) {
|
||||
}
|
||||
|
||||
bool HasSymmetricPadding(const Window& window) {
|
||||
return std::all_of(window.dimensions().begin(), window.dimensions().end(),
|
||||
[](const WindowDimension& dim) {
|
||||
return dim.padding_low() == dim.padding_high();
|
||||
});
|
||||
return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) {
|
||||
return dim.padding_low() == dim.padding_high();
|
||||
});
|
||||
}
|
||||
|
||||
bool HasSymmetricPadding(const PaddingConfig& padding_config) {
|
||||
return std::all_of(padding_config.dimensions().begin(),
|
||||
padding_config.dimensions().end(),
|
||||
[](const PaddingConfig::PaddingConfigDimension& dim) {
|
||||
return dim.edge_padding_low() == dim.edge_padding_high();
|
||||
});
|
||||
return absl::c_all_of(padding_config.dimensions(),
|
||||
[](const PaddingConfig::PaddingConfigDimension& dim) {
|
||||
return dim.edge_padding_low() ==
|
||||
dim.edge_padding_high();
|
||||
});
|
||||
}
|
||||
|
||||
bool HasNegativePadding(const Window& window) {
|
||||
return std::any_of(window.dimensions().begin(), window.dimensions().end(),
|
||||
[](const WindowDimension& dim) {
|
||||
return dim.padding_low() < 0 || dim.padding_high() < 0;
|
||||
});
|
||||
return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
|
||||
return dim.padding_low() < 0 || dim.padding_high() < 0;
|
||||
});
|
||||
}
|
||||
|
||||
bool HasBaseDilation(const Window& window) {
|
||||
@ -190,10 +189,9 @@ bool AllOrNoneReversed(const Window& window) {
|
||||
return true;
|
||||
}
|
||||
bool reversed = window.dimensions()[0].window_reversal();
|
||||
return std::all_of(window.dimensions().begin(), window.dimensions().end(),
|
||||
[&](const WindowDimension& dim) {
|
||||
return dim.window_reversal() == reversed;
|
||||
});
|
||||
return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) {
|
||||
return dim.window_reversal() == reversed;
|
||||
});
|
||||
}
|
||||
|
||||
bool HasDilation(const Window& window) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user