diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index fedaeacad68..ef91c85ec36 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -593,6 +593,7 @@ cc_library(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:framework",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index d479a483607..47b3c6611f3 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -233,14 +233,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
 }
 
 // Returns true (into `result`) if a node placed on `device` must be compiled.
-Status DeviceRequiresCompilation(const string& device, bool* result) {
-  DeviceType device_type("");
-  TF_RETURN_IF_ERROR(DeviceNameToDeviceType(device, &device_type));
-  const XlaOpRegistry::DeviceRegistration* registration = nullptr;
-  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
-    return errors::Internal("Could not find compilation device ",
-                            device_type.type());
-  }
+Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
+                                 jit::DeviceId device, bool* result) {
+  const XlaOpRegistry::DeviceRegistration* registration =
+      device_info_cache.GetCompilationDevice(device);
   *result = registration->autoclustering_policy ==
             XlaOpRegistry::AutoclusteringPolicy::kAlways;
   return Status::OK();
@@ -293,17 +289,20 @@ Status ReplaceFunctionCallWithPartionedCall(
   return Status::OK();
 }
 
-Status InferDeviceForCluster(Node* n, const string& function_name,
-                             const FunctionLibraryDefinition& flib_def,
-                             string* result) {
+xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
+    jit::DeviceInfoCache* device_info_cache, Node* n,
+    const string& function_name, const FunctionLibraryDefinition& flib_def) {
   const FunctionDef* func_def = flib_def.Find(function_name);
   TF_RET_CHECK(func_def) << "Could not find " << function_name;
 
-  std::set<string> device_names;
+  jit::DeviceSet device_set;
+
   for (const NodeDef& ndef : func_def->node_def()) {
     VLOG(3) << ndef.DebugString();
     if (!ndef.device().empty()) {
-      device_names.insert(ndef.device());
+      TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
+                          device_info_cache->GetIdFor(ndef.device()));
+      device_set.Insert(device_id);
    }
  }
 
@@ -311,41 +310,47 @@ Status InferDeviceForCluster(Node* n, const string& function_name,
   if (!n->assigned_device_name().empty()) {
     // TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
     // assignment when constant folding.  We should fix EncapsulateSubgraphsPass
     // instead.
-    device_names.insert(n->assigned_device_name());
+    TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
+                        device_info_cache->GetIdFor(n->assigned_device_name()));
+    device_set.Insert(device_id);
   }
 
-  std::vector<string> device_names_vector;
-  absl::c_copy(device_names, std::back_inserter(device_names_vector));
-
-  Status s = PickDeviceForXla(device_names_vector, true, result);
-  if (s.ok()) {
-    VLOG(2) << "For " << function_name << " PickDeviceForXla("
-            << absl::StrJoin(device_names_vector, ", ") << ") -> " << *result;
-  }
-  return s;
+  TF_ASSIGN_OR_RETURN(jit::DeviceId result,
+                      PickDeviceForXla(*device_info_cache, device_set,
+                                       /*allow_mixing_unknown_and_cpu=*/true));
+  VLOG(2) << "For " << function_name << " PickDeviceForXla("
+          << device_info_cache->DebugString(device_set) << ") -> "
+          << device_info_cache->GetNameFor(result);
+  return result;
 }
 
 Status ReplaceNodeWithXlaCompileAndXlaRun(
+    jit::DeviceInfoCache* device_info_cache,
     const GraphOptimizationPassOptions& options,
     const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
     bool insert_print_nodes, Graph* g, Node* n) {
   XlaClusterInfo cluster_info;
   TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
 
-  string device;
-  TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(),
-                                           flib_def, &device));
+  TF_ASSIGN_OR_RETURN(
+      jit::DeviceId device,
+      InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
+                            flib_def));
+
   bool requires_compilation;
-  TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation));
+  TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
+                                               &requires_compilation));
   if (!lazy_compilation_enabled) {
     requires_compilation = true;
   }
 
+  string device_name_str = string(device_info_cache->GetNameFor(device));
+
   Status status;
   Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
                    .NewSubScope(n->name())
                    .WithDevice(n->requested_device())
-                   .WithAssignedDevice(device);
+                   .WithAssignedDevice(device_name_str);
 
   ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
                                /*constants=*/cluster_info.constant_inputs,
@@ -441,10 +446,12 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
   bool insert_print_nodes =
       GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
 
+  jit::DeviceInfoCache device_info_cache;
+
   for (Node* n : xla_compiled_kernels) {
     TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
-        options, *options.flib_def, lazy_compilation_enabled,
-        insert_print_nodes, graph, n));
+        &device_info_cache, options, *options.flib_def,
+        lazy_compilation_enabled, insert_print_nodes, graph, n));
   }
 
   if (VLOG_IS_ON(1)) {
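Note (illustrative, not part of the patch): the change above replaces string-keyed device bookkeeping with `jit::DeviceId`, an interned integer handle. A minimal standalone sketch of the interning pattern that `DeviceInfoCache::GetIdFor` uses follows; `Interner` and its members are hypothetical names:

```c++
// Each distinct device name gets a dense integer id exactly once; later
// lookups return the same id, so passes can compare and hash cheap ints
// instead of strings.
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

class Interner {
 public:
  int GetIdFor(const std::string& name) {
    auto it = name_to_id_.find(name);
    if (it != name_to_id_.end()) return it->second;  // already interned
    int new_id = static_cast<int>(names_.size());
    names_.push_back(name);  // keeps id -> name lookup O(1)
    name_to_id_.emplace(name, new_id);
    return new_id;
  }
  const std::string& GetNameFor(int id) const { return names_[id]; }

 private:
  std::unordered_map<std::string, int> name_to_id_;
  std::vector<std::string> names_;
};

int main() {
  Interner cache;
  int gpu0 = cache.GetIdFor("/job:localhost/replica:0/task:0/device:GPU:0");
  int cpu0 = cache.GetIdFor("/job:localhost/replica:0/task:0/device:CPU:0");
  // Re-interning an existing name returns the same id.
  assert(gpu0 == cache.GetIdFor("/job:localhost/replica:0/task:0/device:GPU:0"));
  assert(gpu0 != cpu0);
  return 0;
}
```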
#include "tensorflow/compiler/jit/device_util.h" +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -22,43 +23,68 @@ namespace tensorflow { namespace jit { using xla::StatusOr; -StatusOr -DeviceInfoCache::GetCompilationDevice(absl::string_view device_name) { - auto it = device_to_device_registration_.find(device_name); - if (it != device_to_device_registration_.end()) { +void DeviceSet::Insert(DeviceId device_id) { + int word_index = device_id.id() / kWordSize; + int bit_index = device_id.id() % kWordSize; + + if (word_index >= storage_.size()) { + storage_.resize(word_index + 1, 0); + } + + storage_[word_index] |= (1ull << bit_index); +} + +void DeviceSet::UnionWith(const DeviceSet& other) { + if (other.storage_.size() > storage_.size()) { + storage_.resize(other.storage_.size(), 0); + } + + for (int i = 0; i < other.storage_.size(); i++) { + storage_[i] |= other.storage_[i]; + } +} + +bool DeviceSet::IsEmpty() const { + return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; }); +} + +xla::StatusOr DeviceInfoCache::GetIdFor(absl::string_view name) { + TF_RET_CHECK(!name.empty()); + + auto it = name_to_id_.find(name); + if (it != name_to_id_.end()) { return it->second; } - string device_name_str = string(device_name); - TF_ASSIGN_OR_RETURN(const DeviceType& device_type, - GetDeviceTypeFor(device_name_str)); - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { - registration = nullptr; + int new_id = names_.size(); + names_.push_back(string(name)); + id_to_device_type_.push_back(absl::make_unique("")); + DeviceType* device_type = id_to_device_type_.back().get(); + TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type)); + + is_cpu_.push_back(device_type->type_string() == DEVICE_CPU); + is_gpu_.push_back(device_type->type_string() == DEVICE_GPU); + + name_to_id_.emplace(string(name), DeviceId(new_id)); + + const XlaOpRegistry::DeviceRegistration* compilation_device; + if (!XlaOpRegistry::GetCompilationDevice(device_type->type(), + &compilation_device)) { + compilation_device = nullptr; } + id_to_compilation_device_.push_back(compilation_device); - device_to_device_registration_.insert( - {std::move(device_name_str), registration}); - - return registration; + return DeviceId(new_id); } -StatusOr> -DeviceInfoCache::GetDeviceTypeFor(absl::string_view device_name) { - auto it = device_to_device_type_.find(device_name); - if (it != device_to_device_type_.end()) { - return std::cref(*it->second); - } +string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { + std::vector names; + device_set.ForEach([&](DeviceId device_id) { + names.push_back(string(GetNameFor(device_id))); + return false; + }); - string device_name_str = string(device_name); - auto device_type = absl::make_unique(""); - TF_RETURN_IF_ERROR( - DeviceNameToDeviceType(device_name_str, device_type.get())); - - it = device_to_device_type_ - .insert({std::move(device_name_str), std::move(device_type)}) - .first; - return std::cref(*it->second); + return absl::StrCat("[", absl::StrJoin(names, ","), "]"); } } // namespace jit @@ -71,10 +97,11 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) { return Status::OK(); } -Status PickDeviceForXlaImpl(absl::Span device_names, +Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache, + const jit::DeviceSet& devices, bool 
@@ -89,65 +116,79 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
     }                                                     \
   } while (false)
 
-  TF_RET_CHECK(!device_names.empty()) << "No devices to choose from";
+  TF_RET_CHECK(!devices.IsEmpty()) << "No devices to choose from";
   DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);
 
-  absl::flat_hash_set<absl::string_view> device_names_set;
-  for (absl::string_view device_name : device_names) {
-    TF_RET_CHECK(!device_name.empty());
-    device_names_set.insert(device_name);
-  }
+  absl::optional<jit::DeviceId> maybe_gpu_device;
+  absl::optional<jit::DeviceId> maybe_cpu_device;
+  absl::optional<jit::DeviceId> maybe_unknown_device;
 
-  absl::optional<absl::string_view> maybe_gpu_device;
-  absl::optional<absl::string_view> maybe_cpu_device;
-  absl::optional<absl::string_view> maybe_unknown_device;
+  bool multiple_cpu_devices = false;
+  bool multiple_gpu_devices = false;
+  bool multiple_unknown_devices = false;
 
-  for (absl::string_view device_name : device_names_set) {
-    DeviceNameUtils::ParsedName parsed_name;
-    TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name))
-        << device_name;
-    if (parsed_name.type == "GPU") {
+  devices.ForEach([&](jit::DeviceId device) {
+    if (device_info_cache.IsGpu(device)) {
       if (maybe_gpu_device) {
-        FAILED_TO_PICK_DEVICE(errors::Internal(
-            "Multiple GPU devices ", absl::StrJoin(device_names, ", ")));
+        multiple_gpu_devices = true;
+        return false;
       }
-      maybe_gpu_device = device_name;
-    } else if (parsed_name.type == "CPU") {
+      maybe_gpu_device = device;
+    } else if (device_info_cache.IsCpu(device)) {
       if (maybe_cpu_device) {
-        FAILED_TO_PICK_DEVICE(errors::Internal(
-            "Multiple CPU devices ", absl::StrJoin(device_names, ", ")));
+        multiple_cpu_devices = true;
+        return false;
      }
-      maybe_cpu_device = device_name;
+      maybe_cpu_device = device;
     } else {
       if (maybe_unknown_device) {
-        FAILED_TO_PICK_DEVICE(errors::Internal(
-            "Multiple unknown devices ", absl::StrJoin(device_names, ", ")));
+        multiple_unknown_devices = true;
+        return false;
      }
-      maybe_unknown_device = device_name;
+      maybe_unknown_device = device;
     }
+
+    return true;
+  });
+
+  if (multiple_cpu_devices) {
+    FAILED_TO_PICK_DEVICE(errors::Internal(
+        "Multiple CPU devices ", device_info_cache.DebugString(devices)));
+  }
+
+  if (multiple_gpu_devices) {
+    FAILED_TO_PICK_DEVICE(errors::Internal(
+        "Multiple GPU devices ", device_info_cache.DebugString(devices)));
+  }
+
+  if (multiple_unknown_devices) {
+    FAILED_TO_PICK_DEVICE(errors::Internal(
+        "Multiple unknown devices ", device_info_cache.DebugString(devices)));
   }
 
   if (maybe_unknown_device && maybe_gpu_device) {
     FAILED_TO_PICK_DEVICE(errors::Internal(
-        "Found both unknown and GPU devices: ", *maybe_unknown_device, ", ",
-        *maybe_gpu_device));
+        "Found both unknown and GPU devices: ",
+        device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
+        device_info_cache.GetNameFor(*maybe_gpu_device)));
   }
 
   if (!allow_mixing_unknown_and_cpu) {
     if (maybe_unknown_device && maybe_cpu_device) {
       FAILED_TO_PICK_DEVICE(errors::Internal(
-          "Found both unknown and CPU devices: ", *maybe_unknown_device, ", ",
-          *maybe_cpu_device));
+          "Found both unknown and CPU devices: ",
+          device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
+          device_info_cache.GetNameFor(*maybe_cpu_device)));
     }
   }
 
   if (out_device_picked) {
     if (maybe_gpu_device) {
-      *out_device_picked = string(*maybe_gpu_device);
+      *out_device_picked = *maybe_gpu_device;
     } else if (maybe_unknown_device) {
-      *out_device_picked = string(*maybe_unknown_device);
+      *out_device_picked = *maybe_unknown_device;
     } else {
-      *out_device_picked = string(*maybe_cpu_device);
+      *out_device_picked = *maybe_cpu_device;
     }
   }
 
@@ -156,19 +197,24 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
 #undef FAILED_TO_PICK_DEVICE
 }
 
-Status PickDeviceForXla(absl::Span<const string> device_names,
-                        bool allow_mixing_unknown_and_cpu,
-                        string* out_device_picked) {
-  return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
-                              /*out_can_pick_device=*/nullptr,
-                              out_device_picked);
+xla::StatusOr<jit::DeviceId> PickDeviceForXla(
+    const jit::DeviceInfoCache& device_info_cache,
+    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
+  absl::optional<jit::DeviceId> device;
+  TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(
+      device_info_cache, devices, allow_mixing_unknown_and_cpu,
+      /*out_can_pick_device=*/nullptr, &device));
+  return *device;
 }
 
-Status CanPickDeviceForXla(absl::Span<const string> device_names,
-                           bool allow_mixing_unknown_and_cpu,
-                           bool* out_can_pick_device) {
-  return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
-                              out_can_pick_device,
-                              /*out_device_picked=*/nullptr);
+xla::StatusOr<bool> CanPickDeviceForXla(
+    const jit::DeviceInfoCache& device_info_cache,
+    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
+  bool can_pick_device;
+  TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(device_info_cache, devices,
+                                          allow_mixing_unknown_and_cpu,
+                                          &can_pick_device,
+                                          /*out_device_picked=*/nullptr));
+  return can_pick_device;
 }
 
 }  // namespace tensorflow
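Note (illustrative, not part of the patch): `DeviceSet::Insert` above maps an id to a (word, bit) pair in a lazily grown array of 64-bit words. A standalone sketch of that arithmetic, assuming the patch's 64-bit word size, so id 70 lands in word 70 / 64 = 1, bit 70 % 64 = 6:

```c++
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int kWordSize = 64;
  std::vector<uint64_t> storage;

  auto insert = [&](int id) {
    int word_index = id / kWordSize;
    int bit_index = id % kWordSize;
    if (word_index >= static_cast<int>(storage.size())) {
      storage.resize(word_index + 1, 0);  // grow lazily; new words are empty
    }
    storage[word_index] |= (1ull << bit_index);
  };

  insert(3);   // word 0, bit 3
  insert(70);  // word 1, bit 6
  assert(storage.size() == 2);
  assert(storage[0] == (1ull << 3));
  assert(storage[1] == (1ull << 6));
  return 0;
}
```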
diff --git a/tensorflow/compiler/jit/device_util.h b/tensorflow/compiler/jit/device_util.h
index e7c07b5ec4e..f3c6dec687e 100644
--- a/tensorflow/compiler/jit/device_util.h
+++ b/tensorflow/compiler/jit/device_util.h
@@ -23,24 +23,119 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/framework/types.h"
 
 namespace tensorflow {
 namespace jit {
+// Instances of DeviceId represent TensorFlow devices as integers.
+//
+// This helps avoid having to manipulate device names as strings when
+// auto-clustering.
+class DeviceId {
+ public:
+  DeviceId(DeviceId&&) = default;
+  DeviceId(const DeviceId&) = default;
+  DeviceId& operator=(const DeviceId&) = default;
+
+  bool operator==(const DeviceId& other) const { return id() == other.id(); }
+  bool operator!=(const DeviceId& other) const { return !(*this == other); }
+
+ private:
+  int id_;
+
+  explicit DeviceId(int id) : id_(id) {}
+
+  int id() const { return id_; }
+
+  friend class DeviceInfoCache;
+  friend class DeviceSet;
+};
+
+// A set of DeviceIds, represented as a bitmap.
+class DeviceSet {
+ public:
+  void Insert(DeviceId device_id);
+  void UnionWith(const DeviceSet& other);
+  bool IsEmpty() const;
+
+  // Calls `func` on each DeviceId in the set.  Stops iterating early if
+  // `func` returns false.
+  //
+  // TODO(sanjoy): Change this to take a typed std::function if that's
+  // performance neutral.
+  template <typename FnTy>
+  void ForEach(FnTy func) const {
+    // This is really a poor man's iterator, we should consider writing a
+    // proper iterator if this ends up being used widely.
+    for (int word_index = 0; word_index < storage_.size(); word_index++) {
+      uint64 word = storage_[word_index];
+      for (int bit_index = 0; bit_index < kWordSize; bit_index++) {
+        if (word & (1ull << bit_index)) {
+          if (!func(DeviceId(word_index * kWordSize + bit_index))) {
+            return;
+          }
+        }
+      }
+    }
+  }
+
+ private:
+  absl::InlinedVector<uint64, 1> storage_;
+
+  const int kWordSize = 64;
+};
+
 // Caches some miscellaneous information about TF devices.  Thread compatible.
 class DeviceInfoCache {
  public:
-  xla::StatusOr<const XlaOpRegistry::DeviceRegistration*> GetCompilationDevice(
-      absl::string_view device_name);
-  xla::StatusOr<std::reference_wrapper<const DeviceType>> GetDeviceTypeFor(
-      absl::string_view device_name);
+  bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }
+  bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }
+
+  absl::string_view GetNameFor(DeviceId device) const {
+    return names_[device.id()];
+  }
+
+  xla::StatusOr<DeviceId> GetIdFor(absl::string_view name);
+
+  using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;
+
+  DeviceRegistration* GetCompilationDevice(DeviceId device) const {
+    return id_to_compilation_device_[device.id()];
+  }
+
+  xla::StatusOr<DeviceRegistration*> GetCompilationDevice(
+      absl::string_view name) {
+    TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
+    return GetCompilationDevice(device_id);
+  }
+
+  const DeviceType& GetDeviceTypeFor(DeviceId device) const {
+    return *id_to_device_type_[device.id()];
+  }
+
+  using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;
+
+  xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
+      absl::string_view device_name) {
+    TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
+    return std::cref(*id_to_device_type_[device_id.id()]);
+  }
+
+  string DebugString(const DeviceSet& device_set) const;
 
  private:
-  absl::flat_hash_map<string, const XlaOpRegistry::DeviceRegistration*>
-      device_to_device_registration_;
-  absl::flat_hash_map<string, std::unique_ptr<DeviceType>>
-      device_to_device_type_;
+  absl::flat_hash_map<string, DeviceId> name_to_id_;
+
+  // These fields are populated for a device in GetIdFor, *before* we give out
+  // a DeviceId.
+  std::vector<const XlaOpRegistry::DeviceRegistration*>
+      id_to_compilation_device_;
+  std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
+  std::vector<string> names_;
+  std::vector<bool> is_cpu_;
+  std::vector<bool> is_gpu_;
 };
 
 }  // namespace jit
@@ -49,7 +144,7 @@ class DeviceInfoCache {
 
 Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
 
 // Picks the device for which XLA should compile a cluster that contains
-// operations placed in devices in `device_names`. For instance a cluster that
+// operations placed in devices in `devices`. For instance a cluster that
 // contains operations solely placed on the CPU will be compiled into a CPU
 // executable by XLA, whereas a cluster that contains operations placed on the
 // CPU and also operations placed on the GPU will be compiled into a GPU
@@ -82,16 +177,15 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
 // case it is the responsibility of the optimization pass that injected the
 // CPU nodes into the cluster to ensure that these nodes can be compiled by
 // the unknown XLA backend.
-Status PickDeviceForXla(absl::Span<const string> device_names,
-                        bool allow_mixing_unknown_and_cpu,
-                        string* out_device_picked);
+xla::StatusOr<jit::DeviceId> PickDeviceForXla(
+    const jit::DeviceInfoCache& device_info_cache,
+    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
 
 // This is like `PickDeviceForXla` except that it returns false (instead of a
-// non-OK Status) in `out_can_pick_device` if no unambiguous choice of device
-// exists.
-Status CanPickDeviceForXla(absl::Span<const string> device_names,
-                           bool allow_mixing_unknown_and_cpu,
-                           bool* out_can_pick_device);
+// non-OK Status) if no unambiguous choice of device exists.
+xla::StatusOr<bool> CanPickDeviceForXla(
+    const jit::DeviceInfoCache& device_info_cache,
+    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
 
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
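Note (illustrative, not part of the patch): putting the new header API together, here is a hypothetical caller mirroring how build_xla_ops_pass.cc drives it. `PickClusterDevice` is an invented name; the sketch assumes it lives in namespace `tensorflow` with device_util.h and the TF status macros included:

```c++
Status PickClusterDevice(const std::vector<string>& node_devices,
                         jit::DeviceInfoCache* cache, string* picked_name) {
  jit::DeviceSet devices;
  for (const string& name : node_devices) {
    // Intern each name into a DeviceId; GetIdFor rejects empty names.
    TF_ASSIGN_OR_RETURN(jit::DeviceId id, cache->GetIdFor(name));
    devices.Insert(id);
  }

  // GPU is preferred over unknown, unknown over CPU; ambiguous mixes fail.
  TF_ASSIGN_OR_RETURN(jit::DeviceId picked,
                      PickDeviceForXla(*cache, devices,
                                       /*allow_mixing_unknown_and_cpu=*/false));
  *picked_name = string(cache->GetNameFor(picked));
  return Status::OK();
}
```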
diff --git a/tensorflow/compiler/jit/device_util_test.cc b/tensorflow/compiler/jit/device_util_test.cc
index fede5644c00..9396c49d52e 100644
--- a/tensorflow/compiler/jit/device_util_test.cc
+++ b/tensorflow/compiler/jit/device_util_test.cc
@@ -22,12 +22,20 @@ namespace tensorflow {
 namespace {
 
 Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu,
-                        absl::Span<const absl::string_view> inputs,
+                        absl::Span<const absl::string_view> device_names,
                         string* result) {
-  std::vector<string> inputs_string;
-  absl::c_transform(inputs, std::back_inserter(inputs_string),
-                    [](absl::string_view sv) { return string(sv); });
-  return PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, result);
+  jit::DeviceInfoCache cache;
+  jit::DeviceSet device_set;
+  for (absl::string_view name : device_names) {
+    TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, cache.GetIdFor(name));
+    device_set.Insert(device_id);
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      jit::DeviceId result_id,
+      PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu));
+  *result = string(cache.GetNameFor(result_id));
+  return Status::OK();
 }
 
 void CheckPickDeviceResult(absl::string_view expected_result,
@@ -87,5 +95,38 @@ TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
   CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
 }
 
+void SimpleRoundTripTestForDeviceSet(int num_devices) {
+  jit::DeviceSet device_set;
+  jit::DeviceInfoCache device_info_cache;
+
+  std::vector<string> expected_devices, actual_devices;
+
+  for (int i = 0; i < num_devices; i++) {
+    string device_name =
+        absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i);
+    TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id,
+                            device_info_cache.GetIdFor(device_name));
+    device_set.Insert(device_id);
+    expected_devices.push_back(device_name);
+  }
+
+  device_set.ForEach([&](jit::DeviceId device_id) {
+    actual_devices.push_back(string(device_info_cache.GetNameFor(device_id)));
+    return true;
+  });
+
+  EXPECT_EQ(expected_devices, actual_devices);
+}
+
+TEST(DeviceSetTest, SimpleRoundTrip_One) { SimpleRoundTripTestForDeviceSet(1); }
+
+TEST(DeviceSetTest, SimpleRoundTrip_Small) {
+  SimpleRoundTripTestForDeviceSet(8);
+}
+
+TEST(DeviceSetTest, SimpleRoundTrip_Large) {
+  SimpleRoundTripTestForDeviceSet(800);
+}
+
 }  // namespace
 }  // namespace tensorflow
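Note (illustrative, not part of the patch): the round-trip tests above always return `true` from the `ForEach` callback; returning `false` stops the scan early, so the same primitive doubles as a linear search. A hypothetical helper sketch (`FindFirst` is an invented name; it assumes `<functional>` and `absl/types/optional.h` are included, inside namespace `tensorflow`):

```c++
absl::optional<jit::DeviceId> FindFirst(
    const jit::DeviceSet& devices,
    const std::function<bool(jit::DeviceId)>& pred) {
  absl::optional<jit::DeviceId> found;
  devices.ForEach([&](jit::DeviceId id) {
    if (pred(id)) {
      found = id;
      return false;  // match found; stop iterating
    }
    return true;  // keep scanning
  });
  return found;
}
```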
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 7d794b0edc2..a9713f8ea3c 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -57,6 +57,8 @@ namespace tensorflow {
 namespace {
 
 using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
+using jit::DeviceId;
+using jit::DeviceSet;
 using xla::StatusOr;
 
 // The clusters we create here are eventually lowered into an
@@ -117,8 +119,8 @@ class MarkForCompilationPassImpl {
    public:
     // Constructs a trivial cluster representing a single TF node.
     Cluster(int tf_graph_node_id, int effective_cluster_size,
-            bool has_functional_control_flow,
-            absl::flat_hash_set<string> devices, string resource_op_device,
+            bool has_functional_control_flow, DeviceSet devices,
+            absl::optional<DeviceId> resource_op_device,
             absl::optional<int> resource_var_operation_node_id,
             absl::optional<DeadnessPredicate> deadness_predicate,
             bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
@@ -126,7 +128,7 @@ class MarkForCompilationPassImpl {
           effective_cluster_size_(effective_cluster_size),
           has_functional_control_flow_(has_functional_control_flow),
           devices_(std::move(devices)),
-          resource_op_device_(std::move(resource_op_device)),
+          resource_op_device_(resource_op_device),
           deadness_predicate_(deadness_predicate),
           is_xla_compile_attr_true_(is_xla_compile_attr_true),
           xla_scope_(std::move(xla_scope)) {
@@ -162,12 +164,14 @@ class MarkForCompilationPassImpl {
     }
 
     // The set of devices nodes in the cluster are placed on.
-    const absl::flat_hash_set<string>& devices() const { return devices_; }
+    const DeviceSet& devices() const { return devices_; }
 
     // If the cluster has a resource operation then the device the resource
     // operation is placed on.  A cluster may have resource ops placed only on a
     // single device.
-    const string& resource_op_device() const { return resource_op_device_; }
+    const absl::optional<DeviceId>& resource_op_device() const {
+      return resource_op_device_;
+    }
 
     // If not nullopt then the predicate that is true iff the cluster is alive.
     // Otherwise the user has (unsafely) disabled deadness analysis.  If this is
@@ -208,8 +212,8 @@ class MarkForCompilationPassImpl {
     int cycles_graph_node_id_;
     int effective_cluster_size_;
     bool has_functional_control_flow_;
-    absl::flat_hash_set<string> devices_;
-    string resource_op_device_;
+    DeviceSet devices_;
+    absl::optional<DeviceId> resource_op_device_;
     absl::optional<DeadnessPredicate> deadness_predicate_;
     bool is_xla_compile_attr_true_;
     absl::optional<string> xla_scope_;
@@ -279,17 +283,17 @@ class MarkForCompilationPassImpl {
   Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
                           bool has_functional_control_flow,
-                          absl::flat_hash_set<string> devices,
-                          string resource_op_device,
+                          const DeviceSet& device_set,
+                          absl::optional<DeviceId> resource_op_device,
                           absl::optional<int> resource_var_operation_node_id,
                           absl::optional<DeadnessPredicate> deadness_predicate,
                           bool is_xla_compile_attr_true,
                           absl::optional<string> xla_scope) {
     cluster_storage_.push_back(absl::make_unique<Cluster>(
         cycles_graph_node_id, effective_cluster_size,
-        has_functional_control_flow, std::move(devices),
-        std::move(resource_op_device), resource_var_operation_node_id,
-        deadness_predicate, is_xla_compile_attr_true, xla_scope));
+        has_functional_control_flow, device_set, resource_op_device,
+        resource_var_operation_node_id, deadness_predicate,
+        is_xla_compile_attr_true, xla_scope));
     return cluster_storage_.back().get();
   }
@@ -486,13 +490,15 @@ void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
   effective_cluster_size_ += other->effective_cluster_size_;
   has_functional_control_flow_ |= other->has_functional_control_flow_;
 
-  for (string other_device : other->devices_) {
-    devices_.insert(other_device);
-  }
-  other->devices_.clear();
+  devices_.UnionWith(other->devices_);
 
-  if (resource_op_device_.empty()) {
-    resource_op_device_ = std::move(other->resource_op_device_);
+  DCHECK(!(resource_op_device_.has_value() &&
+           other->resource_op_device_.has_value()) ||
+         *resource_op_device_ == *other->resource_op_device_)
+      << "AreDevicesCompatible should have returned false otherwise!";
+
+  if (!resource_op_device_.has_value()) {
+    resource_op_device_ = other->resource_op_device_;
   }
 
   is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
@@ -779,12 +785,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
           deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
     }
 
-    const string& device = !node->assigned_device_name().empty()
-                               ? node->assigned_device_name()
-                               : node->requested_device();
+    const string& device_name_str = !node->assigned_device_name().empty()
+                                        ? node->assigned_device_name()
+                                        : node->requested_device();
+    TF_ASSIGN_OR_RETURN(DeviceId device,
+                        device_info_cache_.GetIdFor(device_name_str));
+
     bool is_resource_op = HasResourceInputOrOutput(*node);
-    string resource_op_device;
+    absl::optional<DeviceId> resource_op_device;
     if (is_resource_op) {
       resource_op_device = device;
     }
@@ -805,15 +813,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
       is_xla_compile_attr_true |= xla_compile_attr;
     }
 
-    absl::flat_hash_set<string> devices;
-    devices.insert(device);
+    DeviceSet devices;
+    devices.Insert(device);
 
     Cluster* new_cluster = MakeNewCluster(
         /*cycles_graph_node_id=*/node->id(),
         /*effective_cluster_size=*/effective_cluster_size,
-        /*has_functional_control_flow=*/has_functional_control_flow,
-        std::move(devices), std::move(resource_op_device),
-        resource_var_operation_node_id, deadness_predicate,
+        /*has_functional_control_flow=*/has_functional_control_flow, devices,
+        resource_op_device, resource_var_operation_node_id, deadness_predicate,
         /*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
         GetXlaScope(node));
 
@@ -1255,27 +1262,22 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
 
 StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
     const Cluster& cluster_a, const Cluster& cluster_b) {
-  std::vector<string> devices;
-  absl::c_remove_copy(cluster_a.devices(), std::back_inserter(devices), "");
-  absl::c_remove_copy(cluster_b.devices(), std::back_inserter(devices), "");
-  absl::c_sort(devices);
-
-  if (devices.empty()) {
-    return false;
-  }
+  DeviceSet devices = cluster_a.devices();
+  devices.UnionWith(cluster_b.devices());
 
   // First check if we will even be able to pick a device for the larger
   // combined cluster.
-  bool can_pick_device;
-  TF_RETURN_IF_ERROR(CanPickDeviceForXla(
-      devices, /*allow_mixing_unknown_and_cpu=*/false, &can_pick_device));
+  TF_ASSIGN_OR_RETURN(
+      bool can_pick_device,
+      CanPickDeviceForXla(device_info_cache_, devices,
+                          /*allow_mixing_unknown_and_cpu=*/false));
 
   if (!can_pick_device) {
     return false;
   }
 
-  string chosen_device;
-  TF_RETURN_IF_ERROR(PickDeviceForXla(
-      devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
+  TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
+                      PickDeviceForXla(device_info_cache_, devices,
+                                       /*allow_mixing_unknown_and_cpu=*/false));
 
   // If we are able to pick a device `chosen_device` for the larger cluster, the
   // resource operations in `cluster_a` and `cluster_b` must be placed on the
   // same device as `chosen_device`.  This is because the _XlaCompile and
   // _XlaRun kernels are going to run on and therefore try to access the
   // resource variables from `chosen_device`, which will be an error if the
   // resource variables are placed on some other device.
-  auto resource_op_device_ok = [&](const string& resource_op_device) {
-    return resource_op_device.empty() || resource_op_device == chosen_device;
-  };
+  auto resource_op_device_ok =
+      [&](absl::optional<DeviceId> resource_op_device) {
+        return !resource_op_device.has_value() ||
+               *resource_op_device == chosen_device;
+      };
 
   return resource_op_device_ok(cluster_a.resource_op_device()) &&
          resource_op_device_ok(cluster_b.resource_op_device());
@@ -1294,22 +1298,18 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
 
 // Returns `true` iff we should compile `cluster`.
 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
     const Cluster& cluster) {
-  std::vector<string> devices;
-  absl::c_remove_copy(cluster.devices(), std::back_inserter(devices), "");
-  absl::c_sort(devices);
+  TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
+                      PickDeviceForXla(device_info_cache_, cluster.devices(),
+                                       /*allow_mixing_unknown_and_cpu=*/false));
 
-  string chosen_device;
-  TF_RETURN_IF_ERROR(PickDeviceForXla(
-      devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
-
-  TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
-                      device_info_cache_.GetDeviceTypeFor(chosen_device));
-  TF_ASSIGN_OR_RETURN(const XlaOpRegistry::DeviceRegistration* registration,
-                      device_info_cache_.GetCompilationDevice(chosen_device));
+  const DeviceType& device_type =
+      device_info_cache_.GetDeviceTypeFor(chosen_device);
+  const XlaOpRegistry::DeviceRegistration* registration =
+      device_info_cache_.GetCompilationDevice(chosen_device);
   TF_RET_CHECK(registration)
-      << "chosen device = " << chosen_device
+      << "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
       << "; device type = " << device_type.type() << "; devices ("
-      << devices.size() << ") = " << absl::StrJoin(devices, ", ");
+      << device_info_cache_.DebugString(cluster.devices());
 
   bool should_compile =
       cluster.is_xla_compile_attr_true() ||
@@ -1343,7 +1343,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
   }
 
   VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
-          << " cluster with device " << chosen_device;
+          << " cluster with device "
+          << device_info_cache_.GetNameFor(chosen_device);
 
   return should_compile;
 }
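Note (illustrative, not part of the patch): once `PickDeviceForXlaImpl` has ruled out the multiple-device and unknown+GPU conflicts, the remaining choice is a fixed preference order: GPU over unknown over CPU. A standalone sketch of just that policy, with plain strings standing in for `DeviceId`s:

```c++
#include <cassert>
#include <string>

// GPU + unknown together is rejected earlier, so that combination never
// reaches this point.
std::string PickPreferred(bool has_gpu, bool has_unknown, bool has_cpu) {
  if (has_gpu) return "GPU";
  if (has_unknown) return "UNKNOWN";
  assert(has_cpu);  // a non-empty device set must contain at least one kind
  return "CPU";
}

int main() {
  assert(PickPreferred(true, false, true) == "GPU");      // GPU beats CPU
  assert(PickPreferred(false, true, true) == "UNKNOWN");  // unknown beats CPU
  assert(PickPreferred(false, false, true) == "CPU");
  return 0;
}
```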