Move some device placement logic to EagerContext.
This will make it easier to incorporate this logic into places like pywrap_tensor.cc next, which needs to use the logic towards fixing b/139690309. I took the opportunity do perform some slight reshuffling of the original logic to make it more readable. PiperOrigin-RevId: 292023452 Change-Id: I2af49f738bf38b776c20fd6edbd525d2429c831f
This commit is contained in:
parent
7a4123bda5
commit
53889e9671
tensorflow/core/common_runtime/eager
@ -30,6 +30,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
|
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
|
||||||
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
|
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
|
||||||
|
#include "tensorflow/core/common_runtime/colocation_graph.h"
|
||||||
#include "tensorflow/core/common_runtime/device_resolver_local.h"
|
#include "tensorflow/core/common_runtime/device_resolver_local.h"
|
||||||
#include "tensorflow/core/common_runtime/device_set.h"
|
#include "tensorflow/core/common_runtime/device_set.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
|
#include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
|
||||||
@ -154,6 +155,75 @@ void EagerContext::InitPrioritizedDeviceTypeList() {
|
|||||||
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
|
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Using absl::StrJoin with lambda does not work in tf-lite builds.
|
||||||
|
// TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
|
||||||
|
std::vector<string> DevicesToString(const std::vector<Device*>& devices) {
|
||||||
|
std::vector<string> v;
|
||||||
|
v.reserve(devices.size());
|
||||||
|
for (Device* d : devices) {
|
||||||
|
v.push_back(d->name());
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Status EagerContext::SelectDevice(const DeviceNameUtils::ParsedName& preferred,
|
||||||
|
const PrioritizedDeviceTypeVector& supported,
|
||||||
|
Device** device) const {
|
||||||
|
std::vector<Device*> selected;
|
||||||
|
const DeviceSet& pflr_devices = *pflr()->device_set();
|
||||||
|
|
||||||
|
// If there are no preferred devices, select the first registered device from
|
||||||
|
// the supported device list.
|
||||||
|
if (!DeviceNameUtils::HasSomeDetails(preferred)) {
|
||||||
|
// TODO(b/148213212): Allow setting default device in eager context.
|
||||||
|
selected = ColocationGraph::FilterSupportedDevices(
|
||||||
|
pflr_devices.devices(), supported, /*default_local_device=*/nullptr);
|
||||||
|
if (selected.empty()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"No supported device found in available devices [",
|
||||||
|
absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "].");
|
||||||
|
}
|
||||||
|
*device = selected[0];
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the caller specified a preferred device, select the first matching
|
||||||
|
// registered device from the supported device list. If nothing matches and
|
||||||
|
// soft placement is enabled, pick a suitable device from the available ones.
|
||||||
|
pflr_devices.FindMatchingDevices(preferred, &selected);
|
||||||
|
|
||||||
|
if (!selected.empty()) {
|
||||||
|
selected = ColocationGraph::FilterSupportedDevices(
|
||||||
|
selected, supported, /*default_local_device=*/nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selected.empty() && AllowSoftPlacement()) {
|
||||||
|
DeviceNameUtils::ParsedName soft_device_name = preferred;
|
||||||
|
soft_device_name.type.clear();
|
||||||
|
soft_device_name.has_type = false;
|
||||||
|
soft_device_name.has_id = false;
|
||||||
|
// TODO(b/148213746): Soft placement logic picks up another task if the
|
||||||
|
// requested does not exist.
|
||||||
|
pflr_devices.FindMatchingDevices(soft_device_name, &selected);
|
||||||
|
if (!selected.empty()) {
|
||||||
|
selected = ColocationGraph::FilterSupportedDevices(
|
||||||
|
selected, supported, /*default_local_device=*/nullptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selected.empty()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Could not satisfy device specification '", preferred,
|
||||||
|
"'. All available devices [",
|
||||||
|
absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "].");
|
||||||
|
}
|
||||||
|
|
||||||
|
*device = selected[0];
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
void EagerContext::ResetClusterFLR(
|
void EagerContext::ResetClusterFLR(
|
||||||
DistributedFunctionLibraryRuntime* cluster_flr) {
|
DistributedFunctionLibraryRuntime* cluster_flr) {
|
||||||
cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
|
cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
|
||||||
|
@ -156,6 +156,23 @@ class EagerContext : public core::RefCounted {
|
|||||||
// Returns the device placement policy for the current thread.
|
// Returns the device placement policy for the current thread.
|
||||||
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
|
ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
|
||||||
|
|
||||||
|
// Select an appropriate device for an operation.
|
||||||
|
//
|
||||||
|
// Given the preferred device for the operation, and the list of devices the
|
||||||
|
// operation supports, finds the best suitable device for the operation in
|
||||||
|
// this context.
|
||||||
|
//
|
||||||
|
// The preferred device is specified as a `ParsedName` containing the elements
|
||||||
|
// (details) that the resulting device should match. If there are no such
|
||||||
|
// devices, and the context currently allows soft device placement, a suitable
|
||||||
|
// device not matching `preferred` will be chosen.
|
||||||
|
//
|
||||||
|
// The chosen device is stored in the `device` argument. The argument is not
|
||||||
|
// modified unless this method returns `Status::OK()`.
|
||||||
|
Status SelectDevice(const DeviceNameUtils::ParsedName& preferred,
|
||||||
|
const PrioritizedDeviceTypeVector& supported,
|
||||||
|
Device** device) const;
|
||||||
|
|
||||||
// Sets the implicit copy policy for the current thread.
|
// Sets the implicit copy policy for the current thread.
|
||||||
void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);
|
void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);
|
||||||
|
|
||||||
|
@ -75,16 +75,6 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Using absl::StrJoin with lambda does not work in tf-lite builds.
|
|
||||||
std::vector<string> DevicesToString(const std::vector<Device*> devices) {
|
|
||||||
std::vector<string> v;
|
|
||||||
v.reserve(devices.size());
|
|
||||||
for (Device* d : devices) {
|
|
||||||
v.push_back(d->name());
|
|
||||||
}
|
|
||||||
return v;
|
|
||||||
}
|
|
||||||
|
|
||||||
const string& DeviceNameOrUnspecified(Device* device) {
|
const string& DeviceNameOrUnspecified(Device* device) {
|
||||||
static string* unspecified_string = new string("<unspecified>");
|
static string* unspecified_string = new string("<unspecified>");
|
||||||
return (device == nullptr) ? *unspecified_string : device->name();
|
return (device == nullptr) ? *unspecified_string : device->name();
|
||||||
@ -208,72 +198,6 @@ Status ValidateInputTypeAndPlacement(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SelectDevice(EagerOperation* op, const NodeDef& ndef,
|
|
||||||
const EagerContext& ctx, Device** device) {
|
|
||||||
std::vector<Device*> final_devices;
|
|
||||||
PrioritizedDeviceTypeVector supported_devs;
|
|
||||||
TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
|
|
||||||
ctx.prioritized_device_type_list(), ndef, &supported_devs,
|
|
||||||
&ctx.HostCPU()->parsed_name()));
|
|
||||||
if (supported_devs.empty()) {
|
|
||||||
return errors::NotFound("Could not find valid device for node.\nNode:",
|
|
||||||
FormatNodeDefForError(ndef),
|
|
||||||
"\nAll kernels registered for op ", ndef.op(),
|
|
||||||
" :\n", KernelsRegisteredForOp(ndef.op()));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (DeviceNameUtils::HasSomeDetails(op->GetDeviceParsedName())) {
|
|
||||||
ctx.pflr()->device_set()->FindMatchingDevices(op->GetDeviceParsedName(),
|
|
||||||
&final_devices);
|
|
||||||
|
|
||||||
if (!final_devices.empty()) {
|
|
||||||
final_devices = ColocationGraph::FilterSupportedDevices(
|
|
||||||
final_devices, supported_devs, /*default_local_device=*/nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (final_devices.empty() && ctx.AllowSoftPlacement()) {
|
|
||||||
DeviceNameUtils::ParsedName soft_device_name = op->GetDeviceParsedName();
|
|
||||||
soft_device_name.type.clear();
|
|
||||||
soft_device_name.has_type = false;
|
|
||||||
soft_device_name.has_id = false;
|
|
||||||
// TODO(fishx): Soft placement logic picks up another task if the
|
|
||||||
// requested does not exist.
|
|
||||||
ctx.pflr()->device_set()->FindMatchingDevices(soft_device_name,
|
|
||||||
&final_devices);
|
|
||||||
if (!final_devices.empty()) {
|
|
||||||
final_devices = ColocationGraph::FilterSupportedDevices(
|
|
||||||
final_devices, supported_devs, /*default_local_device=*/nullptr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (final_devices.empty()) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Could not satisfy device specification '", op->GetDeviceParsedName(),
|
|
||||||
"'. All available devices [",
|
|
||||||
absl::StrJoin(DevicesToString(ctx.pflr()->device_set()->devices()),
|
|
||||||
", "),
|
|
||||||
"]. Eager operation: ", op->DebugString());
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// TODO(fishx): Allow setting default device in eager context.
|
|
||||||
final_devices = ColocationGraph::FilterSupportedDevices(
|
|
||||||
ctx.pflr()->device_set()->devices(), supported_devs,
|
|
||||||
/*default_local_device=*/nullptr);
|
|
||||||
if (final_devices.empty()) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"No OpKernel registered to suppport this eager operation:",
|
|
||||||
op->DebugString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
DVLOG(1) << "Placer place op [" << op->Name()
|
|
||||||
<< "] on device: " << final_devices[0]->name();
|
|
||||||
DVLOG(4) << "Available kernels for " << op->Name() << "are "
|
|
||||||
<< KernelsRegisteredForOp(op->Name());
|
|
||||||
op->SetDevice(final_devices[0]);
|
|
||||||
*device = final_devices[0];
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
|
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
|
||||||
const auto& node_def = op->MutableAttrs()->BuildNodeDef();
|
const auto& node_def = op->MutableAttrs()->BuildNodeDef();
|
||||||
const OpDef* op_def = nullptr;
|
const OpDef* op_def = nullptr;
|
||||||
@ -524,7 +448,24 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
|||||||
|
|
||||||
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
|
const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
|
||||||
if (device == nullptr) {
|
if (device == nullptr) {
|
||||||
TF_RETURN_IF_ERROR(SelectDevice(op, ndef, ctx, &device));
|
PrioritizedDeviceTypeVector supported_devs;
|
||||||
|
TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
|
||||||
|
ctx.prioritized_device_type_list(), ndef, &supported_devs,
|
||||||
|
&ctx.HostCPU()->parsed_name()));
|
||||||
|
if (supported_devs.empty()) {
|
||||||
|
return errors::NotFound("Could not find valid device for node.\nNode:",
|
||||||
|
FormatNodeDefForError(ndef),
|
||||||
|
"\nAll kernels registered for op ", ndef.op(),
|
||||||
|
" :\n", KernelsRegisteredForOp(ndef.op()));
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
ctx.SelectDevice(op->GetDeviceParsedName(), supported_devs, &device));
|
||||||
|
|
||||||
|
DVLOG(1) << "Placer place op [" << op->Name()
|
||||||
|
<< "] on device: " << device->name();
|
||||||
|
DVLOG(4) << "Available kernels for " << op->Name() << "are "
|
||||||
|
<< KernelsRegisteredForOp(op->Name());
|
||||||
|
op->SetDevice(device);
|
||||||
}
|
}
|
||||||
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
|
if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
|
||||||
string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
|
string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",
|
||||||
|
Loading…
Reference in New Issue
Block a user