Move some device placement logic to EagerContext.
This will make it easier to incorporate this logic into places like pywrap_tensor.cc next, which needs to use the logic towards fixing b/139690309. I took the opportunity to perform some slight reshuffling of the original logic to make it more readable. PiperOrigin-RevId: 292023452 Change-Id: I2af49f738bf38b776c20fd6edbd525d2429c831f
This commit is contained in:
parent
7a4123bda5
commit
53889e9671
@ -30,6 +30,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
|
||||
#include "tensorflow/core/common_runtime/colocation_graph.h"
|
||||
#include "tensorflow/core/common_runtime/device_resolver_local.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
|
||||
@ -154,6 +155,75 @@ void EagerContext::InitPrioritizedDeviceTypeList() {
|
||||
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Using absl::StrJoin with lambda does not work in tf-lite builds.
|
||||
// TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
|
||||
std::vector<string> DevicesToString(const std::vector<Device*>& devices) {
|
||||
std::vector<string> v;
|
||||
v.reserve(devices.size());
|
||||
for (Device* d : devices) {
|
||||
v.push_back(d->name());
|
||||
}
|
||||
return v;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Chooses the device an operation should run on.
//
// `preferred` holds the (possibly empty) device constraints requested by the
// caller and `supported` lists the device types the operation can run on. On
// success the chosen device is written to `*device`; on failure an
// InvalidArgument status is returned and `*device` is left untouched.
Status EagerContext::SelectDevice(const DeviceNameUtils::ParsedName& preferred,
                                  const PrioritizedDeviceTypeVector& supported,
                                  Device** device) const {
  const DeviceSet& available = *pflr()->device_set();
  std::vector<Device*> candidates;

  if (DeviceNameUtils::HasSomeDetails(preferred)) {
    // The caller asked for a specific device: start from the registered
    // devices matching the request and keep only the supported ones.
    available.FindMatchingDevices(preferred, &candidates);
    if (!candidates.empty()) {
      candidates = ColocationGraph::FilterSupportedDevices(
          candidates, supported, /*default_local_device=*/nullptr);
    }

    // Nothing matched: with soft placement enabled, retry while ignoring the
    // requested device type and id.
    if (candidates.empty() && AllowSoftPlacement()) {
      DeviceNameUtils::ParsedName relaxed = preferred;
      relaxed.type.clear();
      relaxed.has_type = false;
      relaxed.has_id = false;
      // TODO(b/148213746): Soft placement logic picks up another task if the
      // requested does not exist.
      available.FindMatchingDevices(relaxed, &candidates);
      if (!candidates.empty()) {
        candidates = ColocationGraph::FilterSupportedDevices(
            candidates, supported, /*default_local_device=*/nullptr);
      }
    }

    if (candidates.empty()) {
      return errors::InvalidArgument(
          "Could not satisfy device specification '", preferred,
          "'. All available devices [",
          absl::StrJoin(DevicesToString(available.devices()), ", "), "].");
    }
  } else {
    // No preference was given: consider every registered device that the
    // operation supports.
    // TODO(b/148213212): Allow setting default device in eager context.
    candidates = ColocationGraph::FilterSupportedDevices(
        available.devices(), supported, /*default_local_device=*/nullptr);
    if (candidates.empty()) {
      return errors::InvalidArgument(
          "No supported device found in available devices [",
          absl::StrJoin(DevicesToString(available.devices()), ", "), "].");
    }
  }

  // The first remaining candidate is the one selected.
  *device = candidates.front();
  return Status::OK();
}
|
||||
|
||||
void EagerContext::ResetClusterFLR(
|
||||
DistributedFunctionLibraryRuntime* cluster_flr) {
|
||||
cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
|
||||
|
@ -156,6 +156,23 @@ class EagerContext : public core::RefCounted {
|
||||
// Returns the device placement policy for the current thread.
|
||||
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
|
||||
|
||||
// Select an appropriate device for an operation.
|
||||
//
|
||||
// Given the preferred device for the operation, and the list of devices the
|
||||
// operation supports, finds the most suitable device for the operation in
|
||||
// this context.
|
||||
//
|
||||
// The preferred device is specified as a `ParsedName` containing the elements
|
||||
// (details) that the resulting device should match. If there are no such
|
||||
// devices, and the context currently allows soft device placement, a suitable
|
||||
// device not matching `preferred` will be chosen.
// For that fallback the device type and id of `preferred` are ignored. If
// `preferred` specifies no details at all, the choice is made among all
// available devices that the operation supports.
|
||||
//
|
||||
// The chosen device is stored in the `device` argument. The argument is not
|
||||
// modified unless this method returns `Status::OK()`.
//
// Returns an `InvalidArgument` error if no supported device can be selected.
|
||||
Status SelectDevice(const DeviceNameUtils::ParsedName& preferred,
|
||||
const PrioritizedDeviceTypeVector& supported,
|
||||
Device** device) const;
|
||||
|
||||
// Sets the implicit copy policy for the current thread.
|
||||
void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);
|
||||
|
||||
|
@ -75,16 +75,6 @@ namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Using absl::StrJoin with lambda does not work in tf-lite builds.
|
||||
std::vector<string> DevicesToString(const std::vector<Device*> devices) {
|
||||
std::vector<string> v;
|
||||
v.reserve(devices.size());
|
||||
for (Device* d : devices) {
|
||||
v.push_back(d->name());
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
const string& DeviceNameOrUnspecified(Device* device) {
|
||||
static string* unspecified_string = new string("<unspecified>");
|
||||
return (device == nullptr) ? *unspecified_string : device->name();
|
||||
@ -208,72 +198,6 @@ Status ValidateInputTypeAndPlacement(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Picks the device `op` should run on.
//
// The device types able to run `ndef` are computed first; NotFound is
// returned if there are none. If `op` carries a device specification, the
// matching registered devices are filtered down to the supported ones. If
// that leaves nothing and the context allows soft placement, the search is
// repeated ignoring the requested device type and id. Without a device
// specification, all registered devices are filtered against the supported
// types. InvalidArgument is returned when no device remains.
//
// On success the first remaining device is set on `op` and stored in
// `*device`; on error neither is modified.
Status SelectDevice(EagerOperation* op, const NodeDef& ndef,
                    const EagerContext& ctx, Device** device) {
  PrioritizedDeviceTypeVector supported_devs;
  TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
      ctx.prioritized_device_type_list(), ndef, &supported_devs,
      &ctx.HostCPU()->parsed_name()));
  if (supported_devs.empty()) {
    return errors::NotFound("Could not find valid device for node.\nNode:",
                            FormatNodeDefForError(ndef),
                            "\nAll kernels registered for op ", ndef.op(),
                            " :\n", KernelsRegisteredForOp(ndef.op()));
  }

  // Look these up once instead of on every use below.
  const auto& device_set = *ctx.pflr()->device_set();
  const auto& requested = op->GetDeviceParsedName();

  std::vector<Device*> final_devices;
  if (DeviceNameUtils::HasSomeDetails(requested)) {
    device_set.FindMatchingDevices(requested, &final_devices);
    if (!final_devices.empty()) {
      final_devices = ColocationGraph::FilterSupportedDevices(
          final_devices, supported_devs, /*default_local_device=*/nullptr);
    }

    if (final_devices.empty() && ctx.AllowSoftPlacement()) {
      // Retry while ignoring the requested device type and id.
      DeviceNameUtils::ParsedName soft_device_name = requested;
      soft_device_name.type.clear();
      soft_device_name.has_type = false;
      soft_device_name.has_id = false;
      // TODO(fishx): Soft placement logic picks up another task if the
      // requested does not exist.
      device_set.FindMatchingDevices(soft_device_name, &final_devices);
      if (!final_devices.empty()) {
        final_devices = ColocationGraph::FilterSupportedDevices(
            final_devices, supported_devs, /*default_local_device=*/nullptr);
      }
    }

    if (final_devices.empty()) {
      return errors::InvalidArgument(
          "Could not satisfy device specification '", requested,
          "'. All available devices [",
          absl::StrJoin(DevicesToString(device_set.devices()), ", "),
          "]. Eager operation: ", op->DebugString());
    }
  } else {
    // TODO(fishx): Allow setting default device in eager context.
    final_devices = ColocationGraph::FilterSupportedDevices(
        device_set.devices(), supported_devs,
        /*default_local_device=*/nullptr);
    if (final_devices.empty()) {
      // NOTE(review): "suppport" is misspelled in this message; it is left
      // unchanged here in case callers or tests match on the exact text.
      return errors::InvalidArgument(
          "No OpKernel registered to suppport this eager operation:",
          op->DebugString());
    }
  }

  Device* const chosen = final_devices[0];
  DVLOG(1) << "Placer place op [" << op->Name()
           << "] on device: " << chosen->name();
  // The leading space keeps the op name from running into "are".
  DVLOG(4) << "Available kernels for " << op->Name() << " are "
           << KernelsRegisteredForOp(op->Name());
  op->SetDevice(chosen);
  *device = chosen;
  return Status::OK();
}
|
||||
|
||||
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
|
||||
const auto& node_def = op->MutableAttrs()->BuildNodeDef();
|
||||
const OpDef* op_def = nullptr;
|
||||
@ -524,7 +448,24 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
|
||||
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
|
||||
if (device == nullptr) {
|
||||
TF_RETURN_IF_ERROR(SelectDevice(op, ndef, ctx, &device));
|
||||
PrioritizedDeviceTypeVector supported_devs;
|
||||
TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
|
||||
ctx.prioritized_device_type_list(), ndef, &supported_devs,
|
||||
&ctx.HostCPU()->parsed_name()));
|
||||
if (supported_devs.empty()) {
|
||||
return errors::NotFound("Could not find valid device for node.\nNode:",
|
||||
FormatNodeDefForError(ndef),
|
||||
"\nAll kernels registered for op ", ndef.op(),
|
||||
" :\n", KernelsRegisteredForOp(ndef.op()));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx.SelectDevice(op->GetDeviceParsedName(), supported_devs, &device));
|
||||
|
||||
DVLOG(1) << "Placer place op [" << op->Name()
|
||||
<< "] on device: " << device->name();
|
||||
DVLOG(4) << "Available kernels for " << op->Name() << "are "
|
||||
<< KernelsRegisteredForOp(op->Name());
|
||||
op->SetDevice(device);
|
||||
}
|
||||
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
|
||||
string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
|
||||
|
Loading…
Reference in New Issue
Block a user