[XLA] Migrate from gtl::FlatMap to absl::flat_hash_map
PiperOrigin-RevId: 215272497
This commit is contained in:
parent
44acd839c5
commit
3039a4694e
@ -258,6 +258,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/kernels:variable_ops",
|
"//tensorflow/core/kernels:variable_ops",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -323,6 +324,7 @@ cc_library(
|
|||||||
"//tensorflow/core:graph",
|
"//tensorflow/core:graph",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
@ -400,6 +402,7 @@ cc_library(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/kernels:bounds_check",
|
"//tensorflow/core/kernels:bounds_check",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
@ -471,6 +474,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -509,6 +513,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
"//tensorflow/core/grappler/optimizers/data:graph_utils",
|
"//tensorflow/core/grappler/optimizers/data:graph_utils",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
#include "tensorflow/compiler/jit/deadness_analysis.h"
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
|
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
@ -420,15 +421,15 @@ class PredicateFactory {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>,
|
absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
|
||||||
HashSignatureForAndOr>
|
HashSignatureForAndOr>
|
||||||
interned_and_or_instances_;
|
interned_and_or_instances_;
|
||||||
gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>>
|
absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
|
||||||
interned_not_instances_;
|
interned_not_instances_;
|
||||||
gtl::FlatMap<SignatureForAndRec, std::unique_ptr<Predicate>>
|
absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
|
||||||
interned_and_rec_instances_;
|
interned_and_rec_instances_;
|
||||||
gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>,
|
absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
|
||||||
HashSignatureForSymbol>
|
HashSignatureForSymbol>
|
||||||
interned_symbol_instances_;
|
interned_symbol_instances_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -572,7 +573,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
|
|||||||
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
|
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
|
||||||
bool HasInputsWithMismatchingDeadness(const Node& node) override;
|
bool HasInputsWithMismatchingDeadness(const Node& node) override;
|
||||||
void Print() const override;
|
void Print() const override;
|
||||||
gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
|
absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
|
||||||
|
const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
|
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
|
||||||
@ -614,7 +616,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
|
|||||||
Status HandleNode(Node* n, std::vector<bool>* should_revisit);
|
Status HandleNode(Node* n, std::vector<bool>* should_revisit);
|
||||||
|
|
||||||
const Graph& graph_;
|
const Graph& graph_;
|
||||||
gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
|
absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
|
||||||
PredicateFactory predicate_factory_;
|
PredicateFactory predicate_factory_;
|
||||||
bool vlog_;
|
bool vlog_;
|
||||||
};
|
};
|
||||||
@ -977,9 +979,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
gtl::FlatMap<TensorId, string, TensorId::Hasher>
|
absl::flat_hash_map<TensorId, string, TensorId::Hasher>
|
||||||
DeadnessAnalysisImpl::PredicateMapAsString() const {
|
DeadnessAnalysisImpl::PredicateMapAsString() const {
|
||||||
gtl::FlatMap<TensorId, string, TensorId::Hasher> result;
|
absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
|
||||||
std::vector<TensorId> tensor_ids;
|
std::vector<TensorId> tensor_ids;
|
||||||
for (const auto& kv_pair : predicate_map_) {
|
for (const auto& kv_pair : predicate_map_) {
|
||||||
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
|
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
|
||||||
|
@ -16,15 +16,15 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
|
#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
|
||||||
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
|
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/core/graph/tensor_id.h"
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace deadness_analysis_internal {
|
namespace deadness_analysis_internal {
|
||||||
|
|
||||||
// Returns a map describing the predicate each Tensor was mapped to. For
|
// Returns a map describing the predicate each Tensor was mapped to. For
|
||||||
// testing purposes only.
|
// testing purposes only.
|
||||||
using PredicateMapTy = gtl::FlatMap<TensorId, string, TensorId::Hasher>;
|
using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
|
||||||
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
|
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
|
||||||
|
|
||||||
// Returns a map describing the predicate each Tensor was mapped to. For
|
// Returns a map describing the predicate each Tensor was mapped to. For
|
||||||
|
@ -26,6 +26,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
"//tensorflow/core/kernels:variable_ops",
|
"//tensorflow/core/kernels:variable_ops",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/compiler/jit/defs.h"
|
#include "tensorflow/compiler/jit/defs.h"
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
@ -163,7 +164,7 @@ class XlaExecutableClosureStore {
|
|||||||
private:
|
private:
|
||||||
mutex mutex_;
|
mutex mutex_;
|
||||||
int64 key_counter_ GUARDED_BY(mutex_);
|
int64 key_counter_ GUARDED_BY(mutex_);
|
||||||
gtl::FlatMap<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
|
absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
|
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
|
||||||
};
|
};
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
|
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "tensorflow/cc/framework/ops.h"
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
@ -61,10 +62,10 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
|
|||||||
return ids;
|
return ids;
|
||||||
}
|
}
|
||||||
|
|
||||||
gtl::FlatMap<string, std::vector<string>> GetClusterSets(
|
absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
|
||||||
const Graph& g, std::vector<string>* cluster_names = nullptr) {
|
const Graph& g, std::vector<string>* cluster_names = nullptr) {
|
||||||
CHECK(cluster_names == nullptr || cluster_names->empty());
|
CHECK(cluster_names == nullptr || cluster_names->empty());
|
||||||
gtl::FlatMap<string, std::vector<string>> cluster_sets;
|
absl::flat_hash_map<string, std::vector<string>> cluster_sets;
|
||||||
for (const auto& p : GetClusters(g)) {
|
for (const auto& p : GetClusters(g)) {
|
||||||
cluster_sets[p.second].push_back(p.first);
|
cluster_sets[p.second].push_back(p.first);
|
||||||
}
|
}
|
||||||
@ -566,7 +567,7 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
|
|||||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
TF_EXPECT_OK(root.ToGraph(graph.get()));
|
TF_EXPECT_OK(root.ToGraph(graph.get()));
|
||||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||||
gtl::FlatMap<string, std::vector<string>> cluster_sets =
|
absl::flat_hash_map<string, std::vector<string>> cluster_sets =
|
||||||
GetClusterSets(*graph);
|
GetClusterSets(*graph);
|
||||||
ASSERT_EQ(cluster_sets.size(), 1);
|
ASSERT_EQ(cluster_sets.size(), 1);
|
||||||
std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
|
std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
|
||||||
@ -586,7 +587,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
|
|||||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
TF_EXPECT_OK(root.ToGraph(graph.get()));
|
TF_EXPECT_OK(root.ToGraph(graph.get()));
|
||||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||||
gtl::FlatMap<string, std::vector<string>> cluster_sets =
|
absl::flat_hash_map<string, std::vector<string>> cluster_sets =
|
||||||
GetClusterSets(*graph);
|
GetClusterSets(*graph);
|
||||||
ASSERT_EQ(cluster_sets.size(), 1);
|
ASSERT_EQ(cluster_sets.size(), 1);
|
||||||
std::vector<string> expected_clustered_nodes = {"AssignmentW",
|
std::vector<string> expected_clustered_nodes = {"AssignmentW",
|
||||||
@ -616,7 +617,7 @@ TEST(XlaCompilationTest, ChainOfOps) {
|
|||||||
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
|
||||||
|
|
||||||
std::vector<string> cluster_names;
|
std::vector<string> cluster_names;
|
||||||
gtl::FlatMap<string, std::vector<string>> cluster_sets =
|
absl::flat_hash_map<string, std::vector<string>> cluster_sets =
|
||||||
GetClusterSets(*graph, &cluster_names);
|
GetClusterSets(*graph, &cluster_names);
|
||||||
|
|
||||||
ASSERT_EQ(cluster_sets.size(), 2);
|
ASSERT_EQ(cluster_sets.size(), 2);
|
||||||
|
@ -89,7 +89,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
#include "tensorflow/core/graph/tensor_id.h"
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
|
#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
|
||||||
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
|
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
@ -24,7 +25,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
|
|
||||||
@ -152,7 +152,7 @@ class XlaCompilationCache : public ResourceBase {
|
|||||||
};
|
};
|
||||||
|
|
||||||
mutex compile_cache_mu_;
|
mutex compile_cache_mu_;
|
||||||
gtl::FlatMap<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
|
absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
|
||||||
GUARDED_BY(compile_cache_mu_);
|
GUARDED_BY(compile_cache_mu_);
|
||||||
|
|
||||||
struct CompileStats {
|
struct CompileStats {
|
||||||
@ -165,7 +165,7 @@ class XlaCompilationCache : public ResourceBase {
|
|||||||
mutex compile_stats_mu_;
|
mutex compile_stats_mu_;
|
||||||
|
|
||||||
// Maps cluster names to compilation statistics for said cluster.
|
// Maps cluster names to compilation statistics for said cluster.
|
||||||
gtl::FlatMap<string, CompileStats> compile_stats_
|
absl::flat_hash_map<string, CompileStats> compile_stats_
|
||||||
GUARDED_BY(compile_stats_mu_);
|
GUARDED_BY(compile_stats_mu_);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
|
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
|
||||||
|
@ -635,6 +635,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:ops",
|
"//tensorflow/core:ops",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -649,6 +650,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
|
/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
|
||||||
@ -30,9 +30,9 @@ namespace tensorflow {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>*
|
static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
|
||||||
CreateResourceOpInfoMap() {
|
CreateResourceOpInfoMap() {
|
||||||
auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>;
|
auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
|
||||||
|
|
||||||
auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
|
auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
|
||||||
XlaResourceKind resource_kind) {
|
XlaResourceKind resource_kind) {
|
||||||
@ -103,15 +103,15 @@ CreateResourceOpInfoMap() {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>&
|
static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
|
||||||
GetStaticResourceOpInfoMap() {
|
GetStaticResourceOpInfoMap() {
|
||||||
static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map =
|
static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
|
||||||
CreateResourceOpInfoMap();
|
op_info_map = CreateResourceOpInfoMap();
|
||||||
return *op_info_map;
|
return *op_info_map;
|
||||||
}
|
}
|
||||||
|
|
||||||
const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
|
const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
|
||||||
const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos =
|
const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
|
||||||
GetStaticResourceOpInfoMap();
|
GetStaticResourceOpInfoMap();
|
||||||
auto it = op_infos.find(op);
|
auto it = op_infos.find(op);
|
||||||
return it == op_infos.end() ? nullptr : &it->second;
|
return it == op_infos.end() ? nullptr : &it->second;
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -33,7 +34,7 @@ bool HasResourceInputOrOutput(const OpDef& op_def) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResourceOperationTableTest, HaveAllResourceOps) {
|
TEST(ResourceOperationTableTest, HaveAllResourceOps) {
|
||||||
gtl::FlatMap<string, bool> known_resource_ops;
|
absl::flat_hash_map<string, bool> known_resource_ops;
|
||||||
for (absl::string_view known_resource_op :
|
for (absl::string_view known_resource_op :
|
||||||
resource_op_table_internal::GetKnownResourceOps()) {
|
resource_op_table_internal::GetKnownResourceOps()) {
|
||||||
ASSERT_TRUE(
|
ASSERT_TRUE(
|
||||||
|
@ -220,6 +220,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:shape_inference",
|
"//tensorflow/compiler/xla/service:shape_inference",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/client/padding.h"
|
#include "tensorflow/compiler/xla/client/padding.h"
|
||||||
@ -34,7 +35,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/stacktrace.h"
|
#include "tensorflow/core/platform/stacktrace.h"
|
||||||
@ -1027,7 +1027,7 @@ class XlaBuilder {
|
|||||||
|
|
||||||
// A map from XlaOp::Handle to the index in the instructions_ vector where the
|
// A map from XlaOp::Handle to the index in the instructions_ vector where the
|
||||||
// instruction is held.
|
// instruction is held.
|
||||||
tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
|
absl::flat_hash_map<int64, int64> handle_to_index_;
|
||||||
|
|
||||||
// The embedded computations used by this computation. Each computation was
|
// The embedded computations used by this computation. Each computation was
|
||||||
// the entry computation of some XlaComputation, the key is the unique id of
|
// the entry computation of some XlaComputation, the key is the unique id of
|
||||||
|
@ -146,6 +146,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -250,6 +251,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
@ -333,6 +335,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
@ -395,6 +398,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -485,6 +489,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
@ -903,6 +908,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
@ -952,6 +958,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -987,6 +994,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
],
|
],
|
||||||
@ -1034,6 +1042,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
@ -1087,6 +1096,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
],
|
],
|
||||||
@ -1125,6 +1135,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1146,6 +1157,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1196,6 +1208,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
@ -1216,6 +1229,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
@ -1260,6 +1274,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1280,6 +1295,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1304,6 +1320,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1330,6 +1347,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -1385,6 +1403,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
@ -1640,6 +1659,7 @@ cc_library(
|
|||||||
":while_loop_analysis",
|
":while_loop_analysis",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
@ -1671,6 +1691,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2203,6 +2224,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
@ -2263,6 +2285,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
@ -2319,6 +2342,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
@ -2345,6 +2369,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -2416,6 +2441,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
@ -2460,6 +2486,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
@ -2588,6 +2615,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
@ -2701,6 +2729,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -3147,6 +3176,7 @@ cc_library(
|
|||||||
":hlo_pass_pipeline",
|
":hlo_pass_pipeline",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -3269,6 +3299,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -3298,6 +3329,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -3354,6 +3386,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:ptr_util",
|
"//tensorflow/core:ptr_util",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:inlined_vector",
|
"@com_google_absl//absl/container:inlined_vector",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/backend.h"
|
#include "tensorflow/compiler/xla/service/backend.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
@ -110,7 +111,7 @@ class AllocationTracker {
|
|||||||
|
|
||||||
// A map from device memory opaque value to allocation. One such map is
|
// A map from device memory opaque value to allocation. One such map is
|
||||||
// maintained per device ordinal.
|
// maintained per device ordinal.
|
||||||
using AllocationMap = tensorflow::gtl::FlatMap<const void*, Allocation>;
|
using AllocationMap = absl::flat_hash_map<const void*, Allocation>;
|
||||||
|
|
||||||
tensorflow::mutex mutex_;
|
tensorflow::mutex mutex_;
|
||||||
|
|
||||||
@ -146,7 +147,7 @@ class AllocationTracker {
|
|||||||
// non-owning "view" into a tuple's sub-buffers. The sub-buffers are then
|
// non-owning "view" into a tuple's sub-buffers. The sub-buffers are then
|
||||||
// free'd when both the view *and* the original tuple are Unregistered. This
|
// free'd when both the view *and* the original tuple are Unregistered. This
|
||||||
// refcounting is managed in opaque_to_allocation_map_.
|
// refcounting is managed in opaque_to_allocation_map_.
|
||||||
tensorflow::gtl::FlatMap<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
|
absl::flat_hash_map<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
|
||||||
handle_to_shaped_buffers_ GUARDED_BY(mutex_);
|
handle_to_shaped_buffers_ GUARDED_BY(mutex_);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker);
|
TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker);
|
||||||
|
@ -35,7 +35,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
|
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
@ -186,7 +187,7 @@ class BFloat16Propagation : public HloModulePass {
|
|||||||
|
|
||||||
// Mapping from each HloComputation to the number of callers to it in the
|
// Mapping from each HloComputation to the number of callers to it in the
|
||||||
// module. Populated at the beginning of this pass.
|
// module. Populated at the beginning of this pass.
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, int64> caller_counts_;
|
absl::flat_hash_map<const HloComputation*, int64> caller_counts_;
|
||||||
|
|
||||||
// We first store the potential F32-to-BF16 changes to changes_to_bf16_, which
|
// We first store the potential F32-to-BF16 changes to changes_to_bf16_, which
|
||||||
// are subject to further adjustment, then finally applied to the HLOs. This
|
// are subject to further adjustment, then finally applied to the HLOs. This
|
||||||
@ -195,8 +196,7 @@ class BFloat16Propagation : public HloModulePass {
|
|||||||
//
|
//
|
||||||
// For each HloInstruction, changes_to_bf16_ stores the affected buffers in
|
// For each HloInstruction, changes_to_bf16_ stores the affected buffers in
|
||||||
// the output as a map from in-place pointers to subshapes to shape indices.
|
// the output as a map from in-place pointers to subshapes to shape indices.
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*,
|
absl::flat_hash_map<HloInstruction*, absl::flat_hash_map<Shape*, ShapeIndex>>
|
||||||
tensorflow::gtl::FlatMap<Shape*, ShapeIndex>>
|
|
||||||
changes_to_bf16_;
|
changes_to_bf16_;
|
||||||
|
|
||||||
// Whether the last processed HLO module has been changed by this pass.
|
// Whether the last processed HLO module has been changed by this pass.
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <ostream>
|
#include <ostream>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
@ -41,9 +42,9 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using absl::flat_hash_map;
|
||||||
using absl::StrAppend;
|
using absl::StrAppend;
|
||||||
using absl::StrAppendFormat;
|
using absl::StrAppendFormat;
|
||||||
using ::tensorflow::gtl::FlatMap;
|
|
||||||
using ::tensorflow::gtl::FlatSet;
|
using ::tensorflow::gtl::FlatSet;
|
||||||
using ::tensorflow::strings::HumanReadableNumBytes;
|
using ::tensorflow::strings::HumanReadableNumBytes;
|
||||||
|
|
||||||
@ -519,7 +520,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation,
|
|||||||
// BufferAllocation.
|
// BufferAllocation.
|
||||||
void BufferAssignment::CombineTempAllocations() {
|
void BufferAssignment::CombineTempAllocations() {
|
||||||
VLOG(1) << "CombineTempAllocations()";
|
VLOG(1) << "CombineTempAllocations()";
|
||||||
FlatMap<LogicalBuffer::Color, BufferAllocation, LogicalBuffer::Color::Hasher>
|
flat_hash_map<LogicalBuffer::Color, BufferAllocation,
|
||||||
|
LogicalBuffer::Color::Hasher>
|
||||||
combined_allocation_map;
|
combined_allocation_map;
|
||||||
|
|
||||||
// Move all temp allocations into a single run at the end of the allocations
|
// Move all temp allocations into a single run at the end of the allocations
|
||||||
@ -582,7 +584,8 @@ void BufferAssignment::CombineTempAllocations() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update allocation indices to their new positions.
|
// Update allocation indices to their new positions.
|
||||||
allocation_index_for_buffer_.clear_no_resize();
|
allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(),
|
||||||
|
allocation_index_for_buffer_.end());
|
||||||
for (size_t index = 0; index < allocations_.size(); ++index) {
|
for (size_t index = 0; index < allocations_.size(); ++index) {
|
||||||
BufferAllocation* allocation = &allocations_[index];
|
BufferAllocation* allocation = &allocations_[index];
|
||||||
allocation->set_index(index);
|
allocation->set_index(index);
|
||||||
@ -814,7 +817,7 @@ Status BufferAssigner::AssignBuffersForComputation(
|
|||||||
const HloComputation* computation, bool is_thread_local,
|
const HloComputation* computation, bool is_thread_local,
|
||||||
const FlatSet<const LogicalBuffer*>& colocated_buffers,
|
const FlatSet<const LogicalBuffer*>& colocated_buffers,
|
||||||
const FlatSet<BufferAllocation::Index>& colocated_allocations,
|
const FlatSet<BufferAllocation::Index>& colocated_allocations,
|
||||||
FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>*
|
flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>*
|
||||||
buffers_to_assign_sequentially,
|
buffers_to_assign_sequentially,
|
||||||
BufferAssignment* assignment) {
|
BufferAssignment* assignment) {
|
||||||
// Buffers are sorted and assigned to BufferAllocations in decreasing order of
|
// Buffers are sorted and assigned to BufferAllocations in decreasing order of
|
||||||
@ -833,7 +836,7 @@ Status BufferAssigner::AssignBuffersForComputation(
|
|||||||
|
|
||||||
// Generate a post order sort of instructions for sorting of the
|
// Generate a post order sort of instructions for sorting of the
|
||||||
// LogicalBuffers.
|
// LogicalBuffers.
|
||||||
FlatMap<const HloInstruction*, int> post_order_position;
|
flat_hash_map<const HloInstruction*, int> post_order_position;
|
||||||
int position = 0;
|
int position = 0;
|
||||||
for (auto* instruction : computation->MakeInstructionPostOrder()) {
|
for (auto* instruction : computation->MakeInstructionPostOrder()) {
|
||||||
post_order_position.emplace(instruction, position);
|
post_order_position.emplace(instruction, position);
|
||||||
@ -1043,12 +1046,12 @@ Status BufferAssigner::AssignBuffersForComputation(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
|
flat_hash_map<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
|
||||||
LogicalBuffer::Color::Hasher>
|
LogicalBuffer::Color::Hasher>
|
||||||
BufferAssigner::SplitBuffersByColor(
|
BufferAssigner::SplitBuffersByColor(
|
||||||
const FlatSet<const LogicalBuffer*>& buffers) {
|
const FlatSet<const LogicalBuffer*>& buffers) {
|
||||||
FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
|
flat_hash_map<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
|
||||||
LogicalBuffer::Color::Hasher>
|
LogicalBuffer::Color::Hasher>
|
||||||
color_map;
|
color_map;
|
||||||
for (auto buffer : buffers) {
|
for (auto buffer : buffers) {
|
||||||
color_map[buffer->color()].insert(buffer);
|
color_map[buffer->color()].insert(buffer);
|
||||||
@ -1057,7 +1060,7 @@ BufferAssigner::SplitBuffersByColor(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status BufferAssigner::AssignBuffersWithSequentialOrdering(
|
Status BufferAssigner::AssignBuffersWithSequentialOrdering(
|
||||||
const FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>&
|
const flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>&
|
||||||
buffers_to_assign_sequentially,
|
buffers_to_assign_sequentially,
|
||||||
bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
|
bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
|
||||||
// Run the sequence of instructions through the heap simulator. The heuristic
|
// Run the sequence of instructions through the heap simulator. The heuristic
|
||||||
@ -1155,9 +1158,8 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
|
|||||||
const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
|
const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
|
||||||
// Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
|
// Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
|
||||||
// buffers in this allocation.
|
// buffers in this allocation.
|
||||||
tensorflow::gtl::FlatMap<LogicalBuffer::Id, const LogicalBuffer*>
|
absl::flat_hash_map<LogicalBuffer::Id, const LogicalBuffer*> id_to_buffer;
|
||||||
id_to_buffer;
|
absl::flat_hash_map<const LogicalBuffer*, int64> buffer_sizes;
|
||||||
tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> buffer_sizes;
|
|
||||||
for (const auto& pair : allocation.assigned_buffers()) {
|
for (const auto& pair : allocation.assigned_buffers()) {
|
||||||
const LogicalBuffer* buffer = pair.first;
|
const LogicalBuffer* buffer = pair.first;
|
||||||
const BufferAllocation::OffsetSize& offset_size = pair.second;
|
const BufferAllocation::OffsetSize& offset_size = pair.second;
|
||||||
@ -1679,7 +1681,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
|
|||||||
|
|
||||||
// First assign buffers for global computatations. Temporary buffers for
|
// First assign buffers for global computatations. Temporary buffers for
|
||||||
// sequential computations are collected in 'buffers_to_assign_sequentially'.
|
// sequential computations are collected in 'buffers_to_assign_sequentially'.
|
||||||
FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>
|
flat_hash_map<const HloComputation*, FlatSet<const LogicalBuffer*>>
|
||||||
buffers_to_assign_sequentially;
|
buffers_to_assign_sequentially;
|
||||||
for (auto* computation : global_computations) {
|
for (auto* computation : global_computations) {
|
||||||
TF_RETURN_IF_ERROR(AssignBuffersForComputation(
|
TF_RETURN_IF_ERROR(AssignBuffersForComputation(
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
|
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
|
||||||
#include "tensorflow/compiler/xla/service/heap_simulator.h"
|
#include "tensorflow/compiler/xla/service/heap_simulator.h"
|
||||||
@ -33,7 +34,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -148,7 +148,7 @@ class BufferAllocation {
|
|||||||
|
|
||||||
// Access to the logical buffers assigned to this allocation, and their
|
// Access to the logical buffers assigned to this allocation, and their
|
||||||
// associated logical offsets and sizes.
|
// associated logical offsets and sizes.
|
||||||
const tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize>&
|
const absl::flat_hash_map<const LogicalBuffer*, OffsetSize>&
|
||||||
assigned_buffers() const {
|
assigned_buffers() const {
|
||||||
return assigned_buffers_;
|
return assigned_buffers_;
|
||||||
}
|
}
|
||||||
@ -323,7 +323,7 @@ class BufferAllocation {
|
|||||||
|
|
||||||
// Mapping from the set of buffers assigned to this allocation to their
|
// Mapping from the set of buffers assigned to this allocation to their
|
||||||
// logical offsets and sizes.
|
// logical offsets and sizes.
|
||||||
tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize> assigned_buffers_;
|
absl::flat_hash_map<const LogicalBuffer*, OffsetSize> assigned_buffers_;
|
||||||
|
|
||||||
int64 fragmentation_bytes_ = 0;
|
int64 fragmentation_bytes_ = 0;
|
||||||
std::vector<HeapSimulatorTrace> heap_traces_;
|
std::vector<HeapSimulatorTrace> heap_traces_;
|
||||||
@ -500,7 +500,7 @@ class BufferAssignment {
|
|||||||
int64 temp_allocation_total_size_ = 0;
|
int64 temp_allocation_total_size_ = 0;
|
||||||
|
|
||||||
// Maps Buffers to the index of the BufferAllocation which holds the buffer.
|
// Maps Buffers to the index of the BufferAllocation which holds the buffer.
|
||||||
tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferAllocation::Index>
|
absl::flat_hash_map<const LogicalBuffer*, BufferAllocation::Index>
|
||||||
allocation_index_for_buffer_;
|
allocation_index_for_buffer_;
|
||||||
|
|
||||||
const HloModule* module_;
|
const HloModule* module_;
|
||||||
@ -557,8 +557,8 @@ class BufferAssigner {
|
|||||||
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
|
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
|
||||||
const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
|
const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
|
||||||
colocated_allocations,
|
colocated_allocations,
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*,
|
absl::flat_hash_map<const HloComputation*,
|
||||||
tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
|
tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
|
||||||
buffers_to_assign_sequentially,
|
buffers_to_assign_sequentially,
|
||||||
BufferAssignment* assignment);
|
BufferAssignment* assignment);
|
||||||
|
|
||||||
@ -568,9 +568,8 @@ class BufferAssigner {
|
|||||||
// 'run_whole_module_heap_simulation' is true, the heap simulation will be run
|
// 'run_whole_module_heap_simulation' is true, the heap simulation will be run
|
||||||
// assuming all global computations are sequentially ordered.
|
// assuming all global computations are sequentially ordered.
|
||||||
Status AssignBuffersWithSequentialOrdering(
|
Status AssignBuffersWithSequentialOrdering(
|
||||||
const tensorflow::gtl::FlatMap<
|
const absl::flat_hash_map<const HloComputation*,
|
||||||
const HloComputation*,
|
tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
|
||||||
tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
|
|
||||||
buffers_to_assign_sequentially,
|
buffers_to_assign_sequentially,
|
||||||
bool run_whole_module_heap_simulation, BufferAssignment* assignment);
|
bool run_whole_module_heap_simulation, BufferAssignment* assignment);
|
||||||
|
|
||||||
@ -624,9 +623,9 @@ class BufferAssigner {
|
|||||||
|
|
||||||
// Split a set of buffers into several sets, each of which contains buffers
|
// Split a set of buffers into several sets, each of which contains buffers
|
||||||
// colored with the same color.
|
// colored with the same color.
|
||||||
tensorflow::gtl::FlatMap<LogicalBuffer::Color,
|
absl::flat_hash_map<LogicalBuffer::Color,
|
||||||
tensorflow::gtl::FlatSet<const LogicalBuffer*>,
|
tensorflow::gtl::FlatSet<const LogicalBuffer*>,
|
||||||
LogicalBuffer::Color::Hasher>
|
LogicalBuffer::Color::Hasher>
|
||||||
SplitBuffersByColor(
|
SplitBuffersByColor(
|
||||||
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers);
|
const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers);
|
||||||
|
|
||||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
@ -20,10 +20,10 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -157,7 +157,7 @@ class CallGraphNode {
|
|||||||
|
|
||||||
// The map from instruction to index in callsites_ for looking up the callsite
|
// The map from instruction to index in callsites_ for looking up the callsite
|
||||||
// (if any) associated with a particular instruction in this computation.
|
// (if any) associated with a particular instruction in this computation.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64> callsite_instructions_;
|
absl::flat_hash_map<const HloInstruction*, int64> callsite_instructions_;
|
||||||
|
|
||||||
// The call sites in other computations which call this computation.
|
// The call sites in other computations which call this computation.
|
||||||
std::vector<CallSite> caller_callsites_;
|
std::vector<CallSite> caller_callsites_;
|
||||||
@ -267,7 +267,7 @@ class CallGraph {
|
|||||||
|
|
||||||
// Map from HLO computation to the index of the corresponding call graph node
|
// Map from HLO computation to the index of the corresponding call graph node
|
||||||
// in nodes_.
|
// in nodes_.
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, int64> node_indices_;
|
absl::flat_hash_map<const HloComputation*, int64> node_indices_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/copy_insertion.h"
|
#include "tensorflow/compiler/xla/service/copy_insertion.h"
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
|
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
|
||||||
@ -31,7 +32,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
@ -432,7 +432,7 @@ class CopyRemover {
|
|||||||
// Construct a list for each HLO buffer in the alias analysis. Maintain a
|
// Construct a list for each HLO buffer in the alias analysis. Maintain a
|
||||||
// map from HloValue to the respective list element representing that
|
// map from HloValue to the respective list element representing that
|
||||||
// value. The map is used to construct the copy info map below.
|
// value. The map is used to construct the copy info map below.
|
||||||
tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node;
|
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
|
||||||
for (const HloBuffer& buffer : alias_analysis.buffers()) {
|
for (const HloBuffer& buffer : alias_analysis.buffers()) {
|
||||||
// Verify values contained in the buffer are strictly ordered. This
|
// Verify values contained in the buffer are strictly ordered. This
|
||||||
// should always be the case after adding copies to eliminate
|
// should always be the case after adding copies to eliminate
|
||||||
@ -480,7 +480,7 @@ class CopyRemover {
|
|||||||
// respective ValueNode representing that value.
|
// respective ValueNode representing that value.
|
||||||
void AddValueList(
|
void AddValueList(
|
||||||
absl::Span<const HloValue* const> values,
|
absl::Span<const HloValue* const> values,
|
||||||
tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
|
absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
|
||||||
ValueNode* tail = nullptr;
|
ValueNode* tail = nullptr;
|
||||||
ValueNode* head = nullptr;
|
ValueNode* head = nullptr;
|
||||||
for (const HloValue* value : values) {
|
for (const HloValue* value : values) {
|
||||||
@ -516,8 +516,7 @@ class CopyRemover {
|
|||||||
// respective ValueNode.
|
// respective ValueNode.
|
||||||
void CreateCopyMap(
|
void CreateCopyMap(
|
||||||
const HloModule& module,
|
const HloModule& module,
|
||||||
const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>&
|
const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
|
||||||
value_to_node) {
|
|
||||||
for (HloComputation* computation : module.computations()) {
|
for (HloComputation* computation : module.computations()) {
|
||||||
for (HloInstruction* instruction : computation->instructions()) {
|
for (HloInstruction* instruction : computation->instructions()) {
|
||||||
// Add copies with unambiguous source values to the map. Copies with
|
// Add copies with unambiguous source values to the map. Copies with
|
||||||
@ -916,7 +915,7 @@ class CopyRemover {
|
|||||||
ValueNode* src = nullptr;
|
ValueNode* src = nullptr;
|
||||||
ValueNode* dest = nullptr;
|
ValueNode* dest = nullptr;
|
||||||
};
|
};
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_;
|
absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
HloModule* module_;
|
HloModule* module_;
|
||||||
|
@ -290,6 +290,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
|
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
|
||||||
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
|
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
@ -309,6 +310,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@llvm//:analysis",
|
"@llvm//:analysis",
|
||||||
"@llvm//:target",
|
"@llvm//:target",
|
||||||
],
|
],
|
||||||
@ -471,6 +473,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/stream_executor",
|
"//tensorflow/stream_executor",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
@ -762,6 +765,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:computation_layout",
|
"//tensorflow/compiler/xla/service:computation_layout",
|
||||||
"//tensorflow/compiler/xla/service:layout_assignment",
|
"//tensorflow/compiler/xla/service:layout_assignment",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
|
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
|
||||||
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
|
||||||
@ -38,7 +39,7 @@ using absl::nullopt;
|
|||||||
using absl::optional;
|
using absl::optional;
|
||||||
|
|
||||||
using ShouldMakeOperandColMajorCache =
|
using ShouldMakeOperandColMajorCache =
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, bool>;
|
absl::flat_hash_map<const HloInstruction*, bool>;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) {
|
static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) {
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||||
#include "tensorflow/core/platform/dynamic_annotations.h"
|
#include "tensorflow/core/platform/dynamic_annotations.h"
|
||||||
@ -30,8 +31,7 @@ namespace cpu {
|
|||||||
namespace runtime {
|
namespace runtime {
|
||||||
|
|
||||||
XfeedManager* GetXfeedManager(int device_ordinal) {
|
XfeedManager* GetXfeedManager(int device_ordinal) {
|
||||||
static tensorflow::gtl::FlatMap<int, XfeedManager*>* managers =
|
static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
|
||||||
new tensorflow::gtl::FlatMap<int, XfeedManager*>();
|
|
||||||
static absl::Mutex* mutex = new absl::Mutex();
|
static absl::Mutex* mutex = new absl::Mutex();
|
||||||
|
|
||||||
absl::MutexLock lock(mutex);
|
absl::MutexLock lock(mutex);
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/core/lib/math/math_util.h"
|
#include "tensorflow/core/lib/math/math_util.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
|
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
|
||||||
@ -67,7 +68,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/window_util.h"
|
#include "tensorflow/compiler/xla/window_util.h"
|
||||||
#include "tensorflow/core/lib/core/bits.h"
|
#include "tensorflow/core/lib/core/bits.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -1398,7 +1398,7 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) {
|
|||||||
//
|
//
|
||||||
// So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
|
// So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
|
||||||
// [0->0, 3->1].
|
// [0->0, 3->1].
|
||||||
gtl::FlatMap<int64, int64> unreduced_dim_map;
|
absl::flat_hash_map<int64, int64> unreduced_dim_map;
|
||||||
|
|
||||||
gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(),
|
gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(),
|
||||||
reduce.dimensions().end());
|
reduce.dimensions().end());
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "llvm/ADT/Triple.h"
|
#include "llvm/ADT/Triple.h"
|
||||||
@ -47,7 +48,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -427,7 +427,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
|||||||
// Maps the buffer allocation slices for the parameters to the computation
|
// Maps the buffer allocation slices for the parameters to the computation
|
||||||
// being compiled to their parameter numbers. Only relevant for thread local
|
// being compiled to their parameter numbers. Only relevant for thread local
|
||||||
// computations.
|
// computations.
|
||||||
tensorflow::gtl::FlatMap<BufferAllocation::Index, int64>
|
absl::flat_hash_map<BufferAllocation::Index, int64>
|
||||||
computation_parameter_allocations_;
|
computation_parameter_allocations_;
|
||||||
|
|
||||||
// Maps HLO instructions to their index into the profile counter array.
|
// Maps HLO instructions to their index into the profile counter array.
|
||||||
@ -567,11 +567,11 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<const Literal*, llvm::Constant*,
|
absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
|
||||||
LiteralPtrHashFunctor, LiteralPtrEqualityFunctor>
|
LiteralPtrEqualityFunctor>
|
||||||
emitted_literals_;
|
emitted_literals_;
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*>
|
absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
|
||||||
constant_buffer_to_global_;
|
constant_buffer_to_global_;
|
||||||
|
|
||||||
std::vector<const HloComputation*> thread_local_computations_;
|
std::vector<const HloComputation*> thread_local_computations_;
|
||||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
|
||||||
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
|
@ -16,10 +16,10 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "llvm/Analysis/TargetTransformInfo.h"
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
||||||
#include "llvm/Target/TargetMachine.h"
|
#include "llvm/Target/TargetMachine.h"
|
||||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace cpu {
|
namespace cpu {
|
||||||
@ -97,8 +97,7 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures {
|
|||||||
// This is mutated from within `GetTargetTransformInfoFor` which is
|
// This is mutated from within `GetTargetTransformInfoFor` which is
|
||||||
// semantically a getter (and thus `const`); and is therefore declared
|
// semantically a getter (and thus `const`); and is therefore declared
|
||||||
// mutable. Making this mutable is okay because it has cache semantics.
|
// mutable. Making this mutable is okay because it has cache semantics.
|
||||||
mutable tensorflow::gtl::FlatMap<const llvm::Function*,
|
mutable absl::flat_hash_map<const llvm::Function*, llvm::TargetTransformInfo>
|
||||||
llvm::TargetTransformInfo>
|
|
||||||
target_transform_info_cache_;
|
target_transform_info_cache_;
|
||||||
llvm::TargetMachine* target_machine_;
|
llvm::TargetMachine* target_machine_;
|
||||||
};
|
};
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
@ -48,7 +49,7 @@ Status Defuse(HloInstruction* fusion_instruction) {
|
|||||||
fusion_instruction->fused_instructions_computation();
|
fusion_instruction->fused_instructions_computation();
|
||||||
|
|
||||||
// A map from fused instruction to its defused clone.
|
// A map from fused instruction to its defused clone.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>
|
absl::flat_hash_map<const HloInstruction*, HloInstruction*>
|
||||||
defused_instructions;
|
defused_instructions;
|
||||||
// Initialize map to contain the fusion instruction parameters mapping
|
// Initialize map to contain the fusion instruction parameters mapping
|
||||||
// to the operands of the fusion instruction.
|
// to the operands of the fusion instruction.
|
||||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
@ -91,6 +91,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_reachability",
|
"//tensorflow/compiler/xla/service:hlo_reachability",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -357,6 +358,7 @@ cc_library(
|
|||||||
"//tensorflow/core/platform/default/build_config:cufft_plugin",
|
"//tensorflow/core/platform/default/build_config:cufft_plugin",
|
||||||
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
|
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
|
||||||
"//tensorflow/stream_executor",
|
"//tensorflow/stream_executor",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||||
@ -197,7 +198,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
|
|||||||
}
|
}
|
||||||
module_spec.AddCudaPtxInMemory(ptx().c_str());
|
module_spec.AddCudaPtxInMemory(ptx().c_str());
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<int64, se::DeviceMemoryBase> globals;
|
absl::flat_hash_map<int64, se::DeviceMemoryBase> globals;
|
||||||
se::ModuleHandle module_handle;
|
se::ModuleHandle module_handle;
|
||||||
executor->LoadModule(module_spec, &module_handle);
|
executor->LoadModule(module_spec, &module_handle);
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
@ -35,7 +36,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
|
|
||||||
@ -101,7 +101,7 @@ class GpuExecutable : public Executable {
|
|||||||
const PointsToSet& GetRootPointsToSet() const;
|
const PointsToSet& GetRootPointsToSet() const;
|
||||||
|
|
||||||
using BufferAllocToDeviceMemoryMap =
|
using BufferAllocToDeviceMemoryMap =
|
||||||
tensorflow::gtl::FlatMap<BufferAllocation::Index, se::DeviceMemoryBase>;
|
absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase>;
|
||||||
|
|
||||||
// Loads the PTX or CUBIN for this executable into `executor` and resolves the
|
// Loads the PTX or CUBIN for this executable into `executor` and resolves the
|
||||||
// globals corresponding to constant buffers. Returns a map mapping buffer
|
// globals corresponding to constant buffers. Returns a map mapping buffer
|
||||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
@ -34,7 +34,7 @@ class StreamAssignment {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
int stream_count_ = 1; // At least the main stream.
|
int stream_count_ = 1; // At least the main stream.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int> hlo_to_stream_number_;
|
absl::flat_hash_map<const HloInstruction*, int> hlo_to_stream_number_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Assigns GPU streams to instructions in `module`.
|
// Assigns GPU streams to instructions in `module`.
|
||||||
|
@ -18,13 +18,14 @@ limitations under the License.
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
using tensorflow::gtl::FlatMap;
|
using absl::flat_hash_map;
|
||||||
using tensorflow::gtl::FlatSet;
|
using tensorflow::gtl::FlatSet;
|
||||||
|
|
||||||
/*static*/
|
/*static*/
|
||||||
@ -56,7 +57,7 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
|
|||||||
const HloComputation& computation, const HloInstructionSequence& sequence,
|
const HloComputation& computation, const HloInstructionSequence& sequence,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
|
const absl::flat_hash_map<const HloComputation*, int64>*
|
||||||
memory_by_computation) {
|
memory_by_computation) {
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
HeapSimulator::Result result,
|
HeapSimulator::Result result,
|
||||||
@ -88,7 +89,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
|
|||||||
const HloInstructionSequence& instruction_sequence,
|
const HloInstructionSequence& instruction_sequence,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const BufferValue::SizeFunction& size_fn, const Options& options,
|
const BufferValue::SizeFunction& size_fn, const Options& options,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
|
const absl::flat_hash_map<const HloComputation*, int64>*
|
||||||
memory_by_computation) {
|
memory_by_computation) {
|
||||||
HeapSimulator heap(std::move(algorithm), size_fn, options,
|
HeapSimulator heap(std::move(algorithm), size_fn, options,
|
||||||
/*schedule=*/nullptr, memory_by_computation);
|
/*schedule=*/nullptr, memory_by_computation);
|
||||||
@ -115,8 +116,10 @@ Status HeapSimulator::RunComputation(
|
|||||||
// 'used_buffers' is the reverse map - it tracks which buffers were used by an
|
// 'used_buffers' is the reverse map - it tracks which buffers were used by an
|
||||||
// instruction, so that we can remove the instructions from a buffer's live
|
// instruction, so that we can remove the instructions from a buffer's live
|
||||||
// set after they are visited.
|
// set after they are visited.
|
||||||
FlatMap<const BufferValue*, FlatSet<const HloInstruction*>> live_buffers;
|
flat_hash_map<const BufferValue*, FlatSet<const HloInstruction*>>
|
||||||
FlatMap<const HloInstruction*, FlatSet<const BufferValue*>> used_buffers;
|
live_buffers;
|
||||||
|
flat_hash_map<const HloInstruction*, FlatSet<const BufferValue*>>
|
||||||
|
used_buffers;
|
||||||
auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
|
auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
|
||||||
const HloInstruction* user,
|
const HloInstruction* user,
|
||||||
const BufferValue* buffer) {
|
const BufferValue* buffer) {
|
||||||
@ -345,7 +348,7 @@ HeapSimulator::HeapSimulator(
|
|||||||
std::unique_ptr<HeapAlgorithm> algorithm,
|
std::unique_ptr<HeapAlgorithm> algorithm,
|
||||||
const BufferValue::SizeFunction& size_fn, const Options& options,
|
const BufferValue::SizeFunction& size_fn, const Options& options,
|
||||||
const HloSchedule* schedule,
|
const HloSchedule* schedule,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
|
const absl::flat_hash_map<const HloComputation*, int64>*
|
||||||
memory_by_computation)
|
memory_by_computation)
|
||||||
: no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
|
: no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
|
||||||
algorithm_(std::move(algorithm)),
|
algorithm_(std::move(algorithm)),
|
||||||
@ -536,7 +539,7 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size,
|
|||||||
|
|
||||||
void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
|
void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
|
||||||
const HloInstruction* instruction,
|
const HloInstruction* instruction,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation) {
|
memory_by_computation) {
|
||||||
// We only count the memory usage of the largest subcomputation, instead of
|
// We only count the memory usage of the largest subcomputation, instead of
|
||||||
// adding them all, because subcomputations won't execute in parallel.
|
// adding them all, because subcomputations won't execute in parallel.
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/buffer_value.h"
|
#include "tensorflow/compiler/xla/service/buffer_value.h"
|
||||||
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
|
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
@ -30,7 +31,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
||||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -58,7 +58,7 @@ class HeapSimulator {
|
|||||||
// Result represents the result of the heap simulation.
|
// Result represents the result of the heap simulation.
|
||||||
struct Result {
|
struct Result {
|
||||||
// The assignment of buffers to chunks.
|
// The assignment of buffers to chunks.
|
||||||
tensorflow::gtl::FlatMap<const BufferValue*, Chunk> chunk_map;
|
absl::flat_hash_map<const BufferValue*, Chunk> chunk_map;
|
||||||
|
|
||||||
// The total size in bytes of the heap, containing all assigned chunks.
|
// The total size in bytes of the heap, containing all assigned chunks.
|
||||||
int64 heap_size = 0;
|
int64 heap_size = 0;
|
||||||
@ -100,7 +100,7 @@ class HeapSimulator {
|
|||||||
const HloComputation& computation, const HloInstructionSequence& sequence,
|
const HloComputation& computation, const HloInstructionSequence& sequence,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
|
const absl::flat_hash_map<const HloComputation*, int64>*
|
||||||
memory_by_computation = nullptr);
|
memory_by_computation = nullptr);
|
||||||
|
|
||||||
// Run the heap simulation with the given algorithm, assuming the given
|
// Run the heap simulation with the given algorithm, assuming the given
|
||||||
@ -130,7 +130,7 @@ class HeapSimulator {
|
|||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const BufferValue::SizeFunction& size_fn,
|
const BufferValue::SizeFunction& size_fn,
|
||||||
const Options& options = Options(),
|
const Options& options = Options(),
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
|
const absl::flat_hash_map<const HloComputation*, int64>*
|
||||||
memory_by_computation = nullptr);
|
memory_by_computation = nullptr);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -140,7 +140,7 @@ class HeapSimulator {
|
|||||||
HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
|
HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
|
||||||
const BufferValue::SizeFunction& size_fn,
|
const BufferValue::SizeFunction& size_fn,
|
||||||
const Options& options, const HloSchedule* schedule = nullptr,
|
const Options& options, const HloSchedule* schedule = nullptr,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
|
const absl::flat_hash_map<const HloComputation*, int64>*
|
||||||
memory_by_computation = nullptr);
|
memory_by_computation = nullptr);
|
||||||
~HeapSimulator();
|
~HeapSimulator();
|
||||||
|
|
||||||
@ -172,7 +172,7 @@ class HeapSimulator {
|
|||||||
// handle subcomputations. It would be good to unify the handling of
|
// handle subcomputations. It would be good to unify the handling of
|
||||||
// subcomputations, but it's not clear how.
|
// subcomputations, but it's not clear how.
|
||||||
const HloSchedule* schedule_;
|
const HloSchedule* schedule_;
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
|
const absl::flat_hash_map<const HloComputation*, int64>*
|
||||||
memory_by_computation_;
|
memory_by_computation_;
|
||||||
|
|
||||||
// In addition to Alloc and Free, the heap simulator exposes a concept of
|
// In addition to Alloc and Free, the heap simulator exposes a concept of
|
||||||
@ -193,7 +193,7 @@ class HeapSimulator {
|
|||||||
const BufferValue* canonical = nullptr;
|
const BufferValue* canonical = nullptr;
|
||||||
int64 refcount = 0;
|
int64 refcount = 0;
|
||||||
};
|
};
|
||||||
tensorflow::gtl::FlatMap<const BufferValue*, std::shared_ptr<SharedGroup>>
|
absl::flat_hash_map<const BufferValue*, std::shared_ptr<SharedGroup>>
|
||||||
shared_buffers_;
|
shared_buffers_;
|
||||||
|
|
||||||
// Hold some sets for error-checking the sequence of Alloc and Free calls.
|
// Hold some sets for error-checking the sequence of Alloc and Free calls.
|
||||||
@ -235,7 +235,7 @@ class HeapAlgorithm {
|
|||||||
// analysis, it's not worth making major changes to HeapSimulator now.
|
// analysis, it's not worth making major changes to HeapSimulator now.
|
||||||
virtual void AccountForSubcomputationMemory(
|
virtual void AccountForSubcomputationMemory(
|
||||||
const HloInstruction* instruction,
|
const HloInstruction* instruction,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation) {}
|
memory_by_computation) {}
|
||||||
|
|
||||||
// Free de-allocates a previously allocated buffer.
|
// Free de-allocates a previously allocated buffer.
|
||||||
@ -262,7 +262,7 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
|
|||||||
|
|
||||||
void AccountForSubcomputationMemory(
|
void AccountForSubcomputationMemory(
|
||||||
const HloInstruction* instruction,
|
const HloInstruction* instruction,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation) override;
|
memory_by_computation) override;
|
||||||
|
|
||||||
void Free(const BufferValue* buffer, int64 size) override;
|
void Free(const BufferValue* buffer, int64 size) override;
|
||||||
@ -382,8 +382,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
|
|||||||
// Free time of the buffer.
|
// Free time of the buffer.
|
||||||
int64 end;
|
int64 end;
|
||||||
};
|
};
|
||||||
tensorflow::gtl::FlatMap<const BufferValue*, BufferInterval>
|
absl::flat_hash_map<const BufferValue*, BufferInterval> buffer_intervals_;
|
||||||
buffer_intervals_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// A heap algorithm that chooses the best results from other algorithms added to
|
// A heap algorithm that chooses the best results from other algorithms added to
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/compiler/xla/literal.h"
|
#include "tensorflow/compiler/xla/literal.h"
|
||||||
#include "tensorflow/compiler/xla/service/buffer_value.h"
|
#include "tensorflow/compiler/xla/service/buffer_value.h"
|
||||||
@ -31,7 +32,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
@ -174,7 +174,7 @@ class HeapSimulatorTracker {
|
|||||||
|
|
||||||
// Construct the module sequence grouped by computation.
|
// Construct the module sequence grouped by computation.
|
||||||
HloSchedule schedule(module_.get());
|
HloSchedule schedule(module_.get());
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
|
absl::flat_hash_map<const HloInstruction*, int> reverse_position;
|
||||||
for (int i = 0; i < full_module_sequence.size(); ++i) {
|
for (int i = 0; i < full_module_sequence.size(); ++i) {
|
||||||
const HloInstruction* instruction = full_module_sequence[i];
|
const HloInstruction* instruction = full_module_sequence[i];
|
||||||
schedule.GetOrCreateSequence(instruction->parent())
|
schedule.GetOrCreateSequence(instruction->parent())
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
@ -290,13 +291,11 @@ class BufferValueMap {
|
|||||||
const HloDataflowAnalysis& dataflow_;
|
const HloDataflowAnalysis& dataflow_;
|
||||||
|
|
||||||
// A map containing the set of values contained in each buffer.
|
// A map containing the set of values contained in each buffer.
|
||||||
tensorflow::gtl::FlatMap<BufferNumber,
|
absl::flat_hash_map<BufferNumber, tensorflow::gtl::FlatSet<const HloValue*>>
|
||||||
tensorflow::gtl::FlatSet<const HloValue*>>
|
|
||||||
buffers_;
|
buffers_;
|
||||||
|
|
||||||
// A map indicating which buffer each value is contained in.
|
// A map indicating which buffer each value is contained in.
|
||||||
tensorflow::gtl::FlatMap<const HloValue*, BufferNumber>
|
absl::flat_hash_map<const HloValue*, BufferNumber> value_to_buffer_number_;
|
||||||
value_to_buffer_number_;
|
|
||||||
|
|
||||||
// The buffer number of the next buffer to be created.
|
// The buffer number of the next buffer to be created.
|
||||||
BufferNumber next_buffer_number_ = 0;
|
BufferNumber next_buffer_number_ = 0;
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
|
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||||
@ -110,7 +111,7 @@ class HloAliasAnalysis {
|
|||||||
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
|
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
|
||||||
|
|
||||||
// A map indicating which buffer a value is contained in.
|
// A map indicating which buffer a value is contained in.
|
||||||
tensorflow::gtl::FlatMap<const HloValue*, HloBuffer*> value_to_buffer_;
|
absl::flat_hash_map<const HloValue*, HloBuffer*> value_to_buffer_;
|
||||||
|
|
||||||
// A lazily constructed vector containing all HloBuffers sorted by
|
// A lazily constructed vector containing all HloBuffers sorted by
|
||||||
// HloBuffer::Id.
|
// HloBuffer::Id.
|
||||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -73,12 +73,12 @@ class HloCloneContext {
|
|||||||
return FindOrDie(computations_, old_computation);
|
return FindOrDie(computations_, old_computation);
|
||||||
}
|
}
|
||||||
|
|
||||||
const tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>&
|
const absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
|
||||||
cloned_instructions() const {
|
cloned_instructions() const {
|
||||||
return instructions_;
|
return instructions_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*>&
|
const absl::flat_hash_map<const HloComputation*, HloComputation*>&
|
||||||
cloned_computations() const {
|
cloned_computations() const {
|
||||||
return computations_;
|
return computations_;
|
||||||
}
|
}
|
||||||
@ -86,10 +86,8 @@ class HloCloneContext {
|
|||||||
private:
|
private:
|
||||||
HloModule* module_;
|
HloModule* module_;
|
||||||
string suffix_;
|
string suffix_;
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>
|
absl::flat_hash_map<const HloInstruction*, HloInstruction*> instructions_;
|
||||||
instructions_;
|
absl::flat_hash_map<const HloComputation*, HloComputation*> computations_;
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*>
|
|
||||||
computations_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/numbers.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
@ -297,7 +298,7 @@ void ComputeComputationPostOrder(
|
|||||||
void HloComputation::ComputeInstructionPostOrder(
|
void HloComputation::ComputeInstructionPostOrder(
|
||||||
const HloComputation::ChannelDependencyMap& channel_dependency_map,
|
const HloComputation::ChannelDependencyMap& channel_dependency_map,
|
||||||
std::vector<HloInstruction*>* post_order, HloInstruction* root,
|
std::vector<HloInstruction*>* post_order, HloInstruction* root,
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const {
|
absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
|
||||||
std::vector<HloInstruction*> dfs_stack;
|
std::vector<HloInstruction*> dfs_stack;
|
||||||
dfs_stack.push_back(root);
|
dfs_stack.push_back(root);
|
||||||
while (!dfs_stack.empty()) {
|
while (!dfs_stack.empty()) {
|
||||||
@ -394,7 +395,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
|
|||||||
std::vector<HloInstruction*> post_order;
|
std::vector<HloInstruction*> post_order;
|
||||||
post_order.reserve(instruction_count());
|
post_order.reserve(instruction_count());
|
||||||
std::vector<HloInstruction*> trace_instructions;
|
std::vector<HloInstruction*> trace_instructions;
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, VisitState> visited;
|
absl::flat_hash_map<HloInstruction*, VisitState> visited;
|
||||||
for (auto& instruction : instructions_) {
|
for (auto& instruction : instructions_) {
|
||||||
if (instruction->opcode() == HloOpcode::kTrace) {
|
if (instruction->opcode() == HloOpcode::kTrace) {
|
||||||
// Trace instructions aren't handled by the DFS visitor. Add trace
|
// Trace instructions aren't handled by the DFS visitor. Add trace
|
||||||
@ -505,9 +506,9 @@ HloComputationProto HloComputation::ToProto() const {
|
|||||||
/* static */ StatusOr<std::unique_ptr<HloComputation>>
|
/* static */ StatusOr<std::unique_ptr<HloComputation>>
|
||||||
HloComputation::CreateFromProto(
|
HloComputation::CreateFromProto(
|
||||||
const HloComputationProto& proto,
|
const HloComputationProto& proto,
|
||||||
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
|
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
|
||||||
tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
|
absl::flat_hash_map<int64, HloInstruction*> instruction_map;
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, int64> to_proto_id;
|
absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
|
||||||
std::vector<std::unique_ptr<HloInstruction>> instructions;
|
std::vector<std::unique_ptr<HloInstruction>> instructions;
|
||||||
int64 parameter_count = 0;
|
int64 parameter_count = 0;
|
||||||
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
|
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/iterator_util.h"
|
#include "tensorflow/compiler/xla/iterator_util.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
@ -40,7 +41,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -188,7 +188,7 @@ class HloComputation {
|
|||||||
// calls.
|
// calls.
|
||||||
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
|
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
|
||||||
const HloComputationProto& proto,
|
const HloComputationProto& proto,
|
||||||
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
|
const absl::flat_hash_map<int64, HloComputation*>& computation_map);
|
||||||
|
|
||||||
// Gets the instructions in this computation.
|
// Gets the instructions in this computation.
|
||||||
//
|
//
|
||||||
@ -414,14 +414,14 @@ class HloComputation {
|
|||||||
// cross-replica-sum the union of the dependencies for all participating
|
// cross-replica-sum the union of the dependencies for all participating
|
||||||
// instructions.
|
// instructions.
|
||||||
using ChannelDependencyMap =
|
using ChannelDependencyMap =
|
||||||
tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>;
|
absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
|
||||||
ChannelDependencyMap ComputeChannelDependencies() const;
|
ChannelDependencyMap ComputeChannelDependencies() const;
|
||||||
|
|
||||||
enum VisitState { kVisiting, kVisited };
|
enum VisitState { kVisiting, kVisited };
|
||||||
void ComputeInstructionPostOrder(
|
void ComputeInstructionPostOrder(
|
||||||
const HloComputation::ChannelDependencyMap& channel_dependency_map,
|
const HloComputation::ChannelDependencyMap& channel_dependency_map,
|
||||||
std::vector<HloInstruction*>* post_order, HloInstruction* root,
|
std::vector<HloInstruction*>* post_order, HloInstruction* root,
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const;
|
absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;
|
||||||
|
|
||||||
string name_;
|
string name_;
|
||||||
int64 unique_id_;
|
int64 unique_id_;
|
||||||
@ -439,7 +439,7 @@ class HloComputation {
|
|||||||
// instruction pointer to location in the list for fast lookup.
|
// instruction pointer to location in the list for fast lookup.
|
||||||
using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
|
using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
|
||||||
InstructionList instructions_;
|
InstructionList instructions_;
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, InstructionList::iterator>
|
absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
|
||||||
instruction_iterators_;
|
instruction_iterators_;
|
||||||
|
|
||||||
std::vector<HloInstruction*> param_instructions_;
|
std::vector<HloInstruction*> param_instructions_;
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
@ -106,8 +107,8 @@ Status HloDomainMap::PopulateDomainMetadataMap() {
|
|||||||
auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
|
auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
|
||||||
return a->Matches(*b);
|
return a->Matches(*b);
|
||||||
};
|
};
|
||||||
tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash),
|
absl::flat_hash_map<const DomainMetadata*, int64, decltype(hash),
|
||||||
decltype(equal)>
|
decltype(equal)>
|
||||||
domain_metadata(1024, hash, equal);
|
domain_metadata(1024, hash, equal);
|
||||||
|
|
||||||
for (auto& domain : instruction_domains_) {
|
for (auto& domain : instruction_domains_) {
|
||||||
|
@ -19,13 +19,13 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
|
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -77,8 +77,7 @@ class HloDomainMap {
|
|||||||
private:
|
private:
|
||||||
// Map used for representing instruction ordering, i.e.
|
// Map used for representing instruction ordering, i.e.
|
||||||
// order_map[a] < order_map[b] means a must be ordered before b.
|
// order_map[a] < order_map[b] means a must be ordered before b.
|
||||||
using InstructionOrderMap =
|
using InstructionOrderMap = absl::flat_hash_map<const HloInstruction*, int64>;
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64>;
|
|
||||||
|
|
||||||
HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
|
HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
|
||||||
|
|
||||||
@ -120,8 +119,8 @@ class HloDomainMap {
|
|||||||
|
|
||||||
string domain_kind_;
|
string domain_kind_;
|
||||||
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
|
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
|
absl::flat_hash_map<HloInstruction*, int64> instruction_to_domain_;
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
|
absl::flat_hash_map<HloInstruction*, int64> domain_metadata_id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/ascii.h"
|
#include "absl/strings/ascii.h"
|
||||||
@ -43,7 +44,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/platform/human_readable_json.h"
|
#include "tensorflow/core/platform/human_readable_json.h"
|
||||||
@ -59,8 +59,8 @@ using absl::StrJoin;
|
|||||||
/* static */
|
/* static */
|
||||||
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||||
const HloInstructionProto& proto,
|
const HloInstructionProto& proto,
|
||||||
const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
|
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
||||||
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
|
const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
|
||||||
TF_RET_CHECK(!proto.opcode().empty());
|
TF_RET_CHECK(!proto.opcode().empty());
|
||||||
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
|
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
|
||||||
TF_RET_CHECK(proto.has_shape());
|
TF_RET_CHECK(proto.has_shape());
|
||||||
@ -266,7 +266,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
<< "Expect 1 called computation for fusion instruction but sees "
|
<< "Expect 1 called computation for fusion instruction but sees "
|
||||||
<< proto.called_computation_ids_size();
|
<< proto.called_computation_ids_size();
|
||||||
const int64 fusion_id = proto.called_computation_ids(0);
|
const int64 fusion_id = proto.called_computation_ids(0);
|
||||||
auto* fused_computation = FindPtrOrNull(computation_map, fusion_id);
|
auto* fused_computation =
|
||||||
|
tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
|
||||||
TF_RET_CHECK(fused_computation != nullptr)
|
TF_RET_CHECK(fused_computation != nullptr)
|
||||||
<< "No fusion computation with id " << fusion_id;
|
<< "No fusion computation with id " << fusion_id;
|
||||||
instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(),
|
instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(),
|
||||||
@ -2661,14 +2662,14 @@ class HloInstruction::FusionReusesParamElements {
|
|||||||
// the value of this parameter, which would save stack space but not allow us
|
// the value of this parameter, which would save stack space but not allow us
|
||||||
// to finish early if we find a reuse.
|
// to finish early if we find a reuse.
|
||||||
static UseKind Compute(int64 i, const HloInstruction& hlo) {
|
static UseKind Compute(int64 i, const HloInstruction& hlo) {
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache;
|
absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
|
||||||
return ComputeInternal(i, hlo, &memoization_cache);
|
return ComputeInternal(i, hlo, &memoization_cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static UseKind ComputeInternal(
|
static UseKind ComputeInternal(
|
||||||
int64 i, const HloInstruction& hlo,
|
int64 i, const HloInstruction& hlo,
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
|
absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
|
||||||
if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
|
if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
|
||||||
if (hlo_param->parameter_number() == i) {
|
if (hlo_param->parameter_number() == i) {
|
||||||
return UseKind::kUse;
|
return UseKind::kUse;
|
||||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
@ -50,7 +51,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/iterator_range.h"
|
#include "tensorflow/core/lib/gtl/iterator_range.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -247,7 +247,7 @@ class CanonicalNameMap {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
int64 index;
|
int64 index;
|
||||||
tensorflow::gtl::FlatMap<string, string> canonical_name_map;
|
absl::flat_hash_map<string, string> canonical_name_map;
|
||||||
};
|
};
|
||||||
|
|
||||||
// HLO instructions are the atomic unit of the high-level compiler's IR.
|
// HLO instructions are the atomic unit of the high-level compiler's IR.
|
||||||
@ -350,8 +350,8 @@ class HloInstruction {
|
|||||||
// calls.
|
// calls.
|
||||||
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
|
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
|
||||||
const HloInstructionProto& proto,
|
const HloInstructionProto& proto,
|
||||||
const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
|
const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
|
||||||
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
|
const absl::flat_hash_map<int64, HloComputation*>& computation_map);
|
||||||
|
|
||||||
// Creates a parameter-retrieving instruction.
|
// Creates a parameter-retrieving instruction.
|
||||||
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
|
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <deque>
|
#include <deque>
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/escaping.h"
|
#include "absl/strings/escaping.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
@ -28,7 +29,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/window_util.h"
|
#include "tensorflow/compiler/xla/window_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
@ -1099,7 +1099,7 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
|
|||||||
// Note that we add the unfused instructions to this->parent_ computation.
|
// Note that we add the unfused instructions to this->parent_ computation.
|
||||||
// This is necessary because the unique_id needs for an instruction and
|
// This is necessary because the unique_id needs for an instruction and
|
||||||
// it's only added when inserting to the computation.
|
// it's only added when inserting to the computation.
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> old_to_new;
|
absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
|
||||||
std::vector<HloInstruction*> unfused_instructions;
|
std::vector<HloInstruction*> unfused_instructions;
|
||||||
auto computation_to_merge =
|
auto computation_to_merge =
|
||||||
instruction_to_merge->fused_instructions_computation();
|
instruction_to_merge->fused_instructions_computation();
|
||||||
@ -1392,7 +1392,7 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status HloFusionInstruction::DeduplicateFusionOperands() {
|
Status HloFusionInstruction::DeduplicateFusionOperands() {
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices;
|
absl::flat_hash_map<const HloInstruction*, int> operand_indices;
|
||||||
std::vector<int> operands_to_remove;
|
std::vector<int> operands_to_remove;
|
||||||
for (int i = 0; i < operand_count(); ++i) {
|
for (int i = 0; i < operand_count(); ++i) {
|
||||||
auto emplace_result = operand_indices.emplace(operand(i), i);
|
auto emplace_result = operand_indices.emplace(operand(i), i);
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/heap_simulator.h"
|
#include "tensorflow/compiler/xla/service/heap_simulator.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||||
@ -74,7 +75,7 @@ class ListScheduler {
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation) {
|
memory_by_computation) {
|
||||||
ListScheduler scheduler(computation, points_to_analysis, size_function,
|
ListScheduler scheduler(computation, points_to_analysis, size_function,
|
||||||
memory_by_computation);
|
memory_by_computation);
|
||||||
@ -99,7 +100,7 @@ class ListScheduler {
|
|||||||
ListScheduler(const HloComputation& computation,
|
ListScheduler(const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation)
|
memory_by_computation)
|
||||||
: computation_(computation),
|
: computation_(computation),
|
||||||
points_to_analysis_(points_to_analysis),
|
points_to_analysis_(points_to_analysis),
|
||||||
@ -234,8 +235,7 @@ class ListScheduler {
|
|||||||
|
|
||||||
// Populate the ready list with instructions which have no operands or
|
// Populate the ready list with instructions which have no operands or
|
||||||
// control predecessors.
|
// control predecessors.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64>
|
absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
|
||||||
unscheduled_pred_count;
|
|
||||||
for (auto* instruction : computation_.instructions()) {
|
for (auto* instruction : computation_.instructions()) {
|
||||||
// TODO(b/34466113): Replace this and above with successors() or
|
// TODO(b/34466113): Replace this and above with successors() or
|
||||||
// predecessors() when these methods are added to HloInstruction.
|
// predecessors() when these methods are added to HloInstruction.
|
||||||
@ -251,8 +251,8 @@ class ListScheduler {
|
|||||||
std::multimap<Priority, ReadyListEntry> ready_queue;
|
std::multimap<Priority, ReadyListEntry> ready_queue;
|
||||||
|
|
||||||
// Map of ready instructions to their iterators in ready_queue.
|
// Map of ready instructions to their iterators in ready_queue.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*,
|
absl::flat_hash_map<const HloInstruction*,
|
||||||
std::multimap<Priority, ReadyListEntry>::iterator>
|
std::multimap<Priority, ReadyListEntry>::iterator>
|
||||||
ready_instructions;
|
ready_instructions;
|
||||||
|
|
||||||
auto add_to_ready_queue = [&](HloInstruction* inst) {
|
auto add_to_ready_queue = [&](HloInstruction* inst) {
|
||||||
@ -347,12 +347,11 @@ class ListScheduler {
|
|||||||
// Computations are analyzed in post-order. When scheduling an instruction
|
// Computations are analyzed in post-order. When scheduling an instruction
|
||||||
// that includes subcomputations, such as a while loop, we use this map to
|
// that includes subcomputations, such as a while loop, we use this map to
|
||||||
// look up the memory needed by subcomputations.
|
// look up the memory needed by subcomputations.
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation_;
|
memory_by_computation_;
|
||||||
|
|
||||||
// A map containing the LogicalBuffers that each instruction uses.
|
// A map containing the LogicalBuffers that each instruction uses.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*,
|
absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
|
||||||
std::vector<const LogicalBuffer*>>
|
|
||||||
buffer_uses_;
|
buffer_uses_;
|
||||||
|
|
||||||
// A map containing the count of unscheduled HLOs which using a particular
|
// A map containing the count of unscheduled HLOs which using a particular
|
||||||
@ -379,7 +378,7 @@ StatusOr<HloInstructionSequence> ScheduleComputationHelper(
|
|||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const MemorySchedulerAlgorithm& algorithm,
|
const MemorySchedulerAlgorithm& algorithm,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation) {
|
memory_by_computation) {
|
||||||
VLOG(2) << "Computation: " << computation.name();
|
VLOG(2) << "Computation: " << computation.name();
|
||||||
if (algorithm) {
|
if (algorithm) {
|
||||||
@ -396,13 +395,13 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation) {
|
memory_by_computation) {
|
||||||
// These variables are a hack to prevent overflows.
|
// These variables are a hack to prevent overflows.
|
||||||
int64 cumulative_total_size = 0;
|
int64 cumulative_total_size = 0;
|
||||||
int64 total_hlos = computation.parent()->instruction_count();
|
int64 total_hlos = computation.parent()->instruction_count();
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
|
absl::flat_hash_map<const HloInstruction*, int64> extra_users;
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
|
absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
|
||||||
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
|
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
|
||||||
if (ListScheduler::IgnoreInstruction(*hlo)) {
|
if (ListScheduler::IgnoreInstruction(*hlo)) {
|
||||||
extra_users[hlo] = 0;
|
extra_users[hlo] = 0;
|
||||||
@ -467,7 +466,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation) {
|
memory_by_computation) {
|
||||||
return ListScheduler::Run(computation, points_to_analysis, size_function,
|
return ListScheduler::Run(computation, points_to_analysis, size_function,
|
||||||
memory_by_computation);
|
memory_by_computation);
|
||||||
@ -477,7 +476,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation) {
|
memory_by_computation) {
|
||||||
return HloInstructionSequence(computation.MakeInstructionPostOrder());
|
return HloInstructionSequence(computation.MakeInstructionPostOrder());
|
||||||
}
|
}
|
||||||
@ -486,7 +485,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation) {
|
memory_by_computation) {
|
||||||
// We try a few schedulers and choose whichever returns a lower min-memory,
|
// We try a few schedulers and choose whichever returns a lower min-memory,
|
||||||
// not accounting for fragmentation.
|
// not accounting for fragmentation.
|
||||||
@ -549,7 +548,7 @@ StatusOr<HloSchedule> ScheduleModule(
|
|||||||
HloSchedule schedule(&module);
|
HloSchedule schedule(&module);
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
||||||
TuplePointsToAnalysis::Run(&module));
|
TuplePointsToAnalysis::Run(&module));
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
|
absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
|
||||||
for (const auto* computation : module.MakeComputationPostOrder()) {
|
for (const auto* computation : module.MakeComputationPostOrder()) {
|
||||||
if (!computation->IsFusionComputation()) {
|
if (!computation->IsFusionComputation()) {
|
||||||
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
|
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
|
||||||
@ -577,7 +576,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
|
|||||||
CHECK(!computation.IsFusionComputation());
|
CHECK(!computation.IsFusionComputation());
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
|
||||||
TuplePointsToAnalysis::Run(computation.parent()));
|
TuplePointsToAnalysis::Run(computation.parent()));
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, int64> empty_map;
|
absl::flat_hash_map<const HloComputation*, int64> empty_map;
|
||||||
return ScheduleComputationHelper(computation, *points_to_analysis,
|
return ScheduleComputationHelper(computation, *points_to_analysis,
|
||||||
size_function, nullptr, empty_map);
|
size_function, nullptr, empty_map);
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
||||||
@ -37,7 +38,7 @@ namespace xla {
|
|||||||
typedef std::function<StatusOr<HloInstructionSequence>(
|
typedef std::function<StatusOr<HloInstructionSequence>(
|
||||||
const HloComputation&, const TuplePointsToAnalysis&,
|
const HloComputation&, const TuplePointsToAnalysis&,
|
||||||
const LogicalBuffer::SizeFunction&,
|
const LogicalBuffer::SizeFunction&,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)>
|
const absl::flat_hash_map<const HloComputation*, int64>&)>
|
||||||
MemorySchedulerAlgorithm;
|
MemorySchedulerAlgorithm;
|
||||||
|
|
||||||
// List scheduler
|
// List scheduler
|
||||||
@ -45,7 +46,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation);
|
memory_by_computation);
|
||||||
|
|
||||||
// DFS-order scheduler
|
// DFS-order scheduler
|
||||||
@ -53,7 +54,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation);
|
memory_by_computation);
|
||||||
|
|
||||||
// Naive Post Order scheduler
|
// Naive Post Order scheduler
|
||||||
@ -61,7 +62,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation);
|
memory_by_computation);
|
||||||
|
|
||||||
// The default scheduling algorithm. Runs both the list scheduler
|
// The default scheduling algorithm. Runs both the list scheduler
|
||||||
@ -71,7 +72,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const TuplePointsToAnalysis& points_to_analysis,
|
const TuplePointsToAnalysis& points_to_analysis,
|
||||||
const LogicalBuffer::SizeFunction& size_function,
|
const LogicalBuffer::SizeFunction& size_function,
|
||||||
const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
|
const absl::flat_hash_map<const HloComputation*, int64>&
|
||||||
memory_by_computation);
|
memory_by_computation);
|
||||||
|
|
||||||
// Returns an HloSchedule which seeks to minimize the memory required for
|
// Returns an HloSchedule which seeks to minimize the memory required for
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/heap_simulator.h"
|
#include "tensorflow/compiler/xla/service/heap_simulator.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||||
@ -247,7 +248,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
|
|||||||
EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
|
EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
|
||||||
EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
|
EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
|
absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
|
||||||
memory_by_computation[cond_computation] = 17;
|
memory_by_computation[cond_computation] = 17;
|
||||||
memory_by_computation[body_computation] = 16;
|
memory_by_computation[body_computation] = 16;
|
||||||
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
|
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
|
||||||
@ -409,7 +410,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
|
|||||||
EXPECT_EQ(module->entry_computation()->instruction_count(),
|
EXPECT_EQ(module->entry_computation()->instruction_count(),
|
||||||
schedule.sequence(module->entry_computation()).size());
|
schedule.sequence(module->entry_computation()).size());
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
|
absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
|
||||||
memory_by_computation[cond_computation] = 17;
|
memory_by_computation[cond_computation] = 17;
|
||||||
memory_by_computation[body_computation] = 16;
|
memory_by_computation[body_computation] = 16;
|
||||||
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
|
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
|
||||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
@ -285,8 +286,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
|||||||
<< ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
|
<< ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
|
||||||
<< ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
|
<< ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
|
absl::flat_hash_map<int64, HloComputation*> computation_map;
|
||||||
tensorflow::gtl::FlatMap<HloComputation*, int64> to_proto_id;
|
absl::flat_hash_map<HloComputation*, int64> to_proto_id;
|
||||||
std::vector<std::unique_ptr<HloComputation>> computations;
|
std::vector<std::unique_ptr<HloComputation>> computations;
|
||||||
HloComputation* entry = nullptr;
|
HloComputation* entry = nullptr;
|
||||||
for (const HloComputationProto& computation_proto : proto.computations()) {
|
for (const HloComputationProto& computation_proto : proto.computations()) {
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
@ -30,7 +31,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -250,25 +250,25 @@ class HloModuleGroupMetadata {
|
|||||||
std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
|
std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
|
||||||
|
|
||||||
// Map from each companion while instruction to the index into companion_set_.
|
// Map from each companion while instruction to the index into companion_set_.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
|
absl::flat_hash_map<const HloInstruction*, int64> companion_set_index_;
|
||||||
|
|
||||||
// Map from computation to the instruction using it (a kWhile, kConditional).
|
// Map from computation to the instruction using it (a kWhile, kConditional).
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction>
|
absl::flat_hash_map<const HloComputation*, TrackedInstruction>
|
||||||
tracked_instructions_;
|
tracked_instructions_;
|
||||||
|
|
||||||
// Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of
|
// Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of
|
||||||
// communicating instructions within the proper called computation(s).
|
// communicating instructions within the proper called computation(s).
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, std::vector<HloInstruction*>>
|
absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>
|
||||||
tracked_instructions_comms_;
|
tracked_instructions_comms_;
|
||||||
|
|
||||||
// All channels in the module.
|
// All channels in the module.
|
||||||
std::vector<Channel> channels_;
|
std::vector<Channel> channels_;
|
||||||
|
|
||||||
// Map from channel ids to the index in channels_.
|
// Map from channel ids to the index in channels_.
|
||||||
tensorflow::gtl::FlatMap<int64, int64> channel_id_map_;
|
absl::flat_hash_map<int64, int64> channel_id_map_;
|
||||||
|
|
||||||
// Map from all-reduce ids to the all reduce instructions.
|
// Map from all-reduce ids to the all reduce instructions.
|
||||||
tensorflow::gtl::FlatMap<int64, std::vector<HloInstruction*>> all_reduce_map_;
|
absl::flat_hash_map<int64, std::vector<HloInstruction*>> all_reduce_map_;
|
||||||
|
|
||||||
// The maximum channel id used in the module group.
|
// The maximum channel id used in the module group.
|
||||||
int64 max_channel_id_ = -1;
|
int64 max_channel_id_ = -1;
|
||||||
@ -276,7 +276,7 @@ class HloModuleGroupMetadata {
|
|||||||
// The modules that this metadata was built from.
|
// The modules that this metadata was built from.
|
||||||
const std::vector<HloModule*>& modules_;
|
const std::vector<HloModule*>& modules_;
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
|
absl::flat_hash_map<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
|
||||||
points_to_analyses_;
|
points_to_analyses_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
@ -28,7 +29,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ class HloModuleGroupUtil {
|
|||||||
// * visit_state: map from each instruction to its visit state.
|
// * visit_state: map from each instruction to its visit state.
|
||||||
// * visit_function: function called when each instruction group.
|
// * visit_function: function called when each instruction group.
|
||||||
// * root: the root instruction of the traversal.
|
// * root: the root instruction of the traversal.
|
||||||
using VisitStates = tensorflow::gtl::FlatMap<HloInstruction*, VisitState>;
|
using VisitStates = absl::flat_hash_map<HloInstruction*, VisitState>;
|
||||||
Status VisitTopologicalOrder(VisitStates* visit_state,
|
Status VisitTopologicalOrder(VisitStates* visit_state,
|
||||||
const VisitFunction& visit_function,
|
const VisitFunction& visit_function,
|
||||||
HloInstruction* root);
|
HloInstruction* root);
|
||||||
|
@ -14,9 +14,9 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ string HloOpcodeString(HloOpcode opcode) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
|
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
|
||||||
static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>({
|
static auto* opcode_map = new absl::flat_hash_map<string, HloOpcode>({
|
||||||
#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \
|
#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \
|
||||||
{opcode_name, HloOpcode::enum_name},
|
{opcode_name, HloOpcode::enum_name},
|
||||||
HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY)
|
HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY)
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
|
||||||
@ -28,7 +29,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_value.h"
|
#include "tensorflow/compiler/xla/service/hlo_value.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -120,8 +120,8 @@ class PredecessorHloOrdering : public HloOrdering {
|
|||||||
// predecessors. An instruction is an element of its own predecessor set.
|
// predecessors. An instruction is an element of its own predecessor set.
|
||||||
//
|
//
|
||||||
// Subclasses should fill this in to define the desired ordering.
|
// Subclasses should fill this in to define the desired ordering.
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*,
|
absl::flat_hash_map<const HloComputation*,
|
||||||
std::unique_ptr<HloReachabilityMap>>
|
std::unique_ptr<HloReachabilityMap>>
|
||||||
predecessors_;
|
predecessors_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -204,7 +204,7 @@ class SequentialHloOrdering : public HloOrdering {
|
|||||||
// this map so more than one instruction may have the same position
|
// this map so more than one instruction may have the same position
|
||||||
// value. This is not a problem because ExecutesBefore also verifies
|
// value. This is not a problem because ExecutesBefore also verifies
|
||||||
// instructions are in the same computation.
|
// instructions are in the same computation.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_;
|
absl::flat_hash_map<const HloInstruction*, int> order_position_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||||
@ -98,7 +99,7 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
|
|||||||
if (!proto_dump_path.empty()) {
|
if (!proto_dump_path.empty()) {
|
||||||
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
|
||||||
static auto* const module_id_to_pass_number =
|
static auto* const module_id_to_pass_number =
|
||||||
new tensorflow::gtl::FlatMap<int64, int64>();
|
new absl::flat_hash_map<int64, int64>();
|
||||||
|
|
||||||
tensorflow::mutex_lock lock(mu);
|
tensorflow::mutex_lock lock(mu);
|
||||||
const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
|
const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
|
||||||
|
@ -19,11 +19,11 @@ limitations under the License.
|
|||||||
#include <list>
|
#include <list>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -154,7 +154,7 @@ class HloReachabilityMap {
|
|||||||
|
|
||||||
// Dense assignment from HloInstruction* to number. These numbers index
|
// Dense assignment from HloInstruction* to number. These numbers index
|
||||||
// into the bit_vectors_ vector and into the bits within a BitVector.
|
// into the bit_vectors_ vector and into the bits within a BitVector.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int> indices_;
|
absl::flat_hash_map<const HloInstruction*, int> indices_;
|
||||||
|
|
||||||
// Bitvectors holding the reachability to each instruction. The bit vector for
|
// Bitvectors holding the reachability to each instruction. The bit vector for
|
||||||
// instruction X includes ones for each instruction which X is reachable from.
|
// instruction X includes ones for each instruction which X is reachable from.
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
@ -75,7 +76,7 @@ bool IsRematerializable(const HloInstruction* instruction) {
|
|||||||
// cache before, and eventually calling the IsRematerializable() API.
|
// cache before, and eventually calling the IsRematerializable() API.
|
||||||
bool CanBeRematerialized(
|
bool CanBeRematerialized(
|
||||||
const HloInstruction* instruction,
|
const HloInstruction* instruction,
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
|
absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
|
||||||
auto it = remat_able->find(instruction);
|
auto it = remat_able->find(instruction);
|
||||||
if (it != remat_able->end()) {
|
if (it != remat_able->end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
@ -268,7 +269,7 @@ class InstructionList {
|
|||||||
Item* first_;
|
Item* first_;
|
||||||
|
|
||||||
// Item for each instruction.
|
// Item for each instruction.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_;
|
absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Return the items which use the given LogicalBuffer. Sets
|
// Return the items which use the given LogicalBuffer. Sets
|
||||||
@ -503,7 +504,7 @@ MemoryUsageTracker::MemoryUsageTracker(
|
|||||||
PointsToSet::BufferSet live_out_set =
|
PointsToSet::BufferSet live_out_set =
|
||||||
points_to_analysis.GetPointsToSet(computation_->root_instruction())
|
points_to_analysis.GetPointsToSet(computation_->root_instruction())
|
||||||
.CreateFlattenedSet();
|
.CreateFlattenedSet();
|
||||||
tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
|
absl::flat_hash_map<const LogicalBuffer*, BufferId>
|
||||||
logical_buffer_to_buffer_id;
|
logical_buffer_to_buffer_id;
|
||||||
|
|
||||||
for (auto* item = instruction_list_.first(); item != nullptr;
|
for (auto* item = instruction_list_.first(); item != nullptr;
|
||||||
@ -854,7 +855,7 @@ int64 RematerializationCost(const HloInstruction* instruction,
|
|||||||
Item* PickRematerializationCandidate(
|
Item* PickRematerializationCandidate(
|
||||||
const MemoryUsageTracker& memory_tracker,
|
const MemoryUsageTracker& memory_tracker,
|
||||||
const InstructionList& instruction_list, int64 memory_limit_bytes,
|
const InstructionList& instruction_list, int64 memory_limit_bytes,
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
|
absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
|
||||||
Item* best_item = nullptr;
|
Item* best_item = nullptr;
|
||||||
int64 best_cost = 0;
|
int64 best_cost = 0;
|
||||||
|
|
||||||
@ -983,7 +984,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
|||||||
tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions;
|
tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions;
|
||||||
|
|
||||||
// The map from instructions to their rematerializable status.
|
// The map from instructions to their rematerializable status.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, bool> remat_able;
|
absl::flat_hash_map<const HloInstruction*, bool> remat_able;
|
||||||
|
|
||||||
// The peak memory of the computation at any point in the instruction
|
// The peak memory of the computation at any point in the instruction
|
||||||
// sequence.
|
// sequence.
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
|
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
@ -115,8 +116,7 @@ class HloRematerialization : public HloModulePass {
|
|||||||
// computations called from sequential context
|
// computations called from sequential context
|
||||||
// (CallContext::kSequential). These values are updated as rematerialization
|
// (CallContext::kSequential). These values are updated as rematerialization
|
||||||
// occurs.
|
// occurs.
|
||||||
tensorflow::gtl::FlatMap<const HloComputation*, int64>
|
absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_;
|
||||||
computation_peak_memory_;
|
|
||||||
|
|
||||||
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
|
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
@ -30,7 +31,7 @@ namespace xla {
|
|||||||
|
|
||||||
/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
|
/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
|
||||||
const HloModule* module, const HloScheduleProto& proto) {
|
const HloModule* module, const HloScheduleProto& proto) {
|
||||||
tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation;
|
absl::flat_hash_map<int64, const HloComputation*> id_to_computation;
|
||||||
for (const HloComputation* computation : module->computations()) {
|
for (const HloComputation* computation : module->computations()) {
|
||||||
id_to_computation[computation->unique_id()] = computation;
|
id_to_computation[computation->unique_id()] = computation;
|
||||||
}
|
}
|
||||||
@ -44,7 +45,7 @@ namespace xla {
|
|||||||
<< "No computation exists in HLO module with id " << computation_id;
|
<< "No computation exists in HLO module with id " << computation_id;
|
||||||
const HloComputation* computation = comp_it->second;
|
const HloComputation* computation = comp_it->second;
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction;
|
absl::flat_hash_map<int64, const HloInstruction*> id_to_instruction;
|
||||||
for (const HloInstruction* instruction : computation->instructions()) {
|
for (const HloInstruction* instruction : computation->instructions()) {
|
||||||
id_to_instruction[instruction->unique_id()] = instruction;
|
id_to_instruction[instruction->unique_id()] = instruction;
|
||||||
}
|
}
|
||||||
@ -112,7 +113,7 @@ Status HloSchedule::UpdateComputationSchedule(
|
|||||||
const HloComputation* computation) {
|
const HloComputation* computation) {
|
||||||
// Map from unique ID to HloInstruction pointer for instructions in the
|
// Map from unique ID to HloInstruction pointer for instructions in the
|
||||||
// computation.
|
// computation.
|
||||||
tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
|
absl::flat_hash_map<int, const HloInstruction*> id_to_instruction;
|
||||||
for (const HloInstruction* instruction : computation->instructions()) {
|
for (const HloInstruction* instruction : computation->instructions()) {
|
||||||
InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
|
InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
|
||||||
}
|
}
|
||||||
@ -126,15 +127,13 @@ Status HloSchedule::UpdateComputationSchedule(
|
|||||||
// Map from HloInstruction X to newly added instructions (instruction is in
|
// Map from HloInstruction X to newly added instructions (instruction is in
|
||||||
// computation, but not in schedule) which use X. If an instruction is not in
|
// computation, but not in schedule) which use X. If an instruction is not in
|
||||||
// the map, then it has no users which are newly added instructions.
|
// the map, then it has no users which are newly added instructions.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*,
|
absl::flat_hash_map<const HloInstruction*, std::vector<const HloInstruction*>>
|
||||||
std::vector<const HloInstruction*>>
|
|
||||||
new_instruction_uses;
|
new_instruction_uses;
|
||||||
|
|
||||||
// For each newly added instruction, this is the count of the instruction's
|
// For each newly added instruction, this is the count of the instruction's
|
||||||
// operands that have not yet been scheduled. When this value reaches zero,
|
// operands that have not yet been scheduled. When this value reaches zero,
|
||||||
// then the instruction may be placed in the schedule.
|
// then the instruction may be placed in the schedule.
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int>
|
absl::flat_hash_map<const HloInstruction*, int> unscheduled_operand_count;
|
||||||
unscheduled_operand_count;
|
|
||||||
|
|
||||||
// Create a worklist of newly added instructions which are ready to be added
|
// Create a worklist of newly added instructions which are ready to be added
|
||||||
// to the schedule. Initialize worklist with those that have zero operands.
|
// to the schedule. Initialize worklist with those that have zero operands.
|
||||||
@ -217,9 +216,9 @@ Status HloSchedule::Update() {
|
|||||||
}
|
}
|
||||||
for (auto it = sequences_.begin(); it != sequences_.end();) {
|
for (auto it = sequences_.begin(); it != sequences_.end();) {
|
||||||
if (nonfusion_computations_ids.count(it->first) == 0) {
|
if (nonfusion_computations_ids.count(it->first) == 0) {
|
||||||
it = sequences_.erase(it);
|
sequences_.erase(it++);
|
||||||
} else {
|
} else {
|
||||||
it++;
|
++it;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -254,7 +253,7 @@ Status HloSchedule::Verify() const {
|
|||||||
// For each computation verify the set of instructions is the same and that
|
// For each computation verify the set of instructions is the same and that
|
||||||
// each dependency and control edge is honored.
|
// each dependency and control edge is honored.
|
||||||
for (const HloComputation* computation : nonfusion_computations) {
|
for (const HloComputation* computation : nonfusion_computations) {
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
|
absl::flat_hash_map<const HloInstruction*, int> instruction_position;
|
||||||
int pos = 0;
|
int pos = 0;
|
||||||
for (const HloInstruction* instruction :
|
for (const HloInstruction* instruction :
|
||||||
sequence(computation).instructions()) {
|
sequence(computation).instructions()) {
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
@ -103,8 +104,7 @@ class HloSchedule {
|
|||||||
|
|
||||||
// Returns a map from HloComputation unique ID to instruction sequence. The
|
// Returns a map from HloComputation unique ID to instruction sequence. The
|
||||||
// map contains all sequences in the schedule.
|
// map contains all sequences in the schedule.
|
||||||
const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences()
|
const absl::flat_hash_map<int64, HloInstructionSequence>& sequences() const {
|
||||||
const {
|
|
||||||
return sequences_;
|
return sequences_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ class HloSchedule {
|
|||||||
// A map from computation unique ID to instruction sequence. Unique IDs are
|
// A map from computation unique ID to instruction sequence. Unique IDs are
|
||||||
// used rather than HloComputation pointers because HLO pointers are not
|
// used rather than HloComputation pointers because HLO pointers are not
|
||||||
// unique across HLO transformations because pointers may be recycled.
|
// unique across HLO transformations because pointers may be recycled.
|
||||||
tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_;
|
absl::flat_hash_map<int64, HloInstructionSequence> sequences_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
|
std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
@ -23,7 +24,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -993,7 +993,7 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1,
|
|||||||
|
|
||||||
// Checks various invariants of send and recv instructions.
|
// Checks various invariants of send and recv instructions.
|
||||||
Status VerifySendsAndRecvs(const HloModule& module) {
|
Status VerifySendsAndRecvs(const HloModule& module) {
|
||||||
tensorflow::gtl::FlatMap<int64, const HloInstruction*> host_channels;
|
absl::flat_hash_map<int64, const HloInstruction*> host_channels;
|
||||||
// Host send/recv instructions must have their own unique channel.
|
// Host send/recv instructions must have their own unique channel.
|
||||||
auto check_unique_host_channel = [&](const HloInstruction* instruction) {
|
auto check_unique_host_channel = [&](const HloInstruction* instruction) {
|
||||||
const HloSendRecvInstruction* sendrecv =
|
const HloSendRecvInstruction* sendrecv =
|
||||||
@ -1061,7 +1061,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
|
|||||||
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
|
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
|
||||||
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
|
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<string, const HloInstruction*> instructions;
|
absl::flat_hash_map<string, const HloInstruction*> instructions;
|
||||||
|
|
||||||
for (auto* computation : module->computations()) {
|
for (auto* computation : module->computations()) {
|
||||||
for (const auto& instruction : computation->instructions()) {
|
for (const auto& instruction : computation->instructions()) {
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
|
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
@ -95,7 +96,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache(
|
|||||||
absl::InlinedVector<const HloInstruction*, 4> stack;
|
absl::InlinedVector<const HloInstruction*, 4> stack;
|
||||||
|
|
||||||
enum DfsState { kDiscovered, kVisited };
|
enum DfsState { kDiscovered, kVisited };
|
||||||
gtl::FlatMap<const HloInstruction*, DfsState> dfs_state_map;
|
absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map;
|
||||||
|
|
||||||
stack.push_back(root);
|
stack.push_back(root);
|
||||||
InsertOrDie(&dfs_state_map, root, kDiscovered);
|
InsertOrDie(&dfs_state_map, root, kDiscovered);
|
||||||
|
@ -18,10 +18,10 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/util/ptr_util.h"
|
#include "tensorflow/core/util/ptr_util.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -360,7 +360,7 @@ class IndexedArrayAnalysis {
|
|||||||
|
|
||||||
std::vector<std::unique_ptr<Array>> owned_tensors_;
|
std::vector<std::unique_ptr<Array>> owned_tensors_;
|
||||||
std::vector<Literal> owned_literals_;
|
std::vector<Literal> owned_literals_;
|
||||||
tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
|
absl::flat_hash_map<const HloInstruction*, Array*> cache_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
|
// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
|
||||||
|
@ -22,11 +22,11 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -189,7 +189,7 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
|
|||||||
bool InstructionFusion::CanFuseOnAllPaths(
|
bool InstructionFusion::CanFuseOnAllPaths(
|
||||||
HloInstruction* producer, HloInstruction* consumer,
|
HloInstruction* producer, HloInstruction* consumer,
|
||||||
const HloInstructionSet& do_not_fuse,
|
const HloInstructionSet& do_not_fuse,
|
||||||
tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>*
|
absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
|
||||||
result_cache) {
|
result_cache) {
|
||||||
if (consumer == producer) {
|
if (consumer == producer) {
|
||||||
return true;
|
return true;
|
||||||
@ -241,7 +241,7 @@ InstructionFusion::ComputeGloballyUnfusible(
|
|||||||
// fusing operations that require duplication later depending on
|
// fusing operations that require duplication later depending on
|
||||||
// is_expensive_().
|
// is_expensive_().
|
||||||
HloInstructionSet do_not_duplicate;
|
HloInstructionSet do_not_duplicate;
|
||||||
tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>
|
absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>
|
||||||
can_fuse_on_all_paths_result_cache;
|
can_fuse_on_all_paths_result_cache;
|
||||||
for (HloInstruction* consumer : post_order) {
|
for (HloInstruction* consumer : post_order) {
|
||||||
for (HloInstruction* producer : consumer->operands()) {
|
for (HloInstruction* producer : consumer->operands()) {
|
||||||
@ -430,7 +430,7 @@ class ReversePostOrderFusionQueue : public FusionQueue {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<HloInstruction*> post_order_;
|
std::vector<HloInstruction*> post_order_;
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_;
|
absl::flat_hash_map<HloInstruction*, int> post_order_index_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
@ -158,8 +159,8 @@ class InstructionFusion : public HloModulePass {
|
|||||||
bool CanFuseOnAllPaths(
|
bool CanFuseOnAllPaths(
|
||||||
HloInstruction* producer, HloInstruction* consumer,
|
HloInstruction* producer, HloInstruction* consumer,
|
||||||
const HloInstructionSet& do_not_fuse,
|
const HloInstructionSet& do_not_fuse,
|
||||||
tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>,
|
absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
|
||||||
bool>* result_cache);
|
result_cache);
|
||||||
|
|
||||||
// Computes the set of nodes that we do not want to fuse into any of their
|
// Computes the set of nodes that we do not want to fuse into any of their
|
||||||
// consumers based on a global analysis of the HLO graph.
|
// consumers based on a global analysis of the HLO graph.
|
||||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
@ -38,7 +39,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -228,8 +228,8 @@ class LayoutConstraints {
|
|||||||
// Array-shaped buffers which have not yet been constrained.
|
// Array-shaped buffers which have not yet been constrained.
|
||||||
std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;
|
std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;
|
||||||
|
|
||||||
mutable tensorflow::gtl::FlatMap<const HloInstruction*,
|
mutable absl::flat_hash_map<const HloInstruction*,
|
||||||
std::unique_ptr<PointsToSet::BufferSet>>
|
std::unique_ptr<PointsToSet::BufferSet>>
|
||||||
buffer_sets_cache_;
|
buffer_sets_cache_;
|
||||||
|
|
||||||
HloComputation* computation_;
|
HloComputation* computation_;
|
||||||
|
@ -38,6 +38,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:logical_buffer",
|
"//tensorflow/compiler/xla/service:logical_buffer",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@llvm//:core",
|
"@llvm//:core",
|
||||||
],
|
],
|
||||||
|
@ -16,13 +16,13 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -77,14 +77,14 @@ class AliasAnalysis {
|
|||||||
// A map from a buffer slice to metadata corresponding to its alias.scope
|
// A map from a buffer slice to metadata corresponding to its alias.scope
|
||||||
// metadata. The index kParameterAliasSet is used to hold aliasing
|
// metadata. The index kParameterAliasSet is used to hold aliasing
|
||||||
// information for parameters.
|
// information for parameters.
|
||||||
tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*,
|
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
|
||||||
BufferAllocation::Slice::Hasher>
|
BufferAllocation::Slice::Hasher>
|
||||||
alias_scope_metadata_;
|
alias_scope_metadata_;
|
||||||
|
|
||||||
// A map from a buffer slice to metadata corresponding to its noalias
|
// A map from a buffer slice to metadata corresponding to its noalias
|
||||||
// metadata.
|
// metadata.
|
||||||
tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*,
|
absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
|
||||||
BufferAllocation::Slice::Hasher>
|
BufferAllocation::Slice::Hasher>
|
||||||
noalias_metadata_;
|
noalias_metadata_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||||
@ -126,7 +127,7 @@ class MultiOutputFusion : public HloModulePass {
|
|||||||
std::vector<FusionCandidate> candidates_;
|
std::vector<FusionCandidate> candidates_;
|
||||||
|
|
||||||
// A map that maps an instruction to the index_.
|
// A map that maps an instruction to the index_.
|
||||||
tensorflow::gtl::FlatMap<HloInstruction*, int> candidates_index_;
|
absl::flat_hash_map<HloInstruction*, int> candidates_index_;
|
||||||
|
|
||||||
// The reachability map of current computation.
|
// The reachability map of current computation.
|
||||||
std::unique_ptr<HloReachabilityMap> reachability_;
|
std::unique_ptr<HloReachabilityMap> reachability_;
|
||||||
|
@ -18,9 +18,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
|
|
||||||
@ -78,7 +78,7 @@ class NameUniquer {
|
|||||||
|
|
||||||
// Map from name prefix to the generator data structure which tracks used
|
// Map from name prefix to the generator data structure which tracks used
|
||||||
// identifiers and generates new ones.
|
// identifiers and generates new ones.
|
||||||
tensorflow::gtl::FlatMap<string, SequentialIdGenerator> generated_names_;
|
absl::flat_hash_map<string, SequentialIdGenerator> generated_names_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer);
|
TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer);
|
||||||
};
|
};
|
||||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
@ -36,7 +36,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/compactptrset.h"
|
#include "tensorflow/core/lib/gtl/compactptrset.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
|||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "tensorflow/compiler/xla/service/while_util.h"
|
#include "tensorflow/compiler/xla/service/while_util.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
@ -15,17 +15,17 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
|
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/container/inlined_vector.h"
|
#include "absl/container/inlined_vector.h"
|
||||||
#include "tensorflow/compiler/xla/service/tuple_util.h"
|
#include "tensorflow/compiler/xla/service/tuple_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/while_util.h"
|
#include "tensorflow/compiler/xla/service/while_util.h"
|
||||||
#include "tensorflow/compiler/xla/util.h"
|
#include "tensorflow/compiler/xla/util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
using absl::flat_hash_map;
|
||||||
using absl::InlinedVector;
|
using absl::InlinedVector;
|
||||||
using tensorflow::gtl::FlatMap;
|
|
||||||
using tensorflow::gtl::FlatSet;
|
using tensorflow::gtl::FlatSet;
|
||||||
|
|
||||||
// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
|
// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
|
||||||
@ -34,7 +34,7 @@ using tensorflow::gtl::FlatSet;
|
|||||||
// function hoists the operands in `unhoisted_invariant_instructions` and moves
|
// function hoists the operands in `unhoisted_invariant_instructions` and moves
|
||||||
// them into `hoisted_instructions`.
|
// them into `hoisted_instructions`.
|
||||||
static void CreateLoopInvariantCopy(
|
static void CreateLoopInvariantCopy(
|
||||||
FlatMap<HloInstruction*, HloInstruction*>* hoisted_instructions,
|
flat_hash_map<HloInstruction*, HloInstruction*>* hoisted_instructions,
|
||||||
FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
|
FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
|
||||||
HloInstruction* while_instr, HloInstruction* to_hoist) {
|
HloInstruction* while_instr, HloInstruction* to_hoist) {
|
||||||
HloComputation* parent_of_while = while_instr->parent();
|
HloComputation* parent_of_while = while_instr->parent();
|
||||||
@ -147,7 +147,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
|
|||||||
|
|
||||||
// Maps instructions in the while body to instructions hoisted outside the
|
// Maps instructions in the while body to instructions hoisted outside the
|
||||||
// while that compute the same value.
|
// while that compute the same value.
|
||||||
FlatMap<HloInstruction*, HloInstruction*> hoisted_instructions;
|
flat_hash_map<HloInstruction*, HloInstruction*> hoisted_instructions;
|
||||||
|
|
||||||
// Contains instructions that can be legally hoisted, but were deemed to be
|
// Contains instructions that can be legally hoisted, but were deemed to be
|
||||||
// unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we
|
// unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we
|
||||||
|
@ -14,12 +14,12 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
|
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
#include "tensorflow/compiler/xla/service/call_inliner.h"
|
||||||
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
|
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
@ -181,7 +181,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
|
|||||||
used_tuple_indices.end());
|
used_tuple_indices.end());
|
||||||
std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());
|
std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());
|
||||||
|
|
||||||
tensorflow::gtl::FlatMap<int64, int64> old_to_new_tuple_idx;
|
absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
|
||||||
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
|
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
|
||||||
int64 old_idx = new_to_old_tuple_idx[new_idx];
|
int64 old_idx = new_to_old_tuple_idx[new_idx];
|
||||||
old_to_new_tuple_idx[old_idx] = new_idx;
|
old_to_new_tuple_idx[old_idx] = new_idx;
|
||||||
@ -405,7 +405,7 @@ static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
|
|||||||
// build a map from the tuple element index to the constant value. Limit this
|
// build a map from the tuple element index to the constant value. Limit this
|
||||||
// to scalar constant values because propagating array constants can regress
|
// to scalar constant values because propagating array constants can regress
|
||||||
// performance by forcing us to copy constants.
|
// performance by forcing us to copy constants.
|
||||||
tensorflow::gtl::FlatMap<int, const HloInstruction*> index_to_constant;
|
absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
|
||||||
for (int i = 0; i < root_operands.size(); i++) {
|
for (int i = 0; i < root_operands.size(); i++) {
|
||||||
HloInstruction* instr = root_operands[i];
|
HloInstruction* instr = root_operands[i];
|
||||||
if (instr->opcode() == HloOpcode::kGetTupleElement &&
|
if (instr->opcode() == HloOpcode::kGetTupleElement &&
|
||||||
|
@ -422,6 +422,7 @@ xla_test(
|
|||||||
"//tensorflow/core:regexp_internal",
|
"//tensorflow/core:regexp_internal",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/str_split.h"
|
#include "absl/strings/str_split.h"
|
||||||
@ -32,7 +33,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
|
||||||
#include "tensorflow/core/platform/regexp.h"
|
#include "tensorflow/core/platform/regexp.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -83,7 +83,7 @@ struct ParsedProfileOutputLine {
|
|||||||
|
|
||||||
Status ParseOneProfileOutputLine(
|
Status ParseOneProfileOutputLine(
|
||||||
const string& line, bool expect_hlo,
|
const string& line, bool expect_hlo,
|
||||||
gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results,
|
absl::flat_hash_map<string, ParsedProfileOutputLine>* parsed_results,
|
||||||
absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
|
absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
|
||||||
string separator = "[^:]*:: +";
|
string separator = "[^:]*:: +";
|
||||||
string match_percentage = R"(\d+\.\d*% +\d+Σ)";
|
string match_percentage = R"(\d+\.\d*% +\d+Σ)";
|
||||||
@ -208,7 +208,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
|
|||||||
std::vector<string> profile_output_lines =
|
std::vector<string> profile_output_lines =
|
||||||
absl::StrSplit(profile_output, '\n');
|
absl::StrSplit(profile_output, '\n');
|
||||||
|
|
||||||
gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
|
absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;
|
||||||
|
|
||||||
TF_ASSERT_OK(ParseOneProfileOutputLine(
|
TF_ASSERT_OK(ParseOneProfileOutputLine(
|
||||||
profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines));
|
profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines));
|
||||||
@ -314,7 +314,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
|
|||||||
|
|
||||||
ASSERT_NE(while_body_profile_end, profile_output_lines.end());
|
ASSERT_NE(while_body_profile_end, profile_output_lines.end());
|
||||||
|
|
||||||
gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
|
absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;
|
||||||
|
|
||||||
for (auto while_body_profile_i = while_body_profile_start + 1;
|
for (auto while_body_profile_i = while_body_profile_start + 1;
|
||||||
while_body_profile_i != while_body_profile_end; while_body_profile_i++) {
|
while_body_profile_i != while_body_profile_end; while_body_profile_i++) {
|
||||||
|
Loading…
Reference in New Issue
Block a user