Move some device placement logic to EagerContext.

This will make it easier to incorporate this logic into places like
pywrap_tensor.cc next, which needs to use the logic towards fixing
b/139690309.

I took the opportunity do perform some slight reshuffling of the
original logic to make it more readable.

PiperOrigin-RevId: 292023452
Change-Id: I2af49f738bf38b776c20fd6edbd525d2429c831f
This commit is contained in:
A. Unique TensorFlower 2020-01-28 15:15:04 -08:00 committed by TensorFlower Gardener
parent 7a4123bda5
commit 53889e9671
3 changed files with 105 additions and 77 deletions

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
#include "tensorflow/core/common_runtime/colocation_graph.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
@ -154,6 +155,75 @@ void EagerContext::InitPrioritizedDeviceTypeList() {
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
}
namespace {
// Using absl::StrJoin with lambda does not work in tf-lite builds.
// TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
std::vector<string> DevicesToString(const std::vector<Device*>& devices) {
std::vector<string> v;
v.reserve(devices.size());
for (Device* d : devices) {
v.push_back(d->name());
}
return v;
}
} // namespace
Status EagerContext::SelectDevice(const DeviceNameUtils::ParsedName& preferred,
const PrioritizedDeviceTypeVector& supported,
Device** device) const {
std::vector<Device*> selected;
const DeviceSet& pflr_devices = *pflr()->device_set();
// If there are no preferred devices, select the first registered device from
// the supported device list.
if (!DeviceNameUtils::HasSomeDetails(preferred)) {
// TODO(b/148213212): Allow setting default device in eager context.
selected = ColocationGraph::FilterSupportedDevices(
pflr_devices.devices(), supported, /*default_local_device=*/nullptr);
if (selected.empty()) {
return errors::InvalidArgument(
"No supported device found in available devices [",
absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "].");
}
*device = selected[0];
return Status::OK();
}
// If the caller specified a preferred device, select the first matching
// registered device from the supported device list. If nothing matches and
// soft placement is enabled, pick a suitable device from the available ones.
pflr_devices.FindMatchingDevices(preferred, &selected);
if (!selected.empty()) {
selected = ColocationGraph::FilterSupportedDevices(
selected, supported, /*default_local_device=*/nullptr);
}
if (selected.empty() && AllowSoftPlacement()) {
DeviceNameUtils::ParsedName soft_device_name = preferred;
soft_device_name.type.clear();
soft_device_name.has_type = false;
soft_device_name.has_id = false;
// TODO(b/148213746): Soft placement logic picks up another task if the
// requested does not exist.
pflr_devices.FindMatchingDevices(soft_device_name, &selected);
if (!selected.empty()) {
selected = ColocationGraph::FilterSupportedDevices(
selected, supported, /*default_local_device=*/nullptr);
}
}
if (selected.empty()) {
return errors::InvalidArgument(
"Could not satisfy device specification '", preferred,
"'. All available devices [",
absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "].");
}
*device = selected[0];
return Status::OK();
}
void EagerContext::ResetClusterFLR(
DistributedFunctionLibraryRuntime* cluster_flr) {
cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);

View File

@ -156,6 +156,23 @@ class EagerContext : public core::RefCounted {
// Returns the device placement policy for the current thread.
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
// Select an appropriate device for an operation.
//
// Given the preferred device for the operation, and the list of devices the
// operation supports, finds the best suitable device for the operation in
// this context.
//
// The preferred device is specified as a `ParsedName` containing the elements
// (details) that the resulting device should match. If there are no such
// devices, and the context currently allows soft device placement, a suitable
// device not matching `preferred` will be chosen.
//
// The chosen device is stored in the `device` argument. The argument is not
// modified unless this method returns `Status::OK()`.
Status SelectDevice(const DeviceNameUtils::ParsedName& preferred,
const PrioritizedDeviceTypeVector& supported,
Device** device) const;
// Sets the implicit copy policy for the current thread.
void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);

View File

@ -75,16 +75,6 @@ namespace tensorflow {
namespace {
// Using absl::StrJoin with lambda does not work in tf-lite builds.
std::vector<string> DevicesToString(const std::vector<Device*> devices) {
std::vector<string> v;
v.reserve(devices.size());
for (Device* d : devices) {
v.push_back(d->name());
}
return v;
}
const string& DeviceNameOrUnspecified(Device* device) {
static string* unspecified_string = new string("<unspecified>");
return (device == nullptr) ? *unspecified_string : device->name();
@ -208,72 +198,6 @@ Status ValidateInputTypeAndPlacement(
return Status::OK();
}
Status SelectDevice(EagerOperation* op, const NodeDef& ndef,
const EagerContext& ctx, Device** device) {
std::vector<Device*> final_devices;
PrioritizedDeviceTypeVector supported_devs;
TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
ctx.prioritized_device_type_list(), ndef, &supported_devs,
&ctx.HostCPU()->parsed_name()));
if (supported_devs.empty()) {
return errors::NotFound("Could not find valid device for node.\nNode:",
FormatNodeDefForError(ndef),
"\nAll kernels registered for op ", ndef.op(),
" :\n", KernelsRegisteredForOp(ndef.op()));
}
if (DeviceNameUtils::HasSomeDetails(op->GetDeviceParsedName())) {
ctx.pflr()->device_set()->FindMatchingDevices(op->GetDeviceParsedName(),
&final_devices);
if (!final_devices.empty()) {
final_devices = ColocationGraph::FilterSupportedDevices(
final_devices, supported_devs, /*default_local_device=*/nullptr);
}
if (final_devices.empty() && ctx.AllowSoftPlacement()) {
DeviceNameUtils::ParsedName soft_device_name = op->GetDeviceParsedName();
soft_device_name.type.clear();
soft_device_name.has_type = false;
soft_device_name.has_id = false;
// TODO(fishx): Soft placement logic picks up another task if the
// requested does not exist.
ctx.pflr()->device_set()->FindMatchingDevices(soft_device_name,
&final_devices);
if (!final_devices.empty()) {
final_devices = ColocationGraph::FilterSupportedDevices(
final_devices, supported_devs, /*default_local_device=*/nullptr);
}
}
if (final_devices.empty()) {
return errors::InvalidArgument(
"Could not satisfy device specification '", op->GetDeviceParsedName(),
"'. All available devices [",
absl::StrJoin(DevicesToString(ctx.pflr()->device_set()->devices()),
", "),
"]. Eager operation: ", op->DebugString());
}
} else {
// TODO(fishx): Allow setting default device in eager context.
final_devices = ColocationGraph::FilterSupportedDevices(
ctx.pflr()->device_set()->devices(), supported_devs,
/*default_local_device=*/nullptr);
if (final_devices.empty()) {
return errors::InvalidArgument(
"No OpKernel registered to suppport this eager operation:",
op->DebugString());
}
}
DVLOG(1) << "Placer place op [" << op->Name()
<< "] on device: " << final_devices[0]->name();
DVLOG(4) << "Available kernels for " << op->Name() << "are "
<< KernelsRegisteredForOp(op->Name());
op->SetDevice(final_devices[0]);
*device = final_devices[0];
return Status::OK();
}
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
const auto& node_def = op->MutableAttrs()->BuildNodeDef();
const OpDef* op_def = nullptr;
@ -524,7 +448,24 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
if (device == nullptr) {
TF_RETURN_IF_ERROR(SelectDevice(op, ndef, ctx, &device));
PrioritizedDeviceTypeVector supported_devs;
TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
ctx.prioritized_device_type_list(), ndef, &supported_devs,
&ctx.HostCPU()->parsed_name()));
if (supported_devs.empty()) {
return errors::NotFound("Could not find valid device for node.\nNode:",
FormatNodeDefForError(ndef),
"\nAll kernels registered for op ", ndef.op(),
" :\n", KernelsRegisteredForOp(ndef.op()));
}
TF_RETURN_IF_ERROR(
ctx.SelectDevice(op->GetDeviceParsedName(), supported_devs, &device));
DVLOG(1) << "Placer place op [" << op->Name()
<< "] on device: " << device->name();
DVLOG(4) << "Available kernels for " << op->Name() << "are "
<< KernelsRegisteredForOp(op->Name());
op->SetDevice(device);
}
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",