Better error message when unable to find a kernel due to attr mismatch.

Also stop building a prioritized list every time a device is selected.

PiperOrigin-RevId: 209614803
This commit is contained in:
Akshay Modi 2018-08-21 10:11:56 -07:00 committed by TensorFlower Gardener
parent fe2826658b
commit 99107ec949
3 changed files with 18 additions and 10 deletions

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
@ -78,6 +79,12 @@ void EagerContext::InitDeviceMapAndAsync() {
}
}
}
DeviceSet ds;
for (Device* d : devices_) {
ds.AddDevice(d);
}
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
}
bool EagerContext::Async() const {

View File

@ -93,6 +93,9 @@ class EagerContext {
// TODO(apassos) make this return a constant reference
std::vector<Device*>* devices() { return &devices_; }
const std::vector<DeviceType>& prioritized_device_type_list() {
return prioritized_device_type_list_;
}
// Clears the kernel caches.
void ClearCaches();
@ -210,6 +213,7 @@ class EagerContext {
// Devices owned by device_manager
std::vector<Device*> devices_;
std::vector<DeviceType> prioritized_device_type_list_;
// All devices are not owned.
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
Rendezvous* rendezvous_;

View File

@ -192,17 +192,14 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
}
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
DeviceSet ds;
for (Device* d : *ctx->devices()) {
ds.AddDevice(d);
}
DeviceTypeVector final_devices;
auto status = SupportedDeviceTypesForNode(ds.PrioritizedDeviceTypeList(),
ndef, &final_devices);
if (!status.ok()) return status;
TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
ctx->prioritized_device_type_list(), ndef, &final_devices));
if (final_devices.empty()) {
return errors::Internal("Could not find valid device for node ",
ndef.DebugString());
return errors::Internal(
"Could not find valid device for node.\nNode: ", SummarizeNodeDef(ndef),
"\nAll kernels registered for op ", ndef.op(), " :\n",
KernelsRegisteredForOp(ndef.op()));
}
for (Device* d : *ctx->devices()) {
if (d->device_type() == final_devices[0].type_string()) {
@ -211,7 +208,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
}
}
return errors::Unknown("Could not find a device for node ",
ndef.DebugString());
SummarizeNodeDef(ndef));
}
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {