[XLA] Use absl::c_foo rather than std::foo.
No functional change. PiperOrigin-RevId: 227896034
This commit is contained in:
parent
f9bd1568aa
commit
b4813a0cff
@ -717,6 +717,7 @@ cc_library(
|
|||||||
":types",
|
":types",
|
||||||
":xla_data_proto",
|
":xla_data_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
|
@ -77,7 +77,7 @@ XLA_TEST_F(SortingTest, TopKFullSort) {
|
|||||||
auto x = ConstantR1<float>(&builder, inputs);
|
auto x = ConstantR1<float>(&builder, inputs);
|
||||||
xla::GetTupleElement(xla::TopK(x, kSize), 0);
|
xla::GetTupleElement(xla::TopK(x, kSize), 0);
|
||||||
|
|
||||||
std::sort(inputs.begin(), inputs.end(), std::greater<float>());
|
absl::c_sort(inputs, std::greater<float>());
|
||||||
ComputeAndCompareR1<float>(&builder, inputs, {});
|
ComputeAndCompareR1<float>(&builder, inputs, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,7 +290,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
|
|||||||
/* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
|
/* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
|
||||||
if (shape.IsTuple()) {
|
if (shape.IsTuple()) {
|
||||||
// Tuple shape: all subshapes must have a layout.
|
// Tuple shape: all subshapes must have a layout.
|
||||||
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
|
return absl::c_all_of(shape.tuple_shapes(),
|
||||||
[](const Shape& s) { return HasLayout(s); });
|
[](const Shape& s) { return HasLayout(s); });
|
||||||
} else if (!shape.IsArray()) {
|
} else if (!shape.IsArray()) {
|
||||||
// Opaque, token types etc. ignore layout.
|
// Opaque, token types etc. ignore layout.
|
||||||
@ -424,7 +424,7 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
|
|||||||
positions_in_layout.push_back(
|
positions_in_layout.push_back(
|
||||||
PositionInContainer(layout.minor_to_major(), dim));
|
PositionInContainer(layout.minor_to_major(), dim));
|
||||||
}
|
}
|
||||||
std::sort(positions_in_layout.begin(), positions_in_layout.end());
|
absl::c_sort(positions_in_layout);
|
||||||
for (size_t i = 1; i < positions_in_layout.size(); ++i) {
|
for (size_t i = 1; i < positions_in_layout.size(); ++i) {
|
||||||
if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
|
if (1 != positions_in_layout[i] - positions_in_layout[i - 1]) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -55,7 +55,7 @@ string MetricTableReport::MakeReport(double expected_metric_sum) {
|
|||||||
const auto metric_greater = [](const Entry& a, const Entry& b) {
|
const auto metric_greater = [](const Entry& a, const Entry& b) {
|
||||||
return a.metric > b.metric;
|
return a.metric > b.metric;
|
||||||
};
|
};
|
||||||
std::sort(entries_.begin(), entries_.end(), metric_greater);
|
absl::c_sort(entries_, metric_greater);
|
||||||
|
|
||||||
// Create the report
|
// Create the report
|
||||||
AppendLine();
|
AppendLine();
|
||||||
@ -117,7 +117,7 @@ std::vector<MetricTableReport::Category> MetricTableReport::MakeCategories(
|
|||||||
auto metric_sum_greater = [](const Category& a, const Category& b) {
|
auto metric_sum_greater = [](const Category& a, const Category& b) {
|
||||||
return a.metric_sum > b.metric_sum;
|
return a.metric_sum > b.metric_sum;
|
||||||
};
|
};
|
||||||
std::sort(categories.begin(), categories.end(), metric_sum_greater);
|
absl::c_sort(categories, metric_sum_greater);
|
||||||
|
|
||||||
return categories;
|
return categories;
|
||||||
}
|
}
|
||||||
|
@ -3414,6 +3414,7 @@ cc_library(
|
|||||||
":hlo_profile_printer_data",
|
":hlo_profile_printer_data",
|
||||||
":human_readable_profile_builder",
|
":human_readable_profile_builder",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -2216,8 +2216,7 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
|
|||||||
auto dim_is_one = [&](int64 i) -> bool {
|
auto dim_is_one = [&](int64 i) -> bool {
|
||||||
return reverse->shape().dimensions(i) == 1;
|
return reverse->shape().dimensions(i) == 1;
|
||||||
};
|
};
|
||||||
if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(),
|
if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
|
||||||
dim_is_one)) {
|
|
||||||
return ReplaceInstruction(reverse, reverse->mutable_operand(0));
|
return ReplaceInstruction(reverse, reverse->mutable_operand(0));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -2492,9 +2491,9 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
|
|||||||
// Create a new reduce with the combined reduction dimensions of both
|
// Create a new reduce with the combined reduction dimensions of both
|
||||||
// reduces.
|
// reduces.
|
||||||
std::vector<int64> arg_dims = arg->dimensions();
|
std::vector<int64> arg_dims = arg->dimensions();
|
||||||
std::sort(arg_dims.begin(), arg_dims.end());
|
absl::c_sort(arg_dims);
|
||||||
std::vector<int64> reduce_dims = reduce->dimensions();
|
std::vector<int64> reduce_dims = reduce->dimensions();
|
||||||
std::sort(reduce_dims.begin(), reduce_dims.end());
|
absl::c_sort(reduce_dims);
|
||||||
// Transform reduce_dims to the same rank as the operand of the operand.
|
// Transform reduce_dims to the same rank as the operand of the operand.
|
||||||
for (int64 arg_dim : arg_dims) {
|
for (int64 arg_dim : arg_dims) {
|
||||||
for (int64& dim : reduce_dims) {
|
for (int64& dim : reduce_dims) {
|
||||||
|
@ -86,8 +86,7 @@ std::vector<int64> ColorInterferenceGraph(
|
|||||||
// first, but it would be good to investigate other ordering heuristics too.
|
// first, but it would be good to investigate other ordering heuristics too.
|
||||||
std::vector<int64> nodes(node_count);
|
std::vector<int64> nodes(node_count);
|
||||||
std::iota(nodes.begin(), nodes.end(), 0);
|
std::iota(nodes.begin(), nodes.end(), 0);
|
||||||
std::sort(nodes.begin(), nodes.end(),
|
absl::c_sort(nodes, [&interference_map](const int64 i, const int64 j) {
|
||||||
[&interference_map](const int64 i, const int64 j) {
|
|
||||||
return interference_map[i].size() > interference_map[j].size();
|
return interference_map[i].size() > interference_map[j].size();
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -272,10 +271,11 @@ BufferAllocationProto BufferAllocation::ToProto() const {
|
|||||||
proto_assigned->set_offset(buffer_offset_size.second.offset);
|
proto_assigned->set_offset(buffer_offset_size.second.offset);
|
||||||
proto_assigned->set_size(buffer_offset_size.second.size);
|
proto_assigned->set_size(buffer_offset_size.second.size);
|
||||||
}
|
}
|
||||||
std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(),
|
absl::c_sort(*proto.mutable_assigned(),
|
||||||
[](const BufferAllocationProto::Assigned& assign1,
|
[](const BufferAllocationProto::Assigned& assign1,
|
||||||
const BufferAllocationProto::Assigned& assign2) {
|
const BufferAllocationProto::Assigned& assign2) {
|
||||||
return assign1.logical_buffer_id() < assign2.logical_buffer_id();
|
return assign1.logical_buffer_id() <
|
||||||
|
assign2.logical_buffer_id();
|
||||||
});
|
});
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
@ -308,7 +308,7 @@ string BufferAllocation::ToString() const {
|
|||||||
for (const auto& buffer_offset_size : assigned_buffers_) {
|
for (const auto& buffer_offset_size : assigned_buffers_) {
|
||||||
sorted_buffers.push_back(buffer_offset_size.first);
|
sorted_buffers.push_back(buffer_offset_size.first);
|
||||||
}
|
}
|
||||||
std::sort(sorted_buffers.begin(), sorted_buffers.end(),
|
absl::c_sort(sorted_buffers,
|
||||||
[](const LogicalBuffer* a, const LogicalBuffer* b) {
|
[](const LogicalBuffer* a, const LogicalBuffer* b) {
|
||||||
return a->id() < b->id();
|
return a->id() < b->id();
|
||||||
});
|
});
|
||||||
@ -479,9 +479,8 @@ bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
|
|||||||
// didn't return the empty set) for both HLOs, and the two resulting sets of
|
// didn't return the empty set) for both HLOs, and the two resulting sets of
|
||||||
// slices are disjoint.
|
// slices are disjoint.
|
||||||
return !slices_a.empty() && !slices_b.empty() &&
|
return !slices_a.empty() && !slices_b.empty() &&
|
||||||
std::none_of(slices_a.begin(), slices_a.end(),
|
absl::c_none_of(slices_a, [&](const BufferAllocation::Slice& slice) {
|
||||||
[&](const BufferAllocation::Slice& slice) {
|
return slices_b.contains(slice);
|
||||||
return slices_b.count(slice) > 0;
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -952,9 +951,9 @@ Status BufferAssigner::AssignBuffersForComputation(
|
|||||||
// operands (assuming operands are the same/larger size) enabling the
|
// operands (assuming operands are the same/larger size) enabling the
|
||||||
// important reuse case where an elementwise instruction reuses one of its
|
// important reuse case where an elementwise instruction reuses one of its
|
||||||
// operand's buffer. This improves locality.
|
// operand's buffer. This improves locality.
|
||||||
std::sort(sorted_buffers.begin(), sorted_buffers.end(),
|
absl::c_sort(sorted_buffers,
|
||||||
[has_sequential_order, &liveness, &post_order_position, assignment](
|
[has_sequential_order, &liveness, &post_order_position,
|
||||||
const LogicalBuffer* a, const LogicalBuffer* b) {
|
assignment](const LogicalBuffer* a, const LogicalBuffer* b) {
|
||||||
// Primary sort is by decreasing buffer size.
|
// Primary sort is by decreasing buffer size.
|
||||||
const int64 a_size = assignment->buffer_size_(*a);
|
const int64 a_size = assignment->buffer_size_(*a);
|
||||||
const int64 b_size = assignment->buffer_size_(*b);
|
const int64 b_size = assignment->buffer_size_(*b);
|
||||||
@ -1305,7 +1304,7 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
|
|||||||
live_buffers.end());
|
live_buffers.end());
|
||||||
|
|
||||||
// Stabily sort the live buffers.
|
// Stabily sort the live buffers.
|
||||||
std::sort(live_buffers_vector.begin(), live_buffers_vector.end(),
|
absl::c_sort(live_buffers_vector,
|
||||||
[](const LogicalBuffer* a, const LogicalBuffer* b) {
|
[](const LogicalBuffer* a, const LogicalBuffer* b) {
|
||||||
return a->id() < b->id();
|
return a->id() < b->id();
|
||||||
});
|
});
|
||||||
|
@ -384,7 +384,7 @@ TEST_F(CallGraphTest, ComplexGraph) {
|
|||||||
|
|
||||||
// Verify visitation order of some computations in the graph.
|
// Verify visitation order of some computations in the graph.
|
||||||
auto index_of = [&visited](const HloComputation* comp) {
|
auto index_of = [&visited](const HloComputation* comp) {
|
||||||
auto it = std::find(visited.begin(), visited.end(), comp);
|
auto it = absl::c_find(visited, comp);
|
||||||
EXPECT_NE(it, visited.end());
|
EXPECT_NE(it, visited.end());
|
||||||
return std::distance(visited.begin(), it);
|
return std::distance(visited.begin(), it);
|
||||||
};
|
};
|
||||||
|
@ -42,7 +42,7 @@ void ComputationLayout::SetToDefaultLayout() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ComputationLayout::LayoutIsSet() const {
|
bool ComputationLayout::LayoutIsSet() const {
|
||||||
return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(),
|
return absl::c_all_of(parameter_layouts_,
|
||||||
[](const ShapeLayout& s) { return s.LayoutIsSet(); }) &&
|
[](const ShapeLayout& s) { return s.LayoutIsSet(); }) &&
|
||||||
result_layout_.LayoutIsSet();
|
result_layout_.LayoutIsSet();
|
||||||
}
|
}
|
||||||
|
@ -539,8 +539,7 @@ class CopyRemover {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<const HloValue*> values = buffer.values();
|
std::vector<const HloValue*> values = buffer.values();
|
||||||
std::sort(values.begin(), values.end(),
|
absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
|
||||||
[this](const HloValue* a, const HloValue* b) {
|
|
||||||
return ordering_.IsDefinedBefore(*a, *b);
|
return ordering_.IsDefinedBefore(*a, *b);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -842,9 +841,8 @@ class CopyRemover {
|
|||||||
copy_value_node->next->prev = operand_node;
|
copy_value_node->next->prev = operand_node;
|
||||||
|
|
||||||
// Patch up uses. Remove use of copy from operand_node uses.
|
// Patch up uses. Remove use of copy from operand_node uses.
|
||||||
auto it =
|
auto it = absl::c_find_if(
|
||||||
std::find_if(operand_node->uses.begin(), operand_node->uses.end(),
|
operand_node->uses, [copy_value_node](const HloUse* use) {
|
||||||
[copy_value_node](const HloUse* use) {
|
|
||||||
return use->instruction ==
|
return use->instruction ==
|
||||||
copy_value_node->value->defining_instruction();
|
copy_value_node->value->defining_instruction();
|
||||||
});
|
});
|
||||||
|
@ -77,9 +77,8 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sort the symbols in increasing address order.
|
// Sort the symbols in increasing address order.
|
||||||
std::sort(
|
absl::c_sort(symbols, [](const llvm::object::SymbolRef& a,
|
||||||
symbols.begin(), symbols.end(),
|
const llvm::object::SymbolRef& b) {
|
||||||
[](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) {
|
|
||||||
// getAddress returns a Expected object. Assert there is no error
|
// getAddress returns a Expected object. Assert there is no error
|
||||||
// before extracting the address.
|
// before extracting the address.
|
||||||
llvm::Expected<uint64_t> a_address_or_error = a.getAddress();
|
llvm::Expected<uint64_t> a_address_or_error = a.getAddress();
|
||||||
|
@ -1709,10 +1709,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
|
|||||||
vectorization_factor_in_bytes /
|
vectorization_factor_in_bytes /
|
||||||
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
|
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type());
|
||||||
|
|
||||||
bool is_reduction_over_minor_dimension =
|
bool is_reduction_over_minor_dimension = absl::c_linear_search(
|
||||||
std::find(dimensions.begin(), dimensions.end(),
|
dimensions, LayoutUtil::Minor(arg->shape().layout(), 0));
|
||||||
LayoutUtil::Minor(arg->shape().layout(), 0)) !=
|
|
||||||
dimensions.end();
|
|
||||||
|
|
||||||
unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
|
unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
|
||||||
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
|
ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
|
||||||
@ -2401,8 +2399,7 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate(
|
|||||||
int64 concat_dim = concatenate->dimensions(0);
|
int64 concat_dim = concatenate->dimensions(0);
|
||||||
const Layout& output_layout = output_shape.layout();
|
const Layout& output_layout = output_shape.layout();
|
||||||
auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
|
auto output_min2maj = LayoutUtil::MinorToMajor(output_layout);
|
||||||
auto concat_dim_layout_itr =
|
auto concat_dim_layout_itr = absl::c_find(output_min2maj, concat_dim);
|
||||||
std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim);
|
|
||||||
|
|
||||||
std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
|
std::vector<int64> inner_dims(output_min2maj.begin(), concat_dim_layout_itr);
|
||||||
std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
|
std::vector<int64> outer_dims(std::next(concat_dim_layout_itr),
|
||||||
@ -2956,8 +2953,7 @@ Status IrEmitter::ElementTypesSameAndSupported(
|
|||||||
|
|
||||||
TF_RET_CHECK(!operands.empty());
|
TF_RET_CHECK(!operands.empty());
|
||||||
PrimitiveType primitive_type = operands[0]->shape().element_type();
|
PrimitiveType primitive_type = operands[0]->shape().element_type();
|
||||||
if (std::find(supported_types.begin(), supported_types.end(),
|
if (!absl::c_linear_search(supported_types, primitive_type)) {
|
||||||
primitive_type) == supported_types.end()) {
|
|
||||||
return Unimplemented("unsupported operand type %s in op %s",
|
return Unimplemented("unsupported operand type %s in op %s",
|
||||||
PrimitiveType_Name(primitive_type),
|
PrimitiveType_Name(primitive_type),
|
||||||
HloOpcodeString(instruction.opcode()));
|
HloOpcodeString(instruction.opcode()));
|
||||||
|
@ -154,19 +154,16 @@ bool IsReductionToVector(const HloInstruction& reduce) {
|
|||||||
const HloInstruction* input = reduce.operand(0);
|
const HloInstruction* input = reduce.operand(0);
|
||||||
std::vector<int64> dims_to_keep;
|
std::vector<int64> dims_to_keep;
|
||||||
for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
|
for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
|
||||||
if (!std::count(reduce.dimensions().begin(), reduce.dimensions().end(),
|
if (!absl::c_linear_search(reduce.dimensions(), dim)) {
|
||||||
dim)) {
|
|
||||||
dims_to_keep.push_back(dim);
|
dims_to_keep.push_back(dim);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
|
return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
|
||||||
dims_to_keep) &&
|
dims_to_keep) &&
|
||||||
ShapeUtil::Equal(reduce.shape(), ShapeUtil::FilterDimensions(
|
ShapeUtil::Equal(
|
||||||
[&dims_to_keep](int64 dim) {
|
reduce.shape(),
|
||||||
return std::count(
|
ShapeUtil::FilterDimensions(
|
||||||
dims_to_keep.begin(),
|
[&](int64 dim) { return absl::c_count(dims_to_keep, dim); },
|
||||||
dims_to_keep.end(), dim);
|
|
||||||
},
|
|
||||||
input->shape()));
|
input->shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1506,7 +1506,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
|
|||||||
return !allocation->is_constant();
|
return !allocation->is_constant();
|
||||||
});
|
});
|
||||||
|
|
||||||
std::sort(non_constant_buffers.begin(), non_constant_buffers.end(),
|
absl::c_sort(non_constant_buffers,
|
||||||
[](const BufferAllocation* a, const BufferAllocation* b) {
|
[](const BufferAllocation* a, const BufferAllocation* b) {
|
||||||
return a->index() < b->index();
|
return a->index() < b->index();
|
||||||
});
|
});
|
||||||
|
@ -225,7 +225,7 @@ Status HeapSimulator::RunComputation(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Sort to get a deterministic iteration order.
|
// Sort to get a deterministic iteration order.
|
||||||
std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(),
|
absl::c_sort(operand_buffers_to_free,
|
||||||
[](const BufferValue* x, const BufferValue* y) {
|
[](const BufferValue* x, const BufferValue* y) {
|
||||||
return x->id() < y->id();
|
return x->id() < y->id();
|
||||||
});
|
});
|
||||||
@ -335,8 +335,7 @@ Status HeapSimulator::RunComputation(
|
|||||||
to_free.push_back(buffer);
|
to_free.push_back(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::sort(to_free.begin(), to_free.end(),
|
absl::c_sort(to_free, [](const BufferValue* x, const BufferValue* y) {
|
||||||
[](const BufferValue* x, const BufferValue* y) {
|
|
||||||
return x->id() < y->id();
|
return x->id() < y->id();
|
||||||
});
|
});
|
||||||
for (const BufferValue* buffer : to_free) {
|
for (const BufferValue* buffer : to_free) {
|
||||||
@ -596,7 +595,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Call ops in the run sorted by decreasing size, breaking ties by buffer id.
|
// Call ops in the run sorted by decreasing size, breaking ties by buffer id.
|
||||||
std::sort(run_.begin(), run_.end(), [](const Op& a, const Op& b) {
|
absl::c_sort(run_, [](const Op& a, const Op& b) {
|
||||||
if (a.size != b.size) {
|
if (a.size != b.size) {
|
||||||
return a.size > b.size;
|
return a.size > b.size;
|
||||||
}
|
}
|
||||||
@ -866,7 +865,7 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
|
|||||||
for (auto& entry : buffer_intervals_) {
|
for (auto& entry : buffer_intervals_) {
|
||||||
sorted_buffer_intervals.push_back(entry.second);
|
sorted_buffer_intervals.push_back(entry.second);
|
||||||
}
|
}
|
||||||
std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(),
|
absl::c_sort(sorted_buffer_intervals,
|
||||||
[](const BufferInterval& x, const BufferInterval& y) {
|
[](const BufferInterval& x, const BufferInterval& y) {
|
||||||
if (x.size != y.size) {
|
if (x.size != y.size) {
|
||||||
return x.size > y.size;
|
return x.size > y.size;
|
||||||
@ -881,8 +880,8 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
|
|||||||
for (auto& buffer_interval : sorted_buffer_intervals) {
|
for (auto& buffer_interval : sorted_buffer_intervals) {
|
||||||
auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
|
auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
|
||||||
buffer_interval.start, buffer_interval.end);
|
buffer_interval.start, buffer_interval.end);
|
||||||
std::sort(
|
absl::c_sort(
|
||||||
chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(),
|
chunks_overlapping_in_time,
|
||||||
[](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });
|
[](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });
|
||||||
|
|
||||||
// Find the minimum free chunk that can hold this buffer.
|
// Find the minimum free chunk that can hold this buffer.
|
||||||
|
@ -117,7 +117,7 @@ class BufferValueMap {
|
|||||||
for (const auto& pair : buffers_) {
|
for (const auto& pair : buffers_) {
|
||||||
buffer_numbers.push_back(pair.first);
|
buffer_numbers.push_back(pair.first);
|
||||||
}
|
}
|
||||||
std::sort(buffer_numbers.begin(), buffer_numbers.end());
|
absl::c_sort(buffer_numbers);
|
||||||
return buffer_numbers;
|
return buffer_numbers;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -319,7 +319,7 @@ class BufferValueMap {
|
|||||||
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
||||||
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
||||||
// Uniquify aliased buffers.
|
// Uniquify aliased buffers.
|
||||||
std::sort(aliased_buffers.begin(), aliased_buffers.end());
|
absl::c_sort(aliased_buffers);
|
||||||
aliased_buffers.erase(
|
aliased_buffers.erase(
|
||||||
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
|
std::unique(aliased_buffers.begin(), aliased_buffers.end()),
|
||||||
aliased_buffers.end());
|
aliased_buffers.end());
|
||||||
@ -367,7 +367,7 @@ std::vector<const HloBuffer*> HloAliasAnalysis::ComputeBuffersAt(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sort and uniquify vector before returning.
|
// Sort and uniquify vector before returning.
|
||||||
std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan);
|
absl::c_sort(buffers, HloBuffer::IdLessThan);
|
||||||
buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
|
buffers.erase(std::unique(buffers.begin(), buffers.end()), buffers.end());
|
||||||
|
|
||||||
return buffers;
|
return buffers;
|
||||||
@ -430,8 +430,7 @@ Status HloAliasAnalysis::Verify() const {
|
|||||||
for (const auto& pair : value_to_buffer_) {
|
for (const auto& pair : value_to_buffer_) {
|
||||||
const HloValue* value = pair.first;
|
const HloValue* value = pair.first;
|
||||||
const HloBuffer& buffer = *pair.second;
|
const HloBuffer& buffer = *pair.second;
|
||||||
TF_RET_CHECK(std::find(buffer.values().begin(), buffer.values().end(),
|
TF_RET_CHECK(absl::c_linear_search(buffer.values(), value));
|
||||||
value) != buffer.values().end());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
|
for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
|
||||||
@ -515,7 +514,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
|
|||||||
auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
|
auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
|
||||||
std::vector<const HloValue*> sorted_values(value_set.begin(),
|
std::vector<const HloValue*> sorted_values(value_set.begin(),
|
||||||
value_set.end());
|
value_set.end());
|
||||||
std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan);
|
absl::c_sort(sorted_values, HloValue::IdLessThan);
|
||||||
alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
|
alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
|
||||||
for (const HloValue* value : sorted_values) {
|
for (const HloValue* value : sorted_values) {
|
||||||
alias_analysis->value_to_buffer_[value] =
|
alias_analysis->value_to_buffer_[value] =
|
||||||
@ -547,8 +546,7 @@ bool HloAliasAnalysis::HasLiveRangeInterference(
|
|||||||
// tie-break using value ID. The tie-break is necessary because we need a
|
// tie-break using value ID. The tie-break is necessary because we need a
|
||||||
// strict weak order for std::sort.
|
// strict weak order for std::sort.
|
||||||
std::vector<const HloValue*> values = buffer.values();
|
std::vector<const HloValue*> values = buffer.values();
|
||||||
std::sort(values.begin(), values.end(),
|
absl::c_sort(values, [&ordering](const HloValue* a, const HloValue* b) {
|
||||||
[&ordering](const HloValue* a, const HloValue* b) {
|
|
||||||
if (ordering.IsDefinedBefore(*a, *b)) {
|
if (ordering.IsDefinedBefore(*a, *b)) {
|
||||||
return true;
|
return true;
|
||||||
} else if (ordering.IsDefinedBefore(*b, *a)) {
|
} else if (ordering.IsDefinedBefore(*b, *a)) {
|
||||||
|
@ -49,7 +49,7 @@ std::vector<HloPosition> HloBuffer::ComputePositions() const {
|
|||||||
value->positions().end());
|
value->positions().end());
|
||||||
}
|
}
|
||||||
// Remove duplicates and sort positions.
|
// Remove duplicates and sort positions.
|
||||||
std::sort(positions.begin(), positions.end());
|
absl::c_sort(positions);
|
||||||
positions.erase(std::unique(positions.begin(), positions.end()),
|
positions.erase(std::unique(positions.begin(), positions.end()),
|
||||||
positions.end());
|
positions.end());
|
||||||
return positions;
|
return positions;
|
||||||
|
@ -531,8 +531,7 @@ HloComputation::CreateFromProto(
|
|||||||
HloInstruction* root = instruction_map.at(proto.root_id());
|
HloInstruction* root = instruction_map.at(proto.root_id());
|
||||||
|
|
||||||
// Sort the instructions in the proto id's order.
|
// Sort the instructions in the proto id's order.
|
||||||
std::sort(instructions.begin(), instructions.end(),
|
absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
|
||||||
[&](const std::unique_ptr<HloInstruction>& a,
|
|
||||||
const std::unique_ptr<HloInstruction>& b) {
|
const std::unique_ptr<HloInstruction>& b) {
|
||||||
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
||||||
});
|
});
|
||||||
@ -800,8 +799,7 @@ Status HloComputation::AcceptOrdered(
|
|||||||
absl::Span<HloInstruction* const> order) const {
|
absl::Span<HloInstruction* const> order) const {
|
||||||
VLOG(3) << "Accepting visitor with order.";
|
VLOG(3) << "Accepting visitor with order.";
|
||||||
for (HloInstruction* root : CollectUnreachableRoots()) {
|
for (HloInstruction* root : CollectUnreachableRoots()) {
|
||||||
TF_RET_CHECK(std::find(order.begin(), order.end(), root) != order.end())
|
TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString();
|
||||||
<< root->ToString();
|
|
||||||
}
|
}
|
||||||
TF_RET_CHECK(order.size() == instruction_count());
|
TF_RET_CHECK(order.size() == instruction_count());
|
||||||
absl::flat_hash_set<const HloInstruction*> visited;
|
absl::flat_hash_set<const HloInstruction*> visited;
|
||||||
|
@ -256,7 +256,7 @@ bool HloDataflowAnalysis::Phi(
|
|||||||
input_value_ids.push_back(value->id());
|
input_value_ids.push_back(value->id());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::sort(input_value_ids.begin(), input_value_ids.end());
|
absl::c_sort(input_value_ids);
|
||||||
input_value_ids.erase(
|
input_value_ids.erase(
|
||||||
std::unique(input_value_ids.begin(), input_value_ids.end()),
|
std::unique(input_value_ids.begin(), input_value_ids.end()),
|
||||||
input_value_ids.end());
|
input_value_ids.end());
|
||||||
@ -271,8 +271,7 @@ bool HloDataflowAnalysis::Phi(
|
|||||||
if (current_value_defined_here) {
|
if (current_value_defined_here) {
|
||||||
VLOG(5) << "current_value_defined_here: " << current_value->ToString();
|
VLOG(5) << "current_value_defined_here: " << current_value->ToString();
|
||||||
CHECK(current_value->is_phi());
|
CHECK(current_value->is_phi());
|
||||||
auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
|
auto it = absl::c_find(input_value_ids, current_value->id());
|
||||||
current_value->id());
|
|
||||||
if (it != input_value_ids.end()) {
|
if (it != input_value_ids.end()) {
|
||||||
input_value_ids.erase(it);
|
input_value_ids.erase(it);
|
||||||
}
|
}
|
||||||
@ -921,8 +920,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
|||||||
for (auto& pair : dataflow_analysis->values_) {
|
for (auto& pair : dataflow_analysis->values_) {
|
||||||
dataflow_analysis->values_vector_.push_back(&pair.second);
|
dataflow_analysis->values_vector_.push_back(&pair.second);
|
||||||
}
|
}
|
||||||
std::sort(dataflow_analysis->values_vector_.begin(),
|
absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
|
||||||
dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
|
|
||||||
|
|
||||||
TF_DCHECK_OK(dataflow_analysis->Verify());
|
TF_DCHECK_OK(dataflow_analysis->Verify());
|
||||||
|
|
||||||
@ -937,9 +935,7 @@ Status HloDataflowAnalysis::Verify() const {
|
|||||||
for (const HloValue* value : values()) {
|
for (const HloValue* value : values()) {
|
||||||
for (const HloPosition& position : value->positions()) {
|
for (const HloPosition& position : value->positions()) {
|
||||||
const HloValueSet& value_set = GetValueSet(position);
|
const HloValueSet& value_set = GetValueSet(position);
|
||||||
TF_RET_CHECK(std::find(value_set.values().begin(),
|
TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
|
||||||
value_set.values().end(),
|
|
||||||
value) != value_set.values().end())
|
|
||||||
<< "Value set at position " << position << " does not contain value "
|
<< "Value set at position " << position << " does not contain value "
|
||||||
<< value->ToShortString();
|
<< value->ToShortString();
|
||||||
}
|
}
|
||||||
@ -954,9 +950,7 @@ Status HloDataflowAnalysis::Verify() const {
|
|||||||
const HloValueSet& value_set = pair.second;
|
const HloValueSet& value_set = pair.second;
|
||||||
const HloPosition position{instruction, index};
|
const HloPosition position{instruction, index};
|
||||||
for (const HloValue* value : value_set.values()) {
|
for (const HloValue* value : value_set.values()) {
|
||||||
TF_RET_CHECK(std::find(value->positions().begin(),
|
TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
|
||||||
value->positions().end(),
|
|
||||||
position) != value->positions().end())
|
|
||||||
<< "Value set at position " << position
|
<< "Value set at position " << position
|
||||||
<< " unexpectedly contains value " << value->ToShortString();
|
<< " unexpectedly contains value " << value->ToShortString();
|
||||||
}
|
}
|
||||||
@ -1041,8 +1035,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
|||||||
// Check if one operand of kAdd fused root is kDot or kConvolution.
|
// Check if one operand of kAdd fused root is kDot or kConvolution.
|
||||||
auto* add = user->fused_expression_root();
|
auto* add = user->fused_expression_root();
|
||||||
auto add_operand_it =
|
auto add_operand_it =
|
||||||
std::find_if(add->operands().begin(), add->operands().end(),
|
absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
|
||||||
[&](HloInstruction* operand) {
|
|
||||||
return operand->opcode() == HloOpcode::kConvolution ||
|
return operand->opcode() == HloOpcode::kConvolution ||
|
||||||
operand->opcode() == HloOpcode::kDot;
|
operand->opcode() == HloOpcode::kDot;
|
||||||
});
|
});
|
||||||
@ -1100,13 +1093,12 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
|||||||
// *) The root instruction of the called computation is element-wise on
|
// *) The root instruction of the called computation is element-wise on
|
||||||
// 'operand'.
|
// 'operand'.
|
||||||
const bool found_caller_use =
|
const bool found_caller_use =
|
||||||
std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) {
|
absl::c_find_if(uses, [user](const HloUse& use) {
|
||||||
return use.instruction == user;
|
return use.instruction == user;
|
||||||
}) != uses.end();
|
}) != uses.end();
|
||||||
auto* callee_root = user->to_apply()->root_instruction();
|
auto* callee_root = user->to_apply()->root_instruction();
|
||||||
const bool found_elementwise_callee_use =
|
const bool found_elementwise_callee_use =
|
||||||
std::find_if(
|
absl::c_find_if(uses, [callee_root](const HloUse& use) {
|
||||||
uses.begin(), uses.end(), [callee_root](const HloUse& use) {
|
|
||||||
return use.instruction == callee_root &&
|
return use.instruction == callee_root &&
|
||||||
callee_root->IsElementwiseOnOperand(use.operand_number);
|
callee_root->IsElementwiseOnOperand(use.operand_number);
|
||||||
}) != uses.end();
|
}) != uses.end();
|
||||||
|
@ -43,9 +43,7 @@ class HloDceTest : public HloTestBase {
|
|||||||
// Returns whether the given instruction exists in the given computation.
|
// Returns whether the given instruction exists in the given computation.
|
||||||
bool HasInstruction(const HloComputation& computation,
|
bool HasInstruction(const HloComputation& computation,
|
||||||
const HloInstruction* instruction) {
|
const HloInstruction* instruction) {
|
||||||
return std::find(computation.instructions().begin(),
|
return absl::c_linear_search(computation.instructions(), instruction);
|
||||||
computation.instructions().end(),
|
|
||||||
instruction) != computation.instructions().end();
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -230,7 +230,7 @@ HloDomainMap::MakeNonDomainInstructions(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// sort instructions according to instructions_order
|
// sort instructions according to instructions_order
|
||||||
std::sort(instructions.begin(), instructions.end(),
|
absl::c_sort(instructions,
|
||||||
[&instructions_order](HloInstruction* a, HloInstruction* b) {
|
[&instructions_order](HloInstruction* a, HloInstruction* b) {
|
||||||
return instructions_order.at(a) < instructions_order.at(b);
|
return instructions_order.at(a) < instructions_order.at(b);
|
||||||
});
|
});
|
||||||
|
@ -1248,8 +1248,7 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
|
|||||||
// Extract a slice from the keys and values literals that correspond to
|
// Extract a slice from the keys and values literals that correspond to
|
||||||
// exactly the row in dimension 'sort_dim'.
|
// exactly the row in dimension 'sort_dim'.
|
||||||
std::vector<int64> limit_indices(indices.begin(), indices.end());
|
std::vector<int64> limit_indices(indices.begin(), indices.end());
|
||||||
std::for_each(limit_indices.begin(), limit_indices.end(),
|
absl::c_for_each(limit_indices, [](int64& index) { ++index; });
|
||||||
[](int64& index) { ++index; });
|
|
||||||
limit_indices[sort_dim] = sort_dim_elements;
|
limit_indices[sort_dim] = sort_dim_elements;
|
||||||
TF_ASSIGN_OR_RETURN(auto keys_to_sort,
|
TF_ASSIGN_OR_RETURN(auto keys_to_sort,
|
||||||
keys_literal.Slice(indices, limit_indices)
|
keys_literal.Slice(indices, limit_indices)
|
||||||
|
@ -1673,8 +1673,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
// Extract a slice from the literal that corresponds to exactly the
|
// Extract a slice from the literal that corresponds to exactly the
|
||||||
// row in dimension 'sort_dim'.
|
// row in dimension 'sort_dim'.
|
||||||
std::vector<int64> limit_indices(indices.begin(), indices.end());
|
std::vector<int64> limit_indices(indices.begin(), indices.end());
|
||||||
std::for_each(limit_indices.begin(), limit_indices.end(),
|
absl::c_for_each(limit_indices, [](int64& index) { ++index; });
|
||||||
[](int64& index) { ++index; });
|
|
||||||
limit_indices[sort_dim] = sort_dim_elements;
|
limit_indices[sort_dim] = sort_dim_elements;
|
||||||
TF_ASSIGN_OR_RETURN(auto row_to_sort,
|
TF_ASSIGN_OR_RETURN(auto row_to_sort,
|
||||||
keys_literal.Slice(indices, limit_indices)
|
keys_literal.Slice(indices, limit_indices)
|
||||||
|
@ -561,8 +561,8 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Show the subcomputation if we're showing any of its members.
|
// Show the subcomputation if we're showing any of its members.
|
||||||
return std::any_of(
|
return absl::c_any_of(
|
||||||
subcomp->instructions().begin(), subcomp->instructions().end(),
|
subcomp->instructions(),
|
||||||
[&](const HloInstruction* instr) { return filter_.Show(instr); });
|
[&](const HloInstruction* instr) { return filter_.Show(instr); });
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -735,12 +735,11 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
|
|||||||
const int kMinUsersToOmit = 3;
|
const int kMinUsersToOmit = 3;
|
||||||
return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
|
return instr->opcode() == HloOpcode::kParameter && instr->shape().IsTuple() &&
|
||||||
!instr->IsFused() &&
|
!instr->IsFused() &&
|
||||||
std::count_if(instr->users().begin(), instr->users().end(),
|
absl::c_count_if(instr->users(),
|
||||||
[&](const HloInstruction* user) {
|
[&](const HloInstruction* user) {
|
||||||
return filter_.Show(user);
|
return filter_.Show(user);
|
||||||
}) > kMinUsersToOmit &&
|
}) > kMinUsersToOmit &&
|
||||||
std::all_of(instr->users().begin(), instr->users().end(),
|
absl::c_all_of(instr->users(), [&](const HloInstruction* user) {
|
||||||
[&](const HloInstruction* user) {
|
|
||||||
return !filter_.Show(user) ||
|
return !filter_.Show(user) ||
|
||||||
user->opcode() == HloOpcode::kGetTupleElement;
|
user->opcode() == HloOpcode::kGetTupleElement;
|
||||||
});
|
});
|
||||||
@ -900,8 +899,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
|||||||
// the same color as a parameter. Unless the merged-in parameter is a
|
// the same color as a parameter. Unless the merged-in parameter is a
|
||||||
// parameter to a fusion node that is bound to a constant -- these aren't
|
// parameter to a fusion node that is bound to a constant -- these aren't
|
||||||
// "real" parameters from the user's perspective.
|
// "real" parameters from the user's perspective.
|
||||||
if (std::any_of(instr->operands().begin(), instr->operands().end(),
|
if (absl::c_any_of(instr->operands(), [&](const HloInstruction* operand) {
|
||||||
[&](const HloInstruction* operand) {
|
|
||||||
return operand->opcode() == HloOpcode::kParameter &&
|
return operand->opcode() == HloOpcode::kParameter &&
|
||||||
ShouldMergeIntoUsers(operand) &&
|
ShouldMergeIntoUsers(operand) &&
|
||||||
TryGetFusionParameterConstant(operand) == nullptr;
|
TryGetFusionParameterConstant(operand) == nullptr;
|
||||||
@ -1355,12 +1353,11 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
|
|||||||
NodeFilterResult& filter_result = kv.second;
|
NodeFilterResult& filter_result = kv.second;
|
||||||
const auto& operands = instr->operands();
|
const auto& operands = instr->operands();
|
||||||
|
|
||||||
if (std::any_of(operands.begin(), operands.end(), is_displayed) &&
|
if (absl::c_any_of(operands, is_displayed) &&
|
||||||
!std::all_of(operands.begin(), operands.end(), is_displayed)) {
|
!absl::c_all_of(operands, is_displayed)) {
|
||||||
// Mark nodes with some operands omitted appropriately.
|
// Mark nodes with some operands omitted appropriately.
|
||||||
filter_result = kSomeOperandsOmitted;
|
filter_result = kSomeOperandsOmitted;
|
||||||
} else if (!operands.empty() &&
|
} else if (!operands.empty() && absl::c_none_of(operands, is_displayed)) {
|
||||||
std::none_of(operands.begin(), operands.end(), is_displayed)) {
|
|
||||||
// Mark nodes with *all* operands omitted appropriately.
|
// Mark nodes with *all* operands omitted appropriately.
|
||||||
filter_result = kOmitNodeOperands;
|
filter_result = kOmitNodeOperands;
|
||||||
}
|
}
|
||||||
@ -1368,8 +1365,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
|
|||||||
// Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
|
// Promote nodes with type kSomeUsersOmitted to kNormalNode if all of their
|
||||||
// users made it into the graph.
|
// users made it into the graph.
|
||||||
if (filter_result == kSomeUsersOmitted &&
|
if (filter_result == kSomeUsersOmitted &&
|
||||||
std::all_of(instr->users().begin(), instr->users().end(),
|
absl::c_all_of(instr->users(), is_displayed)) {
|
||||||
is_displayed)) {
|
|
||||||
filter_result = kNormalNode;
|
filter_result = kNormalNode;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -83,15 +83,14 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
return computation_map.at(proto.called_computation_ids(index));
|
return computation_map.at(proto.called_computation_ids(index));
|
||||||
};
|
};
|
||||||
|
|
||||||
TF_RET_CHECK(std::all_of(
|
TF_RET_CHECK(
|
||||||
proto.operand_ids().begin(), proto.operand_ids().end(),
|
absl::c_all_of(proto.operand_ids(),
|
||||||
[&instruction_map](int64 id) { return instruction_map.contains(id); }))
|
[&](int64 id) { return instruction_map.contains(id); }))
|
||||||
<< proto.name() << " instruction contains invalid operand id(s)";
|
<< proto.name() << " instruction contains invalid operand id(s)";
|
||||||
|
|
||||||
TF_RET_CHECK(std::all_of(
|
TF_RET_CHECK(
|
||||||
proto.called_computation_ids().begin(),
|
absl::c_all_of(proto.called_computation_ids(),
|
||||||
proto.called_computation_ids().end(),
|
[&](int64 id) { return computation_map.contains(id); }))
|
||||||
[&computation_map](int64 id) { return computation_map.contains(id); }))
|
|
||||||
<< proto.name() << " instruction references invalid computation id(s)";
|
<< proto.name() << " instruction references invalid computation id(s)";
|
||||||
|
|
||||||
Shape shape(proto.shape());
|
Shape shape(proto.shape());
|
||||||
@ -1599,12 +1598,10 @@ HloInstruction::InstructionVector HloInstruction::unique_operands() const {
|
|||||||
|
|
||||||
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
|
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
|
||||||
TF_RET_CHECK(instruction->parent() == parent());
|
TF_RET_CHECK(instruction->parent() == parent());
|
||||||
if (std::find(control_successors_.begin(), control_successors_.end(),
|
if (!absl::c_linear_search(control_successors_, instruction)) {
|
||||||
instruction) == control_successors_.end()) {
|
|
||||||
control_successors_.push_back(instruction);
|
control_successors_.push_back(instruction);
|
||||||
TF_RET_CHECK(std::find(instruction->control_predecessors_.begin(),
|
TF_RET_CHECK(
|
||||||
instruction->control_predecessors_.end(),
|
!absl::c_linear_search(instruction->control_predecessors_, this));
|
||||||
this) == instruction->control_predecessors_.end());
|
|
||||||
instruction->control_predecessors_.push_back(this);
|
instruction->control_predecessors_.push_back(this);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -1853,7 +1850,7 @@ void HloInstruction::RemoveUser(HloInstruction* user) {
|
|||||||
user_set_.erase(set_it);
|
user_set_.erase(set_it);
|
||||||
// This is linear in the number of the users, but a vector provides a stable
|
// This is linear in the number of the users, but a vector provides a stable
|
||||||
// iteration order and much faster traversal.
|
// iteration order and much faster traversal.
|
||||||
auto vec_it = std::find(users_.begin(), users_.end(), user);
|
auto vec_it = absl::c_find(users_, user);
|
||||||
CHECK(vec_it != users_.end());
|
CHECK(vec_it != users_.end());
|
||||||
users_.erase(vec_it);
|
users_.erase(vec_it);
|
||||||
}
|
}
|
||||||
@ -1871,8 +1868,7 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
|
|||||||
|
|
||||||
RemoveUser(user);
|
RemoveUser(user);
|
||||||
|
|
||||||
TF_RET_CHECK(
|
TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
|
||||||
std::count(user->operands_.begin(), user->operands_.end(), this) >= 0);
|
|
||||||
std::replace(user->operands_.begin(), user->operands_.end(), this,
|
std::replace(user->operands_.begin(), user->operands_.end(), this,
|
||||||
new_producer);
|
new_producer);
|
||||||
new_producer->AddUser(user);
|
new_producer->AddUser(user);
|
||||||
@ -1907,8 +1903,7 @@ Status HloInstruction::ReplaceOperandWithDifferentShape(
|
|||||||
VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
|
VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
|
||||||
<< new_operand->name() << ", was " << old_operand->name();
|
<< new_operand->name() << ", was " << old_operand->name();
|
||||||
|
|
||||||
if (std::find(operands_.begin(), operands_.end(), old_operand) ==
|
if (!absl::c_linear_search(operands_, old_operand)) {
|
||||||
operands_.end()) {
|
|
||||||
old_operand->RemoveUser(this);
|
old_operand->RemoveUser(this);
|
||||||
}
|
}
|
||||||
new_operand->AddUser(this);
|
new_operand->AddUser(this);
|
||||||
@ -2945,7 +2940,7 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
|
|||||||
|
|
||||||
string PaddingConfigToString(const PaddingConfig& padding) {
|
string PaddingConfigToString(const PaddingConfig& padding) {
|
||||||
bool has_interior_padding =
|
bool has_interior_padding =
|
||||||
std::any_of(padding.dimensions().begin(), padding.dimensions().end(),
|
absl::c_any_of(padding.dimensions(),
|
||||||
[](const PaddingConfig::PaddingConfigDimension& dim) {
|
[](const PaddingConfig::PaddingConfigDimension& dim) {
|
||||||
return dim.interior_padding() != 0;
|
return dim.interior_padding() != 0;
|
||||||
});
|
});
|
||||||
|
@ -42,9 +42,7 @@ using absl::StrJoin;
|
|||||||
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
|
bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
|
||||||
const HloInstruction* operand) {
|
const HloInstruction* operand) {
|
||||||
std::vector<int64> operand_indices = instruction->OperandIndices(operand);
|
std::vector<int64> operand_indices = instruction->OperandIndices(operand);
|
||||||
return std::all_of(
|
return absl::c_all_of(operand_indices, [instruction](int64 operand_index) {
|
||||||
operand_indices.begin(), operand_indices.end(),
|
|
||||||
[instruction](int64 operand_index) {
|
|
||||||
return instruction->IsElementwiseOnOperand(operand_index);
|
return instruction->IsElementwiseOnOperand(operand_index);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -814,8 +812,7 @@ std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl(
|
|||||||
std::vector<string> bounds;
|
std::vector<string> bounds;
|
||||||
bounds.reserve(slice_starts_.size());
|
bounds.reserve(slice_starts_.size());
|
||||||
const bool omit_stride =
|
const bool omit_stride =
|
||||||
std::all_of(slice_strides_.begin(), slice_strides_.end(),
|
absl::c_all_of(slice_strides_, [](int64 stride) { return stride == 1; });
|
||||||
[](int64 stride) { return stride == 1; });
|
|
||||||
for (int i = 0; i < slice_starts_.size(); ++i) {
|
for (int i = 0; i < slice_starts_.size(); ++i) {
|
||||||
string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
|
string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]);
|
||||||
bounds.push_back(
|
bounds.push_back(
|
||||||
@ -1051,8 +1048,7 @@ HloInstruction* HloFusionInstruction::AddFusionOperand(
|
|||||||
|
|
||||||
void HloFusionInstruction::MergeFusionInstruction(
|
void HloFusionInstruction::MergeFusionInstruction(
|
||||||
HloFusionInstruction* instruction_to_merge) {
|
HloFusionInstruction* instruction_to_merge) {
|
||||||
CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) !=
|
CHECK(absl::c_linear_search(operands(), instruction_to_merge));
|
||||||
operands().end());
|
|
||||||
// Clone the instruction from which to merge fused instructions.
|
// Clone the instruction from which to merge fused instructions.
|
||||||
std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
|
std::unique_ptr<HloInstruction> cloned = instruction_to_merge->Clone();
|
||||||
HloFusionInstruction* cloned_fusion =
|
HloFusionInstruction* cloned_fusion =
|
||||||
@ -1219,8 +1215,8 @@ HloInstruction* HloFusionInstruction::CloneAndFuseInternal(
|
|||||||
// corresponding fused parameter instruction. Renumber parameters as
|
// corresponding fused parameter instruction. Renumber parameters as
|
||||||
// necessary to make parameter numbers consistent with their index in the
|
// necessary to make parameter numbers consistent with their index in the
|
||||||
// fused_parameter_ vector.
|
// fused_parameter_ vector.
|
||||||
bool in_operand_list = std::find(operands().begin(), operands().end(),
|
bool in_operand_list =
|
||||||
instruction_to_fuse) != operands().end();
|
absl::c_linear_search(operands(), instruction_to_fuse);
|
||||||
CHECK(add_output || in_operand_list);
|
CHECK(add_output || in_operand_list);
|
||||||
if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
|
if (instruction_to_fuse->opcode() == HloOpcode::kTuple) {
|
||||||
// We assume all uses of a kTuple operation are GTE ops, not another
|
// We assume all uses of a kTuple operation are GTE ops, not another
|
||||||
|
@ -107,9 +107,8 @@ HloComputation* HloModule::AddEntryComputation(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
|
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
|
||||||
auto it =
|
auto it = absl::c_find_if(
|
||||||
std::find_if(computations_.begin(), computations_.end(),
|
computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
|
||||||
[&to_remove](const std::unique_ptr<HloComputation>& comp) {
|
|
||||||
return comp.get() == to_remove;
|
return comp.get() == to_remove;
|
||||||
});
|
});
|
||||||
TF_RET_CHECK(it->get() == to_remove);
|
TF_RET_CHECK(it->get() == to_remove);
|
||||||
@ -304,8 +303,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
|||||||
auto module = absl::make_unique<HloModule>(proto.name(), module_config);
|
auto module = absl::make_unique<HloModule>(proto.name(), module_config);
|
||||||
|
|
||||||
// Sort the computations in the proto id's order.
|
// Sort the computations in the proto id's order.
|
||||||
std::sort(computations.begin(), computations.end(),
|
absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
|
||||||
[&](const std::unique_ptr<HloComputation>& a,
|
|
||||||
const std::unique_ptr<HloComputation>& b) {
|
const std::unique_ptr<HloComputation>& b) {
|
||||||
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
||||||
});
|
});
|
||||||
|
@ -38,9 +38,7 @@ class HloModuleDceTest : public HloTestBase {
|
|||||||
// Returns whether the given instruction exists in the given computation.
|
// Returns whether the given instruction exists in the given computation.
|
||||||
bool HasInstruction(const HloComputation& computation,
|
bool HasInstruction(const HloComputation& computation,
|
||||||
const HloInstruction* instruction) {
|
const HloInstruction* instruction) {
|
||||||
return std::find(computation.instructions().begin(),
|
return absl::c_linear_search(computation.instructions(), instruction);
|
||||||
computation.instructions().end(),
|
|
||||||
instruction) != computation.instructions().end();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether the while instruction with name 'while_name' in
|
// Returns whether the while instruction with name 'while_name' in
|
||||||
|
@ -2746,7 +2746,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto is_unique = [](string str) -> bool {
|
auto is_unique = [](string str) -> bool {
|
||||||
std::sort(str.begin(), str.end());
|
absl::c_sort(str);
|
||||||
return std::unique(str.begin(), str.end()) == str.end();
|
return std::unique(str.begin(), str.end()) == str.end();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
|
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
|
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h"
|
||||||
|
|
||||||
@ -34,9 +35,8 @@ string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data,
|
|||||||
for (const HloComputationInfo& computation_info :
|
for (const HloComputationInfo& computation_info :
|
||||||
hlo_profile_printer_data.computation_infos()) {
|
hlo_profile_printer_data.computation_infos()) {
|
||||||
const auto& instruction_infos = computation_info.instruction_infos();
|
const auto& instruction_infos = computation_info.instruction_infos();
|
||||||
bool any_instruction_profiled =
|
bool any_instruction_profiled = absl::c_any_of(
|
||||||
std::any_of(instruction_infos.begin(), instruction_infos.end(),
|
instruction_infos, [&](const HloInstructionInfo& instruction_info) {
|
||||||
[&](const HloInstructionInfo& instruction_info) {
|
|
||||||
return counters[instruction_info.profile_index()] != 0;
|
return counters[instruction_info.profile_index()] != 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ void HloReachabilityMap::SetReachabilityToUnionHelper(
|
|||||||
absl::Span<const HloInstruction* const> inputs,
|
absl::Span<const HloInstruction* const> inputs,
|
||||||
const HloInstruction* instruction, BitVector* bit_vector) {
|
const HloInstruction* instruction, BitVector* bit_vector) {
|
||||||
// If instruction is part of inputs, don't reset the bit_vector.
|
// If instruction is part of inputs, don't reset the bit_vector.
|
||||||
if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) {
|
if (!absl::c_linear_search(inputs, instruction)) {
|
||||||
bit_vector->SetToZero();
|
bit_vector->SetToZero();
|
||||||
}
|
}
|
||||||
bit_vector->Set(GetIndex(instruction));
|
bit_vector->Set(GetIndex(instruction));
|
||||||
|
@ -235,8 +235,7 @@ class InstructionList {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now scan forwards until we find one of the before_instructions.
|
// Now scan forwards until we find one of the before_instructions.
|
||||||
while (std::find(before_instructions.begin(), before_instructions.end(),
|
while (!absl::c_linear_search(before_instructions, min_position_item)) {
|
||||||
min_position_item) == before_instructions.end()) {
|
|
||||||
min_position_item = min_position_item->next;
|
min_position_item = min_position_item->next;
|
||||||
}
|
}
|
||||||
return InsertBefore(to_insert, min_position_item);
|
return InsertBefore(to_insert, min_position_item);
|
||||||
@ -302,7 +301,7 @@ ItemList GetUsers(const InstructionList& instruction_list,
|
|||||||
// A buffer may be used by the instruction via more than one alias. For
|
// A buffer may be used by the instruction via more than one alias. For
|
||||||
// example, a buffer which appears in more than one element of a tuple.
|
// example, a buffer which appears in more than one element of a tuple.
|
||||||
Item* user_item = instruction_list.GetItem(user);
|
Item* user_item = instruction_list.GetItem(user);
|
||||||
if (std::find(users.begin(), users.end(), user_item) == users.end()) {
|
if (!absl::c_linear_search(users, user_item)) {
|
||||||
users.push_back(user_item);
|
users.push_back(user_item);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -456,8 +455,7 @@ class MemoryUsageTracker {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
|
const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
|
||||||
return std::find(in_progress_uses.begin(), in_progress_uses.end(),
|
return absl::c_linear_search(in_progress_uses, buffer_id);
|
||||||
buffer_id) != in_progress_uses.end();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns whether the given instruction is live at the current program
|
// Returns whether the given instruction is live at the current program
|
||||||
@ -535,8 +533,7 @@ MemoryUsageTracker::MemoryUsageTracker(
|
|||||||
bool unused;
|
bool unused;
|
||||||
for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
|
for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
|
||||||
points_to_analysis, &unused)) {
|
points_to_analysis, &unused)) {
|
||||||
if (std::find(buffer->users.begin(), buffer->users.end(),
|
if (!absl::c_linear_search(buffer->users, user_item)) {
|
||||||
user_item) == buffer->users.end()) {
|
|
||||||
buffer->users.push_back(user_item);
|
buffer->users.push_back(user_item);
|
||||||
buffer->unfinished_user_count++;
|
buffer->unfinished_user_count++;
|
||||||
user_item->buffers_used.push_back(buffer->id);
|
user_item->buffers_used.push_back(buffer->id);
|
||||||
@ -784,8 +781,7 @@ bool MemoryUsageTracker::Check() const {
|
|||||||
|
|
||||||
for (const Buffer& buffer : buffers_) {
|
for (const Buffer& buffer : buffers_) {
|
||||||
if (buffer.defining_instruction->instruction == instruction) {
|
if (buffer.defining_instruction->instruction == instruction) {
|
||||||
CHECK(std::find(defined_buffers.begin(), defined_buffers.end(),
|
CHECK(absl::c_linear_search(defined_buffers, buffer.id))
|
||||||
buffer.id) != defined_buffers.end())
|
|
||||||
<< "Instruction " << instruction->name()
|
<< "Instruction " << instruction->name()
|
||||||
<< " defined buffers is missing: " << buffer.ToString();
|
<< " defined buffers is missing: " << buffer.ToString();
|
||||||
}
|
}
|
||||||
@ -808,8 +804,7 @@ bool MemoryUsageTracker::Check() const {
|
|||||||
int64 unfinished_uses = 0;
|
int64 unfinished_uses = 0;
|
||||||
for (Item* user : buffer.users) {
|
for (Item* user : buffer.users) {
|
||||||
const BufferIdList& used_buffers = user->buffers_used;
|
const BufferIdList& used_buffers = user->buffers_used;
|
||||||
CHECK(std::find(used_buffers.begin(), used_buffers.end(), buffer.id) !=
|
CHECK(absl::c_linear_search(used_buffers, buffer.id))
|
||||||
used_buffers.end())
|
|
||||||
<< "Instruction " << user->instruction->name()
|
<< "Instruction " << user->instruction->name()
|
||||||
<< " used buffers is missing " << buffer.ToString();
|
<< " used buffers is missing " << buffer.ToString();
|
||||||
if (!IsFinished(user)) {
|
if (!IsFinished(user)) {
|
||||||
@ -836,7 +831,7 @@ int64 RematerializationCost(const HloInstruction* instruction,
|
|||||||
// If none of the users of 'instruction' have been placed in the sequence (as
|
// If none of the users of 'instruction' have been placed in the sequence (as
|
||||||
// tracked by memory_tracker), then rematerialization of 'instruction' is a
|
// tracked by memory_tracker), then rematerialization of 'instruction' is a
|
||||||
// zero-cost move of 'instruction' in the sequence.
|
// zero-cost move of 'instruction' in the sequence.
|
||||||
if (!std::any_of(instruction->users().begin(), instruction->users().end(),
|
if (!absl::c_any_of(instruction->users(),
|
||||||
[&memory_tracker](const HloInstruction* inst) {
|
[&memory_tracker](const HloInstruction* inst) {
|
||||||
return memory_tracker.IsPlaced(inst);
|
return memory_tracker.IsPlaced(inst);
|
||||||
})) {
|
})) {
|
||||||
|
@ -107,13 +107,12 @@ string HloSharding::ToString() const {
|
|||||||
|
|
||||||
bool HloSharding::UsesDevice(int64 device) const {
|
bool HloSharding::UsesDevice(int64 device) const {
|
||||||
if (IsTuple()) {
|
if (IsTuple()) {
|
||||||
return std::any_of(
|
return absl::c_any_of(tuple_elements_, [&](const HloSharding& s) {
|
||||||
tuple_elements_.begin(), tuple_elements_.end(),
|
return s.UsesDevice(device);
|
||||||
[&](const HloSharding& s) { return s.UsesDevice(device); });
|
});
|
||||||
}
|
}
|
||||||
const auto& devices = tile_assignment_;
|
const auto& devices = tile_assignment_;
|
||||||
return replicated_ ||
|
return replicated_ || absl::c_linear_search(devices, device);
|
||||||
std::find(devices.begin(), devices.end(), device) != devices.end();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
|
std::map<int64, int64> HloSharding::UsedDevices(int64* count) const {
|
||||||
|
@ -101,8 +101,8 @@ class HloSharding {
|
|||||||
if (!IsTuple()) {
|
if (!IsTuple()) {
|
||||||
return replicated_;
|
return replicated_;
|
||||||
}
|
}
|
||||||
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
|
return absl::c_all_of(
|
||||||
[](const HloSharding& s) { return s.IsReplicated(); });
|
tuple_elements_, [](const HloSharding& s) { return s.IsReplicated(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if the tile size is the same as the input size.
|
// Returns true if the tile size is the same as the input size.
|
||||||
@ -110,8 +110,9 @@ class HloSharding {
|
|||||||
if (!IsTuple()) {
|
if (!IsTuple()) {
|
||||||
return maximal_;
|
return maximal_;
|
||||||
}
|
}
|
||||||
return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
|
return absl::c_all_of(tuple_elements_, [](const HloSharding& s) {
|
||||||
[](const HloSharding& s) { return s.IsTileMaximal(); });
|
return s.IsTileMaximal();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if the sharding defines an operation on the given device.
|
// Returns true if the sharding defines an operation on the given device.
|
||||||
|
@ -61,8 +61,7 @@ void CleanNodeName(string* name) {
|
|||||||
name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
|
name->erase(std::remove(name->begin(), name->end(), '%'), name->end());
|
||||||
const string chars_to_replace = "<>[]";
|
const string chars_to_replace = "<>[]";
|
||||||
auto pred = [&](char c) {
|
auto pred = [&](char c) {
|
||||||
return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) !=
|
return absl::c_linear_search(chars_to_replace, c);
|
||||||
chars_to_replace.end();
|
|
||||||
};
|
};
|
||||||
std::replace_if(name->begin(), name->end(), pred, '_');
|
std::replace_if(name->begin(), name->end(), pred, '_');
|
||||||
}
|
}
|
||||||
|
@ -209,7 +209,7 @@ std::ostream& operator<<(std::ostream& out, const HloValue& value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void HloValueSet::SortAndUniquifyValues() {
|
void HloValueSet::SortAndUniquifyValues() {
|
||||||
std::sort(values_.begin(), values_.end(), HloValue::IdLessThan);
|
absl::c_sort(values_, HloValue::IdLessThan);
|
||||||
values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
|
values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
|
||||||
values_.end());
|
values_.end());
|
||||||
}
|
}
|
||||||
|
@ -128,9 +128,9 @@ string HumanReadableProfileBuilder::ToString() const {
|
|||||||
|
|
||||||
// Sort ops in decreasing order of cycles, and print them.
|
// Sort ops in decreasing order of cycles, and print them.
|
||||||
std::vector<OpInfo> sorted_ops(op_infos_);
|
std::vector<OpInfo> sorted_ops(op_infos_);
|
||||||
std::sort(
|
absl::c_sort(sorted_ops, [](const OpInfo& a, const OpInfo& b) {
|
||||||
sorted_ops.begin(), sorted_ops.end(),
|
return a.cycles > b.cycles;
|
||||||
[](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; });
|
});
|
||||||
for (const auto& op : sorted_ops) {
|
for (const auto& op : sorted_ops) {
|
||||||
print_op(op);
|
print_op(op);
|
||||||
}
|
}
|
||||||
|
@ -178,8 +178,8 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
|
|||||||
output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
|
output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
return std::count_if(hlo->operands().begin(), hlo->operands().end(),
|
return absl::c_count_if(
|
||||||
[output_rank](HloInstruction* operand) {
|
hlo->operands(), [output_rank](HloInstruction* operand) {
|
||||||
if (operand->opcode() == HloOpcode::kBroadcast ||
|
if (operand->opcode() == HloOpcode::kBroadcast ||
|
||||||
operand->opcode() == HloOpcode::kIota) {
|
operand->opcode() == HloOpcode::kIota) {
|
||||||
return false;
|
return false;
|
||||||
@ -188,8 +188,7 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
|
|||||||
ShapeUtil::IsEffectiveScalar(operand->shape())) {
|
ShapeUtil::IsEffectiveScalar(operand->shape())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return ShapeUtil::TrueRank(operand->shape()) >=
|
return ShapeUtil::TrueRank(operand->shape()) >= output_rank;
|
||||||
output_rank;
|
|
||||||
}) <= 1;
|
}) <= 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -409,9 +408,8 @@ class ReversePostOrderFusionQueue : public FusionQueue {
|
|||||||
}
|
}
|
||||||
sorted_operand_numbers.push_back(i);
|
sorted_operand_numbers.push_back(i);
|
||||||
}
|
}
|
||||||
std::sort(
|
absl::c_sort(
|
||||||
sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
|
sorted_operand_numbers, [&](int64 i, int64 j) {
|
||||||
[&](int64 i, int64 j) {
|
|
||||||
// Instructions with higher priority in the queue come first.
|
// Instructions with higher priority in the queue come first.
|
||||||
return (
|
return (
|
||||||
FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
|
FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
|
||||||
|
@ -147,12 +147,9 @@ bool LayoutConstraints::OperandBufferForwarded(
|
|||||||
PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
|
PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction);
|
||||||
PointsToSet::BufferSet* operand_buffers =
|
PointsToSet::BufferSet* operand_buffers =
|
||||||
GetBufferSet(instruction->operand(operand_no));
|
GetBufferSet(instruction->operand(operand_no));
|
||||||
for (const LogicalBuffer* output_buffer : *output_buffers) {
|
return absl::c_any_of(*output_buffers, [&](const LogicalBuffer* b) {
|
||||||
if (operand_buffers->count(output_buffer) > 0) {
|
return operand_buffers->count(b) > 0;
|
||||||
return true;
|
});
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status LayoutConstraints::SetBufferLayout(const Layout& layout,
|
Status LayoutConstraints::SetBufferLayout(const Layout& layout,
|
||||||
|
@ -81,9 +81,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
|
|||||||
if (hlo.opcode() == HloOpcode::kParameter) {
|
if (hlo.opcode() == HloOpcode::kParameter) {
|
||||||
const std::vector<HloInstruction*>& parameter_instructions =
|
const std::vector<HloInstruction*>& parameter_instructions =
|
||||||
module_.entry_computation()->parameter_instructions();
|
module_.entry_computation()->parameter_instructions();
|
||||||
if (std::find(parameter_instructions.begin(),
|
if (absl::c_linear_search(parameter_instructions, &hlo)) {
|
||||||
parameter_instructions.end(),
|
|
||||||
&hlo) != parameter_instructions.end()) {
|
|
||||||
array->MarkInvariantOverWholeProgram(context_);
|
array->MarkInvariantOverWholeProgram(context_);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ bool IsAllowed(char character) {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
NameUniquer::NameUniquer(const string& separator) {
|
NameUniquer::NameUniquer(const string& separator) {
|
||||||
CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed))
|
CHECK(absl::c_all_of(separator, IsAllowed))
|
||||||
<< "separator should comprises allowed characters only";
|
<< "separator should comprises allowed characters only";
|
||||||
separator_ = separator;
|
separator_ = separator;
|
||||||
}
|
}
|
||||||
|
@ -260,7 +260,7 @@ PlatformUtil::GetStreamExecutors(
|
|||||||
// Block here in thread_pool destructor until all devices are initialized.
|
// Block here in thread_pool destructor until all devices are initialized.
|
||||||
}
|
}
|
||||||
VLOG(1) << "Device initialization complete";
|
VLOG(1) << "Device initialization complete";
|
||||||
if (std::all_of(stream_executors.begin(), stream_executors.end(),
|
if (absl::c_all_of(stream_executors,
|
||||||
[](se::StreamExecutor* s) { return s == nullptr; })) {
|
[](se::StreamExecutor* s) { return s == nullptr; })) {
|
||||||
return InternalError("no supported devices found for platform %s",
|
return InternalError("no supported devices found for platform %s",
|
||||||
platform->Name());
|
platform->Name());
|
||||||
|
@ -534,9 +534,8 @@ Status ValidateDotDimensionNumbers(
|
|||||||
absl::Span<const int64> contracting_dims,
|
absl::Span<const int64> contracting_dims,
|
||||||
absl::Span<const int64> batch_dims) -> bool {
|
absl::Span<const int64> batch_dims) -> bool {
|
||||||
auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
|
auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
|
||||||
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
|
return absl::c_all_of(contracting_dims, in_range) &&
|
||||||
in_range) &&
|
absl::c_all_of(batch_dims, in_range);
|
||||||
std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
absl::Span<const int64> lhs_contracting_dimensions =
|
absl::Span<const int64> lhs_contracting_dimensions =
|
||||||
@ -563,9 +562,8 @@ Status ValidateDotDimensionNumbers(
|
|||||||
auto is_unique = [&dim_set](int64 i) -> bool {
|
auto is_unique = [&dim_set](int64 i) -> bool {
|
||||||
return dim_set.insert(i).second;
|
return dim_set.insert(i).second;
|
||||||
};
|
};
|
||||||
return std::all_of(contracting_dims.begin(), contracting_dims.end(),
|
return absl::c_all_of(contracting_dims, is_unique) &&
|
||||||
is_unique) &&
|
absl::c_all_of(batch_dims, is_unique);
|
||||||
std::all_of(batch_dims.begin(), batch_dims.end(), is_unique);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
|
if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
|
||||||
@ -1589,29 +1587,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
|||||||
input_dnums[1] = dnums.input_feature_dimension();
|
input_dnums[1] = dnums.input_feature_dimension();
|
||||||
std::copy(dnums.input_spatial_dimensions().begin(),
|
std::copy(dnums.input_spatial_dimensions().begin(),
|
||||||
dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
|
dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2);
|
||||||
std::sort(input_dnums.begin(), input_dnums.end());
|
absl::c_sort(input_dnums);
|
||||||
|
|
||||||
std::vector<int64> window_dnums(num_dims);
|
std::vector<int64> window_dnums(num_dims);
|
||||||
window_dnums[0] = dnums.kernel_input_feature_dimension();
|
window_dnums[0] = dnums.kernel_input_feature_dimension();
|
||||||
window_dnums[1] = dnums.kernel_output_feature_dimension();
|
window_dnums[1] = dnums.kernel_output_feature_dimension();
|
||||||
std::copy(dnums.kernel_spatial_dimensions().begin(),
|
std::copy(dnums.kernel_spatial_dimensions().begin(),
|
||||||
dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
|
dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
|
||||||
std::sort(window_dnums.begin(), window_dnums.end());
|
absl::c_sort(window_dnums);
|
||||||
|
|
||||||
std::vector<int64> output_dnums(num_dims);
|
std::vector<int64> output_dnums(num_dims);
|
||||||
output_dnums[0] = dnums.output_batch_dimension();
|
output_dnums[0] = dnums.output_batch_dimension();
|
||||||
output_dnums[1] = dnums.output_feature_dimension();
|
output_dnums[1] = dnums.output_feature_dimension();
|
||||||
std::copy(dnums.output_spatial_dimensions().begin(),
|
std::copy(dnums.output_spatial_dimensions().begin(),
|
||||||
dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
|
dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2);
|
||||||
std::sort(output_dnums.begin(), output_dnums.end());
|
absl::c_sort(output_dnums);
|
||||||
|
|
||||||
std::vector<int64> expected_dnums(num_dims);
|
std::vector<int64> expected_dnums(num_dims);
|
||||||
std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
|
std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
|
||||||
|
|
||||||
const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
|
const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
|
||||||
if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
|
if (!absl::c_all_of(input_dnums, in_range) ||
|
||||||
!std::all_of(window_dnums.begin(), window_dnums.end(), in_range) ||
|
!absl::c_all_of(window_dnums, in_range) ||
|
||||||
!std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) {
|
!absl::c_all_of(output_dnums, in_range)) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"A dimension number is out of range in convolution: %s.",
|
"A dimension number is out of range in convolution: %s.",
|
||||||
dnums.DebugString());
|
dnums.DebugString());
|
||||||
|
@ -130,8 +130,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
|
|||||||
|
|
||||||
HloInstruction* new_lhs;
|
HloInstruction* new_lhs;
|
||||||
const int64 kLhsIdx = 0;
|
const int64 kLhsIdx = 0;
|
||||||
if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
|
if (absl::c_linear_search(operand_indices, kLhsIdx)) {
|
||||||
operand_indices.end()) {
|
|
||||||
HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
|
HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
|
||||||
const auto& transpose_dimensions = transpose.dimensions();
|
const auto& transpose_dimensions = transpose.dimensions();
|
||||||
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
|
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
|
||||||
@ -154,8 +153,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
|
|||||||
|
|
||||||
HloInstruction* new_rhs;
|
HloInstruction* new_rhs;
|
||||||
const int64 kRhsIdx = 1;
|
const int64 kRhsIdx = 1;
|
||||||
if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
|
if (absl::c_linear_search(operand_indices, kRhsIdx)) {
|
||||||
operand_indices.end()) {
|
|
||||||
HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
|
HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
|
||||||
const auto& transpose_dimensions = transpose.dimensions();
|
const auto& transpose_dimensions = transpose.dimensions();
|
||||||
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
|
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
|
||||||
|
@ -87,9 +87,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
|
|||||||
bool found = false;
|
bool found = false;
|
||||||
ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
|
ForEachElement([&found, &buffer](const ShapeIndex& /*index*/,
|
||||||
const BufferList& pointed_to_buffers) {
|
const BufferList& pointed_to_buffers) {
|
||||||
if (!found &&
|
if (!found && absl::c_linear_search(pointed_to_buffers, &buffer)) {
|
||||||
std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
|
|
||||||
&buffer) != pointed_to_buffers.end()) {
|
|
||||||
found = true;
|
found = true;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -99,8 +97,7 @@ bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
|
|||||||
bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
|
bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
|
||||||
const ShapeIndex& index) const {
|
const ShapeIndex& index) const {
|
||||||
const auto& pointed_to_buffers = element(index);
|
const auto& pointed_to_buffers = element(index);
|
||||||
return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
|
return absl::c_linear_search(pointed_to_buffers, &buffer);
|
||||||
&buffer) != pointed_to_buffers.end();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
|
void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
|
||||||
@ -604,9 +601,8 @@ bool TuplePointsToAnalysis::DoesNotUseOperandBuffer(
|
|||||||
} else if (user->opcode() == HloOpcode::kFusion &&
|
} else if (user->opcode() == HloOpcode::kFusion &&
|
||||||
user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
|
||||||
// Find fusion parameter associated with 'operand'.
|
// Find fusion parameter associated with 'operand'.
|
||||||
auto it = std::find_if(
|
auto it = absl::c_find_if(
|
||||||
user->fused_parameters().begin(), user->fused_parameters().end(),
|
user->fused_parameters(), [&](HloInstruction* fused_param) {
|
||||||
[=](HloInstruction* fused_param) {
|
|
||||||
return user->operand(fused_param->parameter_number()) == operand;
|
return user->operand(fused_param->parameter_number()) == operand;
|
||||||
});
|
});
|
||||||
CHECK(it != user->fused_parameters().end());
|
CHECK(it != user->fused_parameters().end());
|
||||||
@ -672,9 +668,8 @@ bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt(
|
|||||||
}
|
}
|
||||||
// Find fusion parameter associated with 'operand'.
|
// Find fusion parameter associated with 'operand'.
|
||||||
const auto& fused_params = fusion->fused_parameters();
|
const auto& fused_params = fusion->fused_parameters();
|
||||||
auto fused_param_it = std::find_if(
|
auto fused_param_it =
|
||||||
fused_params.begin(), fused_params.end(),
|
absl::c_find_if(fused_params, [&](HloInstruction* fused_param) {
|
||||||
[&](HloInstruction* fused_param) {
|
|
||||||
return fusion->operand(fused_param->parameter_number()) == operand;
|
return fusion->operand(fused_param->parameter_number()) == operand;
|
||||||
});
|
});
|
||||||
if (fused_param_it == fused_params.end()) {
|
if (fused_param_it == fused_params.end()) {
|
||||||
@ -743,8 +738,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
|
|||||||
// Check if one operand of kAdd fused root is kDot or kConvolution.
|
// Check if one operand of kAdd fused root is kDot or kConvolution.
|
||||||
auto* add = user->fused_expression_root();
|
auto* add = user->fused_expression_root();
|
||||||
auto add_operand_it =
|
auto add_operand_it =
|
||||||
std::find_if(add->operands().begin(), add->operands().end(),
|
absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
|
||||||
[&](HloInstruction* operand) {
|
|
||||||
return operand->opcode() == HloOpcode::kConvolution ||
|
return operand->opcode() == HloOpcode::kConvolution ||
|
||||||
operand->opcode() == HloOpcode::kDot;
|
operand->opcode() == HloOpcode::kDot;
|
||||||
});
|
});
|
||||||
|
@ -721,9 +721,8 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
|
|||||||
// to fusion 'operand'.
|
// to fusion 'operand'.
|
||||||
HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
|
HloInstruction* GetFusionParameterForOperand(HloInstruction* fusion,
|
||||||
HloInstruction* operand) {
|
HloInstruction* operand) {
|
||||||
auto it = std::find_if(
|
auto it = absl::c_find_if(
|
||||||
fusion->fused_instructions().begin(),
|
fusion->fused_instructions(), [&](const HloInstruction* fused) {
|
||||||
fusion->fused_instructions().end(), [=](const HloInstruction* fused) {
|
|
||||||
return fused->opcode() == HloOpcode::kParameter &&
|
return fused->opcode() == HloOpcode::kParameter &&
|
||||||
fusion->operand(fused->parameter_number()) == operand;
|
fusion->operand(fused->parameter_number()) == operand;
|
||||||
});
|
});
|
||||||
|
@ -109,8 +109,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
|||||||
// operand appears in, but it may appear more than once!
|
// operand appears in, but it may appear more than once!
|
||||||
if (user->user_count() == 1 && user->users().front() == while_body_root &&
|
if (user->user_count() == 1 && user->users().front() == while_body_root &&
|
||||||
while_body_root->operand_index(user) == user->tuple_index() &&
|
while_body_root->operand_index(user) == user->tuple_index() &&
|
||||||
std::count(while_body_root->operands().begin(),
|
absl::c_count(while_body_root->operands(), user) == 1) {
|
||||||
while_body_root->operands().end(), user) == 1) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -158,7 +157,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
|||||||
// Build up maps from the old/new to the new/old tuple indices.
|
// Build up maps from the old/new to the new/old tuple indices.
|
||||||
std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
|
std::vector<int64> new_to_old_tuple_idx(used_tuple_indices.begin(),
|
||||||
used_tuple_indices.end());
|
used_tuple_indices.end());
|
||||||
std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());
|
absl::c_sort(new_to_old_tuple_idx);
|
||||||
|
|
||||||
absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
|
absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
|
||||||
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
|
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
|
||||||
|
@ -407,10 +407,9 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedLoopOperands) {
|
|||||||
// The original while instruction is still left in the module as a dead
|
// The original while instruction is still left in the module as a dead
|
||||||
// instruction, find a while instruction with a different name as the new
|
// instruction, find a while instruction with a different name as the new
|
||||||
// while instruction.
|
// while instruction.
|
||||||
|
const auto& instrs = m->entry_computation()->instructions();
|
||||||
HloInstruction* new_while_op =
|
HloInstruction* new_while_op =
|
||||||
*std::find_if(m->entry_computation()->instructions().begin(),
|
*absl::c_find_if(instrs, [&](const HloInstruction* instr) {
|
||||||
m->entry_computation()->instructions().end(),
|
|
||||||
[&](const HloInstruction* instr) {
|
|
||||||
return (instr->opcode() == HloOpcode::kWhile &&
|
return (instr->opcode() == HloOpcode::kWhile &&
|
||||||
instr->name() != "while");
|
instr->name() != "while");
|
||||||
});
|
});
|
||||||
|
@ -87,8 +87,7 @@ bool Shape::is_static() const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return !std::any_of(dynamic_dimensions_.begin(), dynamic_dimensions_.end(),
|
return !absl::c_any_of(dynamic_dimensions_, [](bool b) { return b; });
|
||||||
[](bool b) { return b; });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Shape::DeleteDimension(int64 dim_to_delete) {
|
void Shape::DeleteDimension(int64 dim_to_delete) {
|
||||||
|
@ -405,8 +405,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
|
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
|
||||||
return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
|
return IsTuple(shape) && absl::c_any_of(shape.tuple_shapes(), IsTuple);
|
||||||
shape.tuple_shapes().end(), IsTuple);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
|
/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
|
||||||
|
@ -135,7 +135,7 @@ void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
|
|||||||
auto sort_order_less = [this](int64 lhs, int64 rhs) {
|
auto sort_order_less = [this](int64 lhs, int64 rhs) {
|
||||||
return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
|
return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
|
||||||
};
|
};
|
||||||
std::sort(sort_order.begin(), sort_order.end(), sort_order_less);
|
absl::c_sort(sort_order, sort_order_less);
|
||||||
|
|
||||||
// Reorder the array elements according to sort_order. Work through the array
|
// Reorder the array elements according to sort_order. Work through the array
|
||||||
// and follow cycles so we can do the reorder in-place.
|
// and follow cycles so we can do the reorder in-place.
|
||||||
|
@ -467,7 +467,7 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) {
|
|||||||
// servers. The error message is missing the operator ++.
|
// servers. The error message is missing the operator ++.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void iota_int_init_value(std::vector<T>& values, int init_value) {
|
void iota_int_init_value(std::vector<T>& values, int init_value) {
|
||||||
std::for_each(values.begin(), values.end(),
|
absl::c_for_each(values,
|
||||||
[&](T& value) { value = static_cast<T>(init_value++); });
|
[&](T& value) { value = static_cast<T>(init_value++); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
|
|||||||
CHECK_LT(index, rank);
|
CHECK_LT(index, rank);
|
||||||
output[index] = 0;
|
output[index] = 0;
|
||||||
}
|
}
|
||||||
return std::find(output.begin(), output.end(), -1) == output.end();
|
return !absl::c_linear_search(output, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64> InversePermutation(
|
std::vector<int64> InversePermutation(
|
||||||
|
@ -324,8 +324,7 @@ bool IsIdentityPermutation(absl::Span<const int64> permutation);
|
|||||||
|
|
||||||
template <typename Container>
|
template <typename Container>
|
||||||
int64 PositionInContainer(const Container& container, int64 value) {
|
int64 PositionInContainer(const Container& container, int64 value) {
|
||||||
return std::distance(container.begin(),
|
return std::distance(container.begin(), absl::c_find(container, value));
|
||||||
std::find(container.begin(), container.end(), value));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Formats the container as a comma-separated string. StrAppend must support
|
// Formats the container as a comma-separated string. StrAppend must support
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
@ -137,23 +138,21 @@ bool HasPadding(const Window& window) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool HasSymmetricPadding(const Window& window) {
|
bool HasSymmetricPadding(const Window& window) {
|
||||||
return std::all_of(window.dimensions().begin(), window.dimensions().end(),
|
return absl::c_all_of(window.dimensions(), [](const WindowDimension& dim) {
|
||||||
[](const WindowDimension& dim) {
|
|
||||||
return dim.padding_low() == dim.padding_high();
|
return dim.padding_low() == dim.padding_high();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HasSymmetricPadding(const PaddingConfig& padding_config) {
|
bool HasSymmetricPadding(const PaddingConfig& padding_config) {
|
||||||
return std::all_of(padding_config.dimensions().begin(),
|
return absl::c_all_of(padding_config.dimensions(),
|
||||||
padding_config.dimensions().end(),
|
|
||||||
[](const PaddingConfig::PaddingConfigDimension& dim) {
|
[](const PaddingConfig::PaddingConfigDimension& dim) {
|
||||||
return dim.edge_padding_low() == dim.edge_padding_high();
|
return dim.edge_padding_low() ==
|
||||||
|
dim.edge_padding_high();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HasNegativePadding(const Window& window) {
|
bool HasNegativePadding(const Window& window) {
|
||||||
return std::any_of(window.dimensions().begin(), window.dimensions().end(),
|
return absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
|
||||||
[](const WindowDimension& dim) {
|
|
||||||
return dim.padding_low() < 0 || dim.padding_high() < 0;
|
return dim.padding_low() < 0 || dim.padding_high() < 0;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -190,8 +189,7 @@ bool AllOrNoneReversed(const Window& window) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
bool reversed = window.dimensions()[0].window_reversal();
|
bool reversed = window.dimensions()[0].window_reversal();
|
||||||
return std::all_of(window.dimensions().begin(), window.dimensions().end(),
|
return absl::c_all_of(window.dimensions(), [&](const WindowDimension& dim) {
|
||||||
[&](const WindowDimension& dim) {
|
|
||||||
return dim.window_reversal() == reversed;
|
return dim.window_reversal() == reversed;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user