[XLA] Migrate from gtl::FlatSet to absl::flat_hash_set

PiperOrigin-RevId: 215324035

commit 991f06fd50 (parent beede8525b)
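Since the entire change is a container swap, here is a minimal standalone
sketch (not part of this commit; the set contents are illustrative) of why the
rewrite is mechanical: absl::flat_hash_set covers the gtl::FlatSet operations
used below (range construction, insert, count, iteration), so each call site
only needs the type name changed, plus the matching #include and BUILD
dependency edits.

    #include <cstdio>

    #include "absl/container/flat_hash_set.h"

    int main() {
      // Previously spelled tensorflow::gtl::FlatSet<int>; the call sites stay
      // the same apart from the type name.
      absl::flat_hash_set<int> visited;
      visited.insert(1);
      visited.insert(1);  // duplicate inserts are no-ops, as with gtl::FlatSet
      visited.insert(2);
      std::printf("size=%zu count(2)=%zu\n", visited.size(), visited.count(2));
      return 0;
    }

Note that neither container guarantees a stable iteration order, and the
concrete order does change, which is presumably why the FileCheck expectation
in the cpu_noalias_test.cc hunk below swaps its two metadata operands.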
@@ -325,6 +325,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
@@ -407,6 +408,7 @@ cc_library(
         "//tensorflow/core/kernels:bounds_check",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],

@@ -16,11 +16,11 @@ limitations under the License.
 #include "tensorflow/compiler/jit/deadness_analysis.h"
 #include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/hash/hash.h"
 
 // ALGORITHM OVERVIEW
@@ -298,7 +298,7 @@ class SymbolPredicate : public Predicate {
 
 template <typename FunctionTy>
 /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
-  gtl::FlatSet<Predicate*> visited;
+  absl::flat_hash_set<Predicate*> visited;
   std::vector<Predicate*> stack;
 
   stack.push_back(p);
@@ -467,7 +467,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
       is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
   Predicate::Kind other_pred_kind =
       is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
-  gtl::FlatSet<Predicate*> simplified_ops_set;
+  absl::flat_hash_set<Predicate*> simplified_ops_set;
   std::vector<Predicate*> simplified_ops;
   for (Predicate* op : operands) {
     // Simplify A&A => A and A|A => A.
@@ -492,7 +492,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
   }
 
   // Simplify "A&~A=>False" and "A|~A=>True".
-  gtl::FlatSet<Predicate*> negated_ops;
+  absl::flat_hash_set<Predicate*> negated_ops;
   for (Predicate* op : simplified_ops) {
     if (op->kind() == Predicate::Kind::kNot) {
       negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
@@ -512,7 +512,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
   //
   // First find any predicates contained in all subops.
   std::vector<Predicate*> common_inner_operands;
-  gtl::FlatSet<Predicate*> common_inner_operands_set;
+  absl::flat_hash_set<Predicate*> common_inner_operands_set;
   for (Predicate* op : simplified_ops) {
     if (op->kind() != other_pred_kind) {
       common_inner_operands.clear();

@@ -22,6 +22,7 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -44,7 +45,6 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/public/session_options.h"
@@ -78,7 +78,8 @@ void SortControlInputs(GraphDef* gdef) {
 namespace {
 
 bool AreAllParentsGuaranteedConst(
-    const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) {
+    const Node& n,
+    const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
   if (n.type_string() == "GuaranteeConst") {
     // If the current node is itself a cast-to-const, no need
     // to look at the incoming edges.
@@ -101,7 +102,7 @@ bool AreAllParentsGuaranteedConst(
 void MarkGuaranteedConstants(
     const Graph& graph,
     const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
-  gtl::FlatSet<const Node*> guaranteed_const_nodes;
+  absl::flat_hash_set<const Node*> guaranteed_const_nodes;
   std::vector<const Node*> srcs;
   srcs.reserve(src_arg_pairs.size());
   for (const auto& src_arg : src_arg_pairs) {

@@ -15,13 +15,13 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/lib/strings/proto_serialization.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -62,7 +62,7 @@ DataType EdgeType(const Edge* edge) {
 }
 
 // Adds the control inputs of `node` to `*deps`.
-void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlInputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
   for (const Edge* edge : node.in_edges()) {
     if (edge->IsControlEdge()) {
       deps->insert(edge->src());
@@ -71,7 +71,7 @@ void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
 }
 
 // Adds the control outputs of `node` to `*deps`.
-void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlOutputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
   for (const Edge* edge : node.out_edges()) {
     if (edge->IsControlEdge()) {
       deps->insert(edge->dst());
@@ -246,7 +246,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
 
   // Data and control inputs to the new XlaLaunch node.
   std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
-  gtl::FlatSet<Node*> control_inputs;
+  absl::flat_hash_set<Node*> control_inputs;
   DataTypeVector arg_types(num_args);
 
   AddControlInputs(*launch, &control_inputs);
@@ -266,7 +266,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
 
   // Outputs.
   const int num_outputs = launch->output_types().size();
-  gtl::FlatSet<Node*> control_outputs;
+  absl::flat_hash_set<Node*> control_outputs;
   std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
   DataTypeVector output_types(num_outputs);
 
@@ -21,6 +21,7 @@ limitations under the License.
 #include <unordered_map>
 #include <unordered_set>
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/jit/deadness_analysis.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -42,7 +43,6 @@ limitations under the License.
 #include "tensorflow/core/graph/control_flow.h"
 #include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/public/version.h"
 
@@ -371,7 +371,7 @@ bool IsXlaFusable(const NodeDef& node) {
 Status FindCompilationCandidates(
     const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
     const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
-    OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) {
+    OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) {
   OptimizerOptions opts;
   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
       new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
@@ -849,7 +849,7 @@ Status MarkForCompilationPass::RunImpl(
   Graph* graph = options.graph->get();
 
   OrderedNodeSet compilation_candidates;
-  gtl::FlatSet<Node*> isolated_nodes;
+  absl::flat_hash_set<Node*> isolated_nodes;
   TF_RETURN_IF_ERROR(FindCompilationCandidates(
       *graph, options.flib_def,
       (options.session_options != nullptr) ? options.session_options->env

@@ -15,17 +15,18 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
 #include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/framework/memory_types.h"
 #include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 
 namespace tensorflow {
 namespace {
-Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
+Status FindNodesToDecluster(const Graph& graph,
+                            absl::flat_hash_set<Node*>* result,
                             absl::Span<Node* const> post_order) {
   // Find nodes that have at least one user outside their cluster that expects
   // hostmem output. These nodes should be cloned to outside the cluster to
@@ -171,7 +172,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
   GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
                /*edge_filter=*/NotBackedge);
 
-  gtl::FlatSet<Node*> nodes_to_partially_decluster;
+  absl::flat_hash_set<Node*> nodes_to_partially_decluster;
   TF_RETURN_IF_ERROR(
       FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
 
@@ -82,6 +82,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_join.h"
 #include "absl/types/optional.h"
@@ -89,7 +90,6 @@ limitations under the License.
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/util/ptr_util.h"
 
@@ -176,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) {
 // point.
 class ResourceOpSet {
  private:
-  using Impl = gtl::FlatSet<ResourceOp>;
+  using Impl = absl::flat_hash_set<ResourceOp>;
 
  public:
  ResourceOpSet() = default;

@@ -1105,6 +1105,7 @@ cc_library(
         "//tensorflow/core:test",
         "//tensorflow/core:testlib",
         "//tensorflow/core/kernels:ops_util",
+        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
     ],
 )

@@ -45,6 +45,7 @@ limitations under the License.
 #include <random>
 #include <unordered_map>
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/string_view.h"
 #include "tensorflow/compiler/jit/defs.h"
@@ -63,7 +64,6 @@ limitations under the License.
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/public/session.h"
 #include "tensorflow/core/public/session_options.h"
@@ -457,7 +457,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
   Tensor tensor(dtype, TensorShape(shape));
   switch (dtype) {
     case DT_FLOAT: {
-      gtl::FlatSet<float> already_generated;
+      absl::flat_hash_set<float> already_generated;
       std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
       test::FillFn<float>(&tensor, [&](int i) -> float {
         float generated;
@@ -470,7 +470,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
       break;
     }
     case DT_DOUBLE: {
-      gtl::FlatSet<double> already_generated;
+      absl::flat_hash_set<double> already_generated;
       std::uniform_real_distribution<double> distribution(-1.0, 1.0);
       test::FillFn<double>(&tensor, [&](int i) -> double {
         double generated;
@@ -483,7 +483,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
       break;
     }
     case DT_COMPLEX64: {
-      gtl::FlatSet<std::pair<float, float>> already_generated;
+      absl::flat_hash_set<std::pair<float, float>> already_generated;
       std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
       test::FillFn<complex64>(&tensor, [&](int i) {
         complex64 generated;
@@ -500,7 +500,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
       break;
     }
     case DT_INT32: {
-      gtl::FlatSet<int32> already_generated;
+      absl::flat_hash_set<int32> already_generated;
       std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
       test::FillFn<int32>(&tensor, [&](int i) -> int32 {
         int32 generated;
@@ -513,7 +513,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
       break;
     }
     case DT_INT64: {
-      gtl::FlatSet<int64> already_generated;
+      absl::flat_hash_set<int64> already_generated;
       std::uniform_int_distribution<int64> distribution(-(1LL << 40),
                                                         1LL << 40);
       test::FillFn<int64>(&tensor, [&](int i) -> int64 {
@@ -527,7 +527,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
       break;
     }
     case DT_BOOL: {
-      gtl::FlatSet<bool> already_generated;
+      absl::flat_hash_set<bool> already_generated;
       std::bernoulli_distribution distribution;
       test::FillFn<bool>(&tensor, [&](int i) -> bool {
         bool generated;

@@ -221,6 +221,7 @@ cc_library(
         "//tensorflow/core:lib",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",

@@ -22,6 +22,7 @@ limitations under the License.
 #include <utility>
 
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
@@ -33,7 +34,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/compiler/xla/service/shape_inference.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/mutex.h"
 
 namespace xla {
@@ -2290,7 +2290,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
   // also a valid dependency order). The related ops will be added to the
   // subgraph in the same order.
   std::set<int64> related_ops;
-  tensorflow::gtl::FlatSet<int64> related_calls;  // Related computations.
+  absl::flat_hash_set<int64> related_calls;  // Related computations.
   std::queue<int64> worklist;
   worklist.push(root->id());
   related_ops.insert(root->id());

@@ -22,6 +22,7 @@ limitations under the License.
 #include <utility>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/padding.h"
@@ -35,7 +36,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/stacktrace.h"
 #include "tensorflow/core/platform/types.h"
@@ -1035,7 +1035,7 @@ class XlaBuilder {
   std::map<int64, HloComputationProto> embedded_;
 
   // The unique parameter numbers.
-  tensorflow::gtl::FlatSet<int64> parameter_numbers_;
+  absl::flat_hash_set<int64> parameter_numbers_;
 
   // The metadata to attach to each op. This is structured as a "modal"-like
   // operation, in order to simplify client code (and not sprinkle this metadata

@@ -147,6 +147,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
 
@@ -183,6 +184,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
@@ -336,6 +338,7 @@ cc_library(
         "//tensorflow/core:lib_internal",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -490,6 +493,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -781,6 +785,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -959,6 +964,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -995,6 +1001,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
     ],
@@ -1043,6 +1050,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -1136,6 +1144,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -1230,6 +1239,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
@@ -1275,6 +1285,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
 
@@ -1348,6 +1359,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_pass",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -1660,6 +1672,7 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
     ],
@@ -2064,6 +2077,7 @@ cc_library(
         ":logical_buffer",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
 
@@ -2099,6 +2113,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
@@ -2120,6 +2135,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -2203,6 +2219,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -2225,6 +2242,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
@@ -2286,6 +2304,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -2343,6 +2362,7 @@ cc_library(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -2370,6 +2390,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -2487,6 +2508,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -2616,6 +2638,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
@@ -2655,6 +2678,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
     ],
 )
@@ -2730,6 +2754,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -3300,6 +3325,7 @@ cc_library(
         "//tensorflow/core:lib",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
     ],
 )
@@ -3387,6 +3413,7 @@ cc_library(
         "//tensorflow/core:ptr_util",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",

@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -81,7 +82,7 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
   };
 
   auto root = fusion->fused_instructions_computation()->root_instruction();
-  tensorflow::gtl::FlatSet<const HloValue*> changed_root_buffers;
+  absl::flat_hash_set<const HloValue*> changed_root_buffers;
 
   auto root_changes_it = changes_to_bf16_.find(root);
   if (root_changes_it != changes_to_bf16_.end()) {
@@ -500,7 +501,7 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
 
 bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
     HloComputation* computation,
-    tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations) {
+    absl::flat_hash_set<const HloComputation*>* visited_computations) {
   bool parameter_changed = false;
   auto insts = computation->MakeInstructionPostOrder();
   // Do the adjustment on each instruction in the computation in reverse
@@ -560,7 +561,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
         // another input parameter. A fixed point will be reached because the
         // parameters can only be changed from BF16 to F32, not the other way
         // around.
-        tensorflow::gtl::FlatSet<const HloComputation*> visited_in_while;
+        absl::flat_hash_set<const HloComputation*> visited_in_while;
         while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
                                                            &visited_in_while) ||
                ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
@@ -587,7 +588,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
     HloModule* module) {
   const auto& computations_topological_order =
       module->MakeComputationPostOrder();
-  tensorflow::gtl::FlatSet<const HloComputation*> resolved;
+  absl::flat_hash_set<const HloComputation*> resolved;
   for (auto comp_it = computations_topological_order.rbegin();
        comp_it != computations_topological_order.rend(); ++comp_it) {
     if (ContainsKey(resolved, *comp_it)) {

@@ -22,6 +22,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/bfloat16_support.h"
 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -82,7 +83,7 @@ class BFloat16Propagation : public HloModulePass {
 
   // The set of instructions to consider using bfloat16, computed in the forward
   // pass.
-  tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_;
+  absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_;
 
   // ***************************
   // Functions called and state produced by the backward pass (from root to
@@ -111,12 +112,12 @@ class BFloat16Propagation : public HloModulePass {
 
   // The set of HloInstructions that have been visited in the
   // opportunity-finding pass.
-  tensorflow::gtl::FlatSet<const HloInstruction*>
+  absl::flat_hash_set<const HloInstruction*>
       instructions_visited_in_backward_pass_;
 
   // The set of HloComputations that have been visited in the
   // opportunity-finding pass.
-  tensorflow::gtl::FlatSet<const HloComputation*>
+  absl::flat_hash_set<const HloComputation*>
      computations_visited_in_backward_pass_;
 
   // ***************************
@@ -132,7 +133,7 @@ class BFloat16Propagation : public HloModulePass {
   // point is reached.
   bool ResolveInconsistencyOfAliasingBuffersHelper(
       HloComputation* computation,
-      tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations);
+      absl::flat_hash_set<const HloComputation*>* visited_computations);
 
   // Makes the parameters of called computations match how they are called by
   // the given HLO.
@@ -183,7 +184,7 @@ class BFloat16Propagation : public HloModulePass {
       PrimitiveType target_type);
 
   // The set of F32 HLO values that must be kept in F32.
-  tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
+  absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_;
 
   // Mapping from each HloComputation to the number of callers to it in the
   // module. Populated at the beginning of this pass.

@@ -23,6 +23,7 @@ limitations under the License.
 #include <utility>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
@@ -43,9 +44,9 @@ namespace xla {
 namespace {
 
 using absl::flat_hash_map;
+using absl::flat_hash_set;
 using absl::StrAppend;
 using absl::StrAppendFormat;
-using ::tensorflow::gtl::FlatSet;
 using ::tensorflow::strings::HumanReadableNumBytes;
 
 template <typename T>
@@ -129,8 +130,8 @@ Status GatherComputationsByAllocationType(
 
   // Sets for quickly checking membership. Computations are returned in vectors
   // for stable iteration.
-  FlatSet<const HloComputation*> thread_local_set;
-  FlatSet<const HloComputation*> global_set;
+  flat_hash_set<const HloComputation*> thread_local_set;
+  flat_hash_set<const HloComputation*> global_set;
 
   while (!worklist.empty()) {
     auto worklist_front = worklist.front();
@@ -445,7 +446,7 @@ bool BufferAssignment::SharesSliceAtIndex(
 bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
                                           const HloInstruction* hlo_b) const {
   using SliceSet =
-      FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
+      flat_hash_set<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
   // Gets the slices all of instr's subshapes. If any subshape doesn't have an
   // assigned slice, returns the empty set.
   auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
@@ -815,9 +816,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
 
 Status BufferAssigner::AssignBuffersForComputation(
     const HloComputation* computation, bool is_thread_local,
-    const FlatSet<const LogicalBuffer*>& colocated_buffers,
-    const FlatSet<BufferAllocation::Index>& colocated_allocations,
-    flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>*
+    const flat_hash_set<const LogicalBuffer*>& colocated_buffers,
+    const flat_hash_set<BufferAllocation::Index>& colocated_allocations,
+    flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>*
         buffers_to_assign_sequentially,
     BufferAssignment* assignment) {
   // Buffers are sorted and assigned to BufferAllocations in decreasing order of
@@ -853,8 +854,8 @@ Status BufferAssigner::AssignBuffersForComputation(
     // buffers_to_assign_sequentially map, even if we end up with an empty set
     // of buffers. This ensures we can correctly determine whether to run
     // whole-module heap simulation.
-    buffers_to_assign_sequentially->emplace(computation,
-                                            FlatSet<const LogicalBuffer*>());
+    buffers_to_assign_sequentially->emplace(
+        computation, flat_hash_set<const LogicalBuffer*>());
   }
 
   // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
@@ -1046,11 +1047,11 @@ Status BufferAssigner::AssignBuffersForComputation(
   return Status::OK();
 }
 
-flat_hash_map<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
+flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
               LogicalBuffer::Color::Hasher>
 BufferAssigner::SplitBuffersByColor(
-    const FlatSet<const LogicalBuffer*>& buffers) {
-  flat_hash_map<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
+    const flat_hash_set<const LogicalBuffer*>& buffers) {
+  flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
                 LogicalBuffer::Color::Hasher>
       color_map;
   for (auto buffer : buffers) {
@@ -1060,7 +1061,8 @@ BufferAssigner::SplitBuffersByColor(
 }
 
 Status BufferAssigner::AssignBuffersWithSequentialOrdering(
-    const flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>&
+    const flat_hash_map<const HloComputation*,
+                        flat_hash_set<const LogicalBuffer*>>&
         buffers_to_assign_sequentially,
     bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
   // Run the sequence of instructions through the heap simulator. The heuristic
@@ -1086,10 +1088,11 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
     // only live for the duration of their calling instructions.
     VLOG(1) << "Running whole-module heap simulation";
     HloSchedule schedule(&assignment->module());
-    FlatSet<const LogicalBuffer*> all_buffers_to_assign;
+    flat_hash_set<const LogicalBuffer*> all_buffers_to_assign;
    for (const auto& pair : buffers_to_assign_sequentially) {
       const HloComputation* computation = pair.first;
-      const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+      const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
+          pair.second;
       const std::vector<const HloInstruction*>* instruction_sequence =
           hlo_ordering.SequentialOrder(*computation);
       CHECK(instruction_sequence != nullptr) << computation->name();
@@ -1123,7 +1126,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
     VLOG(1) << "Running per-computation heap simulation";
     for (const auto& pair : buffers_to_assign_sequentially) {
       const HloComputation* computation = pair.first;
-      const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+      const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
+          pair.second;
      const std::vector<const HloInstruction*>* instruction_sequence =
           hlo_ordering.SequentialOrder(*computation);
       CHECK(instruction_sequence != nullptr) << computation->name();
@@ -1198,7 +1202,7 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
 
   // Next gather the set of logical buffers live at the earliest point of
   // maximal live set size.
-  tensorflow::gtl::FlatSet<const LogicalBuffer*> live_buffers;
+  absl::flat_hash_set<const LogicalBuffer*> live_buffers;
   live_size = 0;
   for (const auto& event : heap_trace.events()) {
     const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
@@ -1588,8 +1592,8 @@ void BufferAssigner::BuildColocatedBufferSets(
 void BufferAssigner::AssignColocatedBufferSets(
     const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
     BufferAssignment* assignment,
-    FlatSet<const LogicalBuffer*>* colocated_buffers,
-    FlatSet<BufferAllocation::Index>* colocated_allocations) {
+    flat_hash_set<const LogicalBuffer*>* colocated_buffers,
+    flat_hash_set<BufferAllocation::Index>* colocated_allocations) {
   for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
     BufferAllocation* allocation = nullptr;
     // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry
@@ -1662,8 +1666,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   // Once b/32491382 enables module-level liveness analysis, we may be able
   // to assign colocated buffers (or at least reuse their allocation for
   // buffers outside of the set) in AssignBuffersForComputation.
-  FlatSet<const LogicalBuffer*> colocated_buffers;
-  FlatSet<BufferAllocation::Index> colocated_allocations;
+  flat_hash_set<const LogicalBuffer*> colocated_buffers;
+  flat_hash_set<BufferAllocation::Index> colocated_allocations;
   std::vector<ColocatedBufferSet> colocated_buffer_sets;
   BuildColocatedBufferSets(module, assignment->liveness(),
                            assignment->buffer_size_, &colocated_buffer_sets);
@@ -1681,7 +1685,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
 
   // First assign buffers for global computatations. Temporary buffers for
   // sequential computations are collected in 'buffers_to_assign_sequentially'.
-  flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>
+  flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>
      buffers_to_assign_sequentially;
   for (auto* computation : global_computations) {
     TF_RETURN_IF_ERROR(AssignBuffersForComputation(

@@ -23,6 +23,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/heap_simulator.h"
@@ -34,7 +35,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
@@ -554,11 +554,10 @@ class BufferAssigner {
   // true.
   Status AssignBuffersForComputation(
       const HloComputation* computation, bool is_thread_local,
-      const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
-      const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
-          colocated_allocations,
+      const absl::flat_hash_set<const LogicalBuffer*>& colocated_buffers,
+      const absl::flat_hash_set<BufferAllocation::Index>& colocated_allocations,
      absl::flat_hash_map<const HloComputation*,
-                          tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
+                          absl::flat_hash_set<const LogicalBuffer*>>*
           buffers_to_assign_sequentially,
       BufferAssignment* assignment);
 
@@ -569,7 +568,7 @@ class BufferAssigner {
   // assuming all global computations are sequentially ordered.
   Status AssignBuffersWithSequentialOrdering(
       const absl::flat_hash_map<const HloComputation*,
-                                tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
+                                absl::flat_hash_set<const LogicalBuffer*>>&
          buffers_to_assign_sequentially,
       bool run_whole_module_heap_simulation, BufferAssignment* assignment);
 
@@ -589,7 +588,7 @@ class BufferAssigner {
   // alias. Explicitly handling these colocated buffers is necessary because
   // points-to analysis is computation level scope and does not recognize
   // aliasing across computations (b/32491382).
-  using ColocatedBufferSet = tensorflow::gtl::FlatSet<const LogicalBuffer*>;
+  using ColocatedBufferSet = absl::flat_hash_set<const LogicalBuffer*>;
 
   // Returns a vector of ColocatedBufferSet objects, where each
   // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
@@ -604,8 +603,8 @@ class BufferAssigner {
   void AssignColocatedBufferSets(
       const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
       BufferAssignment* assignment,
-      tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
-      tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations);
+      absl::flat_hash_set<const LogicalBuffer*>* colocated_buffers,
+      absl::flat_hash_set<BufferAllocation::Index>* colocated_allocations);
 
   // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
   // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
@@ -624,10 +623,9 @@ class BufferAssigner {
   // Split a set of buffers into several sets, each of which contains buffers
   // colored with the same color.
   absl::flat_hash_map<LogicalBuffer::Color,
-                      tensorflow::gtl::FlatSet<const LogicalBuffer*>,
+                      absl::flat_hash_set<const LogicalBuffer*>,
                       LogicalBuffer::Color::Hasher>
-  SplitBuffersByColor(
-      const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers);
+  SplitBuffersByColor(const absl::flat_hash_set<const LogicalBuffer*>& buffers);
 
   // If true, buffer assignments assumes that input parameter buffers and output
   // buffers can be shared if their sizes match.

@@ -20,6 +20,7 @@ limitations under the License.
 #include <string>
 #include <utility>
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -27,7 +28,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 
 namespace xla {
 
@@ -101,7 +101,7 @@ class BufferLiveness {
   // Set of LogicalBuffers which are aliased in the output of other
   // instructions. For example, a LogicalBuffer which is inserted into a tuple
   // is considered to be aliased and will be in this set.
-  tensorflow::gtl::FlatSet<const LogicalBuffer*> aliased_buffers_;
+  absl::flat_hash_set<const LogicalBuffer*> aliased_buffers_;
 
   // LogicalBuffers that may be live out of the entry computation.
   PointsToSet::BufferSet maybe_live_out_buffers_;

@@ -16,10 +16,10 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/buffer_value.h"
 #include "tensorflow/compiler/xla/service/logical_buffer.h"
 #include "tensorflow/core/lib/gtl/compactptrset.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 
 namespace xla {
 
@@ -38,7 +38,7 @@ BufferValueCompactPointerSet ToBufferValueCompactPointerSet(
   return output;
 }
 
-using BufferValueFlatSet = tensorflow::gtl::FlatSet<const BufferValue*>;
+using BufferValueFlatSet = absl::flat_hash_set<const BufferValue*>;
 template <class LogicalBufferContainerT>
 BufferValueFlatSet ToBufferValueFlatSet(
     const LogicalBufferContainerT& logical_buffer_container) {

@@ -17,6 +17,7 @@ limitations under the License.
 
 #include <queue>
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
@@ -138,7 +139,7 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
 
 bool CallGraph::DominatesHelper(
     const HloComputation* a, const HloComputation* b,
-    tensorflow::gtl::FlatSet<const HloComputation*>* visited) const {
+    absl::flat_hash_set<const HloComputation*>* visited) const {
   if (a == b || ContainsKey(*visited, b)) {
     // The call graph is guaranteed to be acyclic so any previously visited node
     // we encounter was already determined to be dominated.
@@ -163,7 +164,7 @@ bool CallGraph::DominatesHelper(
 
 bool CallGraph::Dominates(const HloComputation* a,
                           const HloComputation* b) const {
-  tensorflow::gtl::FlatSet<const HloComputation*> visited;
+  absl::flat_hash_set<const HloComputation*> visited;
   return DominatesHelper(a, b, &visited);
 }
 
@@ -277,7 +278,7 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
 
 Status CallGraph::VisitNodesInternal(
     const VisitorFunction& visitor_func, const CallGraphNode& node,
-    tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const {
+    absl::flat_hash_set<const CallGraphNode*>* visited) const {
   auto pair = visited->insert(&node);
   if (!pair.second) {
     // Node was not inserted. Node has already been visited.
@@ -294,7 +295,7 @@ Status CallGraph::VisitNodesInternal(
 
 Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
                              bool visit_unreachable_nodes) const {
-  tensorflow::gtl::FlatSet<const CallGraphNode*> visited;
+  absl::flat_hash_set<const CallGraphNode*> visited;
   if (visit_unreachable_nodes) {
     // Traverse from all roots in the call graph.
     for (const CallGraphNode& node : nodes()) {

@@ -21,10 +21,10 @@ limitations under the License.
 #include <ostream>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 
 namespace xla {
 
@@ -145,12 +145,12 @@ class CallGraphNode {
   // The computations called by this computation. The vector is used for a
   // stable ordering and the set enables fast membership testing.
   std::vector<HloComputation*> callees_;
-  tensorflow::gtl::FlatSet<HloComputation*> callee_set_;
+  absl::flat_hash_set<HloComputation*> callee_set_;
 
   // The computations which call this computation. The vector is used for a
   // stable ordering and the set enables fast membership testing.
   std::vector<HloComputation*> callers_;
-  tensorflow::gtl::FlatSet<HloComputation*> caller_set_;
+  absl::flat_hash_set<HloComputation*> caller_set_;
 
   // The call sites in this computation
   std::vector<CallSite> callsites_;
@@ -250,14 +250,14 @@ class CallGraph {
   // 'visited'.
   Status VisitNodesInternal(
       const VisitorFunction& visitor_func, const CallGraphNode& node,
-      tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const;
+      absl::flat_hash_set<const CallGraphNode*>* visited) const;
 
   // Recursive helper for computing whether 'a' dominates 'b' in the call
   // graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
   // and 'visited' is the set of computations which have been visited.
   bool DominatesHelper(
      const HloComputation* a, const HloComputation* b,
-      tensorflow::gtl::FlatSet<const HloComputation*>* visited) const;
+      absl::flat_hash_set<const HloComputation*>* visited) const;
 
   // The HLO module represented by this call graph.
   const HloModule* module_ = nullptr;

@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/copy_insertion.h"
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
@@ -32,7 +33,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace xla {
@@ -904,7 +904,7 @@ class CopyRemover {
   // The heads of all the value lists. Each value list represents the HLO
   // values contained in a particular HLO buffer. The values in the list are
   // in dependency order.
-  tensorflow::gtl::FlatSet<const ValueNode*> value_lists_;
+  absl::flat_hash_set<const ValueNode*> value_lists_;
 
   // Copy removal requires fast access to the value list elements
   // corresponding to the source and destination values of the kCopy
@@ -1009,7 +1009,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
   HloInstruction* root = computation->root_instruction();
 
   // Mark nondistinct/ambiguous indices.
-  tensorflow::gtl::FlatSet<const HloBuffer*> seen;
+  absl::flat_hash_set<const HloBuffer*> seen;
   ShapeUtil::ForEachSubshape(
       root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
         std::vector<const HloBuffer*> buffers_at_index =

@@ -291,6 +291,7 @@ cc_library(
         "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",

@@ -25,6 +25,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/lib/math/math_util.h"
 #include "tensorflow/core/platform/logging.h"
 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
@@ -68,7 +69,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/window_util.h"
 #include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 
 namespace xla {
 
@@ -1400,8 +1400,8 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) {
   // [0->0, 3->1].
   absl::flat_hash_map<int64, int64> unreduced_dim_map;
 
-  gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(),
-                                   reduce.dimensions().end());
+  absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
+                                          reduce.dimensions().end());
 
   const Shape& operand_shape = reduce.operand(0)->shape();
   const Shape& result_shape = reduce.shape();
@@ -1977,7 +1977,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
   //
   // * Implement the memcpy within the innermost loop.
 
-  gtl::FlatSet<int64> inner_dims;
+  absl::flat_hash_set<int64> inner_dims;
   for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
     if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
       break;

@@ -121,7 +121,7 @@ TEST_F(CpuNoAliasTest, Concat) {
 CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]]
 CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32
 CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48
-CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]}
+CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]}
 CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]}
 CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]}
 )";

@@ -476,6 +476,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:instruction_fusion",
         "//tensorflow/compiler/xla/service:pattern_matcher",
+        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
 
@@ -508,6 +509,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:multi_output_fusion",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_set",
     ],
 )
 
@@ -541,6 +543,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "//tensorflow/core:lib",
+       "@com_google_absl//absl/container:flat_hash_set",
     ],
 )

@@ -19,6 +19,7 @@ limitations under the License.
 #include <set>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/compiler/xla/service/copy_insertion.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -27,7 +28,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/logging.h"
 
 namespace xla {

@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
@ -125,8 +126,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
|
||||
}
|
||||
|
||||
// Compute the precise number of operands to the new fusion.
|
||||
tensorflow::gtl::FlatSet<const HloInstruction*> operands(
|
||||
a->operands().begin(), a->operands().end());
|
||||
absl::flat_hash_set<const HloInstruction*> operands(a->operands().begin(),
|
||||
a->operands().end());
|
||||
operands.insert(b->operands().begin(), b->operands().end());
|
||||
// If there's an edge between `a` and `b`, don't count it: We're fusing that
|
||||
// producer -> consumer relationship.
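The hunk above counts the operands a fused instruction would have: the union of both operand sets, minus the producer->consumer edge being fused away. A hedged sketch of that computation with a placeholder Node type (the real code operates on HloInstruction, and the erase step is implied by the comment rather than shown in the hunk):

    #include <vector>
    #include "absl/container/flat_hash_set.h"

    struct Node {
      std::vector<const Node*> operands;
    };

    // Union of both operand sets; when `a` feeds `b`, that edge is internal
    // to the fusion and must not be counted as an external operand.
    int FusedOperandCount(const Node* a, const Node* b) {
      absl::flat_hash_set<const Node*> operands(a->operands.begin(),
                                                a->operands.end());
      operands.insert(b->operands.begin(), b->operands.end());
      operands.erase(a);
      return static_cast<int>(operands.size());
    }

    int main() {
      Node x, a, b;
      a.operands = {&x};
      b.operands = {&a, &x};
      return FusedOperandCount(&a, &b) == 1 ? 0 : 1;  // only x remains
    }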

@ -24,6 +24,7 @@ limitations under the License.
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
@ -101,7 +101,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {

int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
HloInstruction* instr2) {
tensorflow::gtl::FlatSet<HloInstruction*> in_list;
absl::flat_hash_set<HloInstruction*> in_list;
for (auto instr : instr1->operands()) {
if (!IsProfitableOperand(instr)) {
continue;
@ -148,7 +148,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
bool changed = false;
RecomputeReachability();

tensorflow::gtl::FlatSet<HloInstruction*> to_fuse;
absl::flat_hash_set<HloInstruction*> to_fuse;
// Keep a list of the instructions to fuse after making all the fusion
// decisions. We first aggressively add instructions to potential_fusion_list,
// then filter out instructions that will be no longer fusible because of

@ -19,6 +19,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/util.h"
@ -26,7 +27,7 @@ limitations under the License.
namespace xla {

using absl::flat_hash_map;
using tensorflow::gtl::FlatSet;
using absl::flat_hash_set;

/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
@ -116,9 +117,9 @@ Status HeapSimulator::RunComputation(
// 'used_buffers' is the reverse map - it tracks which buffers were used by an
// instruction, so that we can remove the instructions from a buffer's live
// set after they are visited.
flat_hash_map<const BufferValue*, FlatSet<const HloInstruction*>>
flat_hash_map<const BufferValue*, flat_hash_set<const HloInstruction*>>
live_buffers;
flat_hash_map<const HloInstruction*, FlatSet<const BufferValue*>>
flat_hash_map<const HloInstruction*, flat_hash_set<const BufferValue*>>
used_buffers;
auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
const HloInstruction* user,
@ -216,7 +217,7 @@ Status HeapSimulator::RunComputation(
VLOG(4) << " Removing user " << instruction->name() << " from buffer "
<< operand_buffer->ToString();
auto it = live_buffers.find(operand_buffer);
FlatSet<const HloInstruction*>* live_set = &it->second;
flat_hash_set<const HloInstruction*>* live_set = &it->second;
live_set->erase(instruction);
if (live_set->empty()) {
live_buffers.erase(it);
@ -238,7 +239,7 @@ Status HeapSimulator::RunComputation(
// that we should assign.

// Make sure each buffer get reused at most once.
FlatSet<const BufferValue*> reused_buffers;
flat_hash_set<const BufferValue*> reused_buffers;
for (const BufferValue* buffer : buffers_defined_by_instruction) {
if (IgnoreBuffer(buffer)) {
continue;
@ -326,7 +327,7 @@ Status HeapSimulator::RunComputation(
to_free.reserve(live_buffers.size());
for (const auto& buffer_pending : live_buffers) {
const BufferValue* buffer = buffer_pending.first;
const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
const flat_hash_set<const HloInstruction*>& pending = buffer_pending.second;
CHECK_EQ(pending.size(), 1) << *buffer;
CHECK(*pending.begin() == nullptr) << *buffer;
to_free.push_back(buffer);
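Both directions of the bookkeeping above are maps of sets, so only the inner set type changes in this hunk. A self-contained sketch of the erase-user, free-buffer-when-empty pattern it implements (Buffer and Instr are placeholders for the real BufferValue and HloInstruction):

    #include <cstdio>
    #include "absl/container/flat_hash_map.h"
    #include "absl/container/flat_hash_set.h"

    struct Buffer {};
    struct Instr {};

    int main() {
      absl::flat_hash_map<const Buffer*, absl::flat_hash_set<const Instr*>>
          live_buffers;
      Buffer buf;
      Instr user;
      live_buffers[&buf].insert(&user);

      // Mirrors RunComputation: drop the visited user from the buffer's live
      // set, and release the buffer once no pending users remain.
      auto it = live_buffers.find(&buf);
      absl::flat_hash_set<const Instr*>* live_set = &it->second;
      live_set->erase(&user);
      if (live_set->empty()) {
        live_buffers.erase(it);
        std::puts("buffer can be freed");
      }
      return 0;
    }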

@ -22,6 +22,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace xla {

@ -197,8 +197,8 @@ class HeapSimulator {
shared_buffers_;

// Hold some sets for error-checking the sequence of Alloc and Free calls.
tensorflow::gtl::FlatSet<const BufferValue*> allocated_buffers_;
tensorflow::gtl::FlatSet<const BufferValue*> freed_buffers_;
absl::flat_hash_set<const BufferValue*> allocated_buffers_;
absl::flat_hash_set<const BufferValue*> freed_buffers_;

// Debugging information filled in while the heap simulator runs.
HeapSimulatorTrace debug_trace_;

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@ -120,7 +121,7 @@ class BufferValueMap {
}

// Return a set of all the values in the given buffer.
const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer(
const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer(
BufferNumber buffer_number) const {
return buffers_.at(buffer_number);
}
@ -143,7 +144,7 @@ class BufferValueMap {
// Move the given value into the given buffer.
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
tensorflow::gtl::FlatSet<const HloValue*>& old_value_set =
absl::flat_hash_set<const HloValue*>& old_value_set =
buffers_.at(old_buffer_number);
old_value_set.erase(&value);
if (old_value_set.empty()) {
@ -291,7 +292,7 @@ class BufferValueMap {
const HloDataflowAnalysis& dataflow_;

// A map containing the set of values contained in each buffer.
absl::flat_hash_map<BufferNumber, tensorflow::gtl::FlatSet<const HloValue*>>
absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>>
buffers_;

// A map indicating which buffer each value is contained in.
@ -351,7 +352,7 @@ bool HloAliasAnalysis::InstructionBuffersAreAmbiguous(

bool HloAliasAnalysis::InstructionBuffersAreDistinct(
const HloInstruction* instruction) const {
tensorflow::gtl::FlatSet<const HloBuffer*> buffers_seen;
absl::flat_hash_set<const HloBuffer*> buffers_seen;
for (const auto& pair :
dataflow_analysis_->GetInstructionValueSet(instruction)) {
const HloValueSet& value_set = pair.second;

@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

@ -25,6 +25,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
@ -40,7 +41,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
@ -278,10 +278,9 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
namespace {

// Helper which builds a post order of the HLO call graph.
void ComputeComputationPostOrder(
HloComputation* computation,
tensorflow::gtl::FlatSet<HloComputation*>* visited,
std::vector<HloComputation*>* post_order) {
void ComputeComputationPostOrder(HloComputation* computation,
absl::flat_hash_set<HloComputation*>* visited,
std::vector<HloComputation*>* post_order) {
if (visited->insert(computation).second) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@ -416,7 +415,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {

std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
tensorflow::gtl::FlatSet<HloComputation*> visited;
absl::flat_hash_set<HloComputation*> visited;
std::vector<HloComputation*> post_order;

// To avoid special handling of this computation, cast away const of

@ -26,6 +26,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
@ -41,7 +42,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"


@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace xla {
@ -137,8 +137,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
// HLO instructions are grouped into equivalency classes by using the
// cse_equal predicate defined above. This set holds a representative
// instruction for each class.
tensorflow::gtl::FlatSet<HloInstruction*, decltype(&CseHash),
decltype(cse_equal)>
absl::flat_hash_set<HloInstruction*, decltype(&CseHash),
decltype(cse_equal)>
representatives(/*N=*/computation->instruction_count() + 1, &CseHash,
cse_equal);
for (auto instruction : computation->MakeInstructionPostOrder()) {
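The representatives set carries over because absl::flat_hash_set, like gtl::FlatSet, accepts custom hash and equality functors as template parameters and a capacity hint plus functor instances in its constructor. A sketch of the same constructor shape; Item, ItemHash, and ItemEq are invented for illustration and hash through the pointee, the way HloCSE compares instructions rather than pointers:

    #include <cstddef>
    #include <functional>
    #include <string>
    #include "absl/container/flat_hash_set.h"

    struct Item {
      std::string name;
    };

    struct ItemHash {
      std::size_t operator()(const Item* i) const {
        return std::hash<std::string>()(i->name);
      }
    };
    struct ItemEq {
      bool operator()(const Item* a, const Item* b) const {
        return a->name == b->name;
      }
    };

    int main() {
      // Capacity hint first, then the functors -- the same shape as
      // representatives(/*N=*/..., &CseHash, cse_equal) above.
      absl::flat_hash_set<const Item*, ItemHash, ItemEq> reps(
          /*bucket_count=*/16, ItemHash{}, ItemEq{});
      Item a{"add"}, b{"add"};
      reps.insert(&a);
      // Dropped: `b` is equal to `a` under ItemEq, so only one
      // representative per equivalence class survives.
      bool inserted = reps.insert(&b).second;
      return inserted ? 1 : 0;
    }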

@ -19,6 +19,7 @@ limitations under the License.
#include <queue>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@ -91,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis(

bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
const HloInstruction* inst) {
tensorflow::gtl::FlatSet<const HloInstruction*> visited;
absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(inst);
while (!stack.empty()) {
@ -159,8 +160,8 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
void HloDataflowAnalysis::DeleteMarkedValues() {
#ifndef NDEBUG
// Verify that no marked-for-deletion values are in any of the value sets.
tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
value_ids_to_delete_.end());
absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
value_ids_to_delete_.end());
for (const auto& pair : value_sets_) {
const HloInstruction* instruction = pair.first;
const InstructionValueSet& instruction_value_set = pair.second;
@ -673,7 +674,7 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(

void HloDataflowAnalysis::Propagate() {
std::queue<HloInstruction*> worklist;
tensorflow::gtl::FlatSet<HloInstruction*> workset;
absl::flat_hash_set<HloInstruction*> workset;
auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
if (workset.insert(instruction).second) {
worklist.push(instruction);
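Propagate pairs a FIFO queue with a membership set so an instruction is enqueued at most once; the insert(...).second idiom above is unchanged by the migration. A stripped-down sketch with int standing in for HloInstruction*:

    #include <queue>
    #include "absl/container/flat_hash_set.h"

    int main() {
      std::queue<int> worklist;
      absl::flat_hash_set<int> workset;
      auto add_to_worklist = [&](int item) {
        // insert().second is true only for the first occurrence.
        if (workset.insert(item).second) {
          worklist.push(item);
        }
      };
      add_to_worklist(1);
      add_to_worklist(1);  // deduplicated
      return worklist.size() == 1 ? 0 : 1;
    }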

@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@ -217,7 +218,7 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {

/* static */ std::vector<HloInstruction*>
HloDomainMap::MakeNonDomainInstructions(
const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
const absl::flat_hash_set<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order) {
std::vector<HloInstruction*> instructions;
instructions.reserve(instruction_set.size());

@ -20,13 +20,13 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace xla {

@ -110,7 +110,7 @@ class HloDomainMap {
// Out of an instruction set, returns a vector of all the ones which are not
// a kDomain kind.
static std::vector<HloInstruction*> MakeNonDomainInstructions(
const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
const absl::flat_hash_set<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order);

// Populates domain_metadata_id_ that maps each HloInstruction to the unique

@ -20,11 +20,11 @@ limitations under the License.
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace xla {

@ -42,7 +42,7 @@ class DomainMetadata {
// operand/user pathways, without crossing a kDomain instruction of a given
// kind. The reach_set can contain kDomain instructions of other kinds, if
// two domains of different kind intersect each other.
tensorflow::gtl::FlatSet<HloInstruction*> reach_set;
absl::flat_hash_set<HloInstruction*> reach_set;

// The same instructions in reach_set, but purged from kDomain instructions
// and ordered according to their computation graph post-order, i.e.
@ -55,8 +55,8 @@ class DomainMetadata {
// whose dataflow enters the reach set (domain), while the exit_domains
// contains the set of kDomain instructions whose dataflow exit the reach
// set.
tensorflow::gtl::FlatSet<HloInstruction*> enter_domains;
tensorflow::gtl::FlatSet<HloInstruction*> exit_domains;
absl::flat_hash_set<HloInstruction*> enter_domains;
absl::flat_hash_set<HloInstruction*> exit_domains;
};

virtual ~DomainMetadata() = default;

@ -23,6 +23,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
@ -1433,7 +1433,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const {

HloInstruction::InstructionVector HloInstruction::unique_operands() const {
InstructionVector unique;
tensorflow::gtl::FlatSet<const HloInstruction*> seen;
absl::flat_hash_set<const HloInstruction*> seen;
for (HloInstruction* operand : operands()) {
if (seen.insert(operand).second) {
unique.push_back(operand);
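unique_operands keeps results in a vector in first-seen order and uses the set only for membership, so the output never depends on hash iteration order. A sketch of that order-preserving dedup; the same trick appears again below in HloModuleGroupUtil, whose comments note the vector exists "to avoid non-determinism":

    #include <vector>
    #include "absl/container/flat_hash_set.h"

    // Deduplicate while preserving first-seen order, as unique_operands does.
    std::vector<int> UniqueInOrder(const std::vector<int>& xs) {
      std::vector<int> unique;
      absl::flat_hash_set<int> seen;
      for (int x : xs) {
        if (seen.insert(x).second) {
          unique.push_back(x);
        }
      }
      return unique;
    }

    int main() {
      return UniqueInOrder({3, 1, 3, 2, 1}).size() == 3 ? 0 : 1;
    }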

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@ -111,7 +112,7 @@ class ListScheduler {
// LogicalBuffer is in an operand of the instruction as indicated by
// points-to analysis.
for (auto* instruction : computation.instructions()) {
tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses;
absl::flat_hash_set<const LogicalBuffer*> instr_uses;
for (auto* operand : instruction->operands()) {
points_to_analysis.GetPointsToSet(operand).ForEachElement(
[&](const ShapeIndex& /*index*/,
@ -360,7 +361,7 @@ class ListScheduler {
std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;

// Set of instructions which have been scheduled.
tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_;
absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
};

int64 SumLogicalBufferSizes(
@ -418,7 +419,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
total_sizes[hlo] = logical_buffer_size;
cumulative_total_size += logical_buffer_size;
tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
absl::flat_hash_set<const HloInstruction*> unique_operands(
hlo->operands().begin(), hlo->operands().end());
for (const HloInstruction* operand : unique_operands) {
extra_users[hlo] += extra_users[operand];

@ -24,6 +24,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
@ -328,10 +329,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(

// Because we didn't uniquify the names or the ids, double-check that the
// instruction and computation names and ids are unique from the proto.
tensorflow::gtl::FlatSet<string> computation_names;
tensorflow::gtl::FlatSet<string> instruction_names;
tensorflow::gtl::FlatSet<int> computation_ids;
tensorflow::gtl::FlatSet<int> instruction_ids;
absl::flat_hash_set<string> computation_names;
absl::flat_hash_set<string> instruction_names;
absl::flat_hash_set<int> computation_ids;
absl::flat_hash_set<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();

@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

@ -42,7 +42,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
HloInstruction* instruction) {
std::vector<HloInstruction*>
predecessors; // Use a vector to avoid non-determinism.
tensorflow::gtl::FlatSet<HloInstruction*> unique;
absl::flat_hash_set<HloInstruction*> unique;

// Adds to the unique predecessors list; if the predecessors is a companion
// instruction, also add companion instructions; if the predecessors is a
@ -119,7 +119,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
HloInstruction* instruction) {
std::vector<HloInstruction*>
successors; // Use a vector to avoid non-determinism.
tensorflow::gtl::FlatSet<HloInstruction*> unique;
absl::flat_hash_set<HloInstruction*> unique;

// Adds to the unique successors list; if the successor is a companion
// instruction, also add companion instructions; if the successor is a

@ -18,6 +18,7 @@ limitations under the License.
#include <functional>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@ -25,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
@ -75,8 +75,8 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal(
std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
const DebugOptions& debug_options) {
auto repeated_field = debug_options.xla_disable_hlo_passes();
tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
repeated_field.end());
absl::flat_hash_set<string> disabled_pass_names(repeated_field.begin(),
repeated_field.end());
if (!disabled_pass_names.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
<< absl::StrJoin(disabled_pass_names, ", ");

@ -21,6 +21,7 @@ limitations under the License.
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@ -981,7 +982,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// rematerialization is essentially a move). If the next rematerialization of
// the instruction is also a move then the rematerialization is added to the
// blacklist.
tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions;
absl::flat_hash_set<const HloInstruction*> remat_move_instructions;

// The map from instructions to their rematerializable status.
absl::flat_hash_map<const HloInstruction*, bool> remat_able;

@ -16,6 +16,7 @@
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -122,7 +123,7 @@ class HloRematerialization : public HloModulePass {

// Set of computations which have had rematerialization
// applied. Rematerialization is only applied once per computation.
tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_;
absl::flat_hash_set<const HloComputation*> rematerialized_computations_;

// Count of the total instructions rematerialized.
int64 instructions_rematerialized_ = 0;

@ -19,6 +19,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@ -119,7 +120,7 @@ Status HloSchedule::UpdateComputationSchedule(
}

// Set of all HloInstructions in the schedule.
tensorflow::gtl::FlatSet<int> ids_in_schedule;
absl::flat_hash_set<int> ids_in_schedule;
for (int id : sequences_.at(computation->unique_id()).ids()) {
InsertOrDie(&ids_in_schedule, id);
}
@ -210,7 +211,7 @@ Status HloSchedule::Update() {
if (sequences_.size() > nonfusion_computations.size()) {
// Schedule contains some computations which have been removed from the
// HloModule. Remove them from the schedule as well.
tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids;
absl::flat_hash_set<int64> nonfusion_computations_ids;
for (const HloComputation* computation : nonfusion_computations) {
nonfusion_computations_ids.insert(computation->unique_id());
}

@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

@ -167,7 +167,7 @@ void HloValue::SetPositionsAndComputeUses(
positions_.insert(positions_.end(), positions.begin(), positions.end());

// Gather the computation roots at which this value appears.
tensorflow::gtl::FlatSet<HloInstruction*> root_positions;
absl::flat_hash_set<HloInstruction*> root_positions;
for (const HloPosition& position : positions_) {
if (position.instruction ==
position.instruction->parent()->root_instruction()) {

@ -17,6 +17,7 @@ limitations under the License.

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace xla {
namespace gtl = ::tensorflow::gtl;

@ -26,6 +26,7 @@ limitations under the License.
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -39,7 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
@ -504,7 +504,7 @@ class LayoutAssignment : public HloModulePass {

// Every copy added to the module by the layout assignment pass is registered
// here.
tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
absl::flat_hash_set<HloInstruction*> added_copies_;

// The pointer to the channel layout constraints passed in with the
// constructor. If not nullptr, this is an input/output argument.
@ -521,8 +521,7 @@ class LayoutAssignment : public HloModulePass {

// The set of HLO instructions which lacked any layout constraint, thus
// receiving propagated default layouts.
tensorflow::gtl::FlatSet<const HloInstruction*>
unconstrained_layout_instructions_;
absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;
};

} // namespace xla

@ -39,6 +39,7 @@ cc_library(
"//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm//:core",
],

@ -15,7 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"

#include <unordered_set>
#include <map>

#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@ -164,9 +164,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer(
add_buffers_to_worklist(operand);
}

tensorflow::gtl::FlatSet<BufferAllocation::Slice,
BufferAllocation::Slice::Hasher>
buffers;
std::set<BufferAllocation::Slice> buffers;
for (const LogicalBuffer* buffer : worklist) {
// Skip buffers which cannot be added to the noalias set.
if (!assignment.HasAllocation(*buffer) ||
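This is the one call site that moves to an ordered std::set rather than absl::flat_hash_set. The patch doesn't state a reason; a plausible one is that these slices feed emitted alias metadata, and an ordered set gives deterministic iteration without needing the custom Hasher the old declaration carried. A toy illustration of that determinism:

    #include <cstdio>
    #include <set>

    int main() {
      // std::set orders elements with operator<, so iteration order is
      // stable across runs -- useful when set contents feed emitted output.
      std::set<int> slices = {3, 1, 2};
      for (int s : slices) std::printf("%d ", s);  // always prints "1 2 3"
      return 0;
    }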

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace xla {
namespace llvm_ir {

@ -15,10 +15,10 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/multi_output_fusion.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
@ -50,7 +50,7 @@ StatusOr<bool> MultiOutputFusion::Run(HloModule* module) {
all_fusion_candidates_.push_back(instruction);

std::vector<HloInstruction*> candidates;
tensorflow::gtl::FlatSet<HloInstruction*> candidates_set;
absl::flat_hash_set<HloInstruction*> candidates_set;
VLOG(10) << "Looking at instruction: " << instruction->name();
for (auto operand : instruction->operands()) {
// Filter out the non-interesting instructions -- they
@ -172,7 +172,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
// Update the fusible list for fusion. Variable new_fusibles keeps
// track of the new or changed entries.
std::vector<std::pair<HloInstruction*, int64>> new_fusibles;
tensorflow::gtl::FlatSet<HloInstruction*> in_list;
absl::flat_hash_set<HloInstruction*> in_list;
auto it = fusion_node.fusibles.begin();
while (it != fusion_node.fusibles.end()) {
HloInstruction* instr = it->first;

@ -19,9 +19,9 @@ limitations under the License.
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {
@ -69,7 +69,7 @@ class NameUniquer {
int64 next_ = 0;

// Set of all the identifiers which has been used.
tensorflow::gtl::FlatSet<int64> used_;
absl::flat_hash_set<int64> used_;
};

// The string to use to separate the prefix of the name from the uniquing

@ -22,6 +22,7 @@ limitations under the License.
#include <string>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@ -577,7 +577,7 @@ Status ValidateDotDimensionNumbers(
// Check that dimension numbers are unique.
auto dims_unique = [](absl::Span<const int64> contracting_dims,
absl::Span<const int64> batch_dims) -> bool {
tensorflow::gtl::FlatSet<int64> dim_set;
absl::flat_hash_set<int64> dim_set;
auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second;
};

@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include <utility>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
@ -147,7 +147,7 @@ void ScopedShapedBuffer::Deallocate() {
// Deallocate all non-null buffers. A buffer may appear in more than one spot
// in the shape (eg, a tuple with a repeated element) so keep track of what
// has been deallocated.
tensorflow::gtl::FlatSet<void*> deallocated_ptrs;
absl::flat_hash_set<void*> deallocated_ptrs;
for (auto& pair : buffers_) {
se::DeviceMemoryBase& memory_base = pair.second;
if (!memory_base.is_null() &&

@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"


@ -16,17 +16,17 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"

namespace xla {

using absl::flat_hash_map;
using absl::flat_hash_set;
using absl::InlinedVector;
using tensorflow::gtl::FlatSet;

// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
// operands as needed. All of its transitive operands are expected to be either
@ -35,7 +35,7 @@ using tensorflow::gtl::FlatSet;
// them into `hoisted_instructions`.
static void CreateLoopInvariantCopy(
flat_hash_map<HloInstruction*, HloInstruction*>* hoisted_instructions,
FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
flat_hash_set<HloInstruction*>* unhoisted_invariant_instructions,
HloInstruction* while_instr, HloInstruction* to_hoist) {
HloComputation* parent_of_while = while_instr->parent();
HloComputation* while_body = while_instr->while_body();
@ -153,7 +153,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
// unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we
// hoist an instruction in this set, we move it from
// unhoisted_invariant_instructions to hoisted_instructions.
FlatSet<HloInstruction*> unhoisted_invariant_instructions;
flat_hash_set<HloInstruction*> unhoisted_invariant_instructions;

// Invariant GTE's axiomatically satisfy the constraints for
// unhoisted_invariant_instructions -- they can be legally hoisted, but there

@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
@ -114,7 +115,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
return false;
}

tensorflow::gtl::FlatSet<int64> used_tuple_indices;
absl::flat_hash_set<int64> used_tuple_indices;
for (HloComputation* comp : {while_body, while_cond}) {
// The HLO verifier ensures that while_input's shape matches while_init's
// shape, which we verified above is a tuple.

@ -2146,11 +2146,11 @@ xla_test(
":test_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"@com_google_absl//absl/container:flat_hash_set",
],
)


@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/tests/test_utils.h"

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
@ -145,7 +146,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
ASSERT_EQ(args.size(), 2);
const Literal& key_arg = args[0];

tensorflow::gtl::FlatSet<uint32> key_set;
absl::flat_hash_set<uint32> key_set;
for (const float& value : key_arg.data<float>()) {
EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
}
@ -168,7 +169,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
ASSERT_EQ(args.size(), 2);
const Literal& key_arg = args[0];

tensorflow::gtl::FlatSet<int32> key_set;
absl::flat_hash_set<int32> key_set;
for (const int32& value : key_arg.data<int32>()) {
EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
}
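The tests key uniqueness on bit patterns rather than values: bit-casting to an unsigned integer distinguishes inputs (for example +0.0f and -0.0f) that float equality or hashing may conflate. A sketch of the idea using absl::bit_cast in place of the tensorflow::bit_cast helper used above:

    #include <cstdint>
    #include "absl/base/casts.h"
    #include "absl/container/flat_hash_set.h"

    int main() {
      absl::flat_hash_set<uint32_t> key_set;
      // +0.0f and -0.0f compare equal as floats but have distinct bit
      // patterns, so bit-level keys keep them apart in a uniqueness check.
      bool first = key_set.insert(absl::bit_cast<uint32_t>(0.0f)).second;
      bool second = key_set.insert(absl::bit_cast<uint32_t>(-0.0f)).second;
      return (first && second) ? 0 : 1;
    }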