Use integers instead of strings to represent devices in the auto-clustering passes
This, in turn, lets us represent sets of devices as bitmaps, which is a big win. PiperOrigin-RevId: 246939469
This commit is contained in:
parent
233e0ddbe8
commit
b912370109
@ -593,6 +593,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -233,14 +233,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
|
||||
}
|
||||
|
||||
// Returns true (into `result`) if a node placed on `device` must be compiled.
|
||||
Status DeviceRequiresCompilation(const string& device, bool* result) {
|
||||
DeviceType device_type("");
|
||||
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(device, &device_type));
|
||||
const XlaOpRegistry::DeviceRegistration* registration = nullptr;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
|
||||
return errors::Internal("Could not find compilation device ",
|
||||
device_type.type());
|
||||
}
|
||||
Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
|
||||
jit::DeviceId device, bool* result) {
|
||||
const XlaOpRegistry::DeviceRegistration* registration =
|
||||
device_info_cache.GetCompilationDevice(device);
|
||||
*result = registration->autoclustering_policy ==
|
||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||
return Status::OK();
|
||||
@ -293,17 +289,20 @@ Status ReplaceFunctionCallWithPartionedCall(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InferDeviceForCluster(Node* n, const string& function_name,
|
||||
const FunctionLibraryDefinition& flib_def,
|
||||
string* result) {
|
||||
xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
|
||||
jit::DeviceInfoCache* device_info_cache, Node* n,
|
||||
const string& function_name, const FunctionLibraryDefinition& flib_def) {
|
||||
const FunctionDef* func_def = flib_def.Find(function_name);
|
||||
TF_RET_CHECK(func_def) << "Could not find " << function_name;
|
||||
|
||||
std::set<string> device_names;
|
||||
jit::DeviceSet device_set;
|
||||
|
||||
for (const NodeDef& ndef : func_def->node_def()) {
|
||||
VLOG(3) << ndef.DebugString();
|
||||
if (!ndef.device().empty()) {
|
||||
device_names.insert(ndef.device());
|
||||
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
|
||||
device_info_cache->GetIdFor(ndef.device()));
|
||||
device_set.Insert(device_id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -311,41 +310,47 @@ Status InferDeviceForCluster(Node* n, const string& function_name,
|
||||
// TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
|
||||
// assignment when constant folding. We should fix EncapsulateSubgraphsPass
|
||||
// instead.
|
||||
device_names.insert(n->assigned_device_name());
|
||||
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
|
||||
device_info_cache->GetIdFor(n->assigned_device_name()));
|
||||
device_set.Insert(device_id);
|
||||
}
|
||||
|
||||
std::vector<string> device_names_vector;
|
||||
absl::c_copy(device_names, std::back_inserter(device_names_vector));
|
||||
|
||||
Status s = PickDeviceForXla(device_names_vector, true, result);
|
||||
if (s.ok()) {
|
||||
TF_ASSIGN_OR_RETURN(jit::DeviceId result,
|
||||
PickDeviceForXla(*device_info_cache, device_set,
|
||||
/*allow_mixing_unknown_and_cpu=*/true));
|
||||
VLOG(2) << "For " << function_name << " PickDeviceForXla("
|
||||
<< absl::StrJoin(device_names_vector, ", ") << ") -> " << *result;
|
||||
}
|
||||
return s;
|
||||
<< device_info_cache->DebugString(device_set) << ") -> "
|
||||
<< device_info_cache->GetNameFor(result);
|
||||
return result;
|
||||
}
|
||||
|
||||
Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
jit::DeviceInfoCache* device_info_cache,
|
||||
const GraphOptimizationPassOptions& options,
|
||||
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
|
||||
bool insert_print_nodes, Graph* g, Node* n) {
|
||||
XlaClusterInfo cluster_info;
|
||||
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
|
||||
|
||||
string device;
|
||||
TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(),
|
||||
flib_def, &device));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
jit::DeviceId device,
|
||||
InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
|
||||
flib_def));
|
||||
|
||||
bool requires_compilation;
|
||||
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation));
|
||||
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
|
||||
&requires_compilation));
|
||||
if (!lazy_compilation_enabled) {
|
||||
requires_compilation = true;
|
||||
}
|
||||
|
||||
string device_name_str = string(device_info_cache->GetNameFor(device));
|
||||
|
||||
Status status;
|
||||
Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
|
||||
.NewSubScope(n->name())
|
||||
.WithDevice(n->requested_device())
|
||||
.WithAssignedDevice(device);
|
||||
.WithAssignedDevice(device_name_str);
|
||||
|
||||
ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
|
||||
/*constants=*/cluster_info.constant_inputs,
|
||||
@ -441,10 +446,12 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
|
||||
bool insert_print_nodes =
|
||||
GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
|
||||
|
||||
jit::DeviceInfoCache device_info_cache;
|
||||
|
||||
for (Node* n : xla_compiled_kernels) {
|
||||
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
|
||||
options, *options.flib_def, lazy_compilation_enabled,
|
||||
insert_print_nodes, graph, n));
|
||||
&device_info_cache, options, *options.flib_def,
|
||||
lazy_compilation_enabled, insert_print_nodes, graph, n));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(1)) {
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/jit/device_util.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
||||
@ -22,43 +23,68 @@ namespace tensorflow {
|
||||
namespace jit {
|
||||
using xla::StatusOr;
|
||||
|
||||
StatusOr<const XlaOpRegistry::DeviceRegistration*>
|
||||
DeviceInfoCache::GetCompilationDevice(absl::string_view device_name) {
|
||||
auto it = device_to_device_registration_.find(device_name);
|
||||
if (it != device_to_device_registration_.end()) {
|
||||
void DeviceSet::Insert(DeviceId device_id) {
|
||||
int word_index = device_id.id() / kWordSize;
|
||||
int bit_index = device_id.id() % kWordSize;
|
||||
|
||||
if (word_index >= storage_.size()) {
|
||||
storage_.resize(word_index + 1, 0);
|
||||
}
|
||||
|
||||
storage_[word_index] |= (1ull << bit_index);
|
||||
}
|
||||
|
||||
void DeviceSet::UnionWith(const DeviceSet& other) {
|
||||
if (other.storage_.size() > storage_.size()) {
|
||||
storage_.resize(other.storage_.size(), 0);
|
||||
}
|
||||
|
||||
for (int i = 0; i < other.storage_.size(); i++) {
|
||||
storage_[i] |= other.storage_[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool DeviceSet::IsEmpty() const {
|
||||
return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
|
||||
}
|
||||
|
||||
xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
|
||||
TF_RET_CHECK(!name.empty());
|
||||
|
||||
auto it = name_to_id_.find(name);
|
||||
if (it != name_to_id_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
string device_name_str = string(device_name);
|
||||
TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
|
||||
GetDeviceTypeFor(device_name_str));
|
||||
const XlaOpRegistry::DeviceRegistration* registration;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
|
||||
registration = nullptr;
|
||||
int new_id = names_.size();
|
||||
names_.push_back(string(name));
|
||||
id_to_device_type_.push_back(absl::make_unique<DeviceType>(""));
|
||||
DeviceType* device_type = id_to_device_type_.back().get();
|
||||
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
|
||||
|
||||
is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
|
||||
is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
|
||||
|
||||
name_to_id_.emplace(string(name), DeviceId(new_id));
|
||||
|
||||
const XlaOpRegistry::DeviceRegistration* compilation_device;
|
||||
if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
|
||||
&compilation_device)) {
|
||||
compilation_device = nullptr;
|
||||
}
|
||||
id_to_compilation_device_.push_back(compilation_device);
|
||||
|
||||
device_to_device_registration_.insert(
|
||||
{std::move(device_name_str), registration});
|
||||
|
||||
return registration;
|
||||
return DeviceId(new_id);
|
||||
}
|
||||
|
||||
StatusOr<std::reference_wrapper<const DeviceType>>
|
||||
DeviceInfoCache::GetDeviceTypeFor(absl::string_view device_name) {
|
||||
auto it = device_to_device_type_.find(device_name);
|
||||
if (it != device_to_device_type_.end()) {
|
||||
return std::cref(*it->second);
|
||||
}
|
||||
string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
|
||||
std::vector<string> names;
|
||||
device_set.ForEach([&](DeviceId device_id) {
|
||||
names.push_back(string(GetNameFor(device_id)));
|
||||
return false;
|
||||
});
|
||||
|
||||
string device_name_str = string(device_name);
|
||||
auto device_type = absl::make_unique<DeviceType>("");
|
||||
TF_RETURN_IF_ERROR(
|
||||
DeviceNameToDeviceType(device_name_str, device_type.get()));
|
||||
|
||||
it = device_to_device_type_
|
||||
.insert({std::move(device_name_str), std::move(device_type)})
|
||||
.first;
|
||||
return std::cref(*it->second);
|
||||
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
|
||||
}
|
||||
} // namespace jit
|
||||
|
||||
@ -71,10 +97,11 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
|
||||
Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
bool* out_can_pick_device,
|
||||
string* out_device_picked) {
|
||||
absl::optional<jit::DeviceId>* out_device_picked) {
|
||||
if (out_can_pick_device) {
|
||||
*out_can_pick_device = true;
|
||||
}
|
||||
@ -89,65 +116,79 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
TF_RET_CHECK(!device_names.empty()) << "No devices to choose from";
|
||||
TF_RET_CHECK(!devices.IsEmpty()) << "No devices to choose from";
|
||||
DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);
|
||||
|
||||
absl::flat_hash_set<absl::string_view> device_names_set;
|
||||
for (absl::string_view device_name : device_names) {
|
||||
TF_RET_CHECK(!device_name.empty());
|
||||
device_names_set.insert(device_name);
|
||||
}
|
||||
absl::optional<jit::DeviceId> maybe_gpu_device;
|
||||
absl::optional<jit::DeviceId> maybe_cpu_device;
|
||||
absl::optional<jit::DeviceId> maybe_unknown_device;
|
||||
|
||||
absl::optional<absl::string_view> maybe_gpu_device;
|
||||
absl::optional<absl::string_view> maybe_cpu_device;
|
||||
absl::optional<absl::string_view> maybe_unknown_device;
|
||||
bool multiple_cpu_devices = false;
|
||||
bool multiple_gpu_devices = false;
|
||||
bool multiple_unknown_devices = false;
|
||||
|
||||
for (absl::string_view device_name : device_names_set) {
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name))
|
||||
<< device_name;
|
||||
if (parsed_name.type == "GPU") {
|
||||
devices.ForEach([&](jit::DeviceId device) {
|
||||
if (device_info_cache.IsGpu(device)) {
|
||||
if (maybe_gpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple GPU devices ", absl::StrJoin(device_names, ", ")));
|
||||
multiple_gpu_devices = true;
|
||||
return false;
|
||||
}
|
||||
maybe_gpu_device = device_name;
|
||||
} else if (parsed_name.type == "CPU") {
|
||||
maybe_gpu_device = device;
|
||||
} else if (device_info_cache.IsCpu(device)) {
|
||||
if (maybe_cpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple CPU devices ", absl::StrJoin(device_names, ", ")));
|
||||
multiple_cpu_devices = true;
|
||||
return false;
|
||||
}
|
||||
maybe_cpu_device = device_name;
|
||||
maybe_cpu_device = device;
|
||||
} else {
|
||||
if (maybe_unknown_device) {
|
||||
multiple_unknown_devices = true;
|
||||
return false;
|
||||
}
|
||||
maybe_unknown_device = device;
|
||||
}
|
||||
|
||||
return true;
|
||||
});
|
||||
|
||||
if (multiple_cpu_devices) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple unknown devices ", absl::StrJoin(device_names, ", ")));
|
||||
"Multiple CPU devices ", device_info_cache.DebugString(devices)));
|
||||
}
|
||||
maybe_unknown_device = device_name;
|
||||
|
||||
if (multiple_gpu_devices) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple GPU devices ", device_info_cache.DebugString(devices)));
|
||||
}
|
||||
|
||||
if (multiple_unknown_devices) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Multiple unknown devices ", device_info_cache.DebugString(devices)));
|
||||
}
|
||||
|
||||
if (maybe_unknown_device && maybe_gpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Found both unknown and GPU devices: ", *maybe_unknown_device, ", ",
|
||||
*maybe_gpu_device));
|
||||
"Found both unknown and GPU devices: ",
|
||||
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
|
||||
device_info_cache.GetNameFor(*maybe_gpu_device)));
|
||||
}
|
||||
|
||||
if (!allow_mixing_unknown_and_cpu) {
|
||||
if (maybe_unknown_device && maybe_cpu_device) {
|
||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||
"Found both unknown and CPU devices: ", *maybe_unknown_device, ", ",
|
||||
*maybe_cpu_device));
|
||||
"Found both unknown and CPU devices: ",
|
||||
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
|
||||
device_info_cache.GetNameFor(*maybe_cpu_device)));
|
||||
}
|
||||
}
|
||||
|
||||
if (out_device_picked) {
|
||||
if (maybe_gpu_device) {
|
||||
*out_device_picked = string(*maybe_gpu_device);
|
||||
*out_device_picked = *maybe_gpu_device;
|
||||
} else if (maybe_unknown_device) {
|
||||
*out_device_picked = string(*maybe_unknown_device);
|
||||
*out_device_picked = *maybe_unknown_device;
|
||||
} else {
|
||||
*out_device_picked = string(*maybe_cpu_device);
|
||||
*out_device_picked = *maybe_cpu_device;
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,19 +197,24 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
|
||||
#undef FAILED_TO_PICK_DEVICE
|
||||
}
|
||||
|
||||
Status PickDeviceForXla(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
string* out_device_picked) {
|
||||
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
|
||||
/*out_can_pick_device=*/nullptr,
|
||||
out_device_picked);
|
||||
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
|
||||
const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
|
||||
absl::optional<jit::DeviceId> device;
|
||||
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(
|
||||
device_info_cache, devices, allow_mixing_unknown_and_cpu,
|
||||
/*out_can_pick_device=*/nullptr, &device));
|
||||
return *device;
|
||||
}
|
||||
|
||||
Status CanPickDeviceForXla(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
bool* out_can_pick_device) {
|
||||
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
|
||||
out_can_pick_device,
|
||||
/*out_device_picked=*/nullptr);
|
||||
xla::StatusOr<bool> CanPickDeviceForXla(
|
||||
const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
|
||||
bool can_pick_device;
|
||||
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(device_info_cache, devices,
|
||||
allow_mixing_unknown_and_cpu,
|
||||
&can_pick_device,
|
||||
/*out_device_picked=*/nullptr));
|
||||
return can_pick_device;
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
@ -23,24 +23,119 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace jit {
|
||||
// Instances of DeviceId represent TensorFlow devices as integers.
|
||||
//
|
||||
// This helps avoid having to manipulate device names as strings when
|
||||
// auto-clustering.
|
||||
class DeviceId {
|
||||
public:
|
||||
DeviceId(DeviceId&&) = default;
|
||||
DeviceId(const DeviceId&) = default;
|
||||
DeviceId& operator=(const DeviceId&) = default;
|
||||
|
||||
bool operator==(const DeviceId& other) const { return id() == other.id(); }
|
||||
bool operator!=(const DeviceId& other) const { return !(*this == other); }
|
||||
|
||||
private:
|
||||
int id_;
|
||||
|
||||
explicit DeviceId(int id) : id_(id) {}
|
||||
|
||||
int id() const { return id_; }
|
||||
|
||||
friend class DeviceInfoCache;
|
||||
friend class DeviceSet;
|
||||
};
|
||||
|
||||
// A set of DeviceIds, represented as a bitmap.
|
||||
class DeviceSet {
|
||||
public:
|
||||
void Insert(DeviceId device_id);
|
||||
void UnionWith(const DeviceSet& other);
|
||||
bool IsEmpty() const;
|
||||
|
||||
// Calls `func` on each DeviceId in the set. Stops iterating early if `func`
|
||||
// return false.
|
||||
//
|
||||
// TODO(sanjoy): Change this to take a typed std::function if that's
|
||||
// performance neutral.
|
||||
template <typename FnTy>
|
||||
void ForEach(FnTy func) const {
|
||||
// This is really a poor man's iterator, we should consider writing a proper
|
||||
// iterator if this ends up being used widely.
|
||||
for (int word_index = 0; word_index < storage_.size(); word_index++) {
|
||||
uint64 word = storage_[word_index];
|
||||
for (int bit_index = 0; bit_index < kWordSize; bit_index++) {
|
||||
if (word & (1ull << bit_index)) {
|
||||
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
absl::InlinedVector<uint64, 1> storage_;
|
||||
|
||||
const int kWordSize = 64;
|
||||
};
|
||||
|
||||
// Caches some miscellaneous information about TF devices. Thread compatible.
|
||||
class DeviceInfoCache {
|
||||
public:
|
||||
xla::StatusOr<const XlaOpRegistry::DeviceRegistration*> GetCompilationDevice(
|
||||
absl::string_view device_name);
|
||||
xla::StatusOr<std::reference_wrapper<const DeviceType>> GetDeviceTypeFor(
|
||||
absl::string_view device_name);
|
||||
bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }
|
||||
bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }
|
||||
|
||||
absl::string_view GetNameFor(DeviceId device) const {
|
||||
return names_[device.id()];
|
||||
}
|
||||
|
||||
xla::StatusOr<DeviceId> GetIdFor(absl::string_view name);
|
||||
|
||||
using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;
|
||||
|
||||
DeviceRegistration* GetCompilationDevice(DeviceId device) const {
|
||||
return id_to_compilation_device_[device.id()];
|
||||
}
|
||||
|
||||
xla::StatusOr<DeviceRegistration*> GetCompilationDevice(
|
||||
absl::string_view name) {
|
||||
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
|
||||
return GetCompilationDevice(device_id);
|
||||
}
|
||||
|
||||
const DeviceType& GetDeviceTypeFor(DeviceId device) const {
|
||||
return *id_to_device_type_[device.id()];
|
||||
}
|
||||
|
||||
using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;
|
||||
|
||||
xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
|
||||
absl::string_view device_name) {
|
||||
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
|
||||
return std::cref(*id_to_device_type_[device_id.id()]);
|
||||
}
|
||||
|
||||
string DebugString(const DeviceSet& device_set) const;
|
||||
|
||||
private:
|
||||
absl::flat_hash_map<string, const XlaOpRegistry::DeviceRegistration*>
|
||||
device_to_device_registration_;
|
||||
absl::flat_hash_map<string, std::unique_ptr<DeviceType>>
|
||||
device_to_device_type_;
|
||||
absl::flat_hash_map<string, DeviceId> name_to_id_;
|
||||
|
||||
// These fields are populated for a device in GetIdFor, *before* we give out a
|
||||
// DeviceId.
|
||||
std::vector<const XlaOpRegistry::DeviceRegistration*>
|
||||
id_to_compilation_device_;
|
||||
std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
|
||||
std::vector<string> names_;
|
||||
std::vector<bool> is_cpu_;
|
||||
std::vector<bool> is_gpu_;
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
@ -49,7 +144,7 @@ class DeviceInfoCache {
|
||||
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
|
||||
|
||||
// Picks the device for which XLA should compile a cluster that contains
|
||||
// operations placed in devices in `device_names`. For instance a cluster that
|
||||
// operations placed in devices in `devices`. For instance a cluster that
|
||||
// contains operations solely placed on the CPU will be compiled into a CPU
|
||||
// executable by XLA, whereas a cluster that contains operations placed on the
|
||||
// CPU and also operations placed on the GPU will be compiled into a GPU
|
||||
@ -82,16 +177,15 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
|
||||
// case it is the responsibility of the optimization pass that injected the
|
||||
// CPU nodes into the cluster to ensure that these nodes can be compiled by
|
||||
// the unknown XLA backend.
|
||||
Status PickDeviceForXla(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
string* out_device_picked);
|
||||
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
|
||||
const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
|
||||
|
||||
// This is like `PickDeviceForXla` except that it returns false (instead of a
|
||||
// non-OK Status) in `out_can_pick_device` if no unambiguous choice of device
|
||||
// exists.
|
||||
Status CanPickDeviceForXla(absl::Span<const string> device_names,
|
||||
bool allow_mixing_unknown_and_cpu,
|
||||
bool* out_can_pick_device);
|
||||
// non-OK Status) if no unambiguous choice of device exists.
|
||||
xla::StatusOr<bool> CanPickDeviceForXla(
|
||||
const jit::DeviceInfoCache& device_info_cache,
|
||||
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
|
||||
|
@ -22,12 +22,20 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu,
|
||||
absl::Span<const absl::string_view> inputs,
|
||||
absl::Span<const absl::string_view> device_names,
|
||||
string* result) {
|
||||
std::vector<string> inputs_string;
|
||||
absl::c_transform(inputs, std::back_inserter(inputs_string),
|
||||
[](absl::string_view sv) { return string(sv); });
|
||||
return PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, result);
|
||||
jit::DeviceInfoCache cache;
|
||||
jit::DeviceSet device_set;
|
||||
for (absl::string_view name : device_names) {
|
||||
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, cache.GetIdFor(name));
|
||||
device_set.Insert(device_id);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
jit::DeviceId result_id,
|
||||
PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu));
|
||||
*result = string(cache.GetNameFor(result_id));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CheckPickDeviceResult(absl::string_view expected_result,
|
||||
@ -87,5 +95,38 @@ TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
|
||||
CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
|
||||
}
|
||||
|
||||
void SimpleRoundTripTestForDeviceSet(int num_devices) {
|
||||
jit::DeviceSet device_set;
|
||||
jit::DeviceInfoCache device_info_cache;
|
||||
|
||||
std::vector<string> expected_devices, actual_devices;
|
||||
|
||||
for (int i = 0; i < num_devices; i++) {
|
||||
string device_name =
|
||||
absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i);
|
||||
TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id,
|
||||
device_info_cache.GetIdFor(device_name));
|
||||
device_set.Insert(device_id);
|
||||
expected_devices.push_back(device_name);
|
||||
}
|
||||
|
||||
device_set.ForEach([&](jit::DeviceId device_id) {
|
||||
actual_devices.push_back(string(device_info_cache.GetNameFor(device_id)));
|
||||
return true;
|
||||
});
|
||||
|
||||
EXPECT_EQ(expected_devices, actual_devices);
|
||||
}
|
||||
|
||||
TEST(DeviceSetTest, SimpleRoundTrip_One) { SimpleRoundTripTestForDeviceSet(1); }
|
||||
|
||||
TEST(DeviceSetTest, SimpleRoundTrip_Small) {
|
||||
SimpleRoundTripTestForDeviceSet(8);
|
||||
}
|
||||
|
||||
TEST(DeviceSetTest, SimpleRoundTrip_Large) {
|
||||
SimpleRoundTripTestForDeviceSet(800);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -57,6 +57,8 @@ namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
|
||||
using jit::DeviceId;
|
||||
using jit::DeviceSet;
|
||||
using xla::StatusOr;
|
||||
|
||||
// The clusters we create here are eventually lowered into an
|
||||
@ -117,8 +119,8 @@ class MarkForCompilationPassImpl {
|
||||
public:
|
||||
// Constructs a trivial cluster representing a single TF node.
|
||||
Cluster(int tf_graph_node_id, int effective_cluster_size,
|
||||
bool has_functional_control_flow,
|
||||
absl::flat_hash_set<string> devices, string resource_op_device,
|
||||
bool has_functional_control_flow, DeviceSet devices,
|
||||
absl::optional<DeviceId> resource_op_device,
|
||||
absl::optional<int> resource_var_operation_node_id,
|
||||
absl::optional<DeadnessPredicate> deadness_predicate,
|
||||
bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
|
||||
@ -126,7 +128,7 @@ class MarkForCompilationPassImpl {
|
||||
effective_cluster_size_(effective_cluster_size),
|
||||
has_functional_control_flow_(has_functional_control_flow),
|
||||
devices_(std::move(devices)),
|
||||
resource_op_device_(std::move(resource_op_device)),
|
||||
resource_op_device_(resource_op_device),
|
||||
deadness_predicate_(deadness_predicate),
|
||||
is_xla_compile_attr_true_(is_xla_compile_attr_true),
|
||||
xla_scope_(std::move(xla_scope)) {
|
||||
@ -162,12 +164,14 @@ class MarkForCompilationPassImpl {
|
||||
}
|
||||
|
||||
// The set of devices nodes in the cluster are placed on.
|
||||
const absl::flat_hash_set<string>& devices() const { return devices_; }
|
||||
const DeviceSet& devices() const { return devices_; }
|
||||
|
||||
// If the cluster has a resource operation then the device the resource
|
||||
// operation is placed on. A cluster may have resource ops placed only on a
|
||||
// single device.
|
||||
const string& resource_op_device() const { return resource_op_device_; }
|
||||
const absl::optional<DeviceId>& resource_op_device() const {
|
||||
return resource_op_device_;
|
||||
}
|
||||
|
||||
// If not nullopt the a predicate that is true iff the cluster is alive.
|
||||
// Otherwise the user has (unsafely) disabled deadness analysis. If this is
|
||||
@ -208,8 +212,8 @@ class MarkForCompilationPassImpl {
|
||||
int cycles_graph_node_id_;
|
||||
int effective_cluster_size_;
|
||||
bool has_functional_control_flow_;
|
||||
absl::flat_hash_set<string> devices_;
|
||||
string resource_op_device_;
|
||||
DeviceSet devices_;
|
||||
absl::optional<DeviceId> resource_op_device_;
|
||||
absl::optional<DeadnessPredicate> deadness_predicate_;
|
||||
bool is_xla_compile_attr_true_;
|
||||
absl::optional<string> xla_scope_;
|
||||
@ -279,17 +283,17 @@ class MarkForCompilationPassImpl {
|
||||
|
||||
Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
|
||||
bool has_functional_control_flow,
|
||||
absl::flat_hash_set<string> devices,
|
||||
string resource_op_device,
|
||||
const DeviceSet& device_set,
|
||||
absl::optional<DeviceId> resource_op_device,
|
||||
absl::optional<int> resource_var_operation_node_id,
|
||||
absl::optional<DeadnessPredicate> deadness_predicate,
|
||||
bool is_xla_compile_attr_true,
|
||||
absl::optional<string> xla_scope) {
|
||||
cluster_storage_.push_back(absl::make_unique<Cluster>(
|
||||
cycles_graph_node_id, effective_cluster_size,
|
||||
has_functional_control_flow, std::move(devices),
|
||||
std::move(resource_op_device), resource_var_operation_node_id,
|
||||
deadness_predicate, is_xla_compile_attr_true, xla_scope));
|
||||
has_functional_control_flow, device_set, resource_op_device,
|
||||
resource_var_operation_node_id, deadness_predicate,
|
||||
is_xla_compile_attr_true, xla_scope));
|
||||
return cluster_storage_.back().get();
|
||||
}
|
||||
|
||||
@ -486,13 +490,15 @@ void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
|
||||
effective_cluster_size_ += other->effective_cluster_size_;
|
||||
has_functional_control_flow_ |= other->has_functional_control_flow_;
|
||||
|
||||
for (string other_device : other->devices_) {
|
||||
devices_.insert(other_device);
|
||||
}
|
||||
other->devices_.clear();
|
||||
devices_.UnionWith(other->devices_);
|
||||
|
||||
if (resource_op_device_.empty()) {
|
||||
resource_op_device_ = std::move(other->resource_op_device_);
|
||||
DCHECK(!(resource_op_device_.has_value() &&
|
||||
other->resource_op_device_.has_value()) ||
|
||||
*resource_op_device_ == *other->resource_op_device_)
|
||||
<< "AreDevicesCompatible should have returned false otherwise!";
|
||||
|
||||
if (!resource_op_device_.has_value()) {
|
||||
resource_op_device_ = other->resource_op_device_;
|
||||
}
|
||||
|
||||
is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
|
||||
@ -779,12 +785,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
||||
deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
|
||||
}
|
||||
|
||||
const string& device = !node->assigned_device_name().empty()
|
||||
const string& device_name_str = !node->assigned_device_name().empty()
|
||||
? node->assigned_device_name()
|
||||
: node->requested_device();
|
||||
TF_ASSIGN_OR_RETURN(DeviceId device,
|
||||
device_info_cache_.GetIdFor(device_name_str));
|
||||
|
||||
bool is_resource_op = HasResourceInputOrOutput(*node);
|
||||
string resource_op_device;
|
||||
absl::optional<DeviceId> resource_op_device;
|
||||
if (is_resource_op) {
|
||||
resource_op_device = device;
|
||||
}
|
||||
@ -805,15 +813,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
||||
is_xla_compile_attr_true |= xla_compile_attr;
|
||||
}
|
||||
|
||||
absl::flat_hash_set<string> devices;
|
||||
devices.insert(device);
|
||||
DeviceSet devices;
|
||||
devices.Insert(device);
|
||||
|
||||
Cluster* new_cluster = MakeNewCluster(
|
||||
/*cycles_graph_node_id=*/node->id(),
|
||||
/*effective_cluster_size=*/effective_cluster_size,
|
||||
/*has_functional_control_flow=*/has_functional_control_flow,
|
||||
std::move(devices), std::move(resource_op_device),
|
||||
resource_var_operation_node_id, deadness_predicate,
|
||||
/*has_functional_control_flow=*/has_functional_control_flow, devices,
|
||||
resource_op_device, resource_var_operation_node_id, deadness_predicate,
|
||||
/*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
|
||||
GetXlaScope(node));
|
||||
|
||||
@ -1255,27 +1262,22 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
|
||||
|
||||
StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
|
||||
const Cluster& cluster_a, const Cluster& cluster_b) {
|
||||
std::vector<string> devices;
|
||||
absl::c_remove_copy(cluster_a.devices(), std::back_inserter(devices), "");
|
||||
absl::c_remove_copy(cluster_b.devices(), std::back_inserter(devices), "");
|
||||
absl::c_sort(devices);
|
||||
|
||||
if (devices.empty()) {
|
||||
return false;
|
||||
}
|
||||
DeviceSet devices = cluster_a.devices();
|
||||
devices.UnionWith(cluster_b.devices());
|
||||
|
||||
// First check if we will even be able to pick a device for the larger
|
||||
// combined cluster.
|
||||
bool can_pick_device;
|
||||
TF_RETURN_IF_ERROR(CanPickDeviceForXla(
|
||||
devices, /*allow_mixing_unknown_and_cpu=*/false, &can_pick_device));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool can_pick_device,
|
||||
CanPickDeviceForXla(device_info_cache_, devices,
|
||||
/*allow_mixing_unknown_and_cpu=*/false));
|
||||
if (!can_pick_device) {
|
||||
return false;
|
||||
}
|
||||
|
||||
string chosen_device;
|
||||
TF_RETURN_IF_ERROR(PickDeviceForXla(
|
||||
devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
|
||||
TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
|
||||
PickDeviceForXla(device_info_cache_, devices,
|
||||
/*allow_mixing_unknown_and_cpu=*/false));
|
||||
|
||||
// If we are able to pick a device `chosen_device` for the larger cluster, the
|
||||
// resource operations in `cluster_a` and `cluster_b` must be placed on the
|
||||
@ -1283,8 +1285,10 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
|
||||
// _XlaRun kernels are going to run on and therefore try to access the
|
||||
// resource variables from `chosen_device`, which will be an error if the
|
||||
// resource variables are placed on some other device.
|
||||
auto resource_op_device_ok = [&](const string& resource_op_device) {
|
||||
return resource_op_device.empty() || resource_op_device == chosen_device;
|
||||
auto resource_op_device_ok =
|
||||
[&](absl::optional<DeviceId> resource_op_device) {
|
||||
return !resource_op_device.has_value() ||
|
||||
*resource_op_device == chosen_device;
|
||||
};
|
||||
|
||||
return resource_op_device_ok(cluster_a.resource_op_device()) &&
|
||||
@ -1294,22 +1298,18 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
|
||||
// Returns `true` iff we should compile `cluster`.
|
||||
StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
|
||||
const Cluster& cluster) {
|
||||
std::vector<string> devices;
|
||||
absl::c_remove_copy(cluster.devices(), std::back_inserter(devices), "");
|
||||
absl::c_sort(devices);
|
||||
TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
|
||||
PickDeviceForXla(device_info_cache_, cluster.devices(),
|
||||
/*allow_mixing_unknown_and_cpu=*/false));
|
||||
|
||||
string chosen_device;
|
||||
TF_RETURN_IF_ERROR(PickDeviceForXla(
|
||||
devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
|
||||
device_info_cache_.GetDeviceTypeFor(chosen_device));
|
||||
TF_ASSIGN_OR_RETURN(const XlaOpRegistry::DeviceRegistration* registration,
|
||||
device_info_cache_.GetCompilationDevice(chosen_device));
|
||||
const DeviceType& device_type =
|
||||
device_info_cache_.GetDeviceTypeFor(chosen_device);
|
||||
const XlaOpRegistry::DeviceRegistration* registration =
|
||||
device_info_cache_.GetCompilationDevice(chosen_device);
|
||||
TF_RET_CHECK(registration)
|
||||
<< "chosen device = " << chosen_device
|
||||
<< "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
|
||||
<< "; device type = " << device_type.type() << "; devices ("
|
||||
<< devices.size() << ") = " << absl::StrJoin(devices, ", ");
|
||||
<< device_info_cache_.DebugString(cluster.devices());
|
||||
|
||||
bool should_compile =
|
||||
cluster.is_xla_compile_attr_true() ||
|
||||
@ -1343,7 +1343,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
|
||||
}
|
||||
|
||||
VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
|
||||
<< " cluster with device " << chosen_device;
|
||||
<< " cluster with device "
|
||||
<< device_info_cache_.GetNameFor(chosen_device);
|
||||
|
||||
return should_compile;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user