From 99107ec94994bf9f6748fb69fe2fe20d946c6c96 Mon Sep 17 00:00:00 2001
From: Akshay Modi
Date: Tue, 21 Aug 2018 10:11:56 -0700
Subject: [PATCH] Better error message when unable to find a kernel due to attr mismatch.

Also stop building a prioritized list every time a device is selected.

PiperOrigin-RevId: 209614803
---
 tensorflow/core/common_runtime/eager/context.cc |  7 +++++++
 tensorflow/core/common_runtime/eager/context.h  |  4 ++++
 tensorflow/core/common_runtime/eager/execute.cc | 17 +++++++----------
 3 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 5bdd547c7f3..b859b06fa0e 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/eager/context.h"
 
+#include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
@@ -78,6 +79,12 @@ void EagerContext::InitDeviceMapAndAsync() {
       }
     }
   }
+
+  DeviceSet ds;
+  for (Device* d : devices_) {
+    ds.AddDevice(d);
+  }
+  prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
 }
 
 bool EagerContext::Async() const {
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 9835b195113..3c95ac590d1 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -93,6 +93,9 @@ class EagerContext {
 
   // TODO(apassos) make this return a constant reference
   std::vector<Device*>* devices() { return &devices_; }
+  const std::vector<DeviceType>& prioritized_device_type_list() {
+    return prioritized_device_type_list_;
+  }
 
   // Clears the kernel caches.
   void ClearCaches();
@@ -210,6 +213,7 @@ class EagerContext {
 
   // Devices owned by device_manager
   std::vector<Device*> devices_;
+  std::vector<DeviceType> prioritized_device_type_list_;
   // All devices are not owned.
   gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
   Rendezvous* rendezvous_;
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 46065f399c5..5b3a64ba980 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -192,17 +192,14 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
 }
 
 Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
-  DeviceSet ds;
-  for (Device* d : *ctx->devices()) {
-    ds.AddDevice(d);
-  }
   DeviceTypeVector final_devices;
-  auto status = SupportedDeviceTypesForNode(ds.PrioritizedDeviceTypeList(),
-                                            ndef, &final_devices);
-  if (!status.ok()) return status;
+  TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
+      ctx->prioritized_device_type_list(), ndef, &final_devices));
   if (final_devices.empty()) {
-    return errors::Internal("Could not find valid device for node ",
-                            ndef.DebugString());
+    return errors::Internal(
+        "Could not find valid device for node.\nNode: ", SummarizeNodeDef(ndef),
+        "\nAll kernels registered for op ", ndef.op(), " :\n",
+        KernelsRegisteredForOp(ndef.op()));
   }
   for (Device* d : *ctx->devices()) {
     if (d->device_type() == final_devices[0].type_string()) {
@@ -211,7 +208,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
     }
   }
   return errors::Unknown("Could not find a device for node ",
-                         ndef.DebugString());
+                         SummarizeNodeDef(ndef));
 }
 
 Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {