Use integers instead of strings to represent devices in the auto-clustering passes

This, in turn, lets us represent sets of devices as bitmaps, which is a big win.

PiperOrigin-RevId: 246939469
Author: Sanjoy Das
Date: 2019-05-06 18:46:45 -07:00
Committed by: TensorFlower Gardener
Parent: 233e0ddbe8
Commit: b912370109

6 changed files with 374 additions and 184 deletions
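
To make "sets of devices as bitmaps" concrete: device id i lives in bit
i % 64 of word i / 64 of a small word array, so insertion, membership tests,
and set union each become a few bitwise operations instead of hash-set
operations on device-name strings. Below is a minimal standalone C++ sketch of
the idea (SmallDeviceSet and main are illustrative only; the classes this
commit actually adds are jit::DeviceId and jit::DeviceSet in the header diff
further down):

#include <cstdint>
#include <cstdio>
#include <vector>

class SmallDeviceSet {
 public:
  void Insert(int id) {
    size_t word = id / 64, bit = id % 64;
    if (word >= storage_.size()) storage_.resize(word + 1, 0);
    storage_[word] |= (1ull << bit);
  }
  void UnionWith(const SmallDeviceSet& other) {
    if (other.storage_.size() > storage_.size())
      storage_.resize(other.storage_.size(), 0);
    for (size_t i = 0; i < other.storage_.size(); i++)
      storage_[i] |= other.storage_[i];  // One OR per 64 devices.
  }
  bool Contains(int id) const {
    size_t word = id / 64;
    return word < storage_.size() && (storage_[word] & (1ull << (id % 64)));
  }

 private:
  std::vector<uint64_t> storage_;  // Bit (i % 64) of word (i / 64) <=> id i.
};

int main() {
  SmallDeviceSet a, b;
  a.Insert(0);     // e.g. the id handed out for CPU:0
  b.Insert(65);    // e.g. GPU:1; lands in the second word
  a.UnionWith(b);  // merging two clusters' device sets is word-wise OR
  std::printf("%d %d\n", a.Contains(65), a.Contains(3));  // prints "1 0"
  return 0;
}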

@@ -593,6 +593,7 @@ cc_library(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:framework",
+        "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",

@@ -233,14 +233,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
 }
 
 // Returns true (into `result`) if a node placed on `device` must be compiled.
-Status DeviceRequiresCompilation(const string& device, bool* result) {
-  DeviceType device_type("");
-  TF_RETURN_IF_ERROR(DeviceNameToDeviceType(device, &device_type));
-  const XlaOpRegistry::DeviceRegistration* registration = nullptr;
-  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
-    return errors::Internal("Could not find compilation device ",
-                            device_type.type());
-  }
+Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
+                                 jit::DeviceId device, bool* result) {
+  const XlaOpRegistry::DeviceRegistration* registration =
+      device_info_cache.GetCompilationDevice(device);
   *result = registration->autoclustering_policy ==
             XlaOpRegistry::AutoclusteringPolicy::kAlways;
   return Status::OK();
@@ -293,17 +289,20 @@ Status ReplaceFunctionCallWithPartionedCall(
   return Status::OK();
 }
 
-Status InferDeviceForCluster(Node* n, const string& function_name,
-                             const FunctionLibraryDefinition& flib_def,
-                             string* result) {
+xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
+    jit::DeviceInfoCache* device_info_cache, Node* n,
+    const string& function_name, const FunctionLibraryDefinition& flib_def) {
   const FunctionDef* func_def = flib_def.Find(function_name);
   TF_RET_CHECK(func_def) << "Could not find " << function_name;
 
-  std::set<string> device_names;
+  jit::DeviceSet device_set;
   for (const NodeDef& ndef : func_def->node_def()) {
     VLOG(3) << ndef.DebugString();
     if (!ndef.device().empty()) {
-      device_names.insert(ndef.device());
+      TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
+                          device_info_cache->GetIdFor(ndef.device()));
+      device_set.Insert(device_id);
     }
  }
@@ -311,41 +310,47 @@ Status InferDeviceForCluster(Node* n, const string& function_name,
     // TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
     // assignment when constant folding.  We should fix EncapsulateSubgraphsPass
     // instead.
-    device_names.insert(n->assigned_device_name());
+    TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
+                        device_info_cache->GetIdFor(n->assigned_device_name()));
+    device_set.Insert(device_id);
   }
 
-  std::vector<string> device_names_vector;
-  absl::c_copy(device_names, std::back_inserter(device_names_vector));
-
-  Status s = PickDeviceForXla(device_names_vector, true, result);
-  if (s.ok()) {
-    VLOG(2) << "For " << function_name << " PickDeviceForXla("
-            << absl::StrJoin(device_names_vector, ", ") << ") -> " << *result;
-  }
-  return s;
+  TF_ASSIGN_OR_RETURN(jit::DeviceId result,
+                      PickDeviceForXla(*device_info_cache, device_set,
+                                       /*allow_mixing_unknown_and_cpu=*/true));
+  VLOG(2) << "For " << function_name << " PickDeviceForXla("
+          << device_info_cache->DebugString(device_set) << ") -> "
+          << device_info_cache->GetNameFor(result);
+  return result;
 }
 
 Status ReplaceNodeWithXlaCompileAndXlaRun(
+    jit::DeviceInfoCache* device_info_cache,
     const GraphOptimizationPassOptions& options,
     const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
     bool insert_print_nodes, Graph* g, Node* n) {
   XlaClusterInfo cluster_info;
   TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
 
-  string device;
-  TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(),
-                                           flib_def, &device));
+  TF_ASSIGN_OR_RETURN(
+      jit::DeviceId device,
+      InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
+                            flib_def));
+
   bool requires_compilation;
-  TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation));
+  TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
+                                               &requires_compilation));
   if (!lazy_compilation_enabled) {
     requires_compilation = true;
   }
 
+  string device_name_str = string(device_info_cache->GetNameFor(device));
+
   Status status;
   Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
                    .NewSubScope(n->name())
                    .WithDevice(n->requested_device())
-                   .WithAssignedDevice(device);
+                   .WithAssignedDevice(device_name_str);
 
   ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
                                /*constants=*/cluster_info.constant_inputs,
@@ -441,10 +446,12 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
   bool insert_print_nodes =
       GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
 
+  jit::DeviceInfoCache device_info_cache;
+
   for (Node* n : xla_compiled_kernels) {
     TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
-        options, *options.flib_def, lazy_compilation_enabled,
-        insert_print_nodes, graph, n));
+        &device_info_cache, options, *options.flib_def,
+        lazy_compilation_enabled, insert_print_nodes, graph, n));
   }
 
   if (VLOG_IS_ON(1)) {

@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/jit/device_util.h"
 
+#include "absl/algorithm/container.h"
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -22,43 +23,68 @@ namespace tensorflow {
 namespace jit {
 using xla::StatusOr;
 
-StatusOr<const XlaOpRegistry::DeviceRegistration*>
-DeviceInfoCache::GetCompilationDevice(absl::string_view device_name) {
-  auto it = device_to_device_registration_.find(device_name);
-  if (it != device_to_device_registration_.end()) {
-    return it->second;
-  }
-
-  string device_name_str = string(device_name);
-  TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
-                      GetDeviceTypeFor(device_name_str));
-  const XlaOpRegistry::DeviceRegistration* registration;
-  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
-    registration = nullptr;
-  }
-
-  device_to_device_registration_.insert(
-      {std::move(device_name_str), registration});
-
-  return registration;
-}
-
-StatusOr<std::reference_wrapper<const DeviceType>>
-DeviceInfoCache::GetDeviceTypeFor(absl::string_view device_name) {
-  auto it = device_to_device_type_.find(device_name);
-  if (it != device_to_device_type_.end()) {
-    return std::cref(*it->second);
-  }
-
-  string device_name_str = string(device_name);
-  auto device_type = absl::make_unique<DeviceType>("");
-  TF_RETURN_IF_ERROR(
-      DeviceNameToDeviceType(device_name_str, device_type.get()));
-
-  it = device_to_device_type_
-           .insert({std::move(device_name_str), std::move(device_type)})
-           .first;
-
-  return std::cref(*it->second);
-}
+void DeviceSet::Insert(DeviceId device_id) {
+  int word_index = device_id.id() / kWordSize;
+  int bit_index = device_id.id() % kWordSize;
+
+  if (word_index >= storage_.size()) {
+    storage_.resize(word_index + 1, 0);
+  }
+
+  storage_[word_index] |= (1ull << bit_index);
+}
+
+void DeviceSet::UnionWith(const DeviceSet& other) {
+  if (other.storage_.size() > storage_.size()) {
+    storage_.resize(other.storage_.size(), 0);
+  }
+
+  for (int i = 0; i < other.storage_.size(); i++) {
+    storage_[i] |= other.storage_[i];
+  }
+}
+
+bool DeviceSet::IsEmpty() const {
+  return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
+}
+
+xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
+  TF_RET_CHECK(!name.empty());
+
+  auto it = name_to_id_.find(name);
+  if (it != name_to_id_.end()) {
+    return it->second;
+  }
+
+  int new_id = names_.size();
+  names_.push_back(string(name));
+  id_to_device_type_.push_back(absl::make_unique<DeviceType>(""));
+  DeviceType* device_type = id_to_device_type_.back().get();
+  TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
+
+  is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
+  is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
+
+  name_to_id_.emplace(string(name), DeviceId(new_id));
+
+  const XlaOpRegistry::DeviceRegistration* compilation_device;
+  if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
+                                           &compilation_device)) {
+    compilation_device = nullptr;
+  }
+  id_to_compilation_device_.push_back(compilation_device);
+
+  return DeviceId(new_id);
+}
+
+string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
+  std::vector<string> names;
+  device_set.ForEach([&](DeviceId device_id) {
+    names.push_back(string(GetNameFor(device_id)));
+    return true;  // Keep iterating so every device in the set is listed.
+  });
+
+  return absl::StrCat("[", absl::StrJoin(names, ","), "]");
+}
 }  // namespace jit
@@ -71,10 +97,11 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
   return Status::OK();
 }
 
-Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
+Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
+                            const jit::DeviceSet& devices,
                             bool allow_mixing_unknown_and_cpu,
                             bool* out_can_pick_device,
-                            string* out_device_picked) {
+                            absl::optional<jit::DeviceId>* out_device_picked) {
   if (out_can_pick_device) {
     *out_can_pick_device = true;
   }
@@ -89,65 +116,79 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
     }                                         \
   } while (false)
 
-  TF_RET_CHECK(!device_names.empty()) << "No devices to choose from";
+  TF_RET_CHECK(!devices.IsEmpty()) << "No devices to choose from";
   DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);
 
-  absl::flat_hash_set<absl::string_view> device_names_set;
-  for (absl::string_view device_name : device_names) {
-    TF_RET_CHECK(!device_name.empty());
-    device_names_set.insert(device_name);
-  }
-
-  absl::optional<absl::string_view> maybe_gpu_device;
-  absl::optional<absl::string_view> maybe_cpu_device;
-  absl::optional<absl::string_view> maybe_unknown_device;
-
-  for (absl::string_view device_name : device_names_set) {
-    DeviceNameUtils::ParsedName parsed_name;
-    TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name))
-        << device_name;
-    if (parsed_name.type == "GPU") {
-      if (maybe_gpu_device) {
-        FAILED_TO_PICK_DEVICE(errors::Internal(
-            "Multiple GPU devices ", absl::StrJoin(device_names, ", ")));
-      }
-      maybe_gpu_device = device_name;
-    } else if (parsed_name.type == "CPU") {
-      if (maybe_cpu_device) {
-        FAILED_TO_PICK_DEVICE(errors::Internal(
-            "Multiple CPU devices ", absl::StrJoin(device_names, ", ")));
-      }
-      maybe_cpu_device = device_name;
-    } else {
-      if (maybe_unknown_device) {
-        FAILED_TO_PICK_DEVICE(errors::Internal(
-            "Multiple unknown devices ", absl::StrJoin(device_names, ", ")));
-      }
-      maybe_unknown_device = device_name;
-    }
-  }
+  absl::optional<jit::DeviceId> maybe_gpu_device;
+  absl::optional<jit::DeviceId> maybe_cpu_device;
+  absl::optional<jit::DeviceId> maybe_unknown_device;
+
+  bool multiple_cpu_devices = false;
+  bool multiple_gpu_devices = false;
+  bool multiple_unknown_devices = false;
+
+  devices.ForEach([&](jit::DeviceId device) {
+    if (device_info_cache.IsGpu(device)) {
+      if (maybe_gpu_device) {
+        multiple_gpu_devices = true;
+        return false;
+      }
+      maybe_gpu_device = device;
+    } else if (device_info_cache.IsCpu(device)) {
+      if (maybe_cpu_device) {
+        multiple_cpu_devices = true;
+        return false;
+      }
+      maybe_cpu_device = device;
+    } else {
+      if (maybe_unknown_device) {
+        multiple_unknown_devices = true;
+        return false;
+      }
+      maybe_unknown_device = device;
+    }
+
+    return true;
+  });
+
+  if (multiple_cpu_devices) {
+    FAILED_TO_PICK_DEVICE(errors::Internal(
+        "Multiple CPU devices ", device_info_cache.DebugString(devices)));
+  }
+
+  if (multiple_gpu_devices) {
+    FAILED_TO_PICK_DEVICE(errors::Internal(
+        "Multiple GPU devices ", device_info_cache.DebugString(devices)));
+  }
+
+  if (multiple_unknown_devices) {
+    FAILED_TO_PICK_DEVICE(errors::Internal(
+        "Multiple unknown devices ", device_info_cache.DebugString(devices)));
+  }
 
   if (maybe_unknown_device && maybe_gpu_device) {
     FAILED_TO_PICK_DEVICE(errors::Internal(
-        "Found both unknown and GPU devices: ", *maybe_unknown_device, ", ",
-        *maybe_gpu_device));
+        "Found both unknown and GPU devices: ",
+        device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
+        device_info_cache.GetNameFor(*maybe_gpu_device)));
   }
 
   if (!allow_mixing_unknown_and_cpu) {
     if (maybe_unknown_device && maybe_cpu_device) {
       FAILED_TO_PICK_DEVICE(errors::Internal(
-          "Found both unknown and CPU devices: ", *maybe_unknown_device, ", ",
-          *maybe_cpu_device));
+          "Found both unknown and CPU devices: ",
+          device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
+          device_info_cache.GetNameFor(*maybe_cpu_device)));
     }
   }
 
   if (out_device_picked) {
     if (maybe_gpu_device) {
-      *out_device_picked = string(*maybe_gpu_device);
+      *out_device_picked = *maybe_gpu_device;
     } else if (maybe_unknown_device) {
-      *out_device_picked = string(*maybe_unknown_device);
+      *out_device_picked = *maybe_unknown_device;
     } else {
-      *out_device_picked = string(*maybe_cpu_device);
+      *out_device_picked = *maybe_cpu_device;
    }
  }
@@ -156,19 +197,24 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
 
 #undef FAILED_TO_PICK_DEVICE
 }
 
-Status PickDeviceForXla(absl::Span<const string> device_names,
-                        bool allow_mixing_unknown_and_cpu,
-                        string* out_device_picked) {
-  return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
-                              /*out_can_pick_device=*/nullptr,
-                              out_device_picked);
+xla::StatusOr<jit::DeviceId> PickDeviceForXla(
+    const jit::DeviceInfoCache& device_info_cache,
+    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
+  absl::optional<jit::DeviceId> device;
+  TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(
+      device_info_cache, devices, allow_mixing_unknown_and_cpu,
+      /*out_can_pick_device=*/nullptr, &device));
+  return *device;
 }
 
-Status CanPickDeviceForXla(absl::Span<const string> device_names,
-                           bool allow_mixing_unknown_and_cpu,
-                           bool* out_can_pick_device) {
-  return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
-                              out_can_pick_device,
-                              /*out_device_picked=*/nullptr);
+xla::StatusOr<bool> CanPickDeviceForXla(
+    const jit::DeviceInfoCache& device_info_cache,
+    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
+  bool can_pick_device;
+  TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(device_info_cache, devices,
+                                          allow_mixing_unknown_and_cpu,
+                                          &can_pick_device,
+                                          /*out_device_picked=*/nullptr));
+  return can_pick_device;
 }
 
 }  // namespace tensorflow

@@ -23,24 +23,119 @@ limitations under the License.
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/framework/types.h"
 
 namespace tensorflow {
 namespace jit {
+
+// Instances of DeviceId represent TensorFlow devices as integers.
+//
+// This helps avoid having to manipulate device names as strings when
+// auto-clustering.
+class DeviceId {
+ public:
+  DeviceId(DeviceId&&) = default;
+  DeviceId(const DeviceId&) = default;
+  DeviceId& operator=(const DeviceId&) = default;
+
+  bool operator==(const DeviceId& other) const { return id() == other.id(); }
+  bool operator!=(const DeviceId& other) const { return !(*this == other); }
+
+ private:
+  int id_;
+
+  explicit DeviceId(int id) : id_(id) {}
+
+  int id() const { return id_; }
+
+  friend class DeviceInfoCache;
+  friend class DeviceSet;
+};
+
+// A set of DeviceIds, represented as a bitmap.
+class DeviceSet {
+ public:
+  void Insert(DeviceId device_id);
+  void UnionWith(const DeviceSet& other);
+  bool IsEmpty() const;
+
+  // Calls `func` on each DeviceId in the set.  Stops iterating early if `func`
+  // returns false.
+  //
+  // TODO(sanjoy): Change this to take a typed std::function if that's
+  // performance neutral.
+  template <typename FnTy>
+  void ForEach(FnTy func) const {
+    // This is really a poor man's iterator; we should consider writing a
+    // proper iterator if this ends up being used widely.
+    for (int word_index = 0; word_index < storage_.size(); word_index++) {
+      uint64 word = storage_[word_index];
+      for (int bit_index = 0; bit_index < kWordSize; bit_index++) {
+        if (word & (1ull << bit_index)) {
+          if (!func(DeviceId(word_index * kWordSize + bit_index))) {
+            return;
+          }
+        }
+      }
+    }
+  }
+
+ private:
+  absl::InlinedVector<uint64, 1> storage_;
+
+  const int kWordSize = 64;
+};
+
 // Caches some miscellaneous information about TF devices.  Thread compatible.
 class DeviceInfoCache {
  public:
-  xla::StatusOr<const XlaOpRegistry::DeviceRegistration*> GetCompilationDevice(
-      absl::string_view device_name);
-  xla::StatusOr<std::reference_wrapper<const DeviceType>> GetDeviceTypeFor(
-      absl::string_view device_name);
+  bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }
+  bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }
+
+  absl::string_view GetNameFor(DeviceId device) const {
+    return names_[device.id()];
+  }
+
+  xla::StatusOr<DeviceId> GetIdFor(absl::string_view name);
+
+  using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;
+
+  DeviceRegistration* GetCompilationDevice(DeviceId device) const {
+    return id_to_compilation_device_[device.id()];
+  }
+
+  xla::StatusOr<DeviceRegistration*> GetCompilationDevice(
+      absl::string_view name) {
+    TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
+    return GetCompilationDevice(device_id);
+  }
+
+  const DeviceType& GetDeviceTypeFor(DeviceId device) const {
+    return *id_to_device_type_[device.id()];
+  }
+
+  using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;
+
+  xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
+      absl::string_view device_name) {
+    TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
+    return std::cref(*id_to_device_type_[device_id.id()]);
+  }
+
+  string DebugString(const DeviceSet& device_set) const;
 
  private:
-  absl::flat_hash_map<string, const XlaOpRegistry::DeviceRegistration*>
-      device_to_device_registration_;
-  absl::flat_hash_map<string, std::unique_ptr<DeviceType>>
-      device_to_device_type_;
+  absl::flat_hash_map<string, DeviceId> name_to_id_;
+
+  // These fields are populated for a device in GetIdFor, *before* we give out
+  // a DeviceId.
+  std::vector<const XlaOpRegistry::DeviceRegistration*>
+      id_to_compilation_device_;
+  std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
+  std::vector<string> names_;
+  std::vector<bool> is_cpu_;
+  std::vector<bool> is_gpu_;
 };
 
 }  // namespace jit
@@ -49,7 +144,7 @@ class DeviceInfoCache {
 Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
 
 // Picks the device for which XLA should compile a cluster that contains
-// operations placed in devices in `device_names`. For instance a cluster that
+// operations placed in devices in `devices`.  For instance a cluster that
 // contains operations solely placed on the CPU will be compiled into a CPU
 // executable by XLA, whereas a cluster that contains operations placed on the
 // CPU and also operations placed on the GPU will be compiled into a GPU
@@ -82,16 +177,15 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
 // case it is the responsibility of the optimization pass that injected the
 // CPU nodes into the cluster to ensure that these nodes can be compiled by
 // the unknown XLA backend.
-Status PickDeviceForXla(absl::Span<const string> device_names,
-                        bool allow_mixing_unknown_and_cpu,
-                        string* out_device_picked);
+xla::StatusOr<jit::DeviceId> PickDeviceForXla(
+    const jit::DeviceInfoCache& device_info_cache,
+    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
 
 // This is like `PickDeviceForXla` except that it returns false (instead of a
-// non-OK Status) in `out_can_pick_device` if no unambiguous choice of device
-// exists.
-Status CanPickDeviceForXla(absl::Span<const string> device_names,
-                           bool allow_mixing_unknown_and_cpu,
-                           bool* out_can_pick_device);
+// non-OK Status) if no unambiguous choice of device exists.
+xla::StatusOr<bool> CanPickDeviceForXla(
+    const jit::DeviceInfoCache& device_info_cache,
+    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
 
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
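
As a usage illustration of the new header API (a hypothetical snippet, not
part of the commit; it assumes only the declarations above plus the usual TF
status macros, and lives inside namespace tensorflow), the documented
precedence, where a lone GPU beats a lone CPU, plays out like this:

Status PickExampleDevice() {
  jit::DeviceInfoCache cache;
  jit::DeviceSet devices;
  for (absl::string_view name :
       {"/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:GPU:0"}) {
    // GetIdFor interns the name and hands back a dense integer id.
    TF_ASSIGN_OR_RETURN(jit::DeviceId id, cache.GetIdFor(name));
    devices.Insert(id);
  }
  // One CPU plus one GPU is an unambiguous choice: the GPU wins.
  TF_ASSIGN_OR_RETURN(
      jit::DeviceId picked,
      PickDeviceForXla(cache, devices, /*allow_mixing_unknown_and_cpu=*/false));
  VLOG(1) << "Picked " << cache.GetNameFor(picked);  // .../device:GPU:0
  return Status::OK();
}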

@@ -22,12 +22,20 @@ namespace tensorflow {
 namespace {
 
 Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu,
-                        absl::Span<const absl::string_view> inputs,
+                        absl::Span<const absl::string_view> device_names,
                         string* result) {
-  std::vector<string> inputs_string;
-  absl::c_transform(inputs, std::back_inserter(inputs_string),
-                    [](absl::string_view sv) { return string(sv); });
-  return PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, result);
+  jit::DeviceInfoCache cache;
+  jit::DeviceSet device_set;
+  for (absl::string_view name : device_names) {
+    TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, cache.GetIdFor(name));
+    device_set.Insert(device_id);
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      jit::DeviceId result_id,
+      PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu));
+  *result = string(cache.GetNameFor(result_id));
+  return Status::OK();
 }
 
 void CheckPickDeviceResult(absl::string_view expected_result,
@@ -87,5 +95,38 @@ TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
   CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
 }
 
+void SimpleRoundTripTestForDeviceSet(int num_devices) {
+  jit::DeviceSet device_set;
+  jit::DeviceInfoCache device_info_cache;
+
+  std::vector<string> expected_devices, actual_devices;
+
+  for (int i = 0; i < num_devices; i++) {
+    string device_name =
+        absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i);
+    TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id,
+                            device_info_cache.GetIdFor(device_name));
+    device_set.Insert(device_id);
+    expected_devices.push_back(device_name);
+  }
+
+  device_set.ForEach([&](jit::DeviceId device_id) {
+    actual_devices.push_back(string(device_info_cache.GetNameFor(device_id)));
+    return true;
+  });
+
+  EXPECT_EQ(expected_devices, actual_devices);
+}
+
+TEST(DeviceSetTest, SimpleRoundTrip_One) { SimpleRoundTripTestForDeviceSet(1); }
+
+TEST(DeviceSetTest, SimpleRoundTrip_Small) {
+  SimpleRoundTripTestForDeviceSet(8);
+}
+
+TEST(DeviceSetTest, SimpleRoundTrip_Large) {
+  SimpleRoundTripTestForDeviceSet(800);
+}
+
 }  // namespace
 }  // namespace tensorflow

@@ -57,6 +57,8 @@ namespace tensorflow {
 namespace {
 
 using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
+using jit::DeviceId;
+using jit::DeviceSet;
 using xla::StatusOr;
 
 // The clusters we create here are eventually lowered into an
@@ -117,8 +119,8 @@ class MarkForCompilationPassImpl {
    public:
     // Constructs a trivial cluster representing a single TF node.
     Cluster(int tf_graph_node_id, int effective_cluster_size,
-            bool has_functional_control_flow,
-            absl::flat_hash_set<string> devices, string resource_op_device,
+            bool has_functional_control_flow, DeviceSet devices,
+            absl::optional<DeviceId> resource_op_device,
             absl::optional<int> resource_var_operation_node_id,
             absl::optional<DeadnessPredicate> deadness_predicate,
             bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
@@ -126,7 +128,7 @@ class MarkForCompilationPassImpl {
           effective_cluster_size_(effective_cluster_size),
           has_functional_control_flow_(has_functional_control_flow),
           devices_(std::move(devices)),
-          resource_op_device_(std::move(resource_op_device)),
+          resource_op_device_(resource_op_device),
           deadness_predicate_(deadness_predicate),
           is_xla_compile_attr_true_(is_xla_compile_attr_true),
           xla_scope_(std::move(xla_scope)) {
@@ -162,12 +164,14 @@ class MarkForCompilationPassImpl {
     }
 
     // The set of devices the nodes in the cluster are placed on.
-    const absl::flat_hash_set<string>& devices() const { return devices_; }
+    const DeviceSet& devices() const { return devices_; }
 
     // If the cluster has a resource operation then the device the resource
     // operation is placed on.  A cluster may have resource ops placed only on
     // a single device.
-    const string& resource_op_device() const { return resource_op_device_; }
+    const absl::optional<DeviceId>& resource_op_device() const {
+      return resource_op_device_;
+    }
 
     // If not nullopt, a predicate that is true iff the cluster is alive.
     // Otherwise the user has (unsafely) disabled deadness analysis.  If this is
@@ -208,8 +212,8 @@ class MarkForCompilationPassImpl {
     int cycles_graph_node_id_;
     int effective_cluster_size_;
     bool has_functional_control_flow_;
-    absl::flat_hash_set<string> devices_;
-    string resource_op_device_;
+    DeviceSet devices_;
+    absl::optional<DeviceId> resource_op_device_;
     absl::optional<DeadnessPredicate> deadness_predicate_;
     bool is_xla_compile_attr_true_;
     absl::optional<string> xla_scope_;
@@ -279,17 +283,17 @@ class MarkForCompilationPassImpl {
   Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
                           bool has_functional_control_flow,
-                          absl::flat_hash_set<string> devices,
-                          string resource_op_device,
+                          const DeviceSet& device_set,
+                          absl::optional<DeviceId> resource_op_device,
                           absl::optional<int> resource_var_operation_node_id,
                           absl::optional<DeadnessPredicate> deadness_predicate,
                           bool is_xla_compile_attr_true,
                           absl::optional<string> xla_scope) {
     cluster_storage_.push_back(absl::make_unique<Cluster>(
         cycles_graph_node_id, effective_cluster_size,
-        has_functional_control_flow, std::move(devices),
-        std::move(resource_op_device), resource_var_operation_node_id,
-        deadness_predicate, is_xla_compile_attr_true, xla_scope));
+        has_functional_control_flow, device_set, resource_op_device,
+        resource_var_operation_node_id, deadness_predicate,
+        is_xla_compile_attr_true, xla_scope));
     return cluster_storage_.back().get();
   }
@@ -486,13 +490,15 @@ void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
   effective_cluster_size_ += other->effective_cluster_size_;
   has_functional_control_flow_ |= other->has_functional_control_flow_;
 
-  for (string other_device : other->devices_) {
-    devices_.insert(other_device);
-  }
-  other->devices_.clear();
+  devices_.UnionWith(other->devices_);
 
-  if (resource_op_device_.empty()) {
-    resource_op_device_ = std::move(other->resource_op_device_);
+  DCHECK(!(resource_op_device_.has_value() &&
+           other->resource_op_device_.has_value()) ||
+         *resource_op_device_ == *other->resource_op_device_)
+      << "AreDevicesCompatible should have returned false otherwise!";
+
+  if (!resource_op_device_.has_value()) {
+    resource_op_device_ = other->resource_op_device_;
   }
 
   is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
@@ -779,12 +785,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
           deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
     }
 
-    const string& device = !node->assigned_device_name().empty()
-                               ? node->assigned_device_name()
-                               : node->requested_device();
+    const string& device_name_str = !node->assigned_device_name().empty()
+                                        ? node->assigned_device_name()
+                                        : node->requested_device();
+    TF_ASSIGN_OR_RETURN(DeviceId device,
+                        device_info_cache_.GetIdFor(device_name_str));
+
     bool is_resource_op = HasResourceInputOrOutput(*node);
-    string resource_op_device;
+    absl::optional<DeviceId> resource_op_device;
     if (is_resource_op) {
       resource_op_device = device;
     }
@@ -805,15 +813,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
       is_xla_compile_attr_true |= xla_compile_attr;
     }
 
-    absl::flat_hash_set<string> devices;
-    devices.insert(device);
+    DeviceSet devices;
+    devices.Insert(device);
 
     Cluster* new_cluster = MakeNewCluster(
         /*cycles_graph_node_id=*/node->id(),
        /*effective_cluster_size=*/effective_cluster_size,
-        /*has_functional_control_flow=*/has_functional_control_flow,
-        std::move(devices), std::move(resource_op_device),
-        resource_var_operation_node_id, deadness_predicate,
+        /*has_functional_control_flow=*/has_functional_control_flow, devices,
+        resource_op_device, resource_var_operation_node_id, deadness_predicate,
         /*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
         GetXlaScope(node));
@@ -1255,27 +1262,22 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
 
 StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
     const Cluster& cluster_a, const Cluster& cluster_b) {
-  std::vector<string> devices;
-  absl::c_remove_copy(cluster_a.devices(), std::back_inserter(devices), "");
-  absl::c_remove_copy(cluster_b.devices(), std::back_inserter(devices), "");
-  absl::c_sort(devices);
-
-  if (devices.empty()) {
-    return false;
-  }
+  DeviceSet devices = cluster_a.devices();
+  devices.UnionWith(cluster_b.devices());
 
   // First check if we will even be able to pick a device for the larger
   // combined cluster.
-  bool can_pick_device;
-  TF_RETURN_IF_ERROR(CanPickDeviceForXla(
-      devices, /*allow_mixing_unknown_and_cpu=*/false, &can_pick_device));
+  TF_ASSIGN_OR_RETURN(
+      bool can_pick_device,
+      CanPickDeviceForXla(device_info_cache_, devices,
+                          /*allow_mixing_unknown_and_cpu=*/false));
 
   if (!can_pick_device) {
     return false;
   }
 
-  string chosen_device;
-  TF_RETURN_IF_ERROR(PickDeviceForXla(
-      devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
+  TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
+                      PickDeviceForXla(device_info_cache_, devices,
+                                       /*allow_mixing_unknown_and_cpu=*/false));
 
   // If we are able to pick a device `chosen_device` for the larger cluster, the
   // resource operations in `cluster_a` and `cluster_b` must be placed on the
@@ -1283,9 +1285,11 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
   // _XlaRun kernels are going to run on and therefore try to access the
   // resource variables from `chosen_device`, which will be an error if the
   // resource variables are placed on some other device.
-  auto resource_op_device_ok = [&](const string& resource_op_device) {
-    return resource_op_device.empty() || resource_op_device == chosen_device;
-  };
+  auto resource_op_device_ok =
+      [&](absl::optional<DeviceId> resource_op_device) {
+        return !resource_op_device.has_value() ||
+               *resource_op_device == chosen_device;
+      };
 
   return resource_op_device_ok(cluster_a.resource_op_device()) &&
          resource_op_device_ok(cluster_b.resource_op_device());
@@ -1294,22 +1298,18 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
 
 // Returns `true` iff we should compile `cluster`.
 StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
     const Cluster& cluster) {
-  std::vector<string> devices;
-  absl::c_remove_copy(cluster.devices(), std::back_inserter(devices), "");
-  absl::c_sort(devices);
-
-  string chosen_device;
-  TF_RETURN_IF_ERROR(PickDeviceForXla(
-      devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
-
-  TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
-                      device_info_cache_.GetDeviceTypeFor(chosen_device));
-  TF_ASSIGN_OR_RETURN(const XlaOpRegistry::DeviceRegistration* registration,
-                      device_info_cache_.GetCompilationDevice(chosen_device));
+  TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
+                      PickDeviceForXla(device_info_cache_, cluster.devices(),
+                                       /*allow_mixing_unknown_and_cpu=*/false));
+
+  const DeviceType& device_type =
+      device_info_cache_.GetDeviceTypeFor(chosen_device);
+  const XlaOpRegistry::DeviceRegistration* registration =
+      device_info_cache_.GetCompilationDevice(chosen_device);
   TF_RET_CHECK(registration)
-      << "chosen device = " << chosen_device
-      << "; device type = " << device_type.type() << "; devices ("
-      << devices.size() << ") = " << absl::StrJoin(devices, ", ");
+      << "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
+      << "; device type = " << device_type.type() << "; devices = "
+      << device_info_cache_.DebugString(cluster.devices());
 
   bool should_compile =
       cluster.is_xla_compile_attr_true() ||
@@ -1343,7 +1343,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
   }
 
   VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
-          << " cluster with device " << chosen_device;
+          << " cluster with device "
+          << device_info_cache_.GetNameFor(chosen_device);
 
   return should_compile;
 }