Add extra supported device types if the following conditions are satisfied:

1) No kernel is defined for the given op (e.g. PyFunc on a worker process)
2) A device is requested for this node which specifies job/replica/task
3) A local device is provided which specifies job/replica/task
4) The local device does not have the same (job, replica, task) as the
   requested device

The goal is to address the issue where the graph includes an op (e.g. PyFunc)
whose kernel is known to a remote process but not to the current process.

PiperOrigin-RevId: 263880099
This commit is contained in:
parent 9a26222268
commit 31385959d9
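The conditions above reduce to an address-space comparison between the node's requested device and the placer's local device. Below is a minimal, self-contained C++ sketch of that gating logic; ParsedName, IsSameAddressSpace, and ExtraDeviceTypes here are simplified stand-ins modeled on the TensorFlow helpers this commit touches, not the real implementations.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for DeviceNameUtils::ParsedName; only the fields the
// address-space check needs.
struct ParsedName {
  bool has_job = false;
  bool has_replica = false;
  bool has_task = false;
  bool has_type = false;
  std::string job;
  std::string type;
  int replica = 0;
  int task = 0;
};

// Two device names share an address space when (job, replica, task) match.
bool IsSameAddressSpace(const ParsedName& a, const ParsedName& b) {
  return a.has_job && b.has_job && a.job == b.job && a.has_replica &&
         b.has_replica && a.replica == b.replica && a.has_task && b.has_task &&
         a.task == b.task;
}

// Mirrors the fallback this commit adds: when no local kernel matched and the
// node was requested on a different (job, replica, task), report device types
// anyway and let the remote process resolve the kernel.
std::vector<std::pair<std::string, int>> ExtraDeviceTypes(
    const std::vector<std::string>& prioritized_types,
    const ParsedName& requested, const ParsedName& local) {
  std::vector<std::pair<std::string, int>> result;
  if (IsSameAddressSpace(local, requested)) return result;  // condition 4 fails
  if (requested.has_type) {
    result.emplace_back(requested.type, 0);  // a concrete type was requested
  } else {
    for (const auto& type : prioritized_types) result.emplace_back(type, 0);
  }
  return result;
}

int main() {
  ParsedName local;  // stands in for /job:b/replica:0/task:0
  local.has_job = local.has_replica = local.has_task = true;
  local.job = "b";

  ParsedName requested = local;  // /job:b/replica:0/task:1/device:GPU:0
  requested.task = 1;
  requested.has_type = true;
  requested.type = "GPU";

  for (const auto& entry : ExtraDeviceTypes({"GPU", "CPU"}, requested, local))
    std::cout << entry.first << " (priority " << entry.second << ")\n";
  // Prints: GPU (priority 0)
}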
@@ -121,15 +121,19 @@ bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
 }  // namespace
 
 Status Member::SetParentAndSupportedDevices(
-    const Node& node, const std::vector<DeviceType>& types) {
+    const Node& node, const std::vector<DeviceType>& types,
+    const Device* default_local_device) {
   int id = node.id();
   if (id < 0) {
     return errors::Internal("Placer should not be creating a Member for node: ",
                             node.DebugString());
   }
   parent_ = id;
+  const DeviceNameUtils::ParsedName* name =
+      default_local_device == nullptr ? nullptr
+                                      : &default_local_device->parsed_name();
   return SupportedDeviceTypesForNode(types, node.def(),
-                                     &supported_device_types_);
+                                     &supported_device_types_, name);
 }
 
 Status Member::SetAssignedDeviceName(const string& device_name) {
@@ -533,18 +537,18 @@ DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const {
 ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
                                  const FunctionLibraryDefinition* flib_def,
                                  const DeviceSet* device_set,
-                                 const Device* default_device,
+                                 const Device* default_local_device,
                                  bool allow_soft_placement,
                                  bool log_device_placement)
     : graph_(*graph),
       stack_(stack),
       flib_def_(*flib_def),
-      inspecting_placer_(stack, flib_def, device_set, default_device,
+      inspecting_placer_(stack, flib_def, device_set, default_local_device,
                          allow_soft_placement, log_device_placement),
       inspection_required_checker_(graph, flib_def),
       device_set_(*device_set),
       device_types_(device_set->PrioritizedDeviceTypeList()),
-      default_device_(default_device),
+      default_local_device_(default_local_device),
       allow_soft_placement_(allow_soft_placement),
       log_device_placement_(log_device_placement) {
   members_.resize(graph_.num_node_ids());
@@ -930,7 +934,7 @@ void ColocationGraph::GetSoftDeviceCandidates(
   if (!possible_devices->empty()) {
     *possible_devices = FilterSupportedDevices(
         *possible_devices, root_member.supported_device_types(),
-        default_device_);
+        default_local_device_);
   }
 
   if (!possible_devices->empty()) {
@@ -953,7 +957,7 @@ void ColocationGraph::GetSoftDeviceCandidates(
   if (!possible_devices->empty()) {
     *possible_devices = FilterSupportedDevices(
         *possible_devices, root_member.supported_device_types(),
-        default_device_);
+        default_local_device_);
   }
 
   if (!possible_devices->empty()) {
@@ -1007,7 +1011,7 @@ Status ColocationGraph::GetDevicesForNode(
     // Filter devices into those that are compatible with the root
     // node (and its children).
     devices = FilterSupportedDevices(
-        devices, root_member.supported_device_types(), default_device_);
+        devices, root_member.supported_device_types(), default_local_device_);
   }
 
   // Perform soft placement if allow_soft_placement_ is set.
@@ -1094,7 +1098,7 @@ Status ColocationGraph::GetDevicesForNode(
     }
     devices = FilterSupportedDevices(device_set_.devices(),
                                      root_member.supported_device_types(),
-                                     default_device_);
+                                     default_local_device_);
 
     if (devices.empty()) {
       return errors::InvalidArgument(
@@ -1163,7 +1167,12 @@ string ColocationGraph::DebugInfo(const int node_root) const {
     colocation_nodes.push_back(node);
 
     PrioritizedDeviceTypeVector supported_types;
-    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types)
+    const DeviceNameUtils::ParsedName* name =
+        default_local_device_ == nullptr
+            ? nullptr
+            : &default_local_device_->parsed_name();
+    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
+                                name)
         .IgnoreError();
     string devices_registered;
     for (const auto& device_type : supported_types) {
@@ -1239,7 +1248,8 @@ Status ColocationGraph::InitializeMemberWithAssignedDevice(
 }
 
 Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
-  TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(node, device_types_));
+  TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(
+      node, device_types_, default_local_device_));
 
   if (node.has_assigned_device_name()) {
     TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
@@ -1291,19 +1301,19 @@ Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
 /*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices(
     const std::vector<Device*>& devices,
     const PrioritizedDeviceTypeVector& supported_device_types,
-    const Device* default_device) {
+    const Device* default_local_device) {
   Device* filtered_default_device = nullptr;
   std::vector<std::pair<Device*, int32>> prioritized_filtered_devices;
   for (const auto& supported_device_type : supported_device_types) {
     for (Device* device : devices) {
       if (DeviceType(device->attributes().device_type()) ==
           supported_device_type.first) {
-        if (default_device &&
-            (device == default_device ||
+        if (default_local_device &&
+            (device == default_local_device ||
              // TODO(nareshmodi, fishx): At times the device pointer in the
              // device set is different to the one passed in as the default
              // device. Figure out why this might be.
-             device->name() == default_device->name())) {
+             device->name() == default_local_device->name())) {
           filtered_default_device = device;
         } else {
           prioritized_filtered_devices.emplace_back(
@@ -38,7 +38,8 @@ class Member {
   Member() = default;
 
   Status SetParentAndSupportedDevices(const Node& node,
-                                      const std::vector<DeviceType>& types);
+                                      const std::vector<DeviceType>& types,
+                                      const Device* default_local_device);
 
   const DeviceNameUtils::ParsedName& requested_device_name() const {
     return requested_device_name_;
@@ -203,12 +204,13 @@ class Member {
 class ColocationGraph {
  public:
   // graph, flib_def, and device_set must not be null and must outlive
-  // this ColocationGraph. default_device can be null. If not, must outlive
-  // this.
+  // this ColocationGraph. default_local_device can be null. If not, must
+  // outlive this.
   ColocationGraph(const Graph* graph, const FunctionStack& stack,
                   const FunctionLibraryDefinition* flib_def,
-                  const DeviceSet* device_set, const Device* default_device,
-                  bool allow_soft_placement, bool log_device_placement);
+                  const DeviceSet* device_set,
+                  const Device* default_local_device, bool allow_soft_placement,
+                  bool log_device_placement);
 
   Status Initialize();
 
@@ -254,7 +256,7 @@ class ColocationGraph {
   static std::vector<Device*> FilterSupportedDevices(
       const std::vector<Device*>& devices,
       const PrioritizedDeviceTypeVector& supported_device_types,
-      const Device* default_device);
+      const Device* default_local_device);
 
  private:
   // Adds each node of the Graph to this ColocationGraph as a singleton.
@@ -355,7 +357,7 @@ class ColocationGraph {
   PlacerInspectionRequiredOpChecker inspection_required_checker_;
   const DeviceSet& device_set_;
   const std::vector<DeviceType> device_types_;
-  const Device* default_device_;
+  const Device* default_local_device_;
   const bool allow_soft_placement_;
   const bool log_device_placement_;
 
@@ -228,7 +228,8 @@ Status SelectDevice(EagerOperation* op, const NodeDef& ndef, EagerContext* ctx,
   std::vector<Device*> final_devices;
   PrioritizedDeviceTypeVector supported_devs;
   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
-      ctx->prioritized_device_type_list(), ndef, &supported_devs));
+      ctx->prioritized_device_type_list(), ndef, &supported_devs,
+      &ctx->HostCPU()->parsed_name()));
   if (supported_devs.empty()) {
     return errors::NotFound("Could not find valid device for node.\nNode:",
                             FormatNodeDefForError(ndef),
@@ -610,7 +610,7 @@ Status GraphExecutionState::InitBaseGraph(std::unique_ptr<Graph>&& new_graph) {
       OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
 
   Placer placer(new_graph.get(), "", flib_def_.get(), device_set_,
-                /* default_device= */ nullptr,
+                /* default_local_device= */ nullptr,
                 session_options_ == nullptr ||
                     session_options_->config.allow_soft_placement(),
                 session_options_ != nullptr &&
@@ -73,20 +73,20 @@ Status AssignAndLog(int assigned_device, Node* node,
 
 Placer::Placer(Graph* graph, const string& function_name,
                const FunctionLibraryDefinition* flib_def,
-               const DeviceSet* devices, const Device* default_device,
+               const DeviceSet* devices, const Device* default_local_device,
                bool allow_soft_placement, bool log_device_placement)
     : graph_(graph),
       function_name_(function_name),
      flib_def_(flib_def),
      devices_(devices),
-      default_device_(default_device),
+      default_local_device_(default_local_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {}
 
 Placer::Placer(Graph* graph, const string& function_name,
-               const DeviceSet* devices, const Device* default_device)
-    : Placer(graph, function_name, &graph->flib_def(), devices, default_device,
-             true, false) {}
+               const DeviceSet* devices, const Device* default_local_device)
+    : Placer(graph, function_name, &graph->flib_def(), devices,
+             default_local_device, true, false) {}
 
 Placer::Placer(Graph* graph, const string& function_name,
                const DeviceSet* devices)
@@ -113,7 +113,7 @@ Status Placer::Run() {
 
   FunctionStack stack(function_name_);
   ColocationGraph colocation_graph(graph_, stack, flib_def_, devices_,
-                                   default_device_, allow_soft_placement_,
+                                   default_local_device_, allow_soft_placement_,
                                    log_device_placement_);
 
   TF_RETURN_IF_ERROR(colocation_graph.Initialize());
@@ -61,19 +61,20 @@ class Placer {
   // represented by "graph". If "graph" is not representing a function body,
   // "function_name" should be empty.
   //
-  // If non-null, default_device is used where possible as a placement for nodes
-  // which do not have a device specified, ahead of other devices which would
-  // otherwise be higher priority.
+  // If non-null, default_local_device is used where possible as a placement for
+  // nodes which do not have a device specified, ahead of other devices which
+  // would otherwise be higher priority. default_local_device should be on the
+  // local host so that its FLR is directly accessible by the current process.
   //
-  // The "graph", "devices", and "default_device" pointer arguments are borrowed
-  // by this Placer, and must outlive it.
+  // The "graph", "devices", and "default_local_device" pointer arguments are
+  // borrowed by this Placer, and must outlive it.
   Placer(Graph* graph, const string& function_name,
          const FunctionLibraryDefinition* flib_def, const DeviceSet* devices,
-         const Device* default_device, bool allow_soft_placement,
+         const Device* default_local_device, bool allow_soft_placement,
         bool log_device_placement);
 
   Placer(Graph* graph, const string& function_name, const DeviceSet* devices,
-         const Device* default_device);
+         const Device* default_local_device);
 
   Placer(Graph* graph, const string& function_name, const DeviceSet* devices);
 
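As a usage illustration of the constructors above, a hedged sketch of handing the placer a local host device (assumes a TensorFlow source tree; PlaceWithLocalDevice and its wiring are hypothetical, not part of this commit):

#include "tensorflow/core/common_runtime/placer.h"

// Hypothetical helper: place `graph`, preferring `local_cpu` for nodes with
// no requested device, and enabling the remote-kernel fallback for nodes
// whose requested device lives on another (job, replica, task).
tensorflow::Status PlaceWithLocalDevice(tensorflow::Graph* graph,
                                        const tensorflow::DeviceSet* devices,
                                        const tensorflow::Device* local_cpu) {
  tensorflow::Placer placer(graph, /*function_name=*/"", devices, local_cpu);
  return placer.Run();
}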
@@ -96,7 +97,7 @@ class Placer {
   const string function_name_;
   const FunctionLibraryDefinition* const flib_def_;  // Not owned.
   const DeviceSet* const devices_;                   // Not owned.
-  const Device* default_device_;                     // Not owned.
+  const Device* default_local_device_;               // Not owned.
   const bool allow_soft_placement_;
   const bool log_device_placement_;
 
@@ -241,9 +241,9 @@ class PlacerTest : public ::testing::Test {
   // placement will use the default DeviceSet (of 10 CPU and 10 GPU devices).
   //
   // REQUIRES: "*graph" was produced by the most recent call to BuildGraph.
-  Status Place(Graph* graph, DeviceSet* devices, bool allow_soft_placement,
-               bool log_device_placement) {
-    Placer placer(graph, "", &graph->flib_def(), devices, nullptr,
+  Status Place(Graph* graph, DeviceSet* devices, Device* default_local_device,
+               bool allow_soft_placement, bool log_device_placement) {
+    Placer placer(graph, "", &graph->flib_def(), devices, default_local_device,
                   allow_soft_placement, log_device_placement);
     return placer.Run();
   }
@@ -286,15 +286,18 @@ class PlacerTest : public ::testing::Test {
   }
 
   Status Place(Graph* graph, DeviceSet* devices) {
-    return Place(graph, devices, true, false);
+    return Place(graph, devices, nullptr, true, false);
   }
 
   Status Place(Graph* graph, bool allow_soft_placement,
                bool log_device_placement) {
-    return Place(graph, &devices_, allow_soft_placement, log_device_placement);
+    return Place(graph, &devices_, nullptr, allow_soft_placement,
+                 log_device_placement);
   }
 
-  Status Place(Graph* graph) { return Place(graph, &devices_, true, false); }
+  Status Place(Graph* graph) {
+    return Place(graph, &devices_, nullptr, true, false);
+  }
 
   Status CallOptPassesAndPlace(Graph* graph, bool allow_soft_placement,
                                bool log_device_placement) {
@@ -1430,8 +1433,8 @@ TEST_F(PlacerTest, TestUnknownAssignedDevice) {
 }
 
 // Test that placement fails when an op with no registered kernels is
-// requested.
-TEST_F(PlacerTest, TestNoKernelsRegistered) {
+// requested and no device is requested for the node
+TEST_F(PlacerTest, TestNoKernelsRegisteredWithNoRequstedDevice) {
   Graph g(OpRegistry::Global());
   {  // Scope for temporary variables used to construct g.
     GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -1447,6 +1450,58 @@ TEST_F(PlacerTest, TestNoKernelsRegistered) {
   EXPECT_TRUE(absl::StrContains(s.error_message(), "<no registered kernels>"));
 }
 
+// Test that placement fails when an op does not have registered kernel
+// and the requested device has the same (job, replica, task) as the placer's
+// local device
+TEST_F(PlacerTest, TestNoKernelsRegisteredWithRequestedDeviceLocal) {
+  const string cpu_device = "/job:b/replica:0/task:0/device:FakeCPU:0";
+  const string gpu_device = "/job:b/replica:0/task:0/device:FakeGPU:0";
+
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    ops::SourceOp("VariableNoKernels", b.opts().WithName("var"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  GetNodeByName(g, "var")->set_requested_device(gpu_device);
+
+  DeviceSet devices;
+  std::unique_ptr<Device> gpu(FakeDevice::MakeGPU(gpu_device));
+  devices.AddDevice(gpu.get());
+  std::unique_ptr<Device> cpu(FakeDevice::MakeCPU(cpu_device));
+  devices.AddDevice(cpu.get());
+  Status s = Place(&g, &devices, cpu.get(), false, false);
+  EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+  EXPECT_TRUE(absl::StrContains(s.error_message(),
+                                "No OpKernel was registered to support Op "
+                                "'VariableNoKernels' used by {{node var}}"));
+  EXPECT_TRUE(absl::StrContains(s.error_message(), "<no registered kernels>"));
+}
+
+// Test that placement succeeds when an op does not have registered kernel
+// and the requested device has different (job, replica, task) than the placer's
+// local device
+TEST_F(PlacerTest, TestNoKernelsRegisteredWithRequestedDeviceRemote) {
+  const string local_device = "/job:b/replica:0/task:0/device:FakeCPU:0";
+  const string remote_device = "/job:b/replica:0/task:1/device:FakeGPU:0";
+
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    ops::SourceOp("VariableNoKernels", b.opts().WithName("var"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  GetNodeByName(g, "var")->set_requested_device(remote_device);
+
+  DeviceSet heterogeneous;
+  std::unique_ptr<Device> gpu(FakeDevice::MakeGPU(remote_device));
+  heterogeneous.AddDevice(gpu.get());
+  std::unique_ptr<Device> cpu(FakeDevice::MakeCPU(local_device));
+  heterogeneous.AddDevice(cpu.get());
+  TF_EXPECT_OK(Place(&g, &heterogeneous, cpu.get(), false, false));
+  EXPECT_DEVICE_CONTAINS(g, "var", remote_device);
+}
+
 // Test that placement fails when a kernel is registered but no known
 // device supports it.
 TEST_F(PlacerTest, TestNoDevicesRegistered) {
@@ -1359,7 +1359,8 @@ Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
 
 Status SupportedDeviceTypesForNode(
     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
-    PrioritizedDeviceTypeVector* prioritized_device_types) {
+    PrioritizedDeviceTypeVector* prioritized_device_types,
+    const DeviceNameUtils::ParsedName* local_device_name) {
   // TODO(zhifengc): Changes the callers (SimplePlacer and
   // DynamicPlacer) to consider the possibility that 'def' is call to
   // a user-defined function and only calls this
@@ -1367,16 +1368,44 @@ Status SupportedDeviceTypesForNode(
   const OpRegistrationData* op_reg_data;
   const Status s = OpRegistry::Global()->LookUp(def.op(), &op_reg_data);
   if (s.ok()) {
+    bool exists_attr_mismatch = false;
     for (const DeviceType& device_type : prioritized_types) {
       const KernelRegistration* reg = nullptr;
-      bool was_attr_mismatch;
+      bool was_attr_mismatch = false;
       TF_RETURN_IF_ERROR(
           FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
+      exists_attr_mismatch = exists_attr_mismatch || was_attr_mismatch;
       if (reg != nullptr) {
         int32 priority = reg->def.priority();
         prioritized_device_types->emplace_back(device_type, priority);
       }
     }
+    // Add extra supported device types if the following conditions are
+    // satisfied:
+    // 1) No kernel is defined for the given op (e.g. PyFunc on worker process)
+    // 2) A device is requested for this node which specifies job/replica/task
+    // 3) A local device is provided which specifies job/replica/task
+    // 4) The local device does not have the same (job, replica, task) as the
+    //    requested device
+    //
+    // The goal is to address the issue where a graph includes op (e.g. PyFunc)
+    // whose kernel is known to a remote process but not to the current process.
+    if (prioritized_device_types->empty() && !exists_attr_mismatch &&
+        local_device_name != nullptr) {
+      DeviceNameUtils::ParsedName requested_device_name;
+      DeviceNameUtils::ParseFullName(def.device(), &requested_device_name);
+      if (!DeviceNameUtils::IsSameAddressSpace(*local_device_name,
+                                               requested_device_name)) {
+        if (requested_device_name.has_type) {
+          prioritized_device_types->push_back(
+              std::make_pair(DeviceType(requested_device_name.type), 0));
+        } else {
+          for (const DeviceType& device_type : prioritized_types) {
+            prioritized_device_types->push_back(std::make_pair(device_type, 0));
+          }
+        }
+      }
+    }
     std::sort(prioritized_device_types->begin(),
               prioritized_device_types->end(),
               [](const std::pair<DeviceType, int32>& a,
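For reference, the address-space test that gates the new branch can be exercised directly through DeviceNameUtils; a minimal sketch (assumes a TensorFlow source tree; RequestIsRemote is a hypothetical helper, not part of this commit):

#include <string>

#include "tensorflow/core/util/device_name_utils.h"

// Hypothetical helper: true when `requested_device` names a different
// (job, replica, task) than `local_device`, i.e. when the fallback above may
// add device types for an op with no locally registered kernel.
bool RequestIsRemote(const std::string& requested_device,
                     const std::string& local_device) {
  tensorflow::DeviceNameUtils::ParsedName requested, local;
  tensorflow::DeviceNameUtils::ParseFullName(requested_device, &requested);
  tensorflow::DeviceNameUtils::ParseFullName(local_device, &local);
  return !tensorflow::DeviceNameUtils::IsSameAddressSpace(local, requested);
}

// RequestIsRemote("/job:b/replica:0/task:1/device:FakeGPU:0",
//                 "/job:b/replica:0/task:0/device:FakeCPU:0")  -> true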
@@ -1411,7 +1411,8 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
 //   * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
 Status SupportedDeviceTypesForNode(
     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
-    PrioritizedDeviceTypeVector* device_types);
+    PrioritizedDeviceTypeVector* device_types,
+    const DeviceNameUtils::ParsedName* local_device_name = nullptr);
 
 // Returns a message with a description of the kernels registered for op
 // `op_name`.
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 #include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/device_name_utils.h"
 
 class DummyKernel : public tensorflow::OpKernel {
  public:
|
||||
// Kernels with different priorities.
|
||||
REGISTER_OP("Test5").Input("a: T").Input("b: T").Attr("T: type");
|
||||
|
||||
REGISTER_OP("OpWithoutKernel").Input("a: T").Input("b: T").Attr("T: type");
|
||||
|
||||
class TestOp5Cpu : public tensorflow::OpKernel {
|
||||
public:
|
||||
explicit TestOp5Cpu(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
@@ -134,11 +137,13 @@ class OpKernelTest : public ::testing::Test {
   OpKernelTest() : device_(Env::Default()) {}
 
  protected:
-  NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) {
+  NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs,
+                        const string& device = "") {
     NodeDefBuilder builder(op_type + "-op", op_type);
     for (DataType dt : inputs) {
       builder.Input(FakeInput(dt));
     }
+    builder.Device(device);
     NodeDef node_def;
     TF_CHECK_OK(builder.Finalize(&node_def));
     return node_def;
@@ -214,6 +219,38 @@ TEST_F(OpKernelTest, CpuTypeRegistered) {
   EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
 }
 
+TEST_F(OpKernelTest, KernelNotRegistered) {
+  const string& local_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+  const string& remote_device = "/job:worker/replica:0/task:0/device";
+  {
+    // Try a node def of an op which does not have kernel. And the requested
+    // device in NodeDef is on a different address space than the local device.
+    NodeDef ndef =
+        CreateNodeDef("OpWithoutKernel", {DT_STRING, DT_STRING}, remote_device);
+    PrioritizedDeviceTypeVector devs;
+    DeviceNameUtils::ParsedName local_device_name;
+    DeviceNameUtils::ParseFullName(local_device, &local_device_name);
+    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs,
+                                             &local_device_name));
+    EXPECT_EQ(2, devs.size());
+    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
+    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first);
+  }
+
+  {
+    // Try a node def of an op which does not have kernel. And the requested
+    // device in NodeDef is on the same address space as the local device.
+    NodeDef ndef =
+        CreateNodeDef("OpWithoutKernel", {DT_STRING, DT_STRING}, local_device);
+    PrioritizedDeviceTypeVector devs;
+    DeviceNameUtils::ParsedName local_device_name;
+    DeviceNameUtils::ParseFullName(local_device, &local_device_name);
+    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs,
+                                             &local_device_name));
+    EXPECT_EQ(0, devs.size());
+  }
+}
+
 TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
   {
     // Try a node def of an op that is registered for a specific type