From 53889e9671945c98323044e5e9badc0ada82b13a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 28 Jan 2020 15:15:04 -0800
Subject: [PATCH] Move some device placement logic to EagerContext.

This will make it easier to incorporate this logic into places like
pywrap_tensor.cc next, which needs to use the logic towards fixing
b/139690309.

I took the opportunity to perform some slight reshuffling of the
original logic to make it more readable.

PiperOrigin-RevId: 292023452
Change-Id: I2af49f738bf38b776c20fd6edbd525d2429c831f
---
 .../core/common_runtime/eager/context.cc      | 70 ++++++++++++++
 .../core/common_runtime/eager/context.h       | 17 ++++
 .../core/common_runtime/eager/execute.cc      | 95 ++++---------------
 3 files changed, 105 insertions(+), 77 deletions(-)

diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index d80c949286a..301b75dfa68 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -30,6 +30,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
 #include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include "tensorflow/core/common_runtime/colocation_graph.h"
 #include "tensorflow/core/common_runtime/device_resolver_local.h"
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/common_runtime/eager/process_function_library_runtime.h"
@@ -154,6 +155,75 @@ void EagerContext::InitPrioritizedDeviceTypeList() {
   prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
 }
 
+namespace {
+// Collects device names; absl::StrJoin with a lambda fails in tf-lite builds.
+// TODO(b/148160441): Replace with absl::StrJoin once DeviceBase has operator<<.
+std::vector<string> DevicesToString(const std::vector<Device*>& devices) {
+  std::vector<string> names;
+  names.reserve(devices.size());
+  for (Device* device : devices) {
+    names.push_back(device->name());
+  }
+  return names;
+}
+}  // namespace
+
+Status EagerContext::SelectDevice(const DeviceNameUtils::ParsedName& preferred,
+                                  const PrioritizedDeviceTypeVector& supported,
+                                  Device** device) const {
+  const DeviceSet& available = *pflr()->device_set();
+
+  // Looks up the registered devices matching `name`, then narrows any hits
+  // with ColocationGraph::FilterSupportedDevices against `supported`.
+  const auto find_supported = [&](const DeviceNameUtils::ParsedName& name) {
+    std::vector<Device*> matches;
+    available.FindMatchingDevices(name, &matches);
+    if (!matches.empty()) {
+      matches = ColocationGraph::FilterSupportedDevices(
+          matches, supported, /*default_local_device=*/nullptr);
+    }
+    return matches;
+  };
+
+  std::vector<Device*> candidates;
+  if (!DeviceNameUtils::HasSomeDetails(preferred)) {
+    // The caller expressed no preference: consider every registered device
+    // and keep the ones the operation supports.
+    // TODO(b/148213212): Allow setting default device in eager context.
+    candidates = ColocationGraph::FilterSupportedDevices(
+        available.devices(), supported, /*default_local_device=*/nullptr);
+    if (candidates.empty()) {
+      return errors::InvalidArgument(
+          "No supported device found in available devices [",
+          absl::StrJoin(DevicesToString(available.devices()), ", "), "].");
+    }
+  } else {
+    // The caller named a device: start with devices that match it as given.
+    candidates = find_supported(preferred);
+    if (candidates.empty() && AllowSoftPlacement()) {
+      // Soft placement: retry the lookup with the device type and id
+      // constraints dropped from the requested name.
+      // TODO(b/148213746): Soft placement logic picks up another task if the
+      // requested does not exist.
+      DeviceNameUtils::ParsedName relaxed = preferred;
+      relaxed.type.clear();
+      relaxed.has_type = false;
+      relaxed.has_id = false;
+      candidates = find_supported(relaxed);
+    }
+    if (candidates.empty()) {
+      return errors::InvalidArgument(
+          "Could not satisfy device specification '", preferred,
+          "'. All available devices [",
+          absl::StrJoin(DevicesToString(available.devices()), ", "), "].");
+    }
+  }
+
+  // Both paths leave at least one candidate; the first one is the result.
+  *device = candidates[0];
+  return Status::OK();
+}
+
 void EagerContext::ResetClusterFLR(
     DistributedFunctionLibraryRuntime* cluster_flr) {
   cluster_flr_.Reset(cluster_flr, lazy_copy_function_remote_inputs_);
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index f3fd7cf628f..de573410442 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -156,6 +156,23 @@ class EagerContext : public core::RefCounted {
   // Returns the device placement policy for the current thread.
   ContextDevicePlacementPolicy GetDevicePlacementPolicy() const;
 
+  // Selects an appropriate device for an operation.
+  //
+  // Given the preferred device for the operation, and the list of device types
+  // the operation supports, finds the best suitable device for the operation
+  // in this context.
+  //
+  // The preferred device is specified as a `ParsedName` containing the elements
+  // (details) that the resulting device should match. If no supported device
+  // matches and the context allows soft device placement, the device type and
+  // id in `preferred` are ignored and the search is retried.
+  //
+  // On success the chosen device is stored in `device`; otherwise an
+  // InvalidArgument error is returned and `device` is left unmodified.
+  Status SelectDevice(const DeviceNameUtils::ParsedName& preferred,
+                      const PrioritizedDeviceTypeVector& supported,
+                      Device** device) const;
+
   // Sets the implicit copy policy for the current thread.
   void SetThreadLocalMirroringPolicy(ContextMirroringPolicy);
 
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 7f4594662de..c81945f7ef0 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -75,16 +75,6 @@ namespace tensorflow {
 
 namespace {
 
-// Using absl::StrJoin with lambda does not work in tf-lite builds.
-std::vector<string> DevicesToString(const std::vector<Device*> devices) {
-  std::vector<string> v;
-  v.reserve(devices.size());
-  for (Device* d : devices) {
-    v.push_back(d->name());
-  }
-  return v;
-}
-
 const string& DeviceNameOrUnspecified(Device* device) {
   static string* unspecified_string = new string("<unspecified>");
   return (device == nullptr) ? *unspecified_string : device->name();
@@ -208,72 +198,6 @@ Status ValidateInputTypeAndPlacement(
   return Status::OK();
 }
 
-Status SelectDevice(EagerOperation* op, const NodeDef& ndef,
-                    const EagerContext& ctx, Device** device) {
-  std::vector<Device*> final_devices;
-  PrioritizedDeviceTypeVector supported_devs;
-  TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
-      ctx.prioritized_device_type_list(), ndef, &supported_devs,
-      &ctx.HostCPU()->parsed_name()));
-  if (supported_devs.empty()) {
-    return errors::NotFound("Could not find valid device for node.\nNode:",
-                            FormatNodeDefForError(ndef),
-                            "\nAll kernels registered for op ", ndef.op(),
-                            " :\n", KernelsRegisteredForOp(ndef.op()));
-  }
-
-  if (DeviceNameUtils::HasSomeDetails(op->GetDeviceParsedName())) {
-    ctx.pflr()->device_set()->FindMatchingDevices(op->GetDeviceParsedName(),
-                                                  &final_devices);
-
-    if (!final_devices.empty()) {
-      final_devices = ColocationGraph::FilterSupportedDevices(
-          final_devices, supported_devs, /*default_local_device=*/nullptr);
-    }
-
-    if (final_devices.empty() && ctx.AllowSoftPlacement()) {
-      DeviceNameUtils::ParsedName soft_device_name = op->GetDeviceParsedName();
-      soft_device_name.type.clear();
-      soft_device_name.has_type = false;
-      soft_device_name.has_id = false;
-      // TODO(fishx): Soft placement logic picks up another task if the
-      // requested does not exist.
-      ctx.pflr()->device_set()->FindMatchingDevices(soft_device_name,
-                                                    &final_devices);
-      if (!final_devices.empty()) {
-        final_devices = ColocationGraph::FilterSupportedDevices(
-            final_devices, supported_devs, /*default_local_device=*/nullptr);
-      }
-    }
-    if (final_devices.empty()) {
-      return errors::InvalidArgument(
-          "Could not satisfy device specification '", op->GetDeviceParsedName(),
-          "'. All available devices [",
-          absl::StrJoin(DevicesToString(ctx.pflr()->device_set()->devices()),
-                        ", "),
-          "]. Eager operation: ", op->DebugString());
-    }
-  } else {
-    // TODO(fishx): Allow setting default device in eager context.
-    final_devices = ColocationGraph::FilterSupportedDevices(
-        ctx.pflr()->device_set()->devices(), supported_devs,
-        /*default_local_device=*/nullptr);
-    if (final_devices.empty()) {
-      return errors::InvalidArgument(
-          "No OpKernel registered to suppport this eager operation:",
-          op->DebugString());
-    }
-  }
-
-  DVLOG(1) << "Placer place op [" << op->Name()
-           << "] on device: " << final_devices[0]->name();
-  DVLOG(4) << "Available kernels for " << op->Name() << "are "
-           << KernelsRegisteredForOp(op->Name());
-  op->SetDevice(final_devices[0]);
-  *device = final_devices[0];
-  return Status::OK();
-}
-
 Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
   const auto& node_def = op->MutableAttrs()->BuildNodeDef();
   const OpDef* op_def = nullptr;
@@ -524,7 +448,24 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
 
     const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
     if (device == nullptr) {
-      TF_RETURN_IF_ERROR(SelectDevice(op, ndef, ctx, &device));
+      PrioritizedDeviceTypeVector supported_devs;
+      TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
+          ctx.prioritized_device_type_list(), ndef, &supported_devs,
+          &ctx.HostCPU()->parsed_name()));
+      if (supported_devs.empty()) {
+        return errors::NotFound("Could not find valid device for node.\nNode:",
+                                FormatNodeDefForError(ndef),
+                                "\nAll kernels registered for op ", ndef.op(),
+                                " :\n", KernelsRegisteredForOp(ndef.op()));
+      }
+      TF_RETURN_IF_ERROR(
+          ctx.SelectDevice(op->GetDeviceParsedName(), supported_devs, &device));
+
+      DVLOG(1) << "Placer place op [" << op->Name()
+               << "] on device: " << device->name();
+      DVLOG(4) << "Available kernels for " << op->Name() << "are "
+               << KernelsRegisteredForOp(op->Name());
+      op->SetDevice(device);
     }
     if (ctx.LogDevicePlacement() || VLOG_IS_ON(1)) {
       string msg = strings::StrCat("Executing op ", ndef.op(), " in device ",