Better error message when unable to find a kernel due to attr mismatch.
Also stop building a prioritized list every time a device is selected. PiperOrigin-RevId: 209614803
This commit is contained in:
parent
fe2826658b
commit
99107ec949
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/device_set.h"
|
||||||
#include "tensorflow/core/common_runtime/process_util.h"
|
#include "tensorflow/core/common_runtime/process_util.h"
|
||||||
#include "tensorflow/core/framework/resource_mgr.h"
|
#include "tensorflow/core/framework/resource_mgr.h"
|
||||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||||
@ -78,6 +79,12 @@ void EagerContext::InitDeviceMapAndAsync() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DeviceSet ds;
|
||||||
|
for (Device* d : devices_) {
|
||||||
|
ds.AddDevice(d);
|
||||||
|
}
|
||||||
|
prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool EagerContext::Async() const {
|
bool EagerContext::Async() const {
|
||||||
|
@ -93,6 +93,9 @@ class EagerContext {
|
|||||||
|
|
||||||
// TODO(apassos) make this return a constant reference
|
// TODO(apassos) make this return a constant reference
|
||||||
std::vector<Device*>* devices() { return &devices_; }
|
std::vector<Device*>* devices() { return &devices_; }
|
||||||
|
const std::vector<DeviceType>& prioritized_device_type_list() {
|
||||||
|
return prioritized_device_type_list_;
|
||||||
|
}
|
||||||
|
|
||||||
// Clears the kernel caches.
|
// Clears the kernel caches.
|
||||||
void ClearCaches();
|
void ClearCaches();
|
||||||
@ -210,6 +213,7 @@ class EagerContext {
|
|||||||
|
|
||||||
// Devices owned by device_manager
|
// Devices owned by device_manager
|
||||||
std::vector<Device*> devices_;
|
std::vector<Device*> devices_;
|
||||||
|
std::vector<DeviceType> prioritized_device_type_list_;
|
||||||
// All devices are not owned.
|
// All devices are not owned.
|
||||||
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
|
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
|
||||||
Rendezvous* rendezvous_;
|
Rendezvous* rendezvous_;
|
||||||
|
@ -192,17 +192,14 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
|
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
|
||||||
DeviceSet ds;
|
|
||||||
for (Device* d : *ctx->devices()) {
|
|
||||||
ds.AddDevice(d);
|
|
||||||
}
|
|
||||||
DeviceTypeVector final_devices;
|
DeviceTypeVector final_devices;
|
||||||
auto status = SupportedDeviceTypesForNode(ds.PrioritizedDeviceTypeList(),
|
TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
|
||||||
ndef, &final_devices);
|
ctx->prioritized_device_type_list(), ndef, &final_devices));
|
||||||
if (!status.ok()) return status;
|
|
||||||
if (final_devices.empty()) {
|
if (final_devices.empty()) {
|
||||||
return errors::Internal("Could not find valid device for node ",
|
return errors::Internal(
|
||||||
ndef.DebugString());
|
"Could not find valid device for node.\nNode: ", SummarizeNodeDef(ndef),
|
||||||
|
"\nAll kernels registered for op ", ndef.op(), " :\n",
|
||||||
|
KernelsRegisteredForOp(ndef.op()));
|
||||||
}
|
}
|
||||||
for (Device* d : *ctx->devices()) {
|
for (Device* d : *ctx->devices()) {
|
||||||
if (d->device_type() == final_devices[0].type_string()) {
|
if (d->device_type() == final_devices[0].type_string()) {
|
||||||
@ -211,7 +208,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
return errors::Unknown("Could not find a device for node ",
|
return errors::Unknown("Could not find a device for node ",
|
||||||
ndef.DebugString());
|
SummarizeNodeDef(ndef));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
|
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user