236 lines
7.8 KiB
C++
236 lines
7.8 KiB
C++
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/jit/device_util.h"
|
|
|
|
#include "absl/algorithm/container.h"
|
|
#include "absl/container/flat_hash_set.h"
|
|
#include "tensorflow/compiler/xla/status_macros.h"
|
|
|
|
namespace tensorflow {
|
|
namespace jit {
|
|
using xla::StatusOr;
|
|
|
|
// Adds `device_id` to the set by setting its bit in the word-packed bitmap,
// growing the bitmap with zero words first if the id lies past its end.
void DeviceSet::Insert(DeviceId device_id) {
  const int word = device_id.id() / kWordSize;
  const int bit = device_id.id() % kWordSize;

  // Make `word` a valid index; newly added words start with no bits set.
  if (storage_.size() <= word) {
    storage_.resize(word + 1, 0);
  }

  storage_[word] |= uint64{1} << bit;
}
|
|
|
|
void DeviceSet::UnionWith(const DeviceSet& other) {
|
|
if (other.storage_.size() > storage_.size()) {
|
|
storage_.resize(other.storage_.size(), 0);
|
|
}
|
|
|
|
for (int i = 0; i < other.storage_.size(); i++) {
|
|
storage_[i] |= other.storage_[i];
|
|
}
|
|
}
|
|
|
|
bool DeviceSet::IsEmpty() const {
|
|
return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
|
|
}
|
|
|
|
// Returns the dense DeviceId assigned to the device called `name`, assigning
// the next free id (and caching the device's type, CPU/GPU flags and XLA
// compilation-device registration) the first time `name` is seen.
//
// Returns an error if `name` is empty or cannot be parsed as a full device
// name. On error the cache is left unchanged: all per-id tables keep the same
// length, so ids handed out later stay consistent across them.
xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
  TF_RET_CHECK(!name.empty());

  auto it = name_to_id_.find(name);
  if (it != name_to_id_.end()) {
    return it->second;
  }

  // Resolve everything that can fail BEFORE mutating any member. Previously
  // names_ and id_to_device_type_ were appended to before
  // DeviceNameToDeviceType could fail, so a malformed name left those vectors
  // longer than is_cpu_, is_gpu_ and id_to_compilation_device_, and every id
  // assigned afterwards indexed mismatched entries.
  string name_str(name);
  DeviceType device_type("");
  TF_RETURN_IF_ERROR(DeviceNameToDeviceType(name_str, &device_type));

  // A device type with no registered XLA compilation device is recorded as
  // nullptr.
  const XlaOpRegistry::DeviceRegistration* compilation_device = nullptr;
  if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
                                           &compilation_device)) {
    compilation_device = nullptr;
  }

  // Commit: no error status can be returned past this point, so every per-id
  // table grows together and stays indexable by `new_id`.
  const int new_id = names_.size();
  names_.push_back(name_str);
  id_to_device_type_.push_back(absl::make_unique<DeviceType>(device_type));
  is_cpu_.push_back(device_type.type_string() == DEVICE_CPU);
  is_gpu_.push_back(device_type.type_string() == DEVICE_GPU);
  id_to_compilation_device_.push_back(compilation_device);
  name_to_id_.emplace(name_str, DeviceId(new_id));

  return DeviceId(new_id);
}
|
|
|
|
// Renders `device_set` as "[name0,name1,...]", listing each member's cached
// device name in the order DeviceSet::ForEach visits them.
string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
  std::vector<string> device_names;
  device_set.ForEach([&](DeviceId id) {
    device_names.push_back(string(GetNameFor(id)));
    return true;  // Keep visiting; every member is printed.
  });

  const string joined = absl::StrJoin(device_names, ",");
  return absl::StrCat("[", joined, "]");
}
|
|
} // namespace jit
|
|
|
|
// Parses the full device name `device` and stores its device type (the
// ParsedName::type component) into `*device_type`.
// Returns an Internal error, leaving `*device_type` untouched, if `device` is
// not a well-formed full device name.
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
  DeviceNameUtils::ParsedName parsed_name;
  const bool parsed_ok = DeviceNameUtils::ParseFullName(device, &parsed_name);
  if (parsed_ok) {
    *device_type = DeviceType(parsed_name.type);
    return Status::OK();
  }
  return errors::Internal("Malformed assigned device '", device, "'");
}
|
|
|
|
// Shared implementation of PickDeviceForXla and MaybePickDeviceForXla.
//
// Buckets every device in `devices` as GPU, CPU or "unknown" (neither IsGpu
// nor IsCpu according to `device_info_cache`). It then picks one device,
// preferring a GPU device, then an unknown device, then a CPU device.
//
// Picking fails when:
//  - a GPU (or CPU) device conflicts with the GPU (or CPU) device picked so
//    far. Conflict means either name fails to parse, the names are
//    incompatible, or neither is a specification of the other;
//  - more than one unknown device is present;
//  - both an unknown and a GPU device are present;
//  - both an unknown and a CPU device are present and
//    `allow_mixing_unknown_and_cpu` is false;
//  - `devices` is empty.
//
// On failure this returns an errors::Internal status if
// `failure_to_pick_is_error` is true. Otherwise it returns an OK result
// holding absl::nullopt.
xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
    bool failure_to_pick_is_error) {
// Returns from the enclosing function on a failed pick. `failing_status` is
// only evaluated when failure_to_pick_is_error is true, so building the error
// message is skipped in the non-error (nullopt) mode.
#define FAILED_TO_PICK_DEVICE(failing_status) \
  do {                                        \
    if (failure_to_pick_is_error) {           \
      return failing_status;                  \
    } else {                                  \
      return {absl::nullopt};                 \
    }                                         \
  } while (false)

  // The device picked so far in each bucket (empty until one is seen).
  absl::optional<jit::DeviceId> maybe_gpu_device;
  absl::optional<jit::DeviceId> maybe_cpu_device;
  absl::optional<jit::DeviceId> maybe_unknown_device;

  // Set when a bucket receives a second device that cannot be merged with the
  // one already picked.
  bool multiple_cpu_devices = false;
  bool multiple_gpu_devices = false;
  bool multiple_unknown_devices = false;

  // Returns 'true' if d0 and the device held in *d1 are conflicting devices
  // (a name fails to parse, the names are incompatible, or neither name is a
  // specification of the other). If they are compatible, *d1 is left holding
  // whichever of the two DeviceNameUtils::IsSpecification reports as the more
  // specific one. Requires *d1 to hold a value.
  // TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache.
  const auto is_multiple_devices =
      [&](const jit::DeviceId& d0, absl::optional<jit::DeviceId>* d1) -> bool {
    const absl::string_view name0 = device_info_cache.GetNameFor(d0);
    const absl::string_view name1 = device_info_cache.GetNameFor(d1->value());

    DeviceNameUtils::ParsedName parsed0, parsed1;
    if (!DeviceNameUtils::ParseFullName(name0, &parsed0) ||
        !DeviceNameUtils::ParseFullName(name1, &parsed1) ||
        !DeviceNameUtils::AreCompatibleDevNames(parsed0, parsed1)) {
      return true;
    }

    // *d1 is already the more specific name; keep it.
    if (DeviceNameUtils::IsSpecification(parsed0, parsed1)) {
      return false;
    }

    // d0 is the more specific name; replace *d1 with it.
    if (DeviceNameUtils::IsSpecification(parsed1, parsed0)) {
      *d1 = d0;
      return false;
    }

    // Compatible, but neither refines the other.
    return true;
  };

  // Sort each device into its bucket. Returning false from the callback as
  // soon as a conflict is recorded presumably stops the ForEach walk early
  // (verify against DeviceSet::ForEach); the flags above carry the result.
  devices.ForEach([&](jit::DeviceId device) {
    if (device_info_cache.IsGpu(device)) {
      if (maybe_gpu_device) {
        multiple_gpu_devices = is_multiple_devices(device, &maybe_gpu_device);
        if (multiple_gpu_devices) return false;
      } else {
        maybe_gpu_device = device;
      }
    } else if (device_info_cache.IsCpu(device)) {
      if (maybe_cpu_device) {
        multiple_cpu_devices = is_multiple_devices(device, &maybe_cpu_device);
        if (multiple_cpu_devices) return false;
      } else {
        maybe_cpu_device = device;
      }
    } else {
      // Unknown devices are never merged: any second one is a conflict.
      if (maybe_unknown_device) {
        multiple_unknown_devices = true;
        return false;
      }
      maybe_unknown_device = device;
    }

    return true;
  });

  if (multiple_cpu_devices) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Multiple CPU devices ", device_info_cache.DebugString(devices)));
  }

  if (multiple_gpu_devices) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Multiple GPU devices ", device_info_cache.DebugString(devices)));
  }

  if (multiple_unknown_devices) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Multiple unknown devices ", device_info_cache.DebugString(devices)));
  }

  // An unknown device may never be combined with a GPU device.
  if (maybe_unknown_device && maybe_gpu_device) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Found both unknown and GPU devices: ",
        device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
        device_info_cache.GetNameFor(*maybe_gpu_device)));
  }

  // Combining an unknown device with a CPU device is only allowed on request.
  if (!allow_mixing_unknown_and_cpu) {
    if (maybe_unknown_device && maybe_cpu_device) {
      FAILED_TO_PICK_DEVICE(errors::Internal(
          "Found both unknown and CPU devices: ",
          device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
          device_info_cache.GetNameFor(*maybe_cpu_device)));
    }
  }

  // Preference order: GPU, then unknown, then CPU.
  if (maybe_gpu_device) {
    return {*maybe_gpu_device};
  } else if (maybe_unknown_device) {
    return {*maybe_unknown_device};
  } else if (maybe_cpu_device) {
    return {*maybe_cpu_device};
  }

  // No bucket was filled, i.e. `devices` contained nothing.
  FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));

#undef FAILED_TO_PICK_DEVICE
}
|
|
|
|
// Picks the single device a cluster over `devices` should be compiled for.
// Any failure to pick (conflicting devices, disallowed mixes, empty set) is
// reported as an error status; see PickDeviceForXlaImpl for the rules.
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
  xla::StatusOr<absl::optional<jit::DeviceId>> picked =
      PickDeviceForXlaImpl(device_info_cache, devices,
                           allow_mixing_unknown_and_cpu,
                           /*failure_to_pick_is_error=*/true);
  if (!picked.ok()) {
    return picked.status();
  }
  // In error mode every failure path above returns a status, so a successful
  // result always holds a device.
  return *picked.ValueOrDie();
}
|
|
|
|
// Like PickDeviceForXla, but a failure to pick a device is not an error: it
// yields an OK result holding absl::nullopt instead.
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
  constexpr bool kFailureToPickIsError = false;
  return PickDeviceForXlaImpl(device_info_cache, devices,
                              allow_mixing_unknown_and_cpu,
                              kFailureToPickIsError);
}
|
|
} // namespace tensorflow
|