Better error message when unable to find a kernel due to attr mismatch.

Also stop building a prioritized list every time a device is selected.

PiperOrigin-RevId: 209614803
This commit is contained in:
Akshay Modi 2018-08-21 10:11:56 -07:00 committed by TensorFlower Gardener
parent fe2826658b
commit 99107ec949
3 changed files with 18 additions and 10 deletions

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/blocking_counter.h"
@ -78,6 +79,12 @@ void EagerContext::InitDeviceMapAndAsync() {
} }
} }
} }
DeviceSet ds;
for (Device* d : devices_) {
ds.AddDevice(d);
}
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
} }
bool EagerContext::Async() const { bool EagerContext::Async() const {

View File

@ -93,6 +93,9 @@ class EagerContext {
// TODO(apassos) make this return a constant reference // TODO(apassos) make this return a constant reference
std::vector<Device*>* devices() { return &devices_; } std::vector<Device*>* devices() { return &devices_; }
const std::vector<DeviceType>& prioritized_device_type_list() {
return prioritized_device_type_list_;
}
// Clears the kernel caches. // Clears the kernel caches.
void ClearCaches(); void ClearCaches();
@ -210,6 +213,7 @@ class EagerContext {
// Devices owned by device_manager // Devices owned by device_manager
std::vector<Device*> devices_; std::vector<Device*> devices_;
std::vector<DeviceType> prioritized_device_type_list_;
// All devices are not owned. // All devices are not owned.
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_; gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
Rendezvous* rendezvous_; Rendezvous* rendezvous_;

View File

@ -192,17 +192,14 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
} }
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) { Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
DeviceSet ds;
for (Device* d : *ctx->devices()) {
ds.AddDevice(d);
}
DeviceTypeVector final_devices; DeviceTypeVector final_devices;
auto status = SupportedDeviceTypesForNode(ds.PrioritizedDeviceTypeList(), TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
ndef, &final_devices); ctx->prioritized_device_type_list(), ndef, &final_devices));
if (!status.ok()) return status;
if (final_devices.empty()) { if (final_devices.empty()) {
return errors::Internal("Could not find valid device for node ", return errors::Internal(
ndef.DebugString()); "Could not find valid device for node.\nNode: ", SummarizeNodeDef(ndef),
"\nAll kernels registered for op ", ndef.op(), " :\n",
KernelsRegisteredForOp(ndef.op()));
} }
for (Device* d : *ctx->devices()) { for (Device* d : *ctx->devices()) {
if (d->device_type() == final_devices[0].type_string()) { if (d->device_type() == final_devices[0].type_string()) {
@ -211,7 +208,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
} }
} }
return errors::Unknown("Could not find a device for node ", return errors::Unknown("Could not find a device for node ",
ndef.DebugString()); SummarizeNodeDef(ndef));
} }
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) { Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {