[XLA] Use absl::flat_hash_{map,set}::contains() instead of count().
contains() is a much better description of what we're trying to do! Also change a few std containers over to absl containers so we can take advantage of this.

PiperOrigin-RevId: 227894609
commit f9bd1568aa
parent f1d4d18a62
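For illustration only (this snippet is not part of the diff below; the IsVisited helper and the plain int64_t handle type are hypothetical), a minimal sketch of the call-site pattern this commit standardizes on:

    #include <cstdint>

    #include "absl/container/flat_hash_set.h"

    // Membership test on an absl::flat_hash_set. contains() states the intent
    // directly; count() returns a size_t that callers must compare against 0.
    bool IsVisited(const absl::flat_hash_set<int64_t>& visited, int64_t op_handle) {
      // Old style at call sites: visited.count(op_handle) != 0
      return visited.contains(op_handle);
    }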
@@ -741,6 +741,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -22,6 +22,8 @@ limitations under the License.
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
@@ -193,9 +195,9 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(XlaOp root) const {
}

void XlaBuilder::IsConstantVisitor(const int64 op_handle,
std::set<int64>* visited,
absl::flat_hash_set<int64>* visited,
bool* is_constant) const {
if (visited->count(op_handle) != 0 || !*is_constant) {
if (visited->contains(op_handle) || !*is_constant) {
return;
}

@@ -2415,7 +2417,7 @@ StatusOr<bool> XlaBuilder::IsConstant(const XlaOp& operand) const {
TF_RETURN_IF_ERROR(LookUpInstruction(operand).status());

bool is_constant = true;
std::set<int64> visited;
absl::flat_hash_set<int64> visited;
IsConstantVisitor(operand.handle(), &visited, &is_constant);
return is_constant;
}
@@ -727,7 +727,8 @@ class XlaBuilder {
// operation such as `RngNormal` or `Infeed`. The visitor walks the
// computation starting at a given operation and sets is_constant to false iff
// a parameter or stateful operation is encountered.
void IsConstantVisitor(const int64 op_handle, std::set<int64>* visited,
void IsConstantVisitor(const int64 op_handle,
absl::flat_hash_set<int64>* visited,
bool* is_constant) const;

// Checks bounds for convolution parameters.
@@ -18,6 +18,7 @@ limitations under the License.
#include <array>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
@@ -605,24 +606,26 @@ ReferenceUtil::ReduceToRowArray2D(
const std::function<float(float, float)>& reduce_function) {
std::vector<float> result;
CHECK_EQ(dims.size(), 3);
const std::set<int64> dim_set(dims.begin(), dims.end());
const absl::flat_hash_set<int64> dim_set(dims.begin(), dims.end());
CHECK_EQ(dim_set.size(), 3);
for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) {
for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2());
for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
++a0) {
for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
++a1) {
for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3());
for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
++a2) {
for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4());
for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4());
++a3) {
float accumulator = init;
for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1());
++i0) {
for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2());
++i1) {
for (int64 i0 = 0;
i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
for (int64 i1 = 0;
i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
for (int64 i2 = 0;
i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) {
i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
for (int64 i3 = 0;
i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) {
i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
++i3) {
// Handle zero-sized arrays.
if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
array.n4() > 0) {
@@ -515,6 +515,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/container:flat_hash_map",
],
)

@@ -677,6 +678,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -1002,6 +1004,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -1136,6 +1139,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
@@ -1580,6 +1584,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -2116,6 +2121,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/container:flat_hash_set",
],
)

@@ -2288,6 +2294,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -2548,6 +2555,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
],
)

@@ -3188,6 +3196,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
@@ -27,6 +27,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@@ -2539,7 +2540,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
}
if (can_move_reshape_into_reduce) {
changed_ = true;
std::unordered_set<int64> dimensions_not_to_reduce;
absl::flat_hash_set<int64> dimensions_not_to_reduce;
for (auto dim_pair : unmodified_dims) {
if (arg_dim_in_output[dim_pair.second]) {
dimensions_not_to_reduce.insert(dim_pair.first);
@@ -2547,7 +2548,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
}
std::vector<int64> new_reduce_dimensions;
for (int64 i = 0; i < arg->operand(0)->shape().rank(); ++i) {
if (dimensions_not_to_reduce.count(i) == 0) {
if (!dimensions_not_to_reduce.contains(i)) {
new_reduce_dimensions.push_back(i);
}
}
@@ -115,12 +115,10 @@ StatusOr<StreamPool::Ptr> Backend::BorrowStream(int device_ordinal) {

StatusOr<StreamPool::Ptr> Backend::BorrowStream(se::StreamExecutor* executor) {
tensorflow::mutex_lock l(mu_);
if (0 == stream_pools_.count(executor)) {
stream_pools_.emplace(std::piecewise_construct,
std::forward_as_tuple(executor),
std::forward_as_tuple());
if (!stream_pools_.contains(executor)) {
stream_pools_.emplace(executor, absl::make_unique<StreamPool>());
}
return stream_pools_.at(executor).BorrowStream(executor);
return stream_pools_.at(executor)->BorrowStream(executor);
}

Backend::Backend(se::Platform* platform, Compiler* compiler,
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/compiler.h"
@@ -175,7 +176,8 @@ class Backend {
tensorflow::mutex mu_;

// Mapping from stream executor to stream pools, used by `BorrowStream` above.
std::map<se::StreamExecutor*, StreamPool> stream_pools_ GUARDED_BY(mu_);
absl::flat_hash_map<se::StreamExecutor*, std::unique_ptr<StreamPool>>
stream_pools_ GUARDED_BY(mu_);

// The default memory allocator to use.
std::unique_ptr<StreamExecutorMemoryAllocator> memory_allocator_;
@@ -138,8 +138,8 @@ Status GatherComputationsByAllocationType(
worklist.pop_front();
const HloComputation* computation = worklist_front.first;
bool is_thread_local = worklist_front.second;
bool in_thread_local_set = thread_local_set.count(computation) > 0;
bool in_global_set = global_set.count(computation) > 0;
bool in_thread_local_set = thread_local_set.contains(computation);
bool in_global_set = global_set.contains(computation);

// If the computation has already been added to the respective set, then
// nothing to do.
@@ -207,9 +207,9 @@ Status GatherComputationsByAllocationType(

// Add the computations to the vectors in post order.
for (auto* computation : module->MakeComputationPostOrder()) {
if (thread_local_set.count(computation) > 0) {
if (thread_local_set.contains(computation)) {
thread_local_computations->push_back(computation);
} else if (global_set.count(computation) > 0) {
} else if (global_set.contains(computation)) {
global_computations->push_back(computation);
}
// If the computation is not reachable from the entry computation, then it
@@ -219,13 +219,6 @@ Status GatherComputationsByAllocationType(
return Status::OK();
}

size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const {
uint64 h = std::hash<int64>()(s.index());
h = tensorflow::Hash64Combine(h, std::hash<int64>()(s.offset()));
h = tensorflow::Hash64Combine(h, std::hash<int64>()(s.size()));
return h;
}

string BufferAllocation::Slice::ToString() const {
return absl::StrCat("{index:", index(), ", offset:", offset_,
", size:", size_, "}");
@@ -240,7 +233,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice(
void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
int64 size) {
VLOG(4) << "Trying to add " << buffer << " to allocation #" << index();
CHECK(assigned_buffers_.count(&buffer) == 0)
CHECK(!assigned_buffers_.contains(&buffer))
<< "LogicalBuffer " << buffer << " already assigned to allocation "
<< index_;
CHECK_LE(offset, size_) << "LogicalBuffer " << buffer
@@ -346,7 +339,7 @@ const PointsToSet& BufferAssignment::GetPointsToSet(

bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const {
TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
return allocation_index_for_buffer_.count(&buffer) > 0;
return allocation_index_for_buffer_.contains(&buffer);
}

const BufferAllocation& BufferAssignment::GetAssignedAllocation(
@@ -401,7 +394,7 @@ bool BufferAssignment::HasAllocationAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
for (const LogicalBuffer* buffer :
GetPointsToSet(instruction).element(index)) {
if (allocation_index_for_buffer_.count(buffer) > 0) {
if (allocation_index_for_buffer_.contains(buffer)) {
return true;
}
}
@@ -459,8 +452,7 @@ bool BufferAssignment::SharesSliceAtIndex(

bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
const HloInstruction* hlo_b) const {
using SliceSet =
flat_hash_set<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
using SliceSet = flat_hash_set<BufferAllocation::Slice>;
// Gets the slices all of instr's subshapes. If any subshape doesn't have an
// assigned slice, returns the empty set.
auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
@@ -519,7 +511,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer,
void BufferAssignment::AddAssignment(BufferAllocation* allocation,
const LogicalBuffer& buffer, int64 offset,
int64 size) {
CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer))
CHECK(!allocation_index_for_buffer_.contains(&buffer))
<< "LogicalBuffer " << buffer << " already has an allocation.";
CHECK(allocation->is_reusable() || allocation->assigned_buffers().empty())
<< "Non-reusable allocation already assigned a buffer: "
@@ -988,7 +980,7 @@ Status BufferAssigner::AssignBuffersForComputation(
std::vector<BufferAllocation::Index> allocation_indices;
for (const LogicalBuffer* buffer : sorted_buffers) {
VLOG(3) << "Assigning allocation to: " << *buffer;
if (colocated_buffers.count(buffer) > 0) {
if (colocated_buffers.contains(buffer)) {
// Colocated buffers are currently assigned in an earlier pass.
VLOG(3) << "Skipping colocated buffer: " << *buffer;
continue;
@@ -1056,7 +1048,7 @@ Status BufferAssigner::AssignBuffersForComputation(
assignment->GetAllSlices(operand, /*index=*/{})) {
BufferAllocation* allocation =
assignment->GetMutableAllocation(operand_slice.index());
if (colocated_allocations.count(allocation->index()) == 0) {
if (!colocated_allocations.contains(allocation->index())) {
// TODO(b/32491382) Colocated buffers are currently assigned in an
// earlier pass, and so can break the "increasing allocation size"
// invariant in this function (causing this CHECK to fail). However,
@@ -1087,7 +1079,7 @@ Status BufferAssigner::AssignBuffersForComputation(
// Instructions are iterated in increasing buffer size, so any
// previously create allocation must be large enough to hold this
// instruction's output (with the exception of colocated buffers).
if (colocated_allocations.count(allocation->index()) == 0) {
if (!colocated_allocations.contains(allocation->index())) {
// TODO(b/32491382) Colocated buffers are currently assigned in an
// earlier pass, and so can break the "increasing allocation size"
// invariant in this function (causing this CHECK to fail). However,
@@ -1376,7 +1368,7 @@ void BufferAssigner::AddSetToColocatedBufferSets(
std::vector<size_t> overlap_set_indices;
for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) {
for (const LogicalBuffer* buffer : colocated_set) {
if ((*colocated_buffer_sets)[index].count(buffer) > 0) {
if ((*colocated_buffer_sets)[index].contains(buffer)) {
VLOG(5) << "Found overlap with existing set on buffer "
<< buffer->ToString() << "\n"
<< ColocatedBufferSetsToString((*colocated_buffer_sets)[index],
@@ -186,9 +186,10 @@ class BufferAllocation {
end > other.offset_;
}

struct Hasher {
size_t operator()(Slice s) const;
};
template <typename H>
friend H AbslHashValue(H h, const Slice& s) {
return H::combine(std::move(h), s.index(), s.offset(), s.size());
}

string ToString() const;
@@ -21,6 +21,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
@@ -309,7 +310,7 @@ class BufferAssignmentTest : public HloTestBase {
static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
const std::vector<const HloInstruction*>& b,
const BufferAssignment& assignment) {
std::set<BufferAllocation::Slice> a_slices;
absl::flat_hash_set<BufferAllocation::Slice> a_slices;
for (const HloInstruction* instruction : a) {
if (assignment.HasTopLevelAllocation(instruction)) {
a_slices.insert(
@@ -319,8 +320,8 @@ static bool BuffersDistinct(const std::vector<const HloInstruction*>& a,

for (const HloInstruction* instruction : b) {
if (assignment.HasTopLevelAllocation(instruction)) {
if (a_slices.count(assignment.GetUniqueTopLevelSlice(instruction)
.ConsumeValueOrDie())) {
if (a_slices.contains(assignment.GetUniqueTopLevelSlice(instruction)
.ConsumeValueOrDie())) {
return false;
}
}
@@ -72,7 +72,7 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) {
}

Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
if (opaque_to_channel_.count(handle.handle()) == 0) {
if (!opaque_to_channel_.contains(handle.handle())) {
return NotFound("channel handle not found: %d", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];
@@ -94,7 +94,7 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
}

Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) {
if (opaque_to_channel_.count(handle.handle()) == 0) {
if (!opaque_to_channel_.contains(handle.handle())) {
return NotFound("channel handle not found: %d", handle.handle());
}
Channel& channel = opaque_to_channel_[handle.handle()];
@@ -18,6 +18,7 @@ limitations under the License.

#include <map>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status.h"
@@ -83,7 +84,8 @@ class ChannelTracker {

// Mapping from ChannelHandle value to the corresponding registered
// Channel object.
std::map<int64, Channel> opaque_to_channel_ GUARDED_BY(channel_mutex_);
absl::flat_hash_map<int64, Channel> opaque_to_channel_
GUARDED_BY(channel_mutex_);

TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker);
};
@@ -46,8 +46,7 @@ static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) {
for (auto* user : instruction->users()) {
optional<int64> operand_idx = ProfitableToMakeDotOperandColumnMajor(*user);
if (!operand_idx || user->operand(*operand_idx) != instruction ||
std::count(user->operands().begin(), user->operands().end(),
instruction) != 1) {
absl::c_count(user->operands(), instruction) != 1) {
return false;
}
}
@@ -24,11 +24,9 @@ limitations under the License.
#include <utility>
#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
@@ -70,6 +68,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

@@ -1399,7 +1399,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) {

int64 delta = 0;
for (int64 i = 0; i < operand_shape.dimensions_size(); i++) {
if (reduced_dims.count(i)) {
if (reduced_dims.contains(i)) {
delta++;
} else {
InsertOrDie(&unreduced_dim_map, i, i - delta);
@@ -1412,7 +1412,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) {
for (int64 operand_dim_idx = 0;
operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) {
int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx);
if (!reduced_dims.count(operand_dim)) {
if (!reduced_dims.contains(operand_dim)) {
if (FindOrDie(unreduced_dim_map, operand_dim) !=
result_shape.layout().minor_to_major(result_dim_idx++)) {
return false;
@@ -1990,7 +1990,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
// The memcpy will copy elements that are logically this shape (allowed to be
// scalar).
const Shape logical_element_shape = ShapeUtil::FilterDimensions(
[&inner_dims](int64 dim) -> bool { return inner_dims.count(dim); },
[&inner_dims](int64 dim) { return inner_dims.contains(dim); },
operand->shape());

const int64 primitive_elements_per_logical_element =
@@ -448,7 +448,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
computation_to_profile_idx_;

// Maps HLOs to Values emitted for them.
std::unordered_map<const HloInstruction*, llvm::Value*> emitted_value_;
absl::flat_hash_map<const HloInstruction*, llvm::Value*> emitted_value_;

llvm_ir::AliasAnalysis alias_analysis_;
@@ -94,8 +94,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
@@ -135,6 +135,8 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm//:core",
@@ -263,7 +265,9 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -362,6 +366,7 @@ cc_library(
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -57,16 +58,16 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(

// If buffer #i's address is already registered (e.g. external arguments or
// result buffers), use that registered buffer.
if (registered_buffers_.count(i)) {
se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i);
if (reinterpret_cast<uintptr_t>(address.opaque()) % expected_alignment !=
if (se::DeviceMemoryBase* address =
tensorflow::gtl::FindOrNull(registered_buffers_, i)) {
if (reinterpret_cast<uintptr_t>(address->opaque()) % expected_alignment !=
0) {
return InternalError(
"Address of registered buffer %d must be a multiple of %x, but "
"was %p",
i, kEntryParameterAlignBytes, address.opaque());
i, kEntryParameterAlignBytes, address->opaque());
}
buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i));
buffer_allocations->SetBuffer(i, *address);
continue;
}
@@ -20,6 +20,7 @@ limitations under the License.
#include <set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -52,7 +53,8 @@ class BufferAllocations {
DeviceMemoryAllocator* memory_allocator);

private:
std::map<BufferAllocation::Index, se::DeviceMemoryBase> registered_buffers_;
absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase>
registered_buffers_;
};

~BufferAllocations();
@@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@@ -45,10 +46,10 @@ void HloToIrBindings::EmitBasePointersForHlos(

// An HLO can have duplicated operands. This data structure remembers which
// operand HLOs are already bound to avoid rebinding the same HLO.
std::set<const HloInstruction*> already_bound_for_this_function;
absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
auto arg_iter = function->arg_begin();
for (const HloInstruction* io_hlo : io_hlos) {
if (!already_bound_for_this_function.count(io_hlo)) {
if (!already_bound_for_this_function.contains(io_hlo)) {
if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
} else {
@@ -63,7 +64,7 @@ void HloToIrBindings::EmitBasePointersForHlos(
temp_buffer_base_->setName("temp_buffer");

for (const HloInstruction* non_io_hlo : non_io_hlos) {
if (already_bound_for_this_function.count(non_io_hlo)) {
if (already_bound_for_this_function.contains(non_io_hlo)) {
continue;
}
already_bound_for_this_function.insert(non_io_hlo);
@@ -18,6 +18,7 @@ limitations under the License.

#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
@@ -61,7 +62,7 @@ class HloToIrBindings {

// Returns whether `hlo` is bound to an LLVM IR value.
bool BoundToIrValue(const HloInstruction& hlo) const {
return base_ptrs_.count(&hlo);
return base_ptrs_.contains(&hlo);
}

llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; }
@@ -110,7 +111,8 @@ class HloToIrBindings {
// For an instruction that generates multiple outputs, the root will be a
// tuple shape. The IrArray for each element output is stored in the subnode
// in the ShapeTree.
std::unordered_map<const HloInstruction*, ShapeTree<llvm::Value*>> base_ptrs_;
absl::flat_hash_map<const HloInstruction*, ShapeTree<llvm::Value*>>
base_ptrs_;

// The address of the memory block that contains all temporary buffers.
llvm::Value* temp_buffer_base_ = nullptr;
@@ -67,7 +67,7 @@ int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
}
int64 profit = 0;
for (auto instr : instr2->operands()) {
if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) {
if (!IsProfitableOperand(instr) || !in_list.contains(instr)) {
continue;
}
profit += ShapeUtil::ByteSizeOf(instr->shape());
@@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -25,7 +26,7 @@ namespace xla {
namespace gpu {

bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const {
return hlo_to_stream_number_.count(&hlo);
return hlo_to_stream_number_.contains(&hlo);
}

int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const {
@@ -98,10 +99,10 @@ int ComputeStreamToAssign(
// greedy approach. First, we compute as forbidden_stream_numbers the
// streams assigned to GEMMs that are concurrent with `hlo`. Then, we assign
// `hlo` a different stream.
std::set<int> forbidden_stream_numbers;
absl::flat_hash_set<int> forbidden_stream_numbers;
for (const auto* seen_gemm : seen_gemms) {
int stream_num = stream_assignment.StreamNumberForHlo(*seen_gemm);
if (!forbidden_stream_numbers.count(stream_num) &&
if (!forbidden_stream_numbers.contains(stream_num) &&
CanRunConcurrently(*seen_gemm, hlo, reachability)) {
forbidden_stream_numbers.insert(stream_num);
}
@@ -109,7 +110,7 @@ int ComputeStreamToAssign(

for (int stream_num = 0; stream_num < stream_assignment.StreamCount();
++stream_num) {
if (!forbidden_stream_numbers.count(stream_num)) {
if (!forbidden_stream_numbers.contains(stream_num)) {
return stream_num;
}
}
@@ -14,17 +14,19 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"

namespace xla {
namespace gpu {

void ThunkSchedule::AddDependenciesOnTransitiveOperands(
const Thunk& thunk, const HloInstruction& operand,
const std::unordered_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
if (hlo_to_thunk.count(&operand)) {
const absl::flat_hash_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
if (hlo_to_thunk.contains(&operand)) {
// If `operand` is mapped to a thunk, adds `operand` to `thunk`'s dependency
// list if `operand` is assigned to a different stream. As an optimization,
// we skip `operand`'s operands because `operand` depends on them already.
@@ -48,14 +50,14 @@ ThunkSchedule::ThunkSchedule(
const std::vector<HloInstruction*>& hlo_total_order)
: thunks_(std::move(thunks)),
stream_assignment_(std::move(stream_assignment)) {
std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk;
absl::flat_hash_map<const HloInstruction*, Thunk*> hlo_to_thunk;
for (const auto& thunk : *thunks_) {
InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
}

for (HloInstruction* hlo : hlo_total_order) {
if (hlo_to_thunk.count(hlo)) {
thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo));
if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) {
thunk_total_order_.push_back(*thunk);
}
}

@@ -106,7 +108,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() {
// redundant dependency edge.
Array2D<int> last_dependency(stream_count, stream_count, -1);
for (const Thunk* dst : thunk_total_order_) {
if (!depends_on_.count(dst)) {
if (!depends_on_.contains(dst)) {
continue;
}

@@ -134,7 +136,7 @@ void ThunkSchedule::RemoveRedundantDependencyEdges() {

const std::list<const Thunk*>& ThunkSchedule::DependsOn(
const Thunk* thunk) const {
if (depends_on_.count(thunk)) {
if (depends_on_.contains(thunk)) {
return FindOrDie(depends_on_, thunk);
} else {
return empty_thunk_list_;
@@ -21,6 +21,8 @@ limitations under the License.
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -54,7 +56,9 @@ class ThunkSchedule {
// Thunks that `thunk` depends on.
const std::list<const Thunk*>& DependsOn(const Thunk* thunk) const;
// Whether `thunk` is depended by another thunk.
bool Depended(const Thunk* thunk) const { return depended_by_.count(thunk); }
bool Depended(const Thunk* thunk) const {
return depended_by_.contains(thunk);
}

// Delegates to StreamAssignment.
int StreamCount() const { return stream_assignment_->StreamCount(); }
@@ -75,13 +79,13 @@ class ThunkSchedule {
// thunk.hlo_instruction().
void AddDependenciesOnTransitiveOperands(
const Thunk& thunk, const HloInstruction& operand,
const std::unordered_map<const HloInstruction*, Thunk*>& hlo_to_thunk);
const absl::flat_hash_map<const HloInstruction*, Thunk*>& hlo_to_thunk);

std::unique_ptr<ThunkSequence> thunks_;
std::vector<Thunk*> thunk_total_order_;

std::unordered_map<const Thunk*, std::list<const Thunk*>> depends_on_;
std::set<const Thunk*> depended_by_;
absl::flat_hash_map<const Thunk*, std::list<const Thunk*>> depends_on_;
absl::flat_hash_set<const Thunk*> depended_by_;
std::list<const Thunk*> empty_thunk_list_;

std::unique_ptr<StreamAssignment> stream_assignment_;
@@ -199,7 +199,7 @@ Status HeapSimulator::RunComputation(

// If the buffer has no users and isn't an entry parameter or output, it
// must be a dead value.
if (live_buffers.count(buffer) == 0) {
if (!live_buffers.contains(buffer)) {
dead_buffers_to_free.push_back(buffer);
}
}
@@ -253,7 +253,7 @@ Status HeapSimulator::RunComputation(
bool shared = false;
if (options_.may_reuse_operand_buffers) {
for (const BufferValue* operand_buffer : operand_buffers_to_free) {
if (reused_buffers.count(operand_buffer) != 0) {
if (reused_buffers.contains(operand_buffer)) {
continue;
}
if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) &&
@@ -374,15 +374,15 @@ bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const {
return true;
}
return options_.buffers_to_assign != nullptr &&
options_.buffers_to_assign->count(buffer) == 0;
!options_.buffers_to_assign->contains(buffer);
}

// Alloc always calls the underlying heap algorithm.
void HeapSimulator::Alloc(const BufferValue* buffer,
const HloInstruction* instruction) {
CHECK(allocated_buffers_.count(buffer) == 0)
CHECK(!allocated_buffers_.contains(buffer))
<< "Alloc called on allocated buffer: " << *buffer;
CHECK(freed_buffers_.count(buffer) == 0)
CHECK(!freed_buffers_.contains(buffer))
<< "Alloc called on freed buffer: " << *buffer;

allocated_buffers_.insert(buffer);
@@ -411,9 +411,9 @@ void HeapSimulator::Free(const BufferValue* buffer,
buffer = group->canonical;
}

CHECK(allocated_buffers_.count(buffer) > 0)
CHECK(allocated_buffers_.contains(buffer))
<< "Free called on non-allocated buffer: " << *buffer;
CHECK(freed_buffers_.count(buffer) == 0)
CHECK(!freed_buffers_.contains(buffer))
<< "Free called on freed buffer: " << *buffer;

freed_buffers_.insert(buffer);
@@ -433,11 +433,11 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer,
const HloInstruction* instruction) {
CHECK_LE(size_fn_(*buffer), size_fn_(*shared))
<< "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared;
CHECK(allocated_buffers_.count(buffer) == 0)
CHECK(!allocated_buffers_.contains(buffer))
<< "ShareBuffer called on allocated buffer: " << *buffer;
CHECK(freed_buffers_.count(buffer) == 0)
CHECK(!freed_buffers_.contains(buffer))
<< "ShareBuffer called on freed buffer: " << *buffer;
CHECK(freed_buffers_.count(shared) == 0)
CHECK(!freed_buffers_.contains(shared))
<< "ShareBuffer called on freed shared buffer: " << *shared;

const BufferValue* canonical = nullptr;
@@ -452,7 +452,7 @@ void HeapSimulator::ShareBuffer(const BufferValue* buffer,
} else {
// The 'shared' buffer doesn't have a group; it must be the canonical. Add
// both 'buffer' and 'shared' to a new group.
CHECK(allocated_buffers_.count(shared) > 0)
CHECK(allocated_buffers_.contains(shared))
<< "ShareBuffer called on non-allocated shared buffer: " << *shared;
auto group = std::make_shared<SharedGroup>();
canonical = shared;
@@ -207,14 +207,14 @@ Status HloComputation::RemoveInstructionAndUnusedOperands(
TF_RET_CHECK(instruction->user_count() == 0);
TF_RET_CHECK(IsRemovable(instruction))
<< "Cannot remove instruction: " << instruction->ToString();
std::unordered_set<HloInstruction*> removed;
absl::flat_hash_set<HloInstruction*> removed;
std::queue<HloInstruction*> worklist;
worklist.push(instruction);
while (!worklist.empty()) {
HloInstruction* item = worklist.front();
worklist.pop();

if (removed.count(item) != 0 || item->user_count() != 0 ||
if (removed.contains(item) || item->user_count() != 0 ||
item == root_instruction() || !IsRemovable(item) ||
(item->HasSideEffect() && item != instruction)) {
continue;
@@ -694,13 +694,14 @@ bool HloComputation::operator==(const HloComputation& other) const {
if (this == &other) {
return true;
}
std::set<std::pair<const HloInstruction*, const HloInstruction*>> visited;
absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
visited;
std::function<bool(const HloInstruction*, const HloInstruction*)> eq =
[&visited, &eq](const HloInstruction* a, const HloInstruction* b) {
// If <a,b> are visited but not identical, the recursion should have
// been aborted. So, if <a,b> are visited at this point, they must be
// identical.
if (visited.count(std::make_pair(a, b)) > 0) {
if (visited.contains(std::make_pair(a, b))) {
return true;
}
visited.emplace(a, b);
@@ -803,13 +804,13 @@ Status HloComputation::AcceptOrdered(
<< root->ToString();
}
TF_RET_CHECK(order.size() == instruction_count());
std::unordered_set<const HloInstruction*> visited;
absl::flat_hash_set<const HloInstruction*> visited;
for (const HloInstruction* instruction : order) {
VLOG(3) << "Visiting ordered: " << instruction->ToString();
TF_RET_CHECK(instruction_iterators_.count(instruction) == 1)
TF_RET_CHECK(instruction_iterators_.contains(instruction))
<< "Instruction " << instruction->name() << " is not in computation "
<< name();
TF_RET_CHECK(visited.count(instruction) == 0)
TF_RET_CHECK(!visited.contains(instruction))
<< "Instruction " << instruction->name()
<< " appears more than once in order";
HloInstruction* mutable_instruction =
@@ -17,6 +17,7 @@ limitations under the License.

#include <set>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -226,7 +227,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) {
: computation_(computation) {}

Status DefaultAction(HloInstruction* hlo_instruction) override {
EXPECT_EQ(0, visited_set_.count(hlo_instruction));
EXPECT_FALSE(visited_set_.contains(hlo_instruction));
visited_set_.insert(hlo_instruction);
last_visited_ = hlo_instruction;
return Status::OK();
@@ -239,7 +240,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) {
}

HloComputation* computation_;
std::set<HloInstruction*> visited_set_;
absl::flat_hash_set<HloInstruction*> visited_set_;
int64 finish_visit_calls_ = 0;
HloInstruction* last_visited_ = nullptr;
};
@@ -107,7 +107,7 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
return false;
}
}
if (!visited.count(user)) {
if (!visited.contains(user)) {
stack.push_back(user);
}
}
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -65,7 +66,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {

// Now DCE HloComputations. First, collect the computations that are
// referenced by some remaining instruction.
std::unordered_set<HloComputation*> live_computations;
absl::flat_hash_set<HloComputation*> live_computations;
if (HloComputation* entry_computation = module->entry_computation()) {
live_computations.insert(entry_computation);
}
@@ -79,7 +80,7 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {

// Remove dead computations.
for (auto* computation : module->MakeComputationPostOrder()) {
if (live_computations.count(computation) == 0) {
if (!live_computations.contains(computation)) {
TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation));
changed = true;
}
@@ -24,9 +24,9 @@ limitations under the License.
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -380,7 +380,7 @@ class HloDotDumper {
// Each HloInstruction dumped gets a monotically-increasing node ID. This
// must start at 1, because that's where graphviz's accounting starts.
int64 next_node_id_ = 1;
std::unordered_map<const HloInstruction*, int64> node_ids_;
absl::flat_hash_map<const HloInstruction*, int64> node_ids_;

// The "root" tag doesn't have an associated HloInstruction pointer, so we
// need to store it outside the map.
@@ -397,7 +397,7 @@ class HloDotDumper {

// Each HloComputation that's emitted gets a monotonically-increasing ID.
int64 next_cluster_id_ = 1;
std::unordered_map<const HloComputation*, int64> cluster_ids_;
absl::flat_hash_map<const HloComputation*, int64> cluster_ids_;

// Edges to print from Footer(). Edges come at the end because graphviz is
// unhappy if an edge from a subcomputation to a node in the outer computation
@@ -407,7 +407,7 @@ class HloDotDumper {

// When coloring by sharding information, we track the sharding string
// representation to color association, by round-robin the color schemes.
std::unordered_map<HloSharding, ColorScheme, HloSharding::Hasher>
absl::flat_hash_map<HloSharding, ColorScheme, HloSharding::Hasher>
sharding_colors_;
int64 next_shard_color_ = 0;
};
@@ -1286,7 +1286,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
int64 radius) {
// First, find the neighborhood of nodes with distance from root <= radius.
// These nodes are our initial set of "normal" nodes.
std::unordered_map<const HloInstruction*, NodeFilterResult> nodes;
absl::flat_hash_map<const HloInstruction*, NodeFilterResult> nodes;
std::deque<std::pair<const HloInstruction*, /*depth*/ int64>> worklist;
worklist.push_back({root, 0});
while (!worklist.empty()) {
@@ -1307,7 +1307,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
// are not interesting to the graph at hand.
if (instr == root || instr->opcode() != HloOpcode::kTuple) {
for (const HloInstruction* operand : instr->operands()) {
if (!nodes.count(operand)) {
if (!nodes.contains(operand)) {
worklist.push_back({operand, depth + 1});
}
}
@@ -1335,7 +1335,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
continue;
}
for (const HloInstruction* user : instr->users()) {
if (!nodes.count(user)) {
if (!nodes.contains(user)) {
worklist.push_back({user, depth + 1});
}
}
@@ -1344,7 +1344,7 @@ NodeFilter MakeNodeRadiusAroundFilter(const HloInstruction* root,
auto is_displayed = [&](const HloInstruction* instr) {
// Constants are displayed inline with their users; they're never omitted.
// Nodes in subcomputations are always shown.
return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant ||
return nodes.contains(instr) || instr->opcode() == HloOpcode::kConstant ||
instr->parent() != root->parent();
};
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -55,13 +56,13 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
}

Status HandleParameter(HloInstruction* parameter) override {
EXPECT_EQ(0, count_.count(parameter));
EXPECT_FALSE(count_.contains(parameter));
count_[parameter] = GetCountsForNode(parameter);
return Status::OK();
}

Status HandleConstant(HloInstruction* constant) override {
EXPECT_EQ(0, count_.count(constant));
EXPECT_FALSE(count_.contains(constant));
count_[constant] = GetCountsForNode(constant);
return Status::OK();
}
@@ -69,25 +70,25 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
Status HandleAdd(HloInstruction* add) override {
auto lhs = add->operand(0);
auto rhs = add->operand(1);
EXPECT_EQ(0, count_.count(add));
EXPECT_GT(count_.count(lhs), 0);
EXPECT_GT(count_.count(rhs), 0);
EXPECT_FALSE(count_.contains(add));
EXPECT_TRUE(count_.contains(lhs));
EXPECT_TRUE(count_.contains(rhs));
count_[add] = GetCountsForNode(add);
return Status::OK();
}

Status HandleNegate(HloInstruction* negate) override {
auto operand = negate->operand(0);
EXPECT_EQ(0, count_.count(negate));
EXPECT_GT(count_.count(operand), 0);
EXPECT_FALSE(count_.contains(negate));
EXPECT_TRUE(count_.contains(operand));
count_[negate] = GetCountsForNode(negate);
return Status::OK();
}

Status HandleMap(HloInstruction* map) override {
EXPECT_EQ(0, count_.count(map));
EXPECT_FALSE(count_.contains(map));
for (HloInstruction* arg : map->operands()) {
EXPECT_GT(count_.count(arg), 0);
EXPECT_TRUE(count_.contains(arg));
}
count_[map] = GetCountsForNode(map);
return Status::OK();
@@ -96,9 +97,9 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
Status HandleReduce(HloInstruction* reduce) override {
auto arg = reduce->operand(0);
auto init_value = reduce->operand(1);
EXPECT_EQ(0, count_.count(reduce));
EXPECT_GT(count_.count(arg), 0);
EXPECT_GT(count_.count(init_value), 0);
EXPECT_FALSE(count_.contains(reduce));
EXPECT_TRUE(count_.contains(arg));
EXPECT_TRUE(count_.contains(init_value));
count_[reduce] = GetCountsForNode(reduce);
return Status::OK();
}
@@ -128,7 +129,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
}

// Counters for HLOs. Maps HLO to a NumOpsAndUsers.
std::unordered_map<const HloInstruction*, NumOpsAndUsers> count_;
absl::flat_hash_map<const HloInstruction*, NumOpsAndUsers> count_;
};

TEST_F(HloInstructionTest, BasicProperties) {
@@ -137,7 +138,7 @@ TEST_F(HloInstructionTest, BasicProperties) {
EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32));
EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32));
EXPECT_EQ(0, parameter->operand_count());
EXPECT_FALSE(parameter->operand_count());
}

TEST_F(HloInstructionTest, UserWithTwoOperands) {
@@ -981,9 +982,9 @@ TEST_F(HloInstructionTest, FunctionVisitor) {
module->AddEntryComputation(builder.Build());

int visit_num = 0;
std::unordered_map<HloInstruction*, int> visit_order;
absl::flat_hash_map<HloInstruction*, int> visit_order;
EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) {
EXPECT_EQ(0, visit_order.count(inst));
EXPECT_FALSE(visit_order.contains(inst));
visit_order[inst] = visit_num;
visit_num++;
return Status::OK();
@@ -17,6 +17,7 @@ limitations under the License.

#include <deque>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -36,11 +37,11 @@ namespace xla {
namespace {

using Worklist = std::deque<const HloInstruction*>;
using Workset = std::unordered_set<const HloInstruction*>;
using Workset = absl::flat_hash_set<const HloInstruction*>;

void AddToWorklist(const HloInstruction* instruction, Worklist* worklist,
Workset* workset) {
if (workset->count(instruction) == 0) {
if (!workset->contains(instruction)) {
worklist->push_back(instruction);
workset->insert(instruction);
VLOG(3) << "ADD instruction: " << instruction->name();
@@ -392,15 +392,12 @@ namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(
const HloInstruction& hlo,
const std::unordered_set<HloInstruction*>& instructions_in_subcomputation) {
for (HloInstruction* user : hlo.users()) {
if (!instructions_in_subcomputation.count(user)) {
return true;
}
}
return false;
bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
const absl::flat_hash_set<HloInstruction*>&
instructions_in_subcomputation) {
return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
return !instructions_in_subcomputation.contains(user);
});
}
} // anonymous namespace

@@ -411,9 +408,9 @@ HloInstruction* HloModule::OutlineExpressionFromComputation(

// A map from original instructions to their counterparts in the new outlined
// function.
std::unordered_map<HloInstruction*, HloInstruction*> outlined_instructions;
absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
// A set that contains all instructions to be outlined.
std::unordered_set<HloInstruction*> instruction_set_to_outline(
absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
instructions_to_outline.begin(), instructions_to_outline.end());
std::vector<HloInstruction*> arguments;
std::vector<HloInstruction*> outputs;
@@ -502,7 +499,7 @@ std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
// First determine all root computations by building a set of nonroot
// computations (computations which are called by an instruction in the
// module).
std::set<HloComputation*> nonroot_computations;
absl::flat_hash_set<HloComputation*> nonroot_computations;
for (auto& computation : computations_) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@@ -515,19 +512,19 @@ std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
// Keep track of computations which have already been added to the post
// order. This prevents duplication as an embedded computation may be called
// from two different root computations.
std::set<HloComputation*> added_computations;
absl::flat_hash_set<HloComputation*> added_computations;
std::vector<HloComputation*> post_order;
for (auto& computation : computations_) {
if (nonroot_computations.count(computation.get()) == 0) {
if (!nonroot_computations.contains(computation.get())) {
for (HloComputation* embedded_computation :
computation->MakeEmbeddedComputationsList()) {
if (added_computations.count(embedded_computation) == 0) {
if (!added_computations.contains(embedded_computation)) {
post_order.push_back(embedded_computation);
added_computations.insert(embedded_computation);
}
}
// Root computations should only be encountered once.
CHECK_EQ(0, added_computations.count(computation.get()));
CHECK(!added_computations.contains(computation.get()));
post_order.push_back(computation.get());
added_computations.insert(computation.get());
}
@ -199,7 +199,7 @@ bool HloModuleGroupMetadata::IsChannelInstruction(
|
||||
}
|
||||
|
||||
bool HloModuleGroupMetadata::IsCompanionInstruction(HloInstruction* hlo) const {
|
||||
return companion_set_index_.count(hlo) > 0;
|
||||
return companion_set_index_.contains(hlo);
|
||||
}
|
||||
|
||||
bool HloModuleGroupMetadata::InstructionCommunicates(
|
||||
@ -510,7 +510,7 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
|
||||
HloComputation* computation = instruction->parent();
|
||||
const HloModule* module = computation->parent();
|
||||
if (module->entry_computation() == computation ||
|
||||
tracked_instructions_.count(computation) > 0) {
|
||||
tracked_instructions_.contains(computation)) {
|
||||
return Status::OK();
|
||||
}
|
||||
return FailedPrecondition("channel is used in disallowed computation");
|
||||
|
@ -178,7 +178,7 @@ class HloModuleGroupMetadata {
|
||||
// Precondition: IsCompanionWhile(instruction) is true.
|
||||
const std::vector<HloInstruction*>& Companions(
|
||||
const HloInstruction* instruction) const {
|
||||
CHECK_EQ(companion_set_index_.count(instruction), 1);
|
||||
CHECK(companion_set_index_.contains(instruction));
|
||||
return companion_set(companion_set_index_.at(instruction));
|
||||
}
|
||||
|
||||
|
@ -367,7 +367,7 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
|
||||
const HloInstruction* a, const HloInstruction* b) const {
|
||||
CHECK_EQ(a->parent(), b->parent());
|
||||
// If either instruction is not in the order, then 'a' and 'b' are unordered.
|
||||
if (order_position_.count(a) == 0 || order_position_.count(b) == 0) {
|
||||
if (!order_position_.contains(a) || !order_position_.contains(b)) {
|
||||
return false;
|
||||
}
|
||||
return order_position_.at(a) < order_position_.at(b);
|
||||
|
@ -89,7 +89,7 @@ std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
|
||||
|
||||
std::vector<HloPassInterface*> enabled_passes;
|
||||
for (auto& pass : passes_) {
|
||||
if (disabled_pass_names.count(string(pass->name())) == 0) {
|
||||
if (!disabled_pass_names.contains(pass->name())) {
|
||||
enabled_passes.push_back(pass.get());
|
||||
}
|
||||
}
|
||||
|
@ -140,7 +140,7 @@ Status HloSchedule::UpdateComputationSchedule(
|
||||
std::queue<HloInstruction*> worklist;
|
||||
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (ids_in_schedule.count(instruction->unique_id()) == 0) {
|
||||
if (!ids_in_schedule.contains(instruction->unique_id())) {
|
||||
// This is a newly added instruction which is not in the schedule.
|
||||
if (instruction->operands().empty()) {
|
||||
worklist.push(instruction);
|
||||
@ -204,7 +204,7 @@ Status HloSchedule::Update() {
|
||||
std::vector<HloComputation*> nonfusion_computations =
|
||||
module_->MakeNonfusionComputations();
|
||||
for (const HloComputation* computation : nonfusion_computations) {
|
||||
TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
|
||||
TF_RET_CHECK(sequences_.contains(computation->unique_id()))
|
||||
<< "Computation " << computation->name() << " not in HloSchedule.";
|
||||
}
|
||||
if (sequences_.size() > nonfusion_computations.size()) {
|
||||
@ -215,7 +215,7 @@ Status HloSchedule::Update() {
|
||||
nonfusion_computations_ids.insert(computation->unique_id());
|
||||
}
|
||||
for (auto it = sequences_.begin(); it != sequences_.end();) {
|
||||
if (nonfusion_computations_ids.count(it->first) == 0) {
|
||||
if (!nonfusion_computations_ids.contains(it->first)) {
|
||||
sequences_.erase(it++);
|
||||
} else {
|
||||
++it;
|
||||
@ -244,7 +244,7 @@ Status HloSchedule::Verify() const {
|
||||
<< "Schedule has " << sequences_.size() << " sequences, but module has "
|
||||
<< nonfusion_computations.size() << " non-fusion computations";
|
||||
for (const HloComputation* computation : nonfusion_computations) {
|
||||
TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1)
|
||||
TF_RET_CHECK(sequences_.contains(computation->unique_id()))
|
||||
<< "Computation " << computation->name()
|
||||
<< " missing from HLO schedule.";
|
||||
}
|
||||
@ -268,7 +268,7 @@ Status HloSchedule::Verify() const {
|
||||
<< instruction_position.size() << " instructions, expected "
|
||||
<< computation->instruction_count();
|
||||
for (const HloInstruction* instruction : computation->instructions()) {
|
||||
TF_RET_CHECK(instruction_position.count(instruction) == 1)
|
||||
TF_RET_CHECK(instruction_position.contains(instruction))
|
||||
<< "Instruction " << instruction->name() << " is not in schedule";
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,7 @@ class HloSchedule {
|
||||
|
||||
// Returns true if the schedule has a sequence for the given computation.
|
||||
bool is_computation_scheduled(const HloComputation* computation) const {
|
||||
return sequences_.count(computation->unique_id()) == 1;
|
||||
return sequences_.contains(computation->unique_id());
|
||||
}
|
||||
|
||||
// Updates the schedule such that it is (again) a valid schedule for the
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/compiler/xla/overflow_util.h"
|
||||
@ -316,7 +317,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
|
||||
// All tile assignments must be less than the number of available cores and
|
||||
// unique.
|
||||
Status status = Status::OK();
|
||||
std::set<int64> seen_cores;
|
||||
absl::flat_hash_set<int64> seen_cores;
|
||||
tile_assignment_.Each(
|
||||
[&](absl::Span<const int64> indices, int32 core) {
|
||||
// Don't overwrite a bad status, so we report the first error.
|
||||
@ -324,7 +325,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
|
||||
if (core >= num_devices) {
|
||||
status = tensorflow::errors::InvalidArgument(StrCat(
|
||||
"core ", core, " > ", num_devices, " in tile assignment"));
|
||||
} else if (seen_cores.count(core) != 0) {
|
||||
} else if (seen_cores.contains(core)) {
|
||||
status = tensorflow::errors::InvalidArgument(
|
||||
StrCat("core ", core, " is not unique in tile assignment"));
|
||||
}
|
||||
|
@ -99,7 +99,7 @@ std::vector<PassThrough> LocatePassThroughDomainLinks(
|
||||
<< "Instruction is not a kDomain: " << instruction->ToString();
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
if (user->opcode() == HloOpcode::kDomain &&
|
||||
domain.exit_domains.count(user) != 0) {
|
||||
domain.exit_domains.contains(user)) {
|
||||
pass_through.emplace_back(user, instruction);
|
||||
VLOG(2) << "Found passthrough domain link:";
|
||||
VLOG(2) << " " << user->ToString();
|
||||
@ -253,7 +253,7 @@ StatusOr<bool> ApplyShardingFromUsers(HloInstruction* instruction,
|
||||
instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice));
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
if (user->opcode() == HloOpcode::kDomain &&
|
||||
domain.exit_domains.count(user) > 0) {
|
||||
domain.exit_domains.contains(user)) {
|
||||
// If a user is a domain and it is registered in the domain exits, then
|
||||
// the instruction sharding is taken directly from the domain, and no
|
||||
// further users need to be visited.
|
||||
|
@ -103,7 +103,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache(
|
||||
|
||||
do {
|
||||
const HloInstruction* instr = stack.back();
|
||||
if (cache_.count(instr)) {
|
||||
if (cache_.contains(instr)) {
|
||||
stack.pop_back();
|
||||
continue;
|
||||
}
|
||||
@ -111,9 +111,9 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache(
|
||||
switch (FindOrDie(dfs_state_map, instr)) {
|
||||
case kDiscovered: {
|
||||
for (const HloInstruction* operand : instr->operands()) {
|
||||
if (!cache_.count(operand)) {
|
||||
if (!cache_.contains(operand)) {
|
||||
stack.push_back(operand);
|
||||
CHECK(!dfs_state_map.count(operand) ||
|
||||
CHECK(!dfs_state_map.contains(operand) ||
|
||||
dfs_state_map[operand] == kDiscovered);
|
||||
dfs_state_map[operand] = kDiscovered;
|
||||
}
|
||||
|
@ -2135,7 +2135,7 @@ Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
|
||||
for (HloInstruction* instruction :
|
||||
computation->MakeInstructionPostOrder()) {
|
||||
if (instruction->opcode() == HloOpcode::kCopy &&
|
||||
added_copies_.count(instruction) > 0) {
|
||||
added_copies_.contains(instruction)) {
|
||||
VLOG(5) << "Removing added copy: " << instruction->ToString();
|
||||
TF_RETURN_IF_ERROR(
|
||||
instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
|
||||
|
@ -243,7 +243,7 @@ class ChannelLayoutConstraints {
|
||||
|
||||
// Returns true if channel_id has a layout constraint.
|
||||
bool IsChannelConstrained(int64 channel_id) const {
|
||||
return constraints_.count(channel_id) > 0;
|
||||
return constraints_.contains(channel_id);
|
||||
}
|
||||
|
||||
// Given `shape`, apply the layout for `channel_id`. `channel_id` must already
|
||||
@ -276,7 +276,7 @@ class ChannelLayoutConstraints {
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<int64, Layout> constraints_;
|
||||
absl::flat_hash_map<int64, Layout> constraints_;
|
||||
};
|
||||
|
||||
// HLO pass which assigns layouts to all instructions in the HLO module while
|
||||
|
@ -169,6 +169,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:elemental_ir_emitter",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm//:core",
|
||||
|
@ -76,15 +76,12 @@ class AliasAnalysis {
|
||||
// A map from a buffer slice to metadata corresponding to its alias.scope
|
||||
// metadata. The index kParameterAliasSet is used to hold aliasing
|
||||
// information for parameters.
|
||||
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
|
||||
BufferAllocation::Slice::Hasher>
|
||||
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*>
|
||||
alias_scope_metadata_;
|
||||
|
||||
// A map from a buffer slice to metadata corresponding to its noalias
|
||||
// metadata.
|
||||
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
|
||||
BufferAllocation::Slice::Hasher>
|
||||
noalias_metadata_;
|
||||
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*> noalias_metadata_;
|
||||
};
|
||||
|
||||
} // namespace llvm_ir
|
||||
|
@ -35,7 +35,7 @@ using llvm_ir::IrArray;
|
||||
Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
|
||||
indexed_generators_[hlo] =
|
||||
[=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||
if (generated_value_cache_[hlo].count(index.multidim()) > 0) {
|
||||
if (generated_value_cache_[hlo].contains(index.multidim())) {
|
||||
llvm::Value* generated_value =
|
||||
generated_value_cache_[hlo][index.multidim()];
|
||||
llvm::BasicBlock* generated_value_bb = nullptr;
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
@ -134,8 +135,9 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault {
|
||||
|
||||
// Cache of generated values, lest we regenerate an element of a node with
|
||||
// multiple outgoing edges
|
||||
std::unordered_map<const HloInstruction*,
|
||||
std::map<std::vector<llvm::Value*>, llvm::Value*>>
|
||||
absl::flat_hash_map<
|
||||
const HloInstruction*,
|
||||
absl::flat_hash_map<std::vector<llvm::Value*>, llvm::Value*>>
|
||||
generated_value_cache_;
|
||||
};
|
||||
|
||||
|
@ -198,7 +198,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
|
||||
if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) {
|
||||
continue;
|
||||
}
|
||||
if (in_list.count(instr) > 0) {
|
||||
if (in_list.contains(instr)) {
|
||||
continue;
|
||||
}
|
||||
int64 profit = GetProfit(instr, fusion);
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
@ -55,11 +56,10 @@ bool PointsToSet::IsAmbiguous() const {
|
||||
|
||||
bool PointsToSet::IsDistinct() const {
|
||||
bool distinct = true;
|
||||
std::set<const LogicalBuffer*> all_points_to;
|
||||
ForEachElement([&distinct, &all_points_to](const ShapeIndex& /*index*/,
|
||||
const BufferList& points_to) {
|
||||
absl::flat_hash_set<const LogicalBuffer*> all_points_to;
|
||||
ForEachElement([&](const ShapeIndex& /*index*/, const BufferList& points_to) {
|
||||
for (auto& buffer : points_to) {
|
||||
if (all_points_to.count(buffer) != 0) {
|
||||
if (all_points_to.contains(buffer)) {
|
||||
distinct = false;
|
||||
}
|
||||
all_points_to.insert(buffer);
|
||||
|
@ -89,7 +89,7 @@ static void CreateLoopInvariantCopy(
|
||||
|
||||
HloInstruction* next_operand =
|
||||
frame->instruction->mutable_operand(frame->operand_index++);
|
||||
if (hoisted_instructions->count(next_operand) ||
|
||||
if (hoisted_instructions->contains(next_operand) ||
|
||||
next_operand == while_body_param) {
|
||||
continue;
|
||||
}
|
||||
@ -241,7 +241,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
|
||||
|
||||
auto is_invariant = [&](HloInstruction* op) {
|
||||
return hoisted_instructions.find(op) != hoisted_instructions.end() ||
|
||||
unhoisted_invariant_instructions.count(op) ||
|
||||
unhoisted_invariant_instructions.contains(op) ||
|
||||
op->opcode() == HloOpcode::kConstant;
|
||||
};
|
||||
|
||||
|
@ -127,7 +127,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
||||
// through to the while body's root, count that element as "used", since
|
||||
// removing that element would be observable.
|
||||
for (int64 i = 0; i < while_body_root->operand_count(); ++i) {
|
||||
if (used_tuple_indices.count(i)) {
|
||||
if (used_tuple_indices.contains(i)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|