From 53889e9671945c98323044e5e9badc0ada82b13a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Tue, 28 Jan 2020 15:15:04 -0800 Subject: [PATCH] Move some device placement logic to EagerContext. This will make it easier to incorporate this logic into places like pywrap_tensor.cc next, which needs to use the logic towards fixing b/139690309. I took the opportunity do perform some slight reshuffling of the original logic to make it more readable. PiperOrigin-RevId: 292023452 Change-Id: I2af49f738bf38b776c20fd6edbd525d2429c831f --- .../core/common_runtime/eager/context.cc | 70 ++++++++++++++ .../core/common_runtime/eager/context.h | 17 ++++ .../core/common_runtime/eager/execute.cc | 95 ++++--------------- 3 files changed, 105 insertions(+), 77 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index d80c949286a..301b75dfa68 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/collective_executor_mgr.h" #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" +#include "tensorflow/core/common_runtime/colocation_graph.h" #include "tensorflow/core/common_runtime/device_resolver_local.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h" @@ -154,6 +155,75 @@ void EagerContext::InitPrioritizedDeviceTypeList() { prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList(); } +namespace { +// Using absl::StrJoin with lambda does not work in tf-lite builds. +// TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<. +std::vector<string> DevicesToString(const std::vector<Device*>& devices) { + std::vector<string> v; + v.reserve(devices.size()); + for (Device* d : devices) { + v.push_back(d->name()); + } + return v; +} +} // namespace + +Status EagerContext::SelectDevice(const DeviceNameUtils::ParsedName& preferred, + const PrioritizedDeviceTypeVector& supported, + Device** device) const { + std::vector<Device*> selected; + const DeviceSet& pflr_devices = *pflr()->device_set(); + + // If there are no preferred devices, select the first registered device from + // the supported device list. + if (!DeviceNameUtils::HasSomeDetails(preferred)) { + // TODO(b/148213212): Allow setting default device in eager context. + selected = ColocationGraph::FilterSupportedDevices( + pflr_devices.devices(), supported, /*default_local_device=*/nullptr); + if (selected.empty()) { + return errors::InvalidArgument( + "No supported device found in available devices [", + absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "]."); + } + *device = selected[0]; + return Status::OK(); + } + + // If the caller specified a preferred device, select the first matching + // registered device from the supported device list. If nothing matches and + // soft placement is enabled, pick a suitable device from the available ones. + pflr_devices.FindMatchingDevices(preferred, &selected); + + if (!selected.empty()) { + selected = ColocationGraph::FilterSupportedDevices( + selected, supported, /*default_local_device=*/nullptr); + } + + if (selected.empty() && AllowSoftPlacement()) { + DeviceNameUtils::ParsedName soft_device_name = preferred; + soft_device_name.type.clear(); + soft_device_name.has_type = false; + soft_device_name.has_id = false; + // TODO(b/148213746): Soft placement logic picks up another task if the + // requested does not exist. + pflr_devices.FindMatchingDevices(soft_device_name, &selected); + if (!selected.empty()) { + selected = ColocationGraph::FilterSupportedDevices( + selected, supported, /*default_local_device=*/nullptr); + } + } + + if (selected.empty()) { + return errors::InvalidArgument( + "Could not satisfy device specification '", preferred, + "'. All available devices [", + absl::StrJoin(DevicesToString(pflr_devices.devices()), ", "), "]."); + } + + *device = selected[0]; + return Status::OK(); +} + void EagerContext::ResetClusterFLR( DistributedFunctionLibraryRuntime* cluster_flr) { cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_); diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index f3fd7cf628f..de573410442 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -156,6 +156,23 @@ class EagerContext : public core::RefCounted { // Returns the device placement policy for the current thread. ContextDevicePlacementPolicy GetDevicePlacementPolicy() const; + // Select an appropriate device for an operation. + // + // Given the preferred device for the operation, and the list of devices the + // operation supports, finds the best suitable device for the operation in + // this context. + // + // The preferred device is specified as a `ParsedName` containing the elements + // (details) that the resulting device should match. If there are no such + // devices, and the context currently allows soft device placement, a suitable + // device not matching `preferred` will be chosen. + // + // The chosen device is stored in the `device` argument. The argument is not + // modified unless this method returns `Status::OK()`. + Status SelectDevice(const DeviceNameUtils::ParsedName& preferred, + const PrioritizedDeviceTypeVector& supported, + Device** device) const; + // Sets the implicit copy policy for the current thread. void SetThreadLocalMirroringPolicy(ContextMirroringPolicy); diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 7f4594662de..c81945f7ef0 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -75,16 +75,6 @@ namespace tensorflow { namespace { -// Using absl::StrJoin with lambda does not work in tf-lite builds. -std::vector<string> DevicesToString(const std::vector<Device*> devices) { - std::vector<string> v; - v.reserve(devices.size()); - for (Device* d : devices) { - v.push_back(d->name()); - } - return v; -} - const string& DeviceNameOrUnspecified(Device* device) { static string* unspecified_string = new string("<unspecified>"); return (device == nullptr) ? *unspecified_string : device->name(); @@ -208,72 +198,6 @@ Status ValidateInputTypeAndPlacement( return Status::OK(); } -Status SelectDevice(EagerOperation* op, const NodeDef& ndef, - const EagerContext& ctx, Device** device) { - std::vector<Device*> final_devices; - PrioritizedDeviceTypeVector supported_devs; - TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode( - ctx.prioritized_device_type_list(), ndef, &supported_devs, - &ctx.HostCPU()->parsed_name())); - if (supported_devs.empty()) { - return errors::NotFound("Could not find valid device for node.\nNode:", - FormatNodeDefForError(ndef), - "\nAll kernels registered for op ", ndef.op(), - " :\n", KernelsRegisteredForOp(ndef.op())); - } - - if (DeviceNameUtils::HasSomeDetails(op->GetDeviceParsedName())) { - ctx.pflr()->device_set()->FindMatchingDevices(op->GetDeviceParsedName(), - &final_devices); - - if (!final_devices.empty()) { - final_devices = ColocationGraph::FilterSupportedDevices( - final_devices, supported_devs, /*default_local_device=*/nullptr); - } - - if (final_devices.empty() && ctx.AllowSoftPlacement()) { - DeviceNameUtils::ParsedName soft_device_name = op->GetDeviceParsedName(); - soft_device_name.type.clear(); - soft_device_name.has_type = false; - soft_device_name.has_id = false; - // TODO(fishx): Soft placement logic picks up another task if the - // requested does not exist. - ctx.pflr()->device_set()->FindMatchingDevices(soft_device_name, - &final_devices); - if (!final_devices.empty()) { - final_devices = ColocationGraph::FilterSupportedDevices( - final_devices, supported_devs, /*default_local_device=*/nullptr); - } - } - if (final_devices.empty()) { - return errors::InvalidArgument( - "Could not satisfy device specification '", op->GetDeviceParsedName(), - "'. All available devices [", - absl::StrJoin(DevicesToString(ctx.pflr()->device_set()->devices()), - ", "), - "]. Eager operation: ", op->DebugString()); - } - } else { - // TODO(fishx): Allow setting default device in eager context. - final_devices = ColocationGraph::FilterSupportedDevices( - ctx.pflr()->device_set()->devices(), supported_devs, - /*default_local_device=*/nullptr); - if (final_devices.empty()) { - return errors::InvalidArgument( - "No OpKernel registered to suppport this eager operation:", - op->DebugString()); - } - } - - DVLOG(1) << "Placer place op [" << op->Name() - << "] on device: " << final_devices[0]->name(); - DVLOG(4) << "Available kernels for " << op->Name() << "are " - << KernelsRegisteredForOp(op->Name()); - op->SetDevice(final_devices[0]); - *device = final_devices[0]; - return Status::OK(); -} - Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) { const auto& node_def = op->MutableAttrs()->BuildNodeDef(); const OpDef* op_def = nullptr; @@ -524,7 +448,24 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef(); if (device == nullptr) { - TF_RETURN_IF_ERROR(SelectDevice(op, ndef, ctx, &device)); + PrioritizedDeviceTypeVector supported_devs; + TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode( + ctx.prioritized_device_type_list(), ndef, &supported_devs, + &ctx.HostCPU()->parsed_name())); + if (supported_devs.empty()) { + return errors::NotFound("Could not find valid device for node.\nNode:", + FormatNodeDefForError(ndef), + "\nAll kernels registered for op ", ndef.op(), + " :\n", KernelsRegisteredForOp(ndef.op())); + } + TF_RETURN_IF_ERROR( + ctx.SelectDevice(op->GetDeviceParsedName(), supported_devs, &device)); + + DVLOG(1) << "Placer place op [" << op->Name() + << "] on device: " << device->name(); + DVLOG(4) << "Available kernels for " << op->Name() << "are " + << KernelsRegisteredForOp(op->Name()); + op->SetDevice(device); } if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) { string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",