[XLA] Use absl::flat_hash_{map,set}::contains() instead of count().

contains() is a much better description of what we're trying to do!

Also change a few std containers over to absl containers so we can take
advantage of this.
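
As a rough illustration of the pattern (a minimal sketch, not lifted from any one file in this change; int64_t stands in for XLA's int64 typedef):

#include <cstdint>
#include "absl/container/flat_hash_set.h"

bool AlreadyVisited(const absl::flat_hash_set<int64_t>& visited,
                    int64_t op_handle) {
  // Before: if (visited.count(op_handle) != 0) { ... }
  // contains() states the intent directly and returns a bool.
  return visited.contains(op_handle);
}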

PiperOrigin-RevId: 227894609
Justin Lebar 2019-01-04 12:18:26 -08:00 committed by TensorFlower Gardener
parent f1d4d18a62
commit f9bd1568aa
53 changed files with 225 additions and 192 deletions


@ -741,6 +741,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],


@ -22,6 +22,8 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
@ -193,9 +195,9 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
}
void XlaBuilder::IsConstantVisitor(const int64 op_handle,
std::set<int64>* visited,
absl::flat_hash_set<int64>* visited,
bool* is_constant) const {
if (visited->count(op_handle) != 0 || !*is_constant) {
if (visited->contains(op_handle) || !*is_constant) {
return;
}
@ -2415,7 +2417,7 @@ StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());
bool is_constant = true;
std::set<int64> visited;
absl::flat_hash_set<int64> visited;
IsConstantVisitor(operand.handle(), &visited, &is_constant);
return is_constant;
}


@ -727,7 +727,8 @@ class XlaBuilder {
// operation such as `RngNormal` or `Infeed`. The visitor walks the
// computation starting at a given operation and sets is_constant to false iff
// a parameter or stateful operation is encountered.
void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited,
void IsConstantVisitor(const int64 op_handle,
absl::flat_hash_set<int64>* visited,
bool* is_constant) const;
// Checks bounds for convolution parameters.


@ -18,6 +18,7 @@ limitations under the License.
#include <array>
#include <utility>
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -605,24 +606,26 @@ ReferenceUtil::ReduceToRowArray2D(
const std::function<float(float, float)>& reduce_function) {
std::vector<float> result;
CHECK_EQ(dims.size(), 3);
const std::set<int64> dim_set(dims.begin(), dims.end());
const absl::flat_hash_set<int64> dim_set(dims.begin(), dims.end());
CHECK_EQ(dim_set.size(), 3);
for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) {
for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2());
for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
++a0) {
for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
++a1) {
for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3());
for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
++a2) {
for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4());
for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4());
++a3) {
float accumulator = init;
for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1());
++i0) {
for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2());
++i1) {
for (int64 i0 = 0;
i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
for (int64 i1 = 0;
i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
for (int64 i2 = 0;
i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) {
i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
for (int64 i3 = 0;
i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) {
i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
++i3) {
// Handle zero-sized arrays.
if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
array.n4() > 0) {


@ -515,6 +515,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/container:flat_hash_map",
],
)
@ -677,6 +678,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@ -1002,6 +1004,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@ -1136,6 +1139,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
@ -1580,6 +1584,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@ -2116,6 +2121,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/container:flat_hash_set",
],
)
@ -2288,6 +2294,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@ -2548,6 +2555,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
],
)
@ -3188,6 +3196,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",


@ -27,6 +27,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@ -2539,7 +2540,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
}
if (can_move_reshape_into_reduce) {
changed_ = true;
std::unordered_set<int64> dimensions_not_to_reduce;
absl::flat_hash_set<int64> dimensions_not_to_reduce;
for (auto dim_pair : unmodified_dims) {
if (arg_dim_in_output[dim_pair.second]) {
dimensions_not_to_reduce.insert(dim_pair.first);
@ -2547,7 +2548,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
}
std::vector<int64> new_reduce_dimensions;
for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) {
if (dimensions_not_to_reduce.count(i) == 0) {
if (!dimensions_not_to_reduce.contains(i)) {
new_reduce_dimensions.push_back(i);
}
}


@ -115,12 +115,10 @@ StatusOr<StreamPool::Ptr> Backend::BorrowStream(int device_ordinal) {
StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor) {
tensorflow::mutex_lock l(mu_);
if (0 == stream_pools_.count(executor)) {
stream_pools_.emplace(std::piecewise_construct,
std::forward_as_tuple(executor),
std::forward_as_tuple());
if (!stream_pools_.contains(executor)) {
stream_pools_.emplace(executor, absl::make_unique<StreamPool>());
}
return stream_pools_.at(executor).BorrowStream(executor);
return stream_pools_.at(executor)->BorrowStream(executor);
}
Backend::Backend(se::Platform* platform, Compiler* compiler,


@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/compiler.h"
@ -175,7 +176,8 @@ class Backend {
tensorflow::mutex mu_;
// Mapping from stream executor to stream pools, used by `BorrowStream` above.
std::map<se::StreamExecutor*, StreamPool> stream_pools_ GUARDED_BY(mu_);
absl::flat_hash_map<se::StreamExecutor*, std::unique_ptr<StreamPool>>
stream_pools_ GUARDED_BY(mu_);
// The default memory allocator to use.
std::unique_ptr<StreamExecutorMemoryAllocator> memory_allocator_;


@ -138,8 +138,8 @@ Status GatherComputationsByAllocationType(
worklist.pop_front();
const HloComputation* computation = worklist_front.first;
bool is_thread_local = worklist_front.second;
bool in_thread_local_set = thread_local_set.count(computation) > 0;
bool in_global_set = global_set.count(computation) > 0;
bool in_thread_local_set = thread_local_set.contains(computation);
bool in_global_set = global_set.contains(computation);
// If the computation has already been added to the respective set, then
// nothing to do.
@ -207,9 +207,9 @@ Status GatherComputationsByAllocationType(
// Add the computations to the vectors in post order.
for (auto* computation : module->MakeComputationPostOrder()) {
if (thread_local_set.count(computation) > 0) {
if (thread_local_set.contains(computation)) {
thread_local_computations->push_back(computation);
} else if (global_set.count(computation) > 0) {
} else if (global_set.contains(computation)) {
global_computations->push_back(computation);
}
// If the computation is not reachable from the entry computation, then it
@ -219,13 +219,6 @@ Status GatherComputationsByAllocationType(
return Status::OK();
}
size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const {
uint64 h = std::hash<int64>()(s.index());
h = tensorflow::Hash64Combine(h, std::hash<int64>()(s.offset()));
h = tensorflow::Hash64Combine(h, std::hash<int64>()(s.size()));
return h;
}
string BufferAllocation::Slice::ToString() const {
return absl::StrCat("{index:", index(), ", offset:", offset_,
", size:", size_, "}");
@ -240,7 +233,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice(
void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
int64 size) {
VLOG(4) << "Trying to add " << buffer << " to allocation #" << index();
CHECK(assigned_buffers_.count(&buffer) == 0)
CHECK(!assigned_buffers_.contains(&buffer))
<< "LogicalBuffer " << buffer << " already assigned to allocation "
<< index_;
CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
@ -346,7 +339,7 @@ const PointsToSet& BufferAssignment::GetPointsToSet(
bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const {
TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
return allocation_index_for_buffer_.count(&buffer) > 0;
return allocation_index_for_buffer_.contains(&buffer);
}
const BufferAllocation& BufferAssignment::GetAssignedAllocation(
@ -401,7 +394,7 @@ bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
for (const LogicalBuffer* buffer :
GetPointsToSet(instruction).element(index)) {
if (allocation_index_for_buffer_.count(buffer) > 0) {
if (allocation_index_for_buffer_.contains(buffer)) {
return true;
}
}
@ -459,8 +452,7 @@ bool BufferAssignment::SharesSliceAtIndex(
bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
const HloInstruction* hlo_b) const {
using SliceSet =
flat_hash_set<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
using SliceSet = flat_hash_set<BufferAllocation::Slice>;
// Gets the slices all of instr's subshapes. If any subshape doesn't have an
// assigned slice, returns the empty set.
auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
@ -519,7 +511,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer,
void BufferAssignment::AddAssignment(BufferAllocation* allocation,
const LogicalBuffer& buffer, int64 offset,
int64 size) {
CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer))
CHECK(!allocation_index_for_buffer_.contains(&buffer))
<< "LogicalBuffer " << buffer << " already has an allocation.";
CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
<< "Non-reusable allocation already assigned a buffer: "
@ -988,7 +980,7 @@ Status BufferAssigner::AssignBuffersForComputation(
std::vector<BufferAllocation::Index> allocation_indices;
for (const LogicalBuffer* buffer : sorted_buffers) {
VLOG(3) << "Assigning allocation to: " << *buffer;
if (colocated_buffers.count(buffer) > 0) {
if (colocated_buffers.contains(buffer)) {
// Colocated buffers are currently assigned in an earlier pass.
VLOG(3) << "Skipping colocated buffer: " << *buffer;
continue;
@ -1056,7 +1048,7 @@ Status BufferAssigner::AssignBuffersForComputation(
assignment->GetAllSlices(operand, /*index=*/{})) {
BufferAllocation* allocation =
assignment->GetMutableAllocation(operand_slice.index());
if (colocated_allocations.count(allocation->index()) == 0) {
if (!colocated_allocations.contains(allocation->index())) {
// TODO(b/32491382) Colocated buffers are currently assigned in an
// earlier pass, and so can break the "increasing allocation size"
// invariant in this function (causing this CHECK to fail). However,
@ -1087,7 +1079,7 @@ Status BufferAssigner::AssignBuffersForComputation(
// Instructions are iterated in increasing buffer size, so any
// previously create allocation must be large enough to hold this
// instruction's output (with the exception of colocated buffers).
if (colocated_allocations.count(allocation->index()) == 0) {
if (!colocated_allocations.contains(allocation->index())) {
// TODO(b/32491382) Colocated buffers are currently assigned in an
// earlier pass, and so can break the "increasing allocation size"
// invariant in this function (causing this CHECK to fail). However,
@ -1376,7 +1368,7 @@ void BufferAssigner::AddSetToColocatedBufferSets(
std::vector<size_t> overlap_set_indices;
for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) {
for (const LogicalBuffer* buffer : colocated_set) {
if ((*colocated_buffer_sets)[index].count(buffer) > 0) {
if ((*colocated_buffer_sets)[index].contains(buffer)) {
VLOG(5) << "Found overlap with existing set on buffer "
<< buffer->ToString() << "\n"
<< ColocatedBufferSetsToString((*colocated_buffer_sets)[index],


@ -186,9 +186,10 @@ class BufferAllocation {
end > other.offset_;
}
struct Hasher {
size_t operator()(Slice s) const;
};
template <typename H>
friend H AbslHashValue(H h, const Slice& s) {
return H::combine(std::move(h), s.index(), s.offset(), s.size());
}
string ToString() const;
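
(Side note, not part of the diff: defining AbslHashValue like this is what lets Slice live in absl::flat_hash_set/flat_hash_map without a hand-written Hasher. A minimal standalone sketch of the same pattern, using a hypothetical Point type rather than Slice:)

#include <cstdint>
#include <utility>
#include "absl/container/flat_hash_set.h"

struct Point {
  int64_t x = 0;
  int64_t y = 0;
  friend bool operator==(const Point& a, const Point& b) {
    return a.x == b.x && a.y == b.y;
  }
  // Hooks Point into the Abseil hashing framework, so
  // absl::flat_hash_set<Point> works with no explicit hasher argument.
  template <typename H>
  friend H AbslHashValue(H h, const Point& p) {
    return H::combine(std::move(h), p.x, p.y);
  }
};

absl::flat_hash_set<Point> seen_points;  // uses absl::Hash<Point> by default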


@ -21,6 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
@ -309,7 +310,7 @@ class BufferAssignmentTest : public HloTestBase {
static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
const std::vector<const HloInstruction*>& b,
const BufferAssignment& assignment) {
std::set<BufferAllocation::Slice> a_slices;
absl::flat_hash_set<BufferAllocation::Slice> a_slices;
for (const HloInstruction* instruction : a) {
if (assignment.HasTopLevelAllocation(instruction)) {
a_slices.insert(
@ -319,8 +320,8 @@ static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
for (const HloInstruction* instruction : b) {
if (assignment.HasTopLevelAllocation(instruction)) {
if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction)
.ConsumeValueOrDie())) {
if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction)
.ConsumeValueOrDie())) {
return false;
}
}


@ -72,7 +72,7 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) {
}
Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
if (opaque_to_channel_.count(handle.handle()) == 0) {
if (!opaque_to_channel_.contains(handle.handle())) {
return NotFound("channel handle not found: %d", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];
@ -94,7 +94,7 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
}
Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) {
if (opaque_to_channel_.count(handle.handle()) == 0) {
if (!opaque_to_channel_.contains(handle.handle())) {
return NotFound("channel handle not found: %d", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];


@ -18,6 +18,7 @@ limitations under the License.
#include <map>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
@ -83,7 +84,8 @@ class ChannelTracker {
// Mapping from ChannelHandle value to the corresponding registered
// Channel object.
std::map<int64, Channel> opaque_to_channel_ GUARDED_BY(channel_mutex_);
absl::flat_hash_map<int64, Channel> opaque_to_channel_
GUARDED_BY(channel_mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker);
};


@ -46,8 +46,7 @@ static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) {
for (auto* user : instruction->users()) {
optional<int64> operand_idx = ProfitableToMakeDotOperandColumnMajor(*user);
if (!operand_idx || user->operand(*operand_idx) != instruction ||
std::count(user->operands().begin(), user->operands().end(),
instruction) != 1) {
absl::c_count(user->operands(), instruction) != 1) {
return false;
}
}


@ -24,11 +24,9 @@ limitations under the License.
#include <utility>
#include <vector>
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
@ -70,6 +68,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@ -1399,7 +1399,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) {
int64 delta = 0;
for (int64 i = 0; i < operand_shape.dimensions_size(); i++) {
if (reduced_dims.count(i)) {
if (reduced_dims.contains(i)) {
delta++;
} else {
InsertOrDie(&unreduced_dim_map, i, i - delta);
@ -1412,7 +1412,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) {
for (int64 operand_dim_idx = 0;
operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx);
if (!reduced_dims.count(operand_dim)) {
if (!reduced_dims.contains(operand_dim)) {
if (FindOrDie(unreduced_dim_map, operand_dim) !=
result_shape.layout().minor_to_major(result_dim_idx++)) {
return false;
@ -1990,7 +1990,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
// The memcpy will copy elements that are logically this shape (allowed to be
// scalar).
const Shape logical_element_shape = ShapeUtil::FilterDimensions(
[&inner_dims](int64 dim) -> bool { return inner_dims.count(dim); },
[&inner_dims](int64 dim) { return inner_dims.contains(dim); },
operand->shape());
const int64 primitive_elements_per_logical_element =


@ -448,7 +448,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
computation_to_profile_idx_;
// Maps HLOs to Values emitted for them.
std::unordered_map<const HloInstruction*, llvm::Value*> emitted_value_;
absl::flat_hash_map<const HloInstruction*, llvm::Value*> emitted_value_;
llvm_ir::AliasAnalysis alias_analysis_;


@ -94,8 +94,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
@ -135,6 +135,8 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm//:core",
@ -263,7 +265,9 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@ -362,6 +366,7 @@ cc_library(
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",


@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@ -57,16 +58,16 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
// If buffer #i's address is already registered (e.g. external arguments or
// result buffers), use that registered buffer.
if (registered_buffers_.count(i)) {
se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i);
if (reinterpret_cast<uintptr_t>(address.opaque()) % expected_alignment !=
if (se::DeviceMemoryBase* address =
tensorflow::gtl::FindOrNull(registered_buffers_, i)) {
if (reinterpret_cast<uintptr_t>(address->opaque()) % expected_alignment !=
0) {
return InternalError(
"Address of registered buffer %d must be a multiple of %x, but "
"was %p",
i, kEntryParameterAlignBytes, address.opaque());
i, kEntryParameterAlignBytes, address->opaque());
}
buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i));
buffer_allocations->SetBuffer(i, *address);
continue;
}


@ -20,6 +20,7 @@ limitations under the License.
#include <set>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@ -52,7 +53,8 @@ class BufferAllocations {
DeviceMemoryAllocator* memory_allocator);
private:
std::map<BufferAllocation::Index, se::DeviceMemoryBase> registered_buffers_;
absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase>
registered_buffers_;
};
~BufferAllocations();


@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@ -45,10 +46,10 @@ void HloToIrBindings::EmitBasePointersForHlos(
// An HLO can have duplicated operands. This data structure remembers which
// operand HLOs are already bound to avoid rebinding the same HLO.
std::set<const HloInstruction*> already_bound_for_this_function;
absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
auto arg_iter = function->arg_begin();
for (const HloInstruction* io_hlo : io_hlos) {
if (!already_bound_for_this_function.count(io_hlo)) {
if (!already_bound_for_this_function.contains(io_hlo)) {
if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
} else {
@ -63,7 +64,7 @@ void HloToIrBindings::EmitBasePointersForHlos(
temp_buffer_base_->setName("temp_buffer");
for (const HloInstruction* non_io_hlo : non_io_hlos) {
if (already_bound_for_this_function.count(non_io_hlo)) {
if (already_bound_for_this_function.contains(non_io_hlo)) {
continue;
}
already_bound_for_this_function.insert(non_io_hlo);


@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
@ -61,7 +62,7 @@ class HloToIrBindings {
// Returns whether `hlo` is bound to an LLVM IR value.
bool BoundToIrValue(const HloInstruction& hlo) const {
return base_ptrs_.count(&hlo);
return base_ptrs_.contains(&hlo);
}
llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; }
@ -110,7 +111,8 @@ class HloToIrBindings {
// For an instruction that generates multiple outputs, the root will be a
// tuple shape. The IrArray for each element output is stored in the subnode
// in the ShapeTree.
std::unordered_map<const HloInstruction*, ShapeTree<llvm::Value*>> base_ptrs_;
absl::flat_hash_map<const HloInstruction*, ShapeTree<llvm::Value*>>
base_ptrs_;
// The address of the memory block that contains all temporary buffers.
llvm::Value* temp_buffer_base_ = nullptr;


@ -67,7 +67,7 @@ int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
}
int64 profit = 0;
for (auto instr : instr2->operands()) {
if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) {
if (!IsProfitableOperand(instr) || !in_list.contains(instr)) {
continue;
}
profit += ShapeUtil::ByteSizeOf(instr->shape());


@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@ -25,7 +26,7 @@ namespace xla {
namespace gpu {
bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const {
return hlo_to_stream_number_.count(&hlo);
return hlo_to_stream_number_.contains(&hlo);
}
int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const {
@ -98,10 +99,10 @@ int ComputeStreamToAssign(
// greedy approach. First, we compute as forbidden_stream_numbers the
// streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign
// `hlo` a different stream.
std::set<int> forbidden_stream_numbers;
absl::flat_hash_set<int> forbidden_stream_numbers;
for (const auto* seen_gemm : seen_gemms) {
int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm);
if (!forbidden_stream_numbers.count(stream_num) &&
if (!forbidden_stream_numbers.contains(stream_num) &&
CanRunConcurrently(*seen_gemm, hlo, reachability)) {
forbidden_stream_numbers.insert(stream_num);
}
@ -109,7 +110,7 @@ int ComputeStreamToAssign(
for (int stream_num = 0; stream_num < stream_assignment.StreamCount();
++stream_num) {
if (!forbidden_stream_numbers.count(stream_num)) {
if (!forbidden_stream_numbers.contains(stream_num)) {
return stream_num;
}
}


@ -14,17 +14,19 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace xla {
namespace gpu {
void ThunkSchedule::AddDependenciesOnTransitiveOperands(
const Thunk& thunk, const HloInstruction& operand,
const std::unordered_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
if (hlo_to_thunk.count(&operand)) {
const absl::flat_hash_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
if (hlo_to_thunk.contains(&operand)) {
// If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency
// list if `operand` is assigned to a different stream. As an optimization,
// we skip `operand`'s operands because `operand` depends on them already.
@ -48,14 +50,14 @@ ThunkSchedule::ThunkSchedule(
const std::vector<HloInstruction*>& hlo_total_order)
: thunks_(std::move(thunks)),
stream_assignment_(std::move(stream_assignment)) {
std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk;
absl::flat_hash_map<const HloInstruction*, Thunk*> hlo_to_thunk;
for (const auto& thunk : *thunks_) {
InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
}
for (HloInstruction* hlo : hlo_total_order) {
if (hlo_to_thunk.count(hlo)) {
thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo));
if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) {
thunk_total_order_.push_back(*thunk);
}
}
@ -106,7 +108,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() {
// redundant dependency edge.
Array2D<int> last_dependency(stream_count, stream_count, -1);
for (const Thunk* dst : thunk_total_order_) {
if (!depends_on_.count(dst)) {
if (!depends_on_.contains(dst)) {
continue;
}
@ -134,7 +136,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() {
const std::list<const Thunk*>& ThunkSchedule::DependsOn(
const Thunk* thunk) const {
if (depends_on_.count(thunk)) {
if (depends_on_.contains(thunk)) {
return FindOrDie(depends_on_, thunk);
} else {
return empty_thunk_list_;


@ -21,6 +21,8 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -54,7 +56,9 @@ class ThunkSchedule {
// Thunks that `thunk` depends on.
const std::list<const Thunk*>& DependsOn(const Thunk* thunk) const;
// Whether `thunk` is depended by another thunk.
bool Depended(const Thunk* thunk) const { return depended_by_.count(thunk); }
bool Depended(const Thunk* thunk) const {
return depended_by_.contains(thunk);
}
// Delegates to StreamAssignment.
int StreamCount() const { return stream_assignment_->StreamCount(); }
@ -75,13 +79,13 @@ class ThunkSchedule {
// thunk.hlo_instruction().
void AddDependenciesOnTransitiveOperands(
const Thunk& thunk, const HloInstruction& operand,
const std::unordered_map<const HloInstruction*, Thunk*>& hlo_to_thunk);
const absl::flat_hash_map<const HloInstruction*, Thunk*>& hlo_to_thunk);
std::unique_ptr<ThunkSequence> thunks_;
std::vector<Thunk*> thunk_total_order_;
std::unordered_map<const Thunk*, std::list<const Thunk*>> depends_on_;
std::set<const Thunk*> depended_by_;
absl::flat_hash_map<const Thunk*, std::list<const Thunk*>> depends_on_;
absl::flat_hash_set<const Thunk*> depended_by_;
std::list<const Thunk*> empty_thunk_list_;
std::unique_ptr<StreamAssignment> stream_assignment_;


@ -199,7 +199,7 @@ Status HeapSimulator::RunComputation(
// If the buffer has no users and isn't an entry parameter or output, it
// must be a dead value.
if (live_buffers.count(buffer) == 0) {
if (!live_buffers.contains(buffer)) {
dead_buffers_to_free.push_back(buffer);
}
}
@ -253,7 +253,7 @@ Status HeapSimulator::RunComputation(
bool shared = false;
if (options_.may_reuse_operand_buffers) {
for (const BufferValue* operand_buffer : operand_buffers_to_free) {
if (reused_buffers.count(operand_buffer) != 0) {
if (reused_buffers.contains(operand_buffer)) {
continue;
}
if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
@ -374,15 +374,15 @@ bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const {
return true;
}
return options_.buffers_to_assign != nullptr &&
options_.buffers_to_assign->count(buffer) == 0;
!options_.buffers_to_assign->contains(buffer);
}
// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const BufferValue* buffer,
const HloInstruction* instruction) {
CHECK(allocated_buffers_.count(buffer) == 0)
CHECK(!allocated_buffers_.contains(buffer))
<< "Alloc called on allocated buffer: " << *buffer;
CHECK(freed_buffers_.count(buffer) == 0)
CHECK(!freed_buffers_.contains(buffer))
<< "Alloc called on freed buffer: " << *buffer;
allocated_buffers_.insert(buffer);
@ -411,9 +411,9 @@ void HeapSimulator::Free(const BufferValue* buffer,
buffer = group->canonical;
}
CHECK(allocated_buffers_.count(buffer) > 0)
CHECK(allocated_buffers_.contains(buffer))
<< "Free called on non-allocated buffer: " << *buffer;
CHECK(freed_buffers_.count(buffer) == 0)
CHECK(!freed_buffers_.contains(buffer))
<< "Free called on freed buffer: " << *buffer;
freed_buffers_.insert(buffer);
@ -433,11 +433,11 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer,
const HloInstruction* instruction) {
CHECK_LE(size_fn_(*buffer), size_fn_(*shared))
<< "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared;
CHECK(allocated_buffers_.count(buffer) == 0)
CHECK(!allocated_buffers_.contains(buffer))
<< "ShareBuffer called on allocated buffer: " << *buffer;
CHECK(freed_buffers_.count(buffer) == 0)
CHECK(!freed_buffers_.contains(buffer))
<< "ShareBuffer called on freed buffer: " << *buffer;
CHECK(freed_buffers_.count(shared) == 0)
CHECK(!freed_buffers_.contains(shared))
<< "ShareBuffer called on freed shared buffer: " << *shared;
const BufferValue* canonical = nullptr;
@ -452,7 +452,7 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer,
} else {
// The 'shared' buffer doesn't have a group; it must be the canonical. Add
// both 'buffer' and 'shared' to a new group.
CHECK(allocated_buffers_.count(shared) > 0)
CHECK(allocated_buffers_.contains(shared))
<< "ShareBuffer called on non-allocated shared buffer: " << *shared;
auto group = std::make_shared<SharedGroup>();
canonical = shared;


@ -207,14 +207,14 @@ Status HloComputation::RemoveInstructionAndUnusedOperands(
TF_RET_CHECK(instruction->user_count() == 0);
TF_RET_CHECK(IsRemovable(instruction))
<< "Cannot remove instruction: " << instruction->ToString();
std::unordered_set<HloInstruction*> removed;
absl::flat_hash_set<HloInstruction*> removed;
std::queue<HloInstruction*> worklist;
worklist.push(instruction);
while (!worklist.empty()) {
HloInstruction* item = worklist.front();
worklist.pop();
if (removed.count(item) != 0 || item->user_count() != 0 ||
if (removed.contains(item) || item->user_count() != 0 ||
item == root_instruction() || !IsRemovable(item) ||
(item->HasSideEffect() && item != instruction)) {
continue;
@ -694,13 +694,14 @@ bool HloComputation::operator==(const HloComputation& other) const {
if (this == &other) {
return true;
}
std::set<std::pair<const HloInstruction*, const HloInstruction*>> visited;
absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
visited;
std::function<bool(const HloInstruction*, const HloInstruction*)> eq =
[&visited, &eq](const HloInstruction* a, const HloInstruction* b) {
// If <a,b> are visited but not identical, the recursion should have
// been aborted. So, if <a,b> are visited at this point, they must be
// identical.
if (visited.count(std::make_pair(a, b)) > 0) {
if (visited.contains(std::make_pair(a, b))) {
return true;
}
visited.emplace(a, b);
@ -803,13 +804,13 @@ Status HloComputation::AcceptOrdered(
<< root->ToString();
}
TF_RET_CHECK(order.size() == instruction_count());
std::unordered_set<const HloInstruction*> visited;
absl::flat_hash_set<const HloInstruction*> visited;
for (const HloInstruction* instruction : order) {
VLOG(3) << "Visiting ordered: " << instruction->ToString();
TF_RET_CHECK(instruction_iterators_.count(instruction) == 1)
TF_RET_CHECK(instruction_iterators_.contains(instruction))
<< "Instruction " << instruction->name() << " is not in computation "
<< name();
TF_RET_CHECK(visited.count(instruction) == 0)
TF_RET_CHECK(!visited.contains(instruction))
<< "Instruction " << instruction->name()
<< " appears more than once in order";
HloInstruction* mutable_instruction =


@ -17,6 +17,7 @@ limitations under the License.
#include <set>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -226,7 +227,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) {
: computation_(computation) {}
Status DefaultAction(HloInstruction* hlo_instruction) override {
EXPECT_EQ(0, visited_set_.count(hlo_instruction));
EXPECT_FALSE(visited_set_.contains(hlo_instruction));
visited_set_.insert(hlo_instruction);
last_visited_ = hlo_instruction;
return Status::OK();
@ -239,7 +240,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) {
}
HloComputation* computation_;
std::set<HloInstruction*> visited_set_;
absl::flat_hash_set<HloInstruction*> visited_set_;
int64 finish_visit_calls_ = 0;
HloInstruction* last_visited_ = nullptr;
};


@ -107,7 +107,7 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
return false;
}
}
if (!visited.count(user)) {
if (!visited.contains(user)) {
stack.push_back(user);
}
}


@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@ -65,7 +66,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
// Now DCE HloComputations. First, collect the computations that are
// referenced by some remaining instruction.
std::unordered_set<HloComputation*> live_computations;
absl::flat_hash_set<HloComputation*> live_computations;
if (HloComputation* entry_computation = module->entry_computation()) {
live_computations.insert(entry_computation);
}
@ -79,7 +80,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
// Remove dead computations.
for (auto* computation : module->MakeComputationPostOrder()) {
if (live_computations.count(computation) == 0) {
if (!live_computations.contains(computation)) {
TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation));
changed = true;
}


@ -24,9 +24,9 @@ limitations under the License.
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@ -380,7 +380,7 @@ class HloDotDumper {
// Each HloInstruction dumped gets a monotically-increasing node ID. This
// must start at 1, because that's where graphviz's accounting starts.
int64 next_node_id_ = 1;
std::unordered_map<const HloInstruction*, int64> node_ids_;
absl::flat_hash_map<const HloInstruction*, int64> node_ids_;
// The "root" tag doesn't have an associated HloInstruction pointer, so we
// need to store it outside the map.
@ -397,7 +397,7 @@ class HloDotDumper {
// Each HloComputation that's emitted gets a monotonically-increasing ID.
int64 next_cluster_id_ = 1;
std::unordered_map<const HloComputation*, int64> cluster_ids_;
absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;
// Edges to print from Footer(). Edges come at the end because graphviz is
// unhappy if an edge from a subcomputation to a node in the outer computation
@ -407,7 +407,7 @@ class HloDotDumper {
// When coloring by sharding information, we track the sharding string
// representation to color association, by round-robin the color schemes.
std::unordered_map<HloSharding, ColorScheme, HloSharding::Hasher>
absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
sharding_colors_;
int64 next_shard_color_ = 0;
};
@ -1286,7 +1286,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
int64 radius) {
// First, find the neighborhood of nodes with distance from root <= radius.
// These nodes are our initial set of "normal" nodes.
std::unordered_map<const HloInstruction*, NodeFilterResult> nodes;
absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
worklist.push_back({root, 0});
while (!worklist.empty()) {
@ -1307,7 +1307,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
// are not interesting to the graph at hand.
if (instr == root || instr->opcode() != HloOpcode::kTuple) {
for (const HloInstruction* operand : instr->operands()) {
if (!nodes.count(operand)) {
if (!nodes.contains(operand)) {
worklist.push_back({operand, depth + 1});
}
}
@ -1335,7 +1335,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
continue;
}
for (const HloInstruction* user : instr->users()) {
if (!nodes.count(user)) {
if (!nodes.contains(user)) {
worklist.push_back({user, depth + 1});
}
}
@ -1344,7 +1344,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
auto is_displayed = [&](const HloInstruction* instr) {
// Constants are displayed inline with their users; they're never omitted.
// Nodes in subcomputations are always shown.
return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant ||
return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
instr->parent() != root->parent();
};


@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@ -55,13 +56,13 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
}
Status HandleParameter(HloInstruction* parameter) override {
EXPECT_EQ(0, count_.count(parameter));
EXPECT_FALSE(count_.contains(parameter));
count_[parameter] = GetCountsForNode(parameter);
return Status::OK();
}
Status HandleConstant(HloInstruction* constant) override {
EXPECT_EQ(0, count_.count(constant));
EXPECT_FALSE(count_.contains(constant));
count_[constant] = GetCountsForNode(constant);
return Status::OK();
}
@ -69,25 +70,25 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
Status HandleAdd(HloInstruction* add) override {
auto lhs = add->operand(0);
auto rhs = add->operand(1);
EXPECT_EQ(0, count_.count(add));
EXPECT_GT(count_.count(lhs), 0);
EXPECT_GT(count_.count(rhs), 0);
EXPECT_FALSE(count_.contains(add));
EXPECT_TRUE(count_.contains(lhs));
EXPECT_TRUE(count_.contains(rhs));
count_[add] = GetCountsForNode(add);
return Status::OK();
}
Status HandleNegate(HloInstruction* negate) override {
auto operand = negate->operand(0);
EXPECT_EQ(0, count_.count(negate));
EXPECT_GT(count_.count(operand), 0);
EXPECT_FALSE(count_.contains(negate));
EXPECT_TRUE(count_.contains(operand));
count_[negate] = GetCountsForNode(negate);
return Status::OK();
}
Status HandleMap(HloInstruction* map) override {
EXPECT_EQ(0, count_.count(map));
EXPECT_FALSE(count_.contains(map));
for (HloInstruction* arg : map->operands()) {
EXPECT_GT(count_.count(arg), 0);
EXPECT_TRUE(count_.contains(arg));
}
count_[map] = GetCountsForNode(map);
return Status::OK();
@ -96,9 +97,9 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
Status HandleReduce(HloInstruction* reduce) override {
auto arg = reduce->operand(0);
auto init_value = reduce->operand(1);
EXPECT_EQ(0, count_.count(reduce));
EXPECT_GT(count_.count(arg), 0);
EXPECT_GT(count_.count(init_value), 0);
EXPECT_FALSE(count_.contains(reduce));
EXPECT_TRUE(count_.contains(arg));
EXPECT_TRUE(count_.contains(init_value));
count_[reduce] = GetCountsForNode(reduce);
return Status::OK();
}
@ -128,7 +129,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
}
// Counters for HLOs. Maps HLO to a NumOpsAndUsers.
std::unordered_map<const HloInstruction*, NumOpsAndUsers> count_;
absl::flat_hash_map<const HloInstruction*, NumOpsAndUsers> count_;
};
TEST_F(HloInstructionTest, BasicProperties) {
@ -137,7 +138,7 @@ TEST_F(HloInstructionTest, BasicProperties) {
EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32));
EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32));
EXPECT_EQ(0, parameter->operand_count());
EXPECT_FALSE(parameter->operand_count());
}
TEST_F(HloInstructionTest, UserWithTwoOperands) {
@ -981,9 +982,9 @@ TEST_F(HloInstructionTest, FunctionVisitor) {
module->AddEntryComputation(builder.Build());
int visit_num = 0;
std::unordered_map<HloInstruction*, int> visit_order;
absl::flat_hash_map<HloInstruction*, int> visit_order;
EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) {
EXPECT_EQ(0, visit_order.count(inst));
EXPECT_FALSE(visit_order.contains(inst));
visit_order[inst] = visit_num;
visit_num++;
return Status::OK();


@ -17,6 +17,7 @@ limitations under the License.
#include <deque>
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
@ -36,11 +37,11 @@ namespace xla {
namespace {
using Worklist = std::deque<const HloInstruction*>;
using Workset = std::unordered_set<const HloInstruction*>;
using Workset = absl::flat_hash_set<const HloInstruction*>;
void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
Workset* workset) {
if (workset->count(instruction) == 0) {
if (!workset->contains(instruction)) {
worklist->push_back(instruction);
workset->insert(instruction);
VLOG(3) << "ADD instruction: " << instruction->name();


@ -392,15 +392,12 @@ namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(
const HloInstruction& hlo,
const std::unordered_set<HloInstruction*>& instructions_in_subcomputation) {
for (HloInstruction* user : hlo.users()) {
if (!instructions_in_subcomputation.count(user)) {
return true;
}
}
return false;
bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
const absl::flat_hash_set<HloInstruction*>&
instructions_in_subcomputation) {
return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
return !instructions_in_subcomputation.contains(user);
});
}
} // anonymous namespace
@ -411,9 +408,9 @@ HloInstruction* HloModule::OutlineExpressionFromComputation(
// A map from original instructions to their counterparts in the new outlined
// function.
std::unordered_map<HloInstruction*, HloInstruction*> outlined_instructions;
absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
// A set that contains all instructions to be outlined.
std::unordered_set<HloInstruction*> instruction_set_to_outline(
absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
instructions_to_outline.begin(), instructions_to_outline.end());
std::vector<HloInstruction*> arguments;
std::vector<HloInstruction*> outputs;
@ -502,7 +499,7 @@ std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
// First determine all root computations by building a set of nonroot
// computations (computations which are called by an instruction in the
// module).
std::set<HloComputation*> nonroot_computations;
absl::flat_hash_set<HloComputation*> nonroot_computations;
for (auto& computation : computations_) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@ -515,19 +512,19 @@ std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
// Keep track of computations which have already been added to the post
// order. This prevents duplication as an embedded computation may be called
// from two different root computations.
std::set<HloComputation*> added_computations;
absl::flat_hash_set<HloComputation*> added_computations;
std::vector<HloComputation*> post_order;
for (auto& computation : computations_) {
if (nonroot_computations.count(computation.get()) == 0) {
if (!nonroot_computations.contains(computation.get())) {
for (HloComputation* embedded_computation :
computation->MakeEmbeddedComputationsList()) {
if (added_computations.count(embedded_computation) == 0) {
if (!added_computations.contains(embedded_computation)) {
post_order.push_back(embedded_computation);
added_computations.insert(embedded_computation);
}
}
// Root computations should only be encountered once.
CHECK_EQ(0, added_computations.count(computation.get()));
CHECK(!added_computations.contains(computation.get()));
post_order.push_back(computation.get());
added_computations.insert(computation.get());
}


@ -199,7 +199,7 @@ bool HloModuleGroupMetadata::IsChannelInstruction(
}
bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const {
return companion_set_index_.count(hlo) > 0;
return companion_set_index_.contains(hlo);
}
bool HloModuleGroupMetadata::InstructionCommunicates(
@ -510,7 +510,7 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
HloComputation* computation = instruction->parent();
const HloModule* module = computation->parent();
if (module->entry_computation() == computation ||
tracked_instructions_.count(computation) > 0) {
tracked_instructions_.contains(computation)) {
return Status::OK();
}
return FailedPrecondition("channel is used in disallowed computation");


@ -178,7 +178,7 @@ class HloModuleGroupMetadata {
// Precondition: IsCompanionWhile(instruction) is true.
const std::vector<HloInstruction*>& Companions(
const HloInstruction* instruction) const {
CHECK_EQ(companion_set_index_.count(instruction), 1);
CHECK(companion_set_index_.contains(instruction));
return companion_set(companion_set_index_.at(instruction));
}


@ -367,7 +367,7 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
const HloInstruction* a, const HloInstruction* b) const {
CHECK_EQ(a->parent(), b->parent());
// If either instruction is not in the order, then 'a' and 'b' are unordered.
if (order_position_.count(a) == 0 || order_position_.count(b) == 0) {
if (!order_position_.contains(a) || !order_position_.contains(b)) {
return false;
}
return order_position_.at(a) < order_position_.at(b);


@ -89,7 +89,7 @@ std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
std::vector<HloPassInterface*> enabled_passes;
for (auto& pass : passes_) {
if (disabled_pass_names.count(string(pass->name())) == 0) {
if (!disabled_pass_names.contains(pass->name())) {
enabled_passes.push_back(pass.get());
}
}


@ -140,7 +140,7 @@ Status HloSchedule::UpdateComputationSchedule(
std::queue<HloInstruction*> worklist;
for (HloInstruction* instruction : computation->instructions()) {
if (ids_in_schedule.count(instruction->unique_id()) == 0) {
if (!ids_in_schedule.contains(instruction->unique_id())) {
// This is a newly added instruction which is not in the schedule.
if (instruction->operands().empty()) {
worklist.push(instruction);
@ -204,7 +204,7 @@ Status HloSchedule::Update() {
std::vector<HloComputation*> nonfusion_computations =
module_->MakeNonfusionComputations();
for (const HloComputation* computation : nonfusion_computations) {
TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
TF_RET_CHECK(sequences_.contains(computation->unique_id()))
<< "Computation " << computation->name() << " not in HloSchedule.";
}
if (sequences_.size() > nonfusion_computations.size()) {
@ -215,7 +215,7 @@ Status HloSchedule::Update() {
nonfusion_computations_ids.insert(computation->unique_id());
}
for (auto it = sequences_.begin(); it != sequences_.end();) {
if (nonfusion_computations_ids.count(it->first) == 0) {
if (!nonfusion_computations_ids.contains(it->first)) {
sequences_.erase(it++);
} else {
++it;
@ -244,7 +244,7 @@ Status HloSchedule::Verify() const {
<< "Schedule has " << sequences_.size() << " sequences, but module has "
<< nonfusion_computations.size() << " non-fusion computations";
for (const HloComputation* computation : nonfusion_computations) {
TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
TF_RET_CHECK(sequences_.contains(computation->unique_id()))
<< "Computation " << computation->name()
<< " missing from HLO schedule.";
}
@ -268,7 +268,7 @@ Status HloSchedule::Verify() const {
<< instruction_position.size() << " instructions, expected "
<< computation->instruction_count();
for (const HloInstruction* instruction : computation->instructions()) {
TF_RET_CHECK(instruction_position.count(instruction) == 1)
TF_RET_CHECK(instruction_position.contains(instruction))
<< "Instruction " << instruction->name() << " is not in schedule";
}


@ -110,7 +110,7 @@ class HloSchedule {
// Returns true if the schedule has a sequence for the given computation.
bool is_computation_scheduled(const HloComputation* computation) const {
return sequences_.count(computation->unique_id()) == 1;
return sequences_.contains(computation->unique_id());
}
// Updates the schedule such that it is (again) a valid schedule for the


@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/overflow_util.h"
@ -316,7 +317,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
// All tile assignments must be less than the number of available cores and
// unique.
Status status = Status::OK();
std::set<int64> seen_cores;
absl::flat_hash_set<int64> seen_cores;
tile_assignment_.Each(
[&](absl::Span<const int64> indices, int32 core) {
// Don't overwrite a bad status, so we report the first error.
@ -324,7 +325,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
if (core >= num_devices) {
status = tensorflow::errors::InvalidArgument(StrCat(
"core ", core, " > ", num_devices, " in tile assignment"));
} else if (seen_cores.count(core) != 0) {
} else if (seen_cores.contains(core)) {
status = tensorflow::errors::InvalidArgument(
StrCat("core ", core, " is not unique in tile assignment"));
}


@ -99,7 +99,7 @@ std::vector<PassThrough> LocatePassThroughDomainLinks(
<< "Instruction is not a kDomain: " << instruction->ToString();
for (HloInstruction* user : instruction->users()) {
if (user->opcode() == HloOpcode::kDomain &&
domain.exit_domains.count(user) != 0) {
domain.exit_domains.contains(user)) {
pass_through.emplace_back(user, instruction);
VLOG(2) << "Found passthrough domain link:";
VLOG(2) << " " << user->ToString();
@ -253,7 +253,7 @@ StatusOr<bool> ApplyShardingFromUsers(HloInstruction* instruction,
instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice));
for (HloInstruction* user : instruction->users()) {
if (user->opcode() == HloOpcode::kDomain &&
domain.exit_domains.count(user) > 0) {
domain.exit_domains.contains(user)) {
// If a user is a domain and it is registered in the domain exits, then
// the instruction sharding is taken directly from the domain, and no
// further users need to be visited.


@ -103,7 +103,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache(
do {
const HloInstruction* instr = stack.back();
if (cache_.count(instr)) {
if (cache_.contains(instr)) {
stack.pop_back();
continue;
}
@ -111,9 +111,9 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache(
switch (FindOrDie(dfs_state_map, instr)) {
case kDiscovered: {
for (const HloInstruction* operand : instr->operands()) {
if (!cache_.count(operand)) {
if (!cache_.contains(operand)) {
stack.push_back(operand);
CHECK(!dfs_state_map.count(operand) ||
CHECK(!dfs_state_map.contains(operand) ||
dfs_state_map[operand] == kDiscovered);
dfs_state_map[operand] = kDiscovered;
}


@ -2135,7 +2135,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kCopy &&
added_copies_.count(instruction) > 0) {
added_copies_.contains(instruction)) {
VLOG(5) << "Removing added copy: " << instruction->ToString();
TF_RETURN_IF_ERROR(
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));


@ -243,7 +243,7 @@ class ChannelLayoutConstraints {
// Returns true if channel_id has a layout constraint.
bool IsChannelConstrained(int64 channel_id) const {
return constraints_.count(channel_id) > 0;
return constraints_.contains(channel_id);
}
// Given `shape`, apply the layout for `channel_id`. `channel_id` must already
@ -276,7 +276,7 @@ class ChannelLayoutConstraints {
}
private:
std::unordered_map<int64, Layout> constraints_;
absl::flat_hash_map<int64, Layout> constraints_;
};
// HLO pass which assigns layouts to all instructions in the HLO module while


@ -169,6 +169,7 @@ cc_library(
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@llvm//:core",


@ -76,15 +76,12 @@ class AliasAnalysis {
// A map from a buffer slice to metadata corresponding to its alias.scope
// metadata. The index kParameterAliasSet is used to hold aliasing
// information for parameters.
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
BufferAllocation::Slice::Hasher>
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*>
alias_scope_metadata_;
// A map from a buffer slice to metadata corresponding to its noalias
// metadata.
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
BufferAllocation::Slice::Hasher>
noalias_metadata_;
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*> noalias_metadata_;
};
} // namespace llvm_ir


@ -35,7 +35,7 @@ using llvm_ir::IrArray;
Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
indexed_generators_[hlo] =
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
if (generated_value_cache_[hlo].count(index.multidim()) > 0) {
if (generated_value_cache_[hlo].contains(index.multidim())) {
llvm::Value* generated_value =
generated_value_cache_[hlo][index.multidim()];
llvm::BasicBlock* generated_value_bb = nullptr;


@ -19,6 +19,7 @@ limitations under the License.
#include <map>
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
@ -134,8 +135,9 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
// Cache of generated values, lest we regenerate an element of a node with
// multiple outgoing edges
std::unordered_map<const HloInstruction*,
std::map<std::vector<llvm::Value*>, llvm::Value*>>
absl::flat_hash_map<
const HloInstruction*,
absl::flat_hash_map<std::vector<llvm::Value*>, llvm::Value*>>
generated_value_cache_;
};


@ -198,7 +198,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) {
continue;
}
if (in_list.count(instr) > 0) {
if (in_list.contains(instr)) {
continue;
}
int64 profit = GetProfit(instr, fusion);


@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@ -55,11 +56,10 @@ bool PointsToSet::IsAmbiguous() const {
bool PointsToSet::IsDistinct() const {
bool distinct = true;
std::set<const LogicalBuffer*> all_points_to;
ForEachElement([&distinct, &all_points_to](const ShapeIndex& /*index*/,
const BufferList& points_to) {
absl::flat_hash_set<const LogicalBuffer*> all_points_to;
ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
for (auto& buffer : points_to) {
if (all_points_to.count(buffer) != 0) {
if (all_points_to.contains(buffer)) {
distinct = false;
}
all_points_to.insert(buffer);


@ -89,7 +89,7 @@ static void CreateLoopInvariantCopy(
HloInstruction* next_operand =
frame->instruction->mutable_operand(frame->operand_index++);
if (hoisted_instructions->count(next_operand) ||
if (hoisted_instructions->contains(next_operand) ||
next_operand == while_body_param) {
continue;
}
@ -241,7 +241,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
auto is_invariant = [&](HloInstruction* op) {
return hoisted_instructions.find(op) != hoisted_instructions.end() ||
unhoisted_invariant_instructions.count(op) ||
unhoisted_invariant_instructions.contains(op) ||
op->opcode() == HloOpcode::kConstant;
};


@ -127,7 +127,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// through to the while body's root, count that element as "used", since
// removing that element would be observable.
for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
if (used_tuple_indices.count(i)) {
if (used_tuple_indices.contains(i)) {
continue;
}