Better error message when unable to find a kernel due to attr mismatch.
Also stop building a prioritized list every time a device is selected. PiperOrigin-RevId: 209614803
This commit is contained in:
parent
fe2826658b
commit
99107ec949
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
@ -78,6 +79,12 @@ void EagerContext::InitDeviceMapAndAsync() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DeviceSet ds;
|
||||
for (Device* d : devices_) {
|
||||
ds.AddDevice(d);
|
||||
}
|
||||
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
|
||||
}
|
||||
|
||||
bool EagerContext::Async() const {
|
||||
|
@ -93,6 +93,9 @@ class EagerContext {
|
||||
|
||||
// TODO(apassos) make this return a constant reference
|
||||
std::vector<Device*>* devices() { return &devices_; }
|
||||
const std::vector<DeviceType>& prioritized_device_type_list() {
|
||||
return prioritized_device_type_list_;
|
||||
}
|
||||
|
||||
// Clears the kernel caches.
|
||||
void ClearCaches();
|
||||
@ -210,6 +213,7 @@ class EagerContext {
|
||||
|
||||
// Devices owned by device_manager
|
||||
std::vector<Device*> devices_;
|
||||
std::vector<DeviceType> prioritized_device_type_list_;
|
||||
// All devices are not owned.
|
||||
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
|
||||
Rendezvous* rendezvous_;
|
||||
|
@ -192,17 +192,14 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
|
||||
}
|
||||
|
||||
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
|
||||
DeviceSet ds;
|
||||
for (Device* d : *ctx->devices()) {
|
||||
ds.AddDevice(d);
|
||||
}
|
||||
DeviceTypeVector final_devices;
|
||||
auto status = SupportedDeviceTypesForNode(ds.PrioritizedDeviceTypeList(),
|
||||
ndef, &final_devices);
|
||||
if (!status.ok()) return status;
|
||||
TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
|
||||
ctx->prioritized_device_type_list(), ndef, &final_devices));
|
||||
if (final_devices.empty()) {
|
||||
return errors::Internal("Could not find valid device for node ",
|
||||
ndef.DebugString());
|
||||
return errors::Internal(
|
||||
"Could not find valid device for node.\nNode: ", SummarizeNodeDef(ndef),
|
||||
"\nAll kernels registered for op ", ndef.op(), " :\n",
|
||||
KernelsRegisteredForOp(ndef.op()));
|
||||
}
|
||||
for (Device* d : *ctx->devices()) {
|
||||
if (d->device_type() == final_devices[0].type_string()) {
|
||||
@ -211,7 +208,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
|
||||
}
|
||||
}
|
||||
return errors::Unknown("Could not find a device for node ",
|
||||
ndef.DebugString());
|
||||
SummarizeNodeDef(ndef));
|
||||
}
|
||||
|
||||
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user