Use integers instead of strings to represent devices in the auto-clustering passes

This, in turn, lets us represent sets of devices as bitmaps, which is a big win.

PiperOrigin-RevId: 246939469
This commit is contained in:
Sanjoy Das 2019-05-06 18:46:45 -07:00 committed by TensorFlower Gardener
parent 233e0ddbe8
commit b912370109
6 changed files with 374 additions and 184 deletions

View File

@ -593,6 +593,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:framework",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",

View File

@ -233,14 +233,10 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) {
}
// Returns true (into `result`) if a node placed on `device` must be compiled.
Status DeviceRequiresCompilation(const string& device, bool* result) {
DeviceType device_type("");
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(device, &device_type));
const XlaOpRegistry::DeviceRegistration* registration = nullptr;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
return errors::Internal("Could not find compilation device ",
device_type.type());
}
Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache,
jit::DeviceId device, bool* result) {
const XlaOpRegistry::DeviceRegistration* registration =
device_info_cache.GetCompilationDevice(device);
*result = registration->autoclustering_policy ==
XlaOpRegistry::AutoclusteringPolicy::kAlways;
return Status::OK();
@ -293,17 +289,20 @@ Status ReplaceFunctionCallWithPartionedCall(
return Status::OK();
}
Status InferDeviceForCluster(Node* n, const string& function_name,
const FunctionLibraryDefinition& flib_def,
string* result) {
xla::StatusOr<jit::DeviceId> InferDeviceForCluster(
jit::DeviceInfoCache* device_info_cache, Node* n,
const string& function_name, const FunctionLibraryDefinition& flib_def) {
const FunctionDef* func_def = flib_def.Find(function_name);
TF_RET_CHECK(func_def) << "Could not find " << function_name;
std::set<string> device_names;
jit::DeviceSet device_set;
for (const NodeDef& ndef : func_def->node_def()) {
VLOG(3) << ndef.DebugString();
if (!ndef.device().empty()) {
device_names.insert(ndef.device());
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
device_info_cache->GetIdFor(ndef.device()));
device_set.Insert(device_id);
}
}
@ -311,41 +310,47 @@ Status InferDeviceForCluster(Node* n, const string& function_name,
// TODO(sanjoy): We need this because EncapsulateSubgraphsPass drops device
// assignment when constant folding. We should fix EncapsulateSubgraphsPass
// instead.
device_names.insert(n->assigned_device_name());
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id,
device_info_cache->GetIdFor(n->assigned_device_name()));
device_set.Insert(device_id);
}
std::vector<string> device_names_vector;
absl::c_copy(device_names, std::back_inserter(device_names_vector));
Status s = PickDeviceForXla(device_names_vector, true, result);
if (s.ok()) {
TF_ASSIGN_OR_RETURN(jit::DeviceId result,
PickDeviceForXla(*device_info_cache, device_set,
/*allow_mixing_unknown_and_cpu=*/true));
VLOG(2) << "For " << function_name << " PickDeviceForXla("
<< absl::StrJoin(device_names_vector, ", ") << ") -> " << *result;
}
return s;
<< device_info_cache->DebugString(device_set) << ") -> "
<< device_info_cache->GetNameFor(result);
return result;
}
Status ReplaceNodeWithXlaCompileAndXlaRun(
jit::DeviceInfoCache* device_info_cache,
const GraphOptimizationPassOptions& options,
const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled,
bool insert_print_nodes, Graph* g, Node* n) {
XlaClusterInfo cluster_info;
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
string device;
TF_RETURN_IF_ERROR(InferDeviceForCluster(n, cluster_info.function.name(),
flib_def, &device));
TF_ASSIGN_OR_RETURN(
jit::DeviceId device,
InferDeviceForCluster(device_info_cache, n, cluster_info.function.name(),
flib_def));
bool requires_compilation;
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(device, &requires_compilation));
TF_RETURN_IF_ERROR(DeviceRequiresCompilation(*device_info_cache, device,
&requires_compilation));
if (!lazy_compilation_enabled) {
requires_compilation = true;
}
string device_name_str = string(device_info_cache->GetNameFor(device));
Status status;
Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
.NewSubScope(n->name())
.WithDevice(n->requested_device())
.WithAssignedDevice(device);
.WithAssignedDevice(device_name_str);
ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
/*constants=*/cluster_info.constant_inputs,
@ -441,10 +446,12 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
bool insert_print_nodes =
GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs;
jit::DeviceInfoCache device_info_cache;
for (Node* n : xla_compiled_kernels) {
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
options, *options.flib_def, lazy_compilation_enabled,
insert_print_nodes, graph, n));
&device_info_cache, options, *options.flib_def,
lazy_compilation_enabled, insert_print_nodes, graph, n));
}
if (VLOG_IS_ON(1)) {

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/device_util.h"
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -22,43 +23,68 @@ namespace tensorflow {
namespace jit {
using xla::StatusOr;
StatusOr<const XlaOpRegistry::DeviceRegistration*>
DeviceInfoCache::GetCompilationDevice(absl::string_view device_name) {
auto it = device_to_device_registration_.find(device_name);
if (it != device_to_device_registration_.end()) {
void DeviceSet::Insert(DeviceId device_id) {
int word_index = device_id.id() / kWordSize;
int bit_index = device_id.id() % kWordSize;
if (word_index >= storage_.size()) {
storage_.resize(word_index + 1, 0);
}
storage_[word_index] |= (1ull << bit_index);
}
void DeviceSet::UnionWith(const DeviceSet& other) {
if (other.storage_.size() > storage_.size()) {
storage_.resize(other.storage_.size(), 0);
}
for (int i = 0; i < other.storage_.size(); i++) {
storage_[i] |= other.storage_[i];
}
}
bool DeviceSet::IsEmpty() const {
return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
}
xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
TF_RET_CHECK(!name.empty());
auto it = name_to_id_.find(name);
if (it != name_to_id_.end()) {
return it->second;
}
string device_name_str = string(device_name);
TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
GetDeviceTypeFor(device_name_str));
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
registration = nullptr;
int new_id = names_.size();
names_.push_back(string(name));
id_to_device_type_.push_back(absl::make_unique<DeviceType>(""));
DeviceType* device_type = id_to_device_type_.back().get();
TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
name_to_id_.emplace(string(name), DeviceId(new_id));
const XlaOpRegistry::DeviceRegistration* compilation_device;
if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
&compilation_device)) {
compilation_device = nullptr;
}
id_to_compilation_device_.push_back(compilation_device);
return DeviceId(new_id);
}
device_to_device_registration_.insert(
{std::move(device_name_str), registration});
string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
std::vector<string> names;
device_set.ForEach([&](DeviceId device_id) {
names.push_back(string(GetNameFor(device_id)));
return false;
});
return registration;
}
StatusOr<std::reference_wrapper<const DeviceType>>
DeviceInfoCache::GetDeviceTypeFor(absl::string_view device_name) {
auto it = device_to_device_type_.find(device_name);
if (it != device_to_device_type_.end()) {
return std::cref(*it->second);
}
string device_name_str = string(device_name);
auto device_type = absl::make_unique<DeviceType>("");
TF_RETURN_IF_ERROR(
DeviceNameToDeviceType(device_name_str, device_type.get()));
it = device_to_device_type_
.insert({std::move(device_name_str), std::move(device_type)})
.first;
return std::cref(*it->second);
return absl::StrCat("[", absl::StrJoin(names, ","), "]");
}
} // namespace jit
@ -71,10 +97,11 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
return Status::OK();
}
Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
Status PickDeviceForXlaImpl(const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device,
string* out_device_picked) {
absl::optional<jit::DeviceId>* out_device_picked) {
if (out_can_pick_device) {
*out_can_pick_device = true;
}
@ -89,65 +116,79 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
} \
} while (false)
TF_RET_CHECK(!device_names.empty()) << "No devices to choose from";
TF_RET_CHECK(!devices.IsEmpty()) << "No devices to choose from";
DCHECK_NE(out_can_pick_device == nullptr, out_device_picked == nullptr);
absl::flat_hash_set<absl::string_view> device_names_set;
for (absl::string_view device_name : device_names) {
TF_RET_CHECK(!device_name.empty());
device_names_set.insert(device_name);
}
absl::optional<jit::DeviceId> maybe_gpu_device;
absl::optional<jit::DeviceId> maybe_cpu_device;
absl::optional<jit::DeviceId> maybe_unknown_device;
absl::optional<absl::string_view> maybe_gpu_device;
absl::optional<absl::string_view> maybe_cpu_device;
absl::optional<absl::string_view> maybe_unknown_device;
bool multiple_cpu_devices = false;
bool multiple_gpu_devices = false;
bool multiple_unknown_devices = false;
for (absl::string_view device_name : device_names_set) {
DeviceNameUtils::ParsedName parsed_name;
TF_RET_CHECK(DeviceNameUtils::ParseFullName(device_name, &parsed_name))
<< device_name;
if (parsed_name.type == "GPU") {
devices.ForEach([&](jit::DeviceId device) {
if (device_info_cache.IsGpu(device)) {
if (maybe_gpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple GPU devices ", absl::StrJoin(device_names, ", ")));
multiple_gpu_devices = true;
return false;
}
maybe_gpu_device = device_name;
} else if (parsed_name.type == "CPU") {
maybe_gpu_device = device;
} else if (device_info_cache.IsCpu(device)) {
if (maybe_cpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple CPU devices ", absl::StrJoin(device_names, ", ")));
multiple_cpu_devices = true;
return false;
}
maybe_cpu_device = device_name;
maybe_cpu_device = device;
} else {
if (maybe_unknown_device) {
multiple_unknown_devices = true;
return false;
}
maybe_unknown_device = device;
}
return true;
});
if (multiple_cpu_devices) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple unknown devices ", absl::StrJoin(device_names, ", ")));
"Multiple CPU devices ", device_info_cache.DebugString(devices)));
}
maybe_unknown_device = device_name;
if (multiple_gpu_devices) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple GPU devices ", device_info_cache.DebugString(devices)));
}
if (multiple_unknown_devices) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Multiple unknown devices ", device_info_cache.DebugString(devices)));
}
if (maybe_unknown_device && maybe_gpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Found both unknown and GPU devices: ", *maybe_unknown_device, ", ",
*maybe_gpu_device));
"Found both unknown and GPU devices: ",
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
device_info_cache.GetNameFor(*maybe_gpu_device)));
}
if (!allow_mixing_unknown_and_cpu) {
if (maybe_unknown_device && maybe_cpu_device) {
FAILED_TO_PICK_DEVICE(errors::Internal(
"Found both unknown and CPU devices: ", *maybe_unknown_device, ", ",
*maybe_cpu_device));
"Found both unknown and CPU devices: ",
device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
device_info_cache.GetNameFor(*maybe_cpu_device)));
}
}
if (out_device_picked) {
if (maybe_gpu_device) {
*out_device_picked = string(*maybe_gpu_device);
*out_device_picked = *maybe_gpu_device;
} else if (maybe_unknown_device) {
*out_device_picked = string(*maybe_unknown_device);
*out_device_picked = *maybe_unknown_device;
} else {
*out_device_picked = string(*maybe_cpu_device);
*out_device_picked = *maybe_cpu_device;
}
}
@ -156,19 +197,24 @@ Status PickDeviceForXlaImpl(absl::Span<const string> device_names,
#undef FAILED_TO_PICK_DEVICE
}
Status PickDeviceForXla(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
string* out_device_picked) {
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
/*out_can_pick_device=*/nullptr,
out_device_picked);
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
absl::optional<jit::DeviceId> device;
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(
device_info_cache, devices, allow_mixing_unknown_and_cpu,
/*out_can_pick_device=*/nullptr, &device));
return *device;
}
Status CanPickDeviceForXla(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device) {
return PickDeviceForXlaImpl(device_names, allow_mixing_unknown_and_cpu,
out_can_pick_device,
/*out_device_picked=*/nullptr);
xla::StatusOr<bool> CanPickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
bool can_pick_device;
TF_RETURN_IF_ERROR(PickDeviceForXlaImpl(device_info_cache, devices,
allow_mixing_unknown_and_cpu,
&can_pick_device,
/*out_device_picked=*/nullptr));
return can_pick_device;
}
} // namespace tensorflow

View File

@ -23,24 +23,119 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
namespace jit {
// Instances of DeviceId represent TensorFlow devices as integers.
//
// This helps avoid having to manipulate device names as strings when
// auto-clustering.
class DeviceId {
public:
DeviceId(DeviceId&&) = default;
DeviceId(const DeviceId&) = default;
DeviceId& operator=(const DeviceId&) = default;
bool operator==(const DeviceId& other) const { return id() == other.id(); }
bool operator!=(const DeviceId& other) const { return !(*this == other); }
private:
int id_;
explicit DeviceId(int id) : id_(id) {}
int id() const { return id_; }
friend class DeviceInfoCache;
friend class DeviceSet;
};
// A set of DeviceIds, represented as a bitmap.
class DeviceSet {
public:
void Insert(DeviceId device_id);
void UnionWith(const DeviceSet& other);
bool IsEmpty() const;
// Calls `func` on each DeviceId in the set. Stops iterating early if `func`
// return false.
//
// TODO(sanjoy): Change this to take a typed std::function if that's
// performance neutral.
template <typename FnTy>
void ForEach(FnTy func) const {
// This is really a poor man's iterator, we should consider writing a proper
// iterator if this ends up being used widely.
for (int word_index = 0; word_index < storage_.size(); word_index++) {
uint64 word = storage_[word_index];
for (int bit_index = 0; bit_index < kWordSize; bit_index++) {
if (word & (1ull << bit_index)) {
if (!func(DeviceId(word_index * kWordSize + bit_index))) {
return;
}
}
}
}
}
private:
absl::InlinedVector<uint64, 1> storage_;
const int kWordSize = 64;
};
// Caches some miscellaneous information about TF devices. Thread compatible.
class DeviceInfoCache {
public:
xla::StatusOr<const XlaOpRegistry::DeviceRegistration*> GetCompilationDevice(
absl::string_view device_name);
xla::StatusOr<std::reference_wrapper<const DeviceType>> GetDeviceTypeFor(
absl::string_view device_name);
bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }
bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }
absl::string_view GetNameFor(DeviceId device) const {
return names_[device.id()];
}
xla::StatusOr<DeviceId> GetIdFor(absl::string_view name);
using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;
DeviceRegistration* GetCompilationDevice(DeviceId device) const {
return id_to_compilation_device_[device.id()];
}
xla::StatusOr<DeviceRegistration*> GetCompilationDevice(
absl::string_view name) {
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
return GetCompilationDevice(device_id);
}
const DeviceType& GetDeviceTypeFor(DeviceId device) const {
return *id_to_device_type_[device.id()];
}
using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;
xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
absl::string_view device_name) {
TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
return std::cref(*id_to_device_type_[device_id.id()]);
}
string DebugString(const DeviceSet& device_set) const;
private:
absl::flat_hash_map<string, const XlaOpRegistry::DeviceRegistration*>
device_to_device_registration_;
absl::flat_hash_map<string, std::unique_ptr<DeviceType>>
device_to_device_type_;
absl::flat_hash_map<string, DeviceId> name_to_id_;
// These fields are populated for a device in GetIdFor, *before* we give out a
// DeviceId.
std::vector<const XlaOpRegistry::DeviceRegistration*>
id_to_compilation_device_;
std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
std::vector<string> names_;
std::vector<bool> is_cpu_;
std::vector<bool> is_gpu_;
};
} // namespace jit
@ -49,7 +144,7 @@ class DeviceInfoCache {
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
// Picks the device for which XLA should compile a cluster that contains
// operations placed in devices in `device_names`. For instance a cluster that
// operations placed in devices in `devices`. For instance a cluster that
// contains operations solely placed on the CPU will be compiled into a CPU
// executable by XLA, whereas a cluster that contains operations placed on the
// CPU and also operations placed on the GPU will be compiled into a GPU
@ -82,16 +177,15 @@ Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
// case it is the responsibility of the optimization pass that injected the
// CPU nodes into the cluster to ensure that these nodes can be compiled by
// the unknown XLA backend.
Status PickDeviceForXla(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
string* out_device_picked);
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
// This is like `PickDeviceForXla` except that it returns false (instead of a
// non-OK Status) in `out_can_pick_device` if no unambiguous choice of device
// exists.
Status CanPickDeviceForXla(absl::Span<const string> device_names,
bool allow_mixing_unknown_and_cpu,
bool* out_can_pick_device);
// non-OK Status) if no unambiguous choice of device exists.
xla::StatusOr<bool> CanPickDeviceForXla(
const jit::DeviceInfoCache& device_info_cache,
const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_

View File

@ -22,12 +22,20 @@ namespace tensorflow {
namespace {
Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu,
absl::Span<const absl::string_view> inputs,
absl::Span<const absl::string_view> device_names,
string* result) {
std::vector<string> inputs_string;
absl::c_transform(inputs, std::back_inserter(inputs_string),
[](absl::string_view sv) { return string(sv); });
return PickDeviceForXla(inputs_string, allow_mixing_unknown_and_cpu, result);
jit::DeviceInfoCache cache;
jit::DeviceSet device_set;
for (absl::string_view name : device_names) {
TF_ASSIGN_OR_RETURN(jit::DeviceId device_id, cache.GetIdFor(name));
device_set.Insert(device_id);
}
TF_ASSIGN_OR_RETURN(
jit::DeviceId result_id,
PickDeviceForXla(cache, device_set, allow_mixing_unknown_and_cpu));
*result = string(cache.GetNameFor(result_id));
return Status::OK();
}
void CheckPickDeviceResult(absl::string_view expected_result,
@ -87,5 +95,38 @@ TEST(PickDeviceForXla, MultipleDevicesOfSameType) {
CheckPickDeviceHasError(false, {kCPU0, kCPU1, kGPU0});
}
void SimpleRoundTripTestForDeviceSet(int num_devices) {
jit::DeviceSet device_set;
jit::DeviceInfoCache device_info_cache;
std::vector<string> expected_devices, actual_devices;
for (int i = 0; i < num_devices; i++) {
string device_name =
absl::StrCat("/job:localhost/replica:0/task:0/device:XPU:", i);
TF_ASSERT_OK_AND_ASSIGN(jit::DeviceId device_id,
device_info_cache.GetIdFor(device_name));
device_set.Insert(device_id);
expected_devices.push_back(device_name);
}
device_set.ForEach([&](jit::DeviceId device_id) {
actual_devices.push_back(string(device_info_cache.GetNameFor(device_id)));
return true;
});
EXPECT_EQ(expected_devices, actual_devices);
}
TEST(DeviceSetTest, SimpleRoundTrip_One) { SimpleRoundTripTestForDeviceSet(1); }
TEST(DeviceSetTest, SimpleRoundTrip_Small) {
SimpleRoundTripTestForDeviceSet(8);
}
TEST(DeviceSetTest, SimpleRoundTrip_Large) {
SimpleRoundTripTestForDeviceSet(800);
}
} // namespace
} // namespace tensorflow

View File

@ -57,6 +57,8 @@ namespace tensorflow {
namespace {
using DeadnessPredicate = DeadnessAnalysis::DeadnessPredicate;
using jit::DeviceId;
using jit::DeviceSet;
using xla::StatusOr;
// The clusters we create here are eventually lowered into an
@ -117,8 +119,8 @@ class MarkForCompilationPassImpl {
public:
// Constructs a trivial cluster representing a single TF node.
Cluster(int tf_graph_node_id, int effective_cluster_size,
bool has_functional_control_flow,
absl::flat_hash_set<string> devices, string resource_op_device,
bool has_functional_control_flow, DeviceSet devices,
absl::optional<DeviceId> resource_op_device,
absl::optional<int> resource_var_operation_node_id,
absl::optional<DeadnessPredicate> deadness_predicate,
bool is_xla_compile_attr_true, absl::optional<string> xla_scope)
@ -126,7 +128,7 @@ class MarkForCompilationPassImpl {
effective_cluster_size_(effective_cluster_size),
has_functional_control_flow_(has_functional_control_flow),
devices_(std::move(devices)),
resource_op_device_(std::move(resource_op_device)),
resource_op_device_(resource_op_device),
deadness_predicate_(deadness_predicate),
is_xla_compile_attr_true_(is_xla_compile_attr_true),
xla_scope_(std::move(xla_scope)) {
@ -162,12 +164,14 @@ class MarkForCompilationPassImpl {
}
// The set of devices nodes in the cluster are placed on.
const absl::flat_hash_set<string>& devices() const { return devices_; }
const DeviceSet& devices() const { return devices_; }
// If the cluster has a resource operation then the device the resource
// operation is placed on. A cluster may have resource ops placed only on a
// single device.
const string& resource_op_device() const { return resource_op_device_; }
const absl::optional<DeviceId>& resource_op_device() const {
return resource_op_device_;
}
// If not nullopt the a predicate that is true iff the cluster is alive.
// Otherwise the user has (unsafely) disabled deadness analysis. If this is
@ -208,8 +212,8 @@ class MarkForCompilationPassImpl {
int cycles_graph_node_id_;
int effective_cluster_size_;
bool has_functional_control_flow_;
absl::flat_hash_set<string> devices_;
string resource_op_device_;
DeviceSet devices_;
absl::optional<DeviceId> resource_op_device_;
absl::optional<DeadnessPredicate> deadness_predicate_;
bool is_xla_compile_attr_true_;
absl::optional<string> xla_scope_;
@ -279,17 +283,17 @@ class MarkForCompilationPassImpl {
Cluster* MakeNewCluster(int cycles_graph_node_id, int effective_cluster_size,
bool has_functional_control_flow,
absl::flat_hash_set<string> devices,
string resource_op_device,
const DeviceSet& device_set,
absl::optional<DeviceId> resource_op_device,
absl::optional<int> resource_var_operation_node_id,
absl::optional<DeadnessPredicate> deadness_predicate,
bool is_xla_compile_attr_true,
absl::optional<string> xla_scope) {
cluster_storage_.push_back(absl::make_unique<Cluster>(
cycles_graph_node_id, effective_cluster_size,
has_functional_control_flow, std::move(devices),
std::move(resource_op_device), resource_var_operation_node_id,
deadness_predicate, is_xla_compile_attr_true, xla_scope));
has_functional_control_flow, device_set, resource_op_device,
resource_var_operation_node_id, deadness_predicate,
is_xla_compile_attr_true, xla_scope));
return cluster_storage_.back().get();
}
@ -486,13 +490,15 @@ void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) {
effective_cluster_size_ += other->effective_cluster_size_;
has_functional_control_flow_ |= other->has_functional_control_flow_;
for (string other_device : other->devices_) {
devices_.insert(other_device);
}
other->devices_.clear();
devices_.UnionWith(other->devices_);
if (resource_op_device_.empty()) {
resource_op_device_ = std::move(other->resource_op_device_);
DCHECK(!(resource_op_device_.has_value() &&
other->resource_op_device_.has_value()) ||
*resource_op_device_ == *other->resource_op_device_)
<< "AreDevicesCompatible should have returned false otherwise!";
if (!resource_op_device_.has_value()) {
resource_op_device_ = other->resource_op_device_;
}
is_xla_compile_attr_true_ |= other->is_xla_compile_attr_true_;
@ -779,12 +785,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
deadness_analysis_->GetPredicateFor(node, Graph::kControlSlot));
}
const string& device = !node->assigned_device_name().empty()
const string& device_name_str = !node->assigned_device_name().empty()
? node->assigned_device_name()
: node->requested_device();
TF_ASSIGN_OR_RETURN(DeviceId device,
device_info_cache_.GetIdFor(device_name_str));
bool is_resource_op = HasResourceInputOrOutput(*node);
string resource_op_device;
absl::optional<DeviceId> resource_op_device;
if (is_resource_op) {
resource_op_device = device;
}
@ -805,15 +813,14 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() {
is_xla_compile_attr_true |= xla_compile_attr;
}
absl::flat_hash_set<string> devices;
devices.insert(device);
DeviceSet devices;
devices.Insert(device);
Cluster* new_cluster = MakeNewCluster(
/*cycles_graph_node_id=*/node->id(),
/*effective_cluster_size=*/effective_cluster_size,
/*has_functional_control_flow=*/has_functional_control_flow,
std::move(devices), std::move(resource_op_device),
resource_var_operation_node_id, deadness_predicate,
/*has_functional_control_flow=*/has_functional_control_flow, devices,
resource_op_device, resource_var_operation_node_id, deadness_predicate,
/*is_xla_compile_attr_true=*/is_xla_compile_attr_true,
GetXlaScope(node));
@ -1255,27 +1262,22 @@ void MarkForCompilationPassImpl::VLogClusteringSummary() {
StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
const Cluster& cluster_a, const Cluster& cluster_b) {
std::vector<string> devices;
absl::c_remove_copy(cluster_a.devices(), std::back_inserter(devices), "");
absl::c_remove_copy(cluster_b.devices(), std::back_inserter(devices), "");
absl::c_sort(devices);
if (devices.empty()) {
return false;
}
DeviceSet devices = cluster_a.devices();
devices.UnionWith(cluster_b.devices());
// First check if we will even be able to pick a device for the larger
// combined cluster.
bool can_pick_device;
TF_RETURN_IF_ERROR(CanPickDeviceForXla(
devices, /*allow_mixing_unknown_and_cpu=*/false, &can_pick_device));
TF_ASSIGN_OR_RETURN(
bool can_pick_device,
CanPickDeviceForXla(device_info_cache_, devices,
/*allow_mixing_unknown_and_cpu=*/false));
if (!can_pick_device) {
return false;
}
string chosen_device;
TF_RETURN_IF_ERROR(PickDeviceForXla(
devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
PickDeviceForXla(device_info_cache_, devices,
/*allow_mixing_unknown_and_cpu=*/false));
// If we are able to pick a device `chosen_device` for the larger cluster, the
// resource operations in `cluster_a` and `cluster_b` must be placed on the
@ -1283,8 +1285,10 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
// _XlaRun kernels are going to run on and therefore try to access the
// resource variables from `chosen_device`, which will be an error if the
// resource variables are placed on some other device.
auto resource_op_device_ok = [&](const string& resource_op_device) {
return resource_op_device.empty() || resource_op_device == chosen_device;
auto resource_op_device_ok =
[&](absl::optional<DeviceId> resource_op_device) {
return !resource_op_device.has_value() ||
*resource_op_device == chosen_device;
};
return resource_op_device_ok(cluster_a.resource_op_device()) &&
@ -1294,22 +1298,18 @@ StatusOr<bool> MarkForCompilationPassImpl::AreDevicesCompatible(
// Returns `true` iff we should compile `cluster`.
StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
const Cluster& cluster) {
std::vector<string> devices;
absl::c_remove_copy(cluster.devices(), std::back_inserter(devices), "");
absl::c_sort(devices);
TF_ASSIGN_OR_RETURN(DeviceId chosen_device,
PickDeviceForXla(device_info_cache_, cluster.devices(),
/*allow_mixing_unknown_and_cpu=*/false));
string chosen_device;
TF_RETURN_IF_ERROR(PickDeviceForXla(
devices, /*allow_mixing_unknown_and_cpu=*/false, &chosen_device));
TF_ASSIGN_OR_RETURN(const DeviceType& device_type,
device_info_cache_.GetDeviceTypeFor(chosen_device));
TF_ASSIGN_OR_RETURN(const XlaOpRegistry::DeviceRegistration* registration,
device_info_cache_.GetCompilationDevice(chosen_device));
const DeviceType& device_type =
device_info_cache_.GetDeviceTypeFor(chosen_device);
const XlaOpRegistry::DeviceRegistration* registration =
device_info_cache_.GetCompilationDevice(chosen_device);
TF_RET_CHECK(registration)
<< "chosen device = " << chosen_device
<< "chosen device = " << device_info_cache_.GetNameFor(chosen_device)
<< "; device type = " << device_type.type() << "; devices ("
<< devices.size() << ") = " << absl::StrJoin(devices, ", ");
<< device_info_cache_.DebugString(cluster.devices());
bool should_compile =
cluster.is_xla_compile_attr_true() ||
@ -1343,7 +1343,8 @@ StatusOr<bool> MarkForCompilationPassImpl::ShouldCompileClusterImpl(
}
VLOG(3) << (should_compile ? "Compiling" : "Not compiling")
<< " cluster with device " << chosen_device;
<< " cluster with device "
<< device_info_cache_.GetNameFor(chosen_device);
return should_compile;
}