Use integers instead of strings to represent devices in the auto-clustering passes
This, in turn, lets us represent sets of devices as bitmaps, which is a big win. PiperOrigin-RevId: 246939469
This commit is contained in:
parent
233e0ddbe8
commit
b912370109
@ -593,6 +593,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
@ -233,14 +233,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Returns true (into `result`) if a node placed on `device` must be compiled.
|
// Returns true (into `result`) if a node placed on `device` must be compiled.
|
||||||
Status DeviceRequiresCompilation(const string& device, bool* result) {
|
Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
|
||||||
DeviceType device_type("");
|
jit::DeviceId device, bool* result) {
|
||||||
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(device, &device_type));
|
const XlaOpRegistry::DeviceRegistration* registration =
|
||||||
const XlaOpRegistry::DeviceRegistration* registration = nullptr;
|
device_info_cache.GetCompilationDevice(device);
|
||||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
|
|
||||||
return errors::Internal("Could not find compilation device ",
|
|
||||||
device_type.type());
|
|
||||||
}
|
|
||||||
*result = registration->autoclustering_policy ==
|
*result = registration->autoclustering_policy ==
|
||||||
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
XlaOpRegistry::AutoclusteringPolicy::kAlways;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -293,17 +289,20 @@ Status ReplaceFunctionCallWithPartionedCall(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status InferDeviceForCluster(Node* n, const string& function_name,
|
xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
|
||||||
const FunctionLibraryDefinition& flib_def,
|
jit::DeviceInfoCache* device_info_cache, Node* n,
|
||||||
string* result) {
|
const string& function_name, const FunctionLibraryDefinition& flib_def) {
|
||||||
const FunctionDef* func_def = flib_def.Find(function_name);
|
const FunctionDef* func_def = flib_def.Find(function_name);
|
||||||
TF_RET_CHECK(func_def) << "Could not find " << function_name;
|
TF_RET_CHECK(func_def) << "Could not find " << function_name;
|
||||||
|
|
||||||
std::set<string> device_names;
|
jit::DeviceSet device_set;
|
||||||
|
|
||||||
for (const NodeDef& ndef : func_def->node_def()) {
|
for (const NodeDef& ndef : func_def->node_def()) {
|
||||||
VLOG(3) << ndef.DebugString();
|
VLOG(3) << ndef.DebugString();
|
||||||
if (!ndef.device().empty()) {
|
if (!ndef.device().empty()) {
|
||||||
device_names.insert(ndef.device());
|
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
|
||||||
|
device_info_cache->GetIdFor(ndef.device()));
|
||||||
|
device_set.Insert(device_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -311,41 +310,47 @@ Status InferDeviceForCluster(Node* n, const string& function_name,
|
|||||||
// TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
|
// TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
|
||||||
// assignment when constant folding. We should fix EncapsulateSubgraphsPass
|
// assignment when constant folding. We should fix EncapsulateSubgraphsPass
|
||||||
// instead.
|
// instead.
|
||||||
device_names.insert(n->assigned_device_name());
|
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
|
||||||
|
device_info_cache->GetIdFor(n->assigned_device_name()));
|
||||||
|
device_set.Insert(device_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<string> device_names_vector;
|
TF_ASSIGN_OR_RETURN(jit::DeviceId result,
|
||||||
absl::c_copy(device_names, std::back_inserter(device_names_vector));
|
PickDeviceForXla(*device_info_cache, device_set,
|
||||||
|
/*allow_mixing_unknown_and_cpu=*/true));
|
||||||
Status s = PickDeviceForXla(device_names_vector, true, result);
|
VLOG(2) << "For " << function_name << " PickDeviceForXla("
|
||||||
if (s.ok()) {
|
<< device_info_cache->DebugString(device_set) << ") -> "
|
||||||
VLOG(2) << "For " << function_name << " PickDeviceForXla("
|
<< device_info_cache->GetNameFor(result);
|
||||||
<< absl::StrJoin(device_names_vector, ", ") << ") -> " << *result;
|
return result;
|
||||||
}
|
|
||||||
return s;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ReplaceNodeWithXlaCompileAndXlaRun(
|
Status ReplaceNodeWithXlaCompileAndXlaRun(
|
||||||
|
jit::DeviceInfoCache* device_info_cache,
|
||||||
const GraphOptimizationPassOptions& options,
|
const GraphOptimizationPassOptions& options,
|
||||||
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
|
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
|
||||||
bool insert_print_nodes, Graph* g, Node* n) {
|
bool insert_print_nodes, Graph* g, Node* n) {
|
||||||
XlaClusterInfo cluster_info;
|
XlaClusterInfo cluster_info;
|
||||||
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
|
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
|
||||||
|
|
||||||
string device;
|
TF_ASSIGN_OR_RETURN(
|
||||||
TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(),
|
jit::DeviceId device,
|
||||||
flib_def, &device));
|
InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
|
||||||
|
flib_def));
|
||||||
|
|
||||||
bool requires_compilation;
|
bool requires_compilation;
|
||||||
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation));
|
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
|
||||||
|
&requires_compilation));
|
||||||
if (!lazy_compilation_enabled) {
|
if (!lazy_compilation_enabled) {
|
||||||
requires_compilation = true;
|
requires_compilation = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string device_name_str = string(device_info_cache->GetNameFor(device));
|
||||||
|
|
||||||
Status status;
|
Status status;
|
||||||
Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
|
Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
|
||||||
.NewSubScope(n->name())
|
.NewSubScope(n->name())
|
||||||
.WithDevice(n->requested_device())
|
.WithDevice(n->requested_device())
|
||||||
.WithAssignedDevice(device);
|
.WithAssignedDevice(device_name_str);
|
||||||
|
|
||||||
ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
|
ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
|
||||||
/*constants=*/cluster_info.constant_inputs,
|
/*constants=*/cluster_info.constant_inputs,
|
||||||
@ -441,10 +446,12 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
|
|||||||
bool insert_print_nodes =
|
bool insert_print_nodes =
|
||||||
GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
|
GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
|
||||||
|
|
||||||
|
jit::DeviceInfoCache device_info_cache;
|
||||||
|
|
||||||
for (Node* n : xla_compiled_kernels) {
|
for (Node* n : xla_compiled_kernels) {
|
||||||
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
|
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
|
||||||
options, *options.flib_def, lazy_compilation_enabled,
|
&device_info_cache, options, *options.flib_def,
|
||||||
insert_print_nodes, graph, n));
|
lazy_compilation_enabled, insert_print_nodes, graph, n));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (VLOG_IS_ON(1)) {
|
if (VLOG_IS_ON(1)) {
|
||||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/jit/device_util.h"
|
#include "tensorflow/compiler/jit/device_util.h"
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
|
|
||||||
@ -22,43 +23,68 @@ namespace tensorflow {
|
|||||||
namespace jit {
|
namespace jit {
|
||||||
using xla::StatusOr;
|
using xla::StatusOr;
|
||||||
|
|
||||||
StatusOr<const XlaOpRegistry::DeviceRegistration*>
|
void DeviceSet::Insert(DeviceId device_id) {
|
||||||
DeviceInfoCache::GetCompilationDevice(absl::string_view device_name) {
|
int word_index = device_id.id() / kWordSize;
|
||||||
auto it = device_to_device_registration_.find(device_name);
|
int bit_index = device_id.id() % kWordSize;
|
||||||
if (it != device_to_device_registration_.end()) {
|
|
||||||
|
if (word_index >= storage_.size()) {
|
||||||
|
storage_.resize(word_index + 1, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
storage_[word_index] |= (1ull << bit_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceSet::UnionWith(const DeviceSet& other) {
|
||||||
|
if (other.storage_.size() > storage_.size()) {
|
||||||
|
storage_.resize(other.storage_.size(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < other.storage_.size(); i++) {
|
||||||
|
storage_[i] |= other.storage_[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool DeviceSet::IsEmpty() const {
|
||||||
|
return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
|
||||||
|
TF_RET_CHECK(!name.empty());
|
||||||
|
|
||||||
|
auto it = name_to_id_.find(name);
|
||||||
|
if (it != name_to_id_.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
string device_name_str = string(device_name);
|
int new_id = names_.size();
|
||||||
TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
|
names_.push_back(string(name));
|
||||||
GetDeviceTypeFor(device_name_str));
|
id_to_device_type_.push_back(absl::make_unique<DeviceType>(""));
|
||||||
const XlaOpRegistry::DeviceRegistration* registration;
|
DeviceType* device_type = id_to_device_type_.back().get();
|
||||||
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
|
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
|
||||||
registration = nullptr;
|
|
||||||
|
is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
|
||||||
|
is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
|
||||||
|
|
||||||
|
name_to_id_.emplace(string(name), DeviceId(new_id));
|
||||||
|
|
||||||
|
const XlaOpRegistry::DeviceRegistration* compilation_device;
|
||||||
|
if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
|
||||||
|
&compilation_device)) {
|
||||||
|
compilation_device = nullptr;
|
||||||
}
|
}
|
||||||
|
id_to_compilation_device_.push_back(compilation_device);
|
||||||
|
|
||||||
device_to_device_registration_.insert(
|
return DeviceId(new_id);
|
||||||
{std::move(device_name_str), registration});
|
|
||||||
|
|
||||||
return registration;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::reference_wrapper<const DeviceType>>
|
string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
|
||||||
DeviceInfoCache::GetDeviceTypeFor(absl::string_view device_name) {
|
std::vector<string> names;
|
||||||
auto it = device_to_device_type_.find(device_name);
|
device_set.ForEach([&](DeviceId device_id) {
|
||||||
if (it != device_to_device_type_.end()) {
|
names.push_back(string(GetNameFor(device_id)));
|
||||||
return std::cref(*it->second);
|
return false;
|
||||||
}
|
});
|
||||||
|
|
||||||
string device_name_str = string(device_name);
|
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
|
||||||
auto device_type = absl::make_unique<DeviceType>("");
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
DeviceNameToDeviceType(device_name_str, device_type.get()));
|
|
||||||
|
|
||||||
it = device_to_device_type_
|
|
||||||
.insert({std::move(device_name_str), std::move(device_type)})
|
|
||||||
.first;
|
|
||||||
return std::cref(*it->second);
|
|
||||||
}
|
}
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
|
|
||||||
@ -71,10 +97,11 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
|
Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
|
||||||
|
const jit::DeviceSet& devices,
|
||||||
bool allow_mixing_unknown_and_cpu,
|
bool allow_mixing_unknown_and_cpu,
|
||||||
bool* out_can_pick_device,
|
bool* out_can_pick_device,
|
||||||
string* out_device_picked) {
|
absl::optional<jit::DeviceId>* out_device_picked) {
|
||||||
if (out_can_pick_device) {
|
if (out_can_pick_device) {
|
||||||
*out_can_pick_device = true;
|
*out_can_pick_device = true;
|
||||||
}
|
}
|
||||||
@ -89,65 +116,79 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
|
|||||||
} \
|
} \
|
||||||
} while (false)
|
} while (false)
|
||||||
|
|
||||||
TF_RET_CHECK(!device_names.empty()) << "No devices to choose from";
|
TF_RET_CHECK(!devices.IsEmpty()) << "No devices to choose from";
|
||||||
DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);
|
DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);
|
||||||
|
|
||||||
absl::flat_hash_set<absl::string_view> device_names_set;
|
absl::optional<jit::DeviceId> maybe_gpu_device;
|
||||||
for (absl::string_view device_name : device_names) {
|
absl::optional<jit::DeviceId> maybe_cpu_device;
|
||||||
TF_RET_CHECK(!device_name.empty());
|
absl::optional<jit::DeviceId> maybe_unknown_device;
|
||||||
device_names_set.insert(device_name);
|
|
||||||
}
|
|
||||||
|
|
||||||
absl::optional<absl::string_view> maybe_gpu_device;
|
bool multiple_cpu_devices = false;
|
||||||
absl::optional<absl::string_view> maybe_cpu_device;
|
bool multiple_gpu_devices = false;
|
||||||
absl::optional<absl::string_view> maybe_unknown_device;
|
bool multiple_unknown_devices = false;
|
||||||
|
|
||||||
for (absl::string_view device_name : device_names_set) {
|
devices.ForEach([&](jit::DeviceId device) {
|
||||||
DeviceNameUtils::ParsedName parsed_name;
|
if (device_info_cache.IsGpu(device)) {
|
||||||
TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name))
|
|
||||||
<< device_name;
|
|
||||||
if (parsed_name.type == "GPU") {
|
|
||||||
if (maybe_gpu_device) {
|
if (maybe_gpu_device) {
|
||||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
multiple_gpu_devices = true;
|
||||||
"Multiple GPU devices ", absl::StrJoin(device_names, ", ")));
|
return false;
|
||||||
}
|
}
|
||||||
maybe_gpu_device = device_name;
|
maybe_gpu_device = device;
|
||||||
} else if (parsed_name.type == "CPU") {
|
} else if (device_info_cache.IsCpu(device)) {
|
||||||
if (maybe_cpu_device) {
|
if (maybe_cpu_device) {
|
||||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
multiple_cpu_devices = true;
|
||||||
"Multiple CPU devices ", absl::StrJoin(device_names, ", ")));
|
return false;
|
||||||
}
|
}
|
||||||
maybe_cpu_device = device_name;
|
maybe_cpu_device = device;
|
||||||
} else {
|
} else {
|
||||||
if (maybe_unknown_device) {
|
if (maybe_unknown_device) {
|
||||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
multiple_unknown_devices = true;
|
||||||
"Multiple unknown devices ", absl::StrJoin(device_names, ", ")));
|
return false;
|
||||||
}
|
}
|
||||||
maybe_unknown_device = device_name;
|
maybe_unknown_device = device;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (multiple_cpu_devices) {
|
||||||
|
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||||
|
"Multiple CPU devices ", device_info_cache.DebugString(devices)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (multiple_gpu_devices) {
|
||||||
|
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||||
|
"Multiple GPU devices ", device_info_cache.DebugString(devices)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (multiple_unknown_devices) {
|
||||||
|
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||||
|
"Multiple unknown devices ", device_info_cache.DebugString(devices)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (maybe_unknown_device && maybe_gpu_device) {
|
if (maybe_unknown_device && maybe_gpu_device) {
|
||||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||||
"Found both unknown and GPU devices: ", *maybe_unknown_device, ", ",
|
"Found both unknown and GPU devices: ",
|
||||||
*maybe_gpu_device));
|
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
|
||||||
|
device_info_cache.GetNameFor(*maybe_gpu_device)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!allow_mixing_unknown_and_cpu) {
|
if (!allow_mixing_unknown_and_cpu) {
|
||||||
if (maybe_unknown_device && maybe_cpu_device) {
|
if (maybe_unknown_device && maybe_cpu_device) {
|
||||||
FAILED_TO_PICK_DEVICE(errors::Internal(
|
FAILED_TO_PICK_DEVICE(errors::Internal(
|
||||||
"Found both unknown and CPU devices: ", *maybe_unknown_device, ", ",
|
"Found both unknown and CPU devices: ",
|
||||||
*maybe_cpu_device));
|
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
|
||||||
|
device_info_cache.GetNameFor(*maybe_cpu_device)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (out_device_picked) {
|
if (out_device_picked) {
|
||||||
if (maybe_gpu_device) {
|
if (maybe_gpu_device) {
|
||||||
*out_device_picked = string(*maybe_gpu_device);
|
*out_device_picked = *maybe_gpu_device;
|
||||||
} else if (maybe_unknown_device) {
|
} else if (maybe_unknown_device) {
|
||||||
*out_device_picked = string(*maybe_unknown_device);
|
*out_device_picked = *maybe_unknown_device;
|
||||||
} else {
|
} else {
|
||||||
*out_device_picked = string(*maybe_cpu_device);
|
*out_device_picked = *maybe_cpu_device;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -156,19 +197,24 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
|
|||||||
#undef FAILED_TO_PICK_DEVICE
|
#undef FAILED_TO_PICK_DEVICE
|
||||||
}
|
}
|
||||||
|
|
||||||
Status PickDeviceForXla(absl::Span<const string> device_names,
|
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
|
||||||
bool allow_mixing_unknown_and_cpu,
|
const jit::DeviceInfoCache& device_info_cache,
|
||||||
string* out_device_picked) {
|
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
|
||||||
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
|
absl::optional<jit::DeviceId> device;
|
||||||
/*out_can_pick_device=*/nullptr,
|
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(
|
||||||
out_device_picked);
|
device_info_cache, devices, allow_mixing_unknown_and_cpu,
|
||||||
|
/*out_can_pick_device=*/nullptr, &device));
|
||||||
|
return *device;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CanPickDeviceForXla(absl::Span<const string> device_names,
|
xla::StatusOr<bool> CanPickDeviceForXla(
|
||||||
bool allow_mixing_unknown_and_cpu,
|
const jit::DeviceInfoCache& device_info_cache,
|
||||||
bool* out_can_pick_device) {
|
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
|
||||||
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
|
bool can_pick_device;
|
||||||
out_can_pick_device,
|
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(device_info_cache, devices,
|
||||||
/*out_device_picked=*/nullptr);
|
allow_mixing_unknown_and_cpu,
|
||||||
|
&can_pick_device,
|
||||||
|
/*out_device_picked=*/nullptr));
|
||||||
|
return can_pick_device;
|
||||||
}
|
}
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -23,24 +23,119 @@ limitations under the License.
|
|||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||||
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
// Instances of DeviceId represent TensorFlow devices as integers.
|
||||||
|
//
|
||||||
|
// This helps avoid having to manipulate device names as strings when
|
||||||
|
// auto-clustering.
|
||||||
|
class DeviceId {
|
||||||
|
public:
|
||||||
|
DeviceId(DeviceId&&) = default;
|
||||||
|
DeviceId(const DeviceId&) = default;
|
||||||
|
DeviceId& operator=(const DeviceId&) = default;
|
||||||
|
|
||||||
|
bool operator==(const DeviceId& other) const { return id() == other.id(); }
|
||||||
|
bool operator!=(const DeviceId& other) const { return !(*this == other); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int id_;
|
||||||
|
|
||||||
|
explicit DeviceId(int id) : id_(id) {}
|
||||||
|
|
||||||
|
int id() const { return id_; }
|
||||||
|
|
||||||
|
friend class DeviceInfoCache;
|
||||||
|
friend class DeviceSet;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A set of DeviceIds, represented as a bitmap.
|
||||||
|
class DeviceSet {
|
||||||
|
public:
|
||||||
|
void Insert(DeviceId device_id);
|
||||||
|
void UnionWith(const DeviceSet& other);
|
||||||
|
bool IsEmpty() const;
|
||||||
|
|
||||||
|
// Calls `func` on each DeviceId in the set. Stops iterating early if `func`
|
||||||
|
// returns false.
|
||||||
|
//
|
||||||
|
// TODO(sanjoy): Change this to take a typed std::function if that's
|
||||||
|
// performance neutral.
|
||||||
|
template <typename FnTy>
|
||||||
|
void ForEach(FnTy func) const {
|
||||||
|
// This is really a poor man's iterator, we should consider writing a proper
|
||||||
|
// iterator if this ends up being used widely.
|
||||||
|
for (int word_index = 0; word_index < storage_.size(); word_index++) {
|
||||||
|
uint64 word = storage_[word_index];
|
||||||
|
for (int bit_index = 0; bit_index < kWordSize; bit_index++) {
|
||||||
|
if (word & (1ull << bit_index)) {
|
||||||
|
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
absl::InlinedVector<uint64, 1> storage_;
|
||||||
|
|
||||||
|
const int kWordSize = 64;
|
||||||
|
};
|
||||||
|
|
||||||
// Caches some miscellaneous information about TF devices. Thread compatible.
|
// Caches some miscellaneous information about TF devices. Thread compatible.
|
||||||
class DeviceInfoCache {
|
class DeviceInfoCache {
|
||||||
public:
|
public:
|
||||||
xla::StatusOr<const XlaOpRegistry::DeviceRegistration*> GetCompilationDevice(
|
bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }
|
||||||
absl::string_view device_name);
|
bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }
|
||||||
xla::StatusOr<std::reference_wrapper<const DeviceType>> GetDeviceTypeFor(
|
|
||||||
absl::string_view device_name);
|
absl::string_view GetNameFor(DeviceId device) const {
|
||||||
|
return names_[device.id()];
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::StatusOr<DeviceId> GetIdFor(absl::string_view name);
|
||||||
|
|
||||||
|
using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;
|
||||||
|
|
||||||
|
DeviceRegistration* GetCompilationDevice(DeviceId device) const {
|
||||||
|
return id_to_compilation_device_[device.id()];
|
||||||
|
}
|
||||||
|
|
||||||
|
xla::StatusOr<DeviceRegistration*> GetCompilationDevice(
|
||||||
|
absl::string_view name) {
|
||||||
|
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
|
||||||
|
return GetCompilationDevice(device_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
const DeviceType& GetDeviceTypeFor(DeviceId device) const {
|
||||||
|
return *id_to_device_type_[device.id()];
|
||||||
|
}
|
||||||
|
|
||||||
|
using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;
|
||||||
|
|
||||||
|
xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
|
||||||
|
absl::string_view device_name) {
|
||||||
|
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
|
||||||
|
return std::cref(*id_to_device_type_[device_id.id()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
string DebugString(const DeviceSet& device_set) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::flat_hash_map<string, const XlaOpRegistry::DeviceRegistration*>
|
absl::flat_hash_map<string, DeviceId> name_to_id_;
|
||||||
device_to_device_registration_;
|
|
||||||
absl::flat_hash_map<string, std::unique_ptr<DeviceType>>
|
// These fields are populated for a device in GetIdFor, *before* we give out a
|
||||||
device_to_device_type_;
|
// DeviceId.
|
||||||
|
std::vector<const XlaOpRegistry::DeviceRegistration*>
|
||||||
|
id_to_compilation_device_;
|
||||||
|
std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
|
||||||
|
std::vector<string> names_;
|
||||||
|
std::vector<bool> is_cpu_;
|
||||||
|
std::vector<bool> is_gpu_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
@ -49,7 +144,7 @@ class DeviceInfoCache {
|
|||||||
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
|
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
|
||||||
|
|
||||||
// Picks the device for which XLA should compile a cluster that contains
|
// Picks the device for which XLA should compile a cluster that contains
|
||||||
// operations placed in devices in `device_names`. For instance a cluster that
|
// operations placed in devices in `devices`. For instance a cluster that
|
||||||
// contains operations solely placed on the CPU will be compiled into a CPU
|
// contains operations solely placed on the CPU will be compiled into a CPU
|
||||||
// executable by XLA, whereas a cluster that contains operations placed on the
|
// executable by XLA, whereas a cluster that contains operations placed on the
|
||||||
// CPU and also operations placed on the GPU will be compiled into a GPU
|
// CPU and also operations placed on the GPU will be compiled into a GPU
|
||||||
@ -82,16 +177,15 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
|
|||||||
// case it is the responsibility of the optimization pass that injected the
|
// case it is the responsibility of the optimization pass that injected the
|
||||||
// CPU nodes into the cluster to ensure that these nodes can be compiled by
|
// CPU nodes into the cluster to ensure that these nodes can be compiled by
|
||||||
// the unknown XLA backend.
|
// the unknown XLA backend.
|
||||||
Status PickDeviceForXla(absl::Span<const string> device_names,
|
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
|
||||||
bool allow_mixing_unknown_and_cpu,
|
const jit::DeviceInfoCache& device_info_cache,
|
||||||
string* out_device_picked);
|
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
|
||||||
|
|
||||||
// This is like `PickDeviceForXla` except that it returns false (instead of a
|
// This is like `PickDeviceForXla` except that it returns false (instead of a
|
||||||
// non-OK Status) in `out_can_pick_device` if no unambiguous choice of device
|
// non-OK Status) if no unambiguous choice of device exists.
|
||||||
// exists.
|
xla::StatusOr<bool> CanPickDeviceForXla(
|
||||||
Status CanPickDeviceForXla(absl::Span<const string> device_names,
|
const jit::DeviceInfoCache& device_info_cache,
|
||||||
bool allow_mixing_unknown_and_cpu,
|
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
|
||||||
bool* out_can_pick_device);
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
|
#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
|
||||||
|
@ -22,12 +22,20 @@ namespace tensorflow {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu,
|
Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu,
|
||||||
absl::Span<const absl::string_view> inputs,
|
absl::Span<const absl::string_view> device_names,
|
||||||
string* result) {
|
string* result) {
|
||||||
std::vector<string> inputs_string;
|
jit::DeviceInfoCache cache;
|
||||||
absl::c_transform(inputs, std::back_inserter(inputs_string),
|
jit::DeviceSet device_set;
|
||||||
[](absl::string_view sv) { return string(sv); });
|
for (absl::string_view name : device_names) {
|
||||||
return PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, result);
|
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, cache.GetIdFor(name));
|
||||||
|
device_set.Insert(device_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
jit::DeviceId result_id,
|
||||||
|
PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu));
|
||||||
|
*result = string(cache.GetNameFor(result_id));
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckPickDeviceResult(absl::string_view expected_result,
|
void CheckPickDeviceResult(absl::string_view expected_result,
|
||||||
@ -87,5 +95,38 @@ TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
|
|||||||
CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
|
CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SimpleRoundTripTestForDeviceSet(int num_devices) {
|
||||||
|
jit::DeviceSet device_set;
|
||||||
|
jit::DeviceInfoCache device_info_cache;
|
||||||
|
|
||||||
|
std::vector<string> expected_devices, actual_devices;
|
||||||
|
|
||||||
|
for (int i = 0; i < num_devices; i++) {
|
||||||
|
string device_name =
|
||||||
|
absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i);
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id,
|
||||||
|
device_info_cache.GetIdFor(device_name));
|
||||||
|
device_set.Insert(device_id);
|
||||||
|
expected_devices.push_back(device_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
device_set.ForEach([&](jit::DeviceId device_id) {
|
||||||
|
actual_devices.push_back(string(device_info_cache.GetNameFor(device_id)));
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
|
EXPECT_EQ(expected_devices, actual_devices);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DeviceSetTest, SimpleRoundTrip_One) { SimpleRoundTripTestForDeviceSet(1); }
|
||||||
|
|
||||||
|
TEST(DeviceSetTest, SimpleRoundTrip_Small) {
|
||||||
|
SimpleRoundTripTestForDeviceSet(8);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(DeviceSetTest, SimpleRoundTrip_Large) {
|
||||||
|
SimpleRoundTripTestForDeviceSet(800);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -57,6 +57,8 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
|
using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
|
||||||
|
using jit::DeviceId;
|
||||||
|
using jit::DeviceSet;
|
||||||
using xla::StatusOr;
|
using xla::StatusOr;
|
||||||
|
|
||||||
// The clusters we create here are eventually lowered into an
|
// The clusters we create here are eventually lowered into an
|
||||||
@ -117,8 +119,8 @@ class MarkForCompilationPassImpl {
|
|||||||
public:
|
public:
|
||||||
// Constructs a trivial cluster representing a single TF node.
|
// Constructs a trivial cluster representing a single TF node.
|
||||||
Cluster(int tf_graph_node_id, int effective_cluster_size,
|
Cluster(int tf_graph_node_id, int effective_cluster_size,
|
||||||
bool has_functional_control_flow,
|
bool has_functional_control_flow, DeviceSet devices,
|
||||||
absl::flat_hash_set<string> devices, string resource_op_device,
|
absl::optional<DeviceId> resource_op_device,
|
||||||
absl::optional<int> resource_var_operation_node_id,
|
absl::optional<int> resource_var_operation_node_id,
|
||||||
absl::optional<DeadnessPredicate> deadness_predicate,
|
absl::optional<DeadnessPredicate> deadness_predicate,
|
||||||
bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
|
bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
|
||||||
@ -126,7 +128,7 @@ class MarkForCompilationPassImpl {
|
|||||||
effective_cluster_size_(effective_cluster_size),
|
effective_cluster_size_(effective_cluster_size),
|
||||||
has_functional_control_flow_(has_functional_control_flow),
|
has_functional_control_flow_(has_functional_control_flow),
|
||||||
devices_(std::move(devices)),
|
devices_(std::move(devices)),
|
||||||
resource_op_device_(std::move(resource_op_device)),
|
resource_op_device_(resource_op_device),
|
||||||
deadness_predicate_(deadness_predicate),
|
deadness_predicate_(deadness_predicate),
|
||||||
is_xla_compile_attr_true_(is_xla_compile_attr_true),
|
is_xla_compile_attr_true_(is_xla_compile_attr_true),
|
||||||
xla_scope_(std::move(xla_scope)) {
|
xla_scope_(std::move(xla_scope)) {
|
||||||
@ -162,12 +164,14 @@ class MarkForCompilationPassImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// The set of devices nodes in the cluster are placed on.
|
// The set of devices nodes in the cluster are placed on.
|
||||||
const absl::flat_hash_set<string>& devices() const { return devices_; }
|
const DeviceSet& devices() const { return devices_; }
|
||||||
|
|
||||||
// If the cluster has a resource operation then the device the resource
|
// If the cluster has a resource operation then the device the resource
|
||||||
// operation is placed on. A cluster may have resource ops placed only on a
|
// operation is placed on. A cluster may have resource ops placed only on a
|
||||||
// single device.
|
// single device.
|
||||||
const string& resource_op_device() const { return resource_op_device_; }
|
const absl::optional<DeviceId>& resource_op_device() const {
|
||||||
|
return resource_op_device_;
|
||||||
|
}
|
||||||
|
|
||||||
// If not nullopt then a predicate that is true iff the cluster is alive.
|
// If not nullopt then a predicate that is true iff the cluster is alive.
|
||||||
// Otherwise the user has (unsafely) disabled deadness analysis. If this is
|
// Otherwise the user has (unsafely) disabled deadness analysis. If this is
|
||||||
@ -208,8 +212,8 @@ class MarkForCompilationPassImpl {
|
|||||||
int cycles_graph_node_id_;
|
int cycles_graph_node_id_;
|
||||||
int effective_cluster_size_;
|
int effective_cluster_size_;
|
||||||
bool has_functional_control_flow_;
|
bool has_functional_control_flow_;
|
||||||
absl::flat_hash_set<string> devices_;
|
DeviceSet devices_;
|
||||||
string resource_op_device_;
|
absl::optional<DeviceId> resource_op_device_;
|
||||||
absl::optional<DeadnessPredicate> deadness_predicate_;
|
absl::optional<DeadnessPredicate> deadness_predicate_;
|
||||||
bool is_xla_compile_attr_true_;
|
bool is_xla_compile_attr_true_;
|
||||||
absl::optional<string> xla_scope_;
|
absl::optional<string> xla_scope_;
|
||||||
@ -279,17 +283,17 @@ class MarkForCompilationPassImpl {
|
|||||||
|
|
||||||
Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
|
Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
|
||||||
bool has_functional_control_flow,
|
bool has_functional_control_flow,
|
||||||
absl::flat_hash_set<string> devices,
|
const DeviceSet& device_set,
|
||||||
string resource_op_device,
|
absl::optional<DeviceId> resource_op_device,
|
||||||
absl::optional<int> resource_var_operation_node_id,
|
absl::optional<int> resource_var_operation_node_id,
|
||||||
absl::optional<DeadnessPredicate> deadness_predicate,
|
absl::optional<DeadnessPredicate> deadness_predicate,
|
||||||
bool is_xla_compile_attr_true,
|
bool is_xla_compile_attr_true,
|
||||||
absl::optional<string> xla_scope) {
|
absl::optional<string> xla_scope) {
|
||||||
cluster_storage_.push_back(absl::make_unique<Cluster>(
|
cluster_storage_.push_back(absl::make_unique<Cluster>(
|
||||||
cycles_graph_node_id, effective_cluster_size,
|
cycles_graph_node_id, effective_cluster_size,
|
||||||
has_functional_control_flow, std::move(devices),
|
has_functional_control_flow, device_set, resource_op_device,
|
||||||
std::move(resource_op_device), resource_var_operation_node_id,
|
resource_var_operation_node_id, deadness_predicate,
|
||||||
deadness_predicate, is_xla_compile_attr_true, xla_scope));
|
is_xla_compile_attr_true, xla_scope));
|
||||||
return cluster_storage_.back().get();
|
return cluster_storage_.back().get();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -486,13 +490,15 @@ void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
|
|||||||
effective_cluster_size_ += other->effective_cluster_size_;
|
effective_cluster_size_ += other->effective_cluster_size_;
|
||||||
has_functional_control_flow_ |= other->has_functional_control_flow_;
|
has_functional_control_flow_ |= other->has_functional_control_flow_;
|
||||||
|
|
||||||
for (string other_device : other->devices_) {
|
devices_.UnionWith(other->devices_);
|
||||||
devices_.insert(other_device);
|
|
||||||
}
|
|
||||||
other->devices_.clear();
|
|
||||||
|
|
||||||
if (resource_op_device_.empty()) {
|
DCHECK(!(resource_op_device_.has_value() &&
|
||||||
resource_op_device_ = std::move(other->resource_op_device_);
|
other->resource_op_device_.has_value()) ||
|
||||||
|
*resource_op_device_ == *other->resource_op_device_)
|
||||||
|
<< "AreDevicesCompatible should have returned false otherwise!";
|
||||||
|
|
||||||
|
if (!resource_op_device_.has_value()) {
|
||||||
|
resource_op_device_ = other->resource_op_device_;
|
||||||
}
|
}
|
||||||
|
|
||||||
is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
|
is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
|
||||||
@ -779,12 +785,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
|||||||
deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
|
deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
|
||||||
}
|
}
|
||||||
|
|
||||||
const string& device = !node->assigned_device_name().empty()
|
const string& device_name_str = !node->assigned_device_name().empty()
|
||||||
? node->assigned_device_name()
|
? node->assigned_device_name()
|
||||||
: node->requested_device();
|
: node->requested_device();
|
||||||
|
TF_ASSIGN_OR_RETURN(DeviceId device,
|
||||||
|
device_info_cache_.GetIdFor(device_name_str));
|
||||||
|
|
||||||
bool is_resource_op = HasResourceInputOrOutput(*node);
|
bool is_resource_op = HasResourceInputOrOutput(*node);
|
||||||
string resource_op_device;
|
absl::optional<DeviceId> resource_op_device;
|
||||||
if (is_resource_op) {
|
if (is_resource_op) {
|
||||||
resource_op_device = device;
|
resource_op_device = device;
|
||||||
}
|
}
|
||||||
@ -805,15 +813,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
|
|||||||
is_xla_compile_attr_true |= xla_compile_attr;
|
is_xla_compile_attr_true |= xla_compile_attr;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::flat_hash_set<string> devices;
|
DeviceSet devices;
|
||||||
devices.insert(device);
|
devices.Insert(device);
|
||||||
|
|
||||||
Cluster* new_cluster = MakeNewCluster(
|
Cluster* new_cluster = MakeNewCluster(
|
||||||
/*cycles_graph_node_id=*/node->id(),
|
/*cycles_graph_node_id=*/node->id(),
|
||||||
/*effective_cluster_size=*/effective_cluster_size,
|
/*effective_cluster_size=*/effective_cluster_size,
|
||||||
/*has_functional_control_flow=*/has_functional_control_flow,
|
/*has_functional_control_flow=*/has_functional_control_flow, devices,
|
||||||
std::move(devices), std::move(resource_op_device),
|
resource_op_device, resource_var_operation_node_id, deadness_predicate,
|
||||||
resource_var_operation_node_id, deadness_predicate,
|
|
||||||
/*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
|
/*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
|
||||||
GetXlaScope(node));
|
GetXlaScope(node));
|
||||||
|
|
||||||
@ -1255,27 +1262,22 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
|
|||||||
|
|
||||||
StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
|
StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
|
||||||
const Cluster& cluster_a, const Cluster& cluster_b) {
|
const Cluster& cluster_a, const Cluster& cluster_b) {
|
||||||
std::vector<string> devices;
|
DeviceSet devices = cluster_a.devices();
|
||||||
absl::c_remove_copy(cluster_a.devices(), std::back_inserter(devices), "");
|
devices.UnionWith(cluster_b.devices());
|
||||||
absl::c_remove_copy(cluster_b.devices(), std::back_inserter(devices), "");
|
|
||||||
absl::c_sort(devices);
|
|
||||||
|
|
||||||
if (devices.empty()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// First check if we will even be able to pick a device for the larger
|
// First check if we will even be able to pick a device for the larger
|
||||||
// combined cluster.
|
// combined cluster.
|
||||||
bool can_pick_device;
|
TF_ASSIGN_OR_RETURN(
|
||||||
TF_RETURN_IF_ERROR(CanPickDeviceForXla(
|
bool can_pick_device,
|
||||||
devices, /*allow_mixing_unknown_and_cpu=*/false, &can_pick_device));
|
CanPickDeviceForXla(device_info_cache_, devices,
|
||||||
|
/*allow_mixing_unknown_and_cpu=*/false));
|
||||||
if (!can_pick_device) {
|
if (!can_pick_device) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
string chosen_device;
|
TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
|
||||||
TF_RETURN_IF_ERROR(PickDeviceForXla(
|
PickDeviceForXla(device_info_cache_, devices,
|
||||||
devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
|
/*allow_mixing_unknown_and_cpu=*/false));
|
||||||
|
|
||||||
// If we are able to pick a device `chosen_device` for the larger cluster, the
|
// If we are able to pick a device `chosen_device` for the larger cluster, the
|
||||||
// resource operations in `cluster_a` and `cluster_b` must be placed on the
|
// resource operations in `cluster_a` and `cluster_b` must be placed on the
|
||||||
@ -1283,9 +1285,11 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
|
|||||||
// _XlaRun kernels are going to run on and therefore try to access the
|
// _XlaRun kernels are going to run on and therefore try to access the
|
||||||
// resource variables from `chosen_device`, which will be an error if the
|
// resource variables from `chosen_device`, which will be an error if the
|
||||||
// resource variables are placed on some other device.
|
// resource variables are placed on some other device.
|
||||||
auto resource_op_device_ok = [&](const string& resource_op_device) {
|
auto resource_op_device_ok =
|
||||||
return resource_op_device.empty() || resource_op_device == chosen_device;
|
[&](absl::optional<DeviceId> resource_op_device) {
|
||||||
};
|
return !resource_op_device.has_value() ||
|
||||||
|
*resource_op_device == chosen_device;
|
||||||
|
};
|
||||||
|
|
||||||
return resource_op_device_ok(cluster_a.resource_op_device()) &&
|
return resource_op_device_ok(cluster_a.resource_op_device()) &&
|
||||||
resource_op_device_ok(cluster_b.resource_op_device());
|
resource_op_device_ok(cluster_b.resource_op_device());
|
||||||
@ -1294,22 +1298,18 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
|
|||||||
// Returns `true` iff we should compile `cluster`.
|
// Returns `true` iff we should compile `cluster`.
|
||||||
StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
|
StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
|
||||||
const Cluster& cluster) {
|
const Cluster& cluster) {
|
||||||
std::vector<string> devices;
|
TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
|
||||||
absl::c_remove_copy(cluster.devices(), std::back_inserter(devices), "");
|
PickDeviceForXla(device_info_cache_, cluster.devices(),
|
||||||
absl::c_sort(devices);
|
/*allow_mixing_unknown_and_cpu=*/false));
|
||||||
|
|
||||||
string chosen_device;
|
const DeviceType& device_type =
|
||||||
TF_RETURN_IF_ERROR(PickDeviceForXla(
|
device_info_cache_.GetDeviceTypeFor(chosen_device);
|
||||||
devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
|
const XlaOpRegistry::DeviceRegistration* registration =
|
||||||
|
device_info_cache_.GetCompilationDevice(chosen_device);
|
||||||
TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
|
|
||||||
device_info_cache_.GetDeviceTypeFor(chosen_device));
|
|
||||||
TF_ASSIGN_OR_RETURN(const XlaOpRegistry::DeviceRegistration* registration,
|
|
||||||
device_info_cache_.GetCompilationDevice(chosen_device));
|
|
||||||
TF_RET_CHECK(registration)
|
TF_RET_CHECK(registration)
|
||||||
<< "chosen device = " << chosen_device
|
<< "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
|
||||||
<< "; device type = " << device_type.type() << "; devices ("
|
<< "; device type = " << device_type.type() << "; devices ("
|
||||||
<< devices.size() << ") = " << absl::StrJoin(devices, ", ");
|
<< device_info_cache_.DebugString(cluster.devices());
|
||||||
|
|
||||||
bool should_compile =
|
bool should_compile =
|
||||||
cluster.is_xla_compile_attr_true() ||
|
cluster.is_xla_compile_attr_true() ||
|
||||||
@ -1343,7 +1343,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
|
VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
|
||||||
<< " cluster with device " << chosen_device;
|
<< " cluster with device "
|
||||||
|
<< device_info_cache_.GetNameFor(chosen_device);
|
||||||
|
|
||||||
return should_compile;
|
return should_compile;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user