diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 3cdda3ed753..a708033c650 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -192,7 +192,7 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
 }
 
 Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
-  DeviceTypeVector final_devices;
+  PrioritizedDeviceTypeVector final_devices;
   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
       ctx->prioritized_device_type_list(), ndef, &final_devices));
   if (final_devices.empty()) {
@@ -202,7 +202,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
         " :\n", KernelsRegisteredForOp(ndef.op()));
   }
   for (Device* d : *ctx->devices()) {
-    if (d->device_type() == final_devices[0].type_string()) {
+    if (d->device_type() == final_devices[0].first.type_string()) {
      *device = d;
       return Status::OK();
     }
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index f8d933b45e0..01e4072f603 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -47,42 +47,51 @@ const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
 // returned list is sorted by preferred type (higher numeric type is preferred).
 std::vector<Device*> FilterSupportedDevices(
     const std::vector<Device*>& devices,
-    const DeviceTypeVector& supported_device_types,
+    const PrioritizedDeviceTypeVector& supported_device_types,
     const Device* default_device) {
   Device* filtered_default_device = nullptr;
-  std::vector<Device*> filtered_devices;
-  for (const DeviceType& d : supported_device_types) {
+  std::vector<std::pair<Device*, int32>> prioritized_filtered_devices;
+  for (const auto& supported_device_type : supported_device_types) {
     for (Device* device : devices) {
-      if (DeviceType(device->attributes().device_type()) == d) {
+      if (DeviceType(device->attributes().device_type()) ==
+          supported_device_type.first) {
         if (device == default_device) {
           filtered_default_device = device;
         } else {
-          filtered_devices.emplace_back(device);
+          prioritized_filtered_devices.emplace_back(
+              device, supported_device_type.second);
         }
       }
     }
   }
 
-  auto device_sort = [](const Device* a, const Device* b) {
-    auto a_priority = DeviceSet::DeviceTypeOrder(DeviceType(a->device_type()));
-    auto b_priority = DeviceSet::DeviceTypeOrder(DeviceType(b->device_type()));
+  auto device_sort = [](const std::pair<Device*, int32>& a,
+                        const std::pair<Device*, int32>& b) {
+    if (a.second != b.second) {
+      return a.second > b.second;
+    }
+
+    auto a_priority =
+        DeviceSet::DeviceTypeOrder(DeviceType(a.first->device_type()));
+    auto b_priority =
+        DeviceSet::DeviceTypeOrder(DeviceType(b.first->device_type()));
     // First sort by prioritized device type (higher is preferred) and
     // then by device name (lexicographically).
     if (a_priority != b_priority) {
       return a_priority > b_priority;
     }
 
-    return StringPiece(a->name()) < StringPiece(b->name());
+    return StringPiece(a.first->name()) < StringPiece(b.first->name());
   };
 
-  std::vector<Device*>::iterator sort_start;
+  std::sort(prioritized_filtered_devices.begin(),
+            prioritized_filtered_devices.end(), device_sort);
+
+  std::vector<Device*> filtered_devices;
   if (filtered_default_device != nullptr) {
-    // Put the default device first outside of the normal ordering.
     filtered_devices.emplace_back(filtered_default_device);
-    std::iter_swap(filtered_devices.begin(), std::prev(filtered_devices.end()));
-    sort_start = std::next(filtered_devices.begin());
-  } else {
-    sort_start = filtered_devices.begin();
   }
-  std::sort(sort_start, filtered_devices.end(), device_sort);
+  for (const auto& prioritized_filtered_device :
+       prioritized_filtered_devices) {
+    filtered_devices.push_back(prioritized_filtered_device.first);
+  }
   return filtered_devices;
 }
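Note on the ordering above: a candidate device now wins first on the explicit kernel priority attached to its type, then on the pre-existing DeviceSet::DeviceTypeOrder ranking, and finally on its name. The standalone sketch below (a hypothetical FakeDevice struct and made-up order values, not TensorFlow code) shows the same three-level comparator in isolation:

// Illustration only: FakeDevice and DefaultTypeOrder are stand-ins for
// tensorflow::Device and DeviceSet::DeviceTypeOrder.
#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct FakeDevice {
  std::string name;
  std::string type;  // e.g. "CPU" or "GPU"
};

int DefaultTypeOrder(const std::string& type) {
  // Stand-in for DeviceSet::DeviceTypeOrder; higher means more preferred.
  static const std::map<std::string, int> kOrder = {{"GPU", 210}, {"CPU", 60}};
  auto it = kOrder.find(type);
  return it == kOrder.end() ? 0 : it->second;
}

int main() {
  // (device, kernel priority) pairs, as produced by the filtering loop above.
  std::vector<std::pair<FakeDevice, int>> prioritized = {
      {{"/device:GPU:0", "GPU"}, 1},
      {{"/device:CPU:0", "CPU"}, 2},
      {{"/device:CPU:1", "CPU"}, 2},
  };
  std::sort(prioritized.begin(), prioritized.end(),
            [](const std::pair<FakeDevice, int>& a,
               const std::pair<FakeDevice, int>& b) {
              // 1. Explicit kernel priority (higher wins).
              if (a.second != b.second) return a.second > b.second;
              // 2. Default device-type order (higher wins).
              int ap = DefaultTypeOrder(a.first.type);
              int bp = DefaultTypeOrder(b.first.type);
              if (ap != bp) return ap > bp;
              // 3. Device name, lexicographically.
              return a.first.name < b.first.name;
            });
  for (const auto& p : prioritized) {
    std::cout << p.first.name << " (priority " << p.second << ")\n";
  }
  // Prints CPU:0, CPU:1, then GPU:0 -- the Priority(2) CPU kernels win even
  // though GPU ranks higher in the default order.
}

With priorities 2/2/1 the two CPU devices sort ahead of the GPU even though GPU ranks higher in the default order, which is the behavior the placer tests below rely on.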
@@ -472,7 +481,7 @@ class ColocationGraph {
     // The intersection of all device types supported by this node,
     // and those of all of its children, in priority order
     // of the preferred device.
-    DeviceTypeVector supported_device_types;
+    PrioritizedDeviceTypeVector supported_device_types;
 
     // The merged form of the device requested for this node, with
     // those of all of its children.
@@ -511,8 +520,8 @@ class ColocationGraph {
       const string& op_type = node->type_string();
       string devices_registered;
       for (const auto& device_type : members_[id].supported_device_types) {
-        strings::StrAppend(&devices_registered, DeviceTypeString(device_type),
-                           " ");
+        strings::StrAppend(&devices_registered,
+                           DeviceTypeString(device_type.first), " ");
       }
 
       type_to_devices[op_type] = std::move(devices_registered);
@@ -565,8 +574,9 @@ class ColocationGraph {
           "' does not match any device");
     }
 
-    for (const DeviceType& d : member->supported_device_types) {
-      if (DeviceType(assigned_device->attributes().device_type()) == d) {
+    for (const auto& d : member->supported_device_types) {
+      if (DeviceType(assigned_device->attributes().device_type()) ==
+          d.first) {
         return Status::OK();
       }
     }
@@ -623,24 +633,102 @@ class ColocationGraph {
     return Status::OK();
   }
 
+  static bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
+    for (const auto& prioritized_device_type : device_types) {
+      if (prioritized_device_type.second != 0) return true;
+    }
+    return false;
+  }
+
+  static bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
+                                const PrioritizedDeviceTypeVector& b_types) {
+    if (a_types.size() != b_types.size()) {
+      return false;
+    }
+    for (int i = 0; i < a_types.size(); ++i) {
+      if (a_types[i].first != b_types[i].first) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   // Updates target to contain the intersection of the device types in
   // "target" and "other".
-  static void MergeSupportedDevices(DeviceTypeVector* target,
-                                    const DeviceTypeVector& other) {
-    DeviceTypeVector temp = *target;
+  static void MergeSupportedDevices(PrioritizedDeviceTypeVector* target,
+                                    const PrioritizedDeviceTypeVector& other) {
+    PrioritizedDeviceTypeVector temp = *target;
     target->clear();
-    // Iterate in priority order.
-    for (const DeviceType& device_type : temp) {
+
+    // Generate intersection with priorities.
+    PrioritizedDeviceTypeVector target_intersection;
+    PrioritizedDeviceTypeVector other_intersection;
+    for (const auto& prioritized_device_type : temp) {
       bool found = false;
-      for (const DeviceType& other_device_type : other) {
-        if (device_type == other_device_type) {
+      for (const auto& other_prioritized_device_type : other) {
+        if (prioritized_device_type.first ==
+            other_prioritized_device_type.first) {
           found = true;
+          other_intersection.push_back(other_prioritized_device_type);
           break;
         }
       }
       if (found) {
-        target->push_back(device_type);
+        target_intersection.push_back(prioritized_device_type);
+      }
+    }
+
+    // Sort the devices by priority order.
+    auto device_sort = [](const std::pair<DeviceType, int32>& a,
+                          const std::pair<DeviceType, int32>& b) {
+      // First look at set priorities.
+ if (a.second != b.second) { + return a.second > b.second; + } + // Then fallback to default priorities. + auto a_priority = DeviceSet::DeviceTypeOrder(a.first); + auto b_priority = DeviceSet::DeviceTypeOrder(b.first); + if (a_priority != b_priority) { + return a_priority > b_priority; + } + // Finally just look at the Device type strings. + return a.first.type_string() < b.first.type_string(); + }; + + std::sort(target_intersection.begin(), target_intersection.end(), + device_sort); + std::sort(other_intersection.begin(), other_intersection.end(), + device_sort); + + bool is_target_prioritized = HasPriorities(target_intersection); + bool is_other_prioritized = HasPriorities(other_intersection); + // If neither are prioritized then we just return the original i.e. target + // prioritization. + if (!is_target_prioritized && !is_other_prioritized) { + *target = target_intersection; + } + // If only one is prioritized, then we respect priorities of that in the + // intersection. + if (is_target_prioritized && !is_other_prioritized) { + *target = target_intersection; + } + if (!is_target_prioritized && is_other_prioritized) { + *target = other_intersection; + } + // If both have priorities and agree then we go with that. If the + // prioritization order is different, then we just fallback to the default + // i.e. what the DeviceTypeOrder suggests. In that case, we also set the + // merged priorities to 0, so that downstream merges work correctly as well. + if (is_target_prioritized && is_other_prioritized) { + bool priorities_agree = + ArePrioritiesSame(target_intersection, other_intersection); + if (priorities_agree) { + *target = target_intersection; + } else { + for (const auto& prioritized_device : target_intersection) { + target->push_back(std::make_pair(prioritized_device.first, 0)); + } + std::sort(target->begin(), target->end(), device_sort); } } } diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 69f1611c1dd..009f905f108 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -164,6 +164,13 @@ REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeGPU"), DummyOp); REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeCPU"), DummyOp); REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp); +// Op that has kernels with device priorities specified. +REGISTER_OP("TestDatasetOp").Input("a: float").Output("b: float"); +REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeCPU").Priority(2), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeGPU").Priority(1), + DummyOp); + //////////////////////////////////////////////////////////////////////////////// // // A PlacerTest method has three phases: @@ -285,6 +292,251 @@ TEST_F(PlacerTest, TestNoConstraints) { EXPECT_DEVICE_TYPE(g, "n2", "FakeGPU"); } +// Test that a graph with no constraints but using kernels that have a specified +// device priority will successfully assign nodes to the device with higher +// priority +TEST_F(PlacerTest, TestNoConstraintsWithPrioritizedKernels) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. 
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0), + b.opts().WithName("n1")); + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 1), + b.opts().WithName("n2")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU"); +} + +TEST_F(PlacerTest, TestGPUInputIntoPrioritizedKernel) { + Graph g(OpRegistry::Global()); + { + // Scope for temp variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in")); + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0), + b.opts().WithName("n1")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU"); +} + +// Tests that a GPU kernel colocated with prioritized kernel respects it. +TEST_F(PlacerTest, TestGPUInputColocatedWithPrioritizedKernel) { + Graph g(OpRegistry::Global()); + { + // Scope for temp variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in")); + // We colocate n1 with in. + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0), + b.opts().WithName("n1").WithAttr("_class", {"loc:@in"})); + // We don't colocate n2 with in. + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0), + b.opts().WithName("n2")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU"); +} + +REGISTER_OP("CreateDatasetCPU").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetCPU").Device("FakeCPU"), DummyOp); + +REGISTER_OP("CreateDatasetSP").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeCPU").Priority(2), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeGPU").Priority(1), + DummyOp); + +REGISTER_OP("CreateDatasetRP").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeCPU").Priority(1), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeGPU").Priority(2), + DummyOp); + +REGISTER_OP("CreateDatasetNP").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeCPU"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeGPU"), DummyOp); + +REGISTER_OP("IteratorNP").Input("i: resource").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeCPU"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeGPU"), DummyOp); + +REGISTER_OP("IteratorSP").Input("i: resource").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeCPU").Priority(2), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeGPU").Priority(1), + DummyOp); + +REGISTER_OP("IteratorRP").Input("i: resource").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeCPU").Priority(1), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeGPU").Priority(2), + DummyOp); + +REGISTER_OP("IteratorGPU").Input("i: resource").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("IteratorGPU").Device("FakeGPU"), DummyOp); + +// Test reference edges with one node having 
prioritized kernels and the other +// has no preference. We should respect priority here. +TEST_F(PlacerTest, TestDSWithPriority) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeCPU"); +} + +// Test reference edges with one node having kernels with regular priority and +// the other has no preference. We should respect priority here. +TEST_F(PlacerTest, TestDSWithGPUPriority) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test reference edges with one node having prioritized kernels and the other +// has no preference. We should respect priority here. +TEST_F(PlacerTest, TestITWithPriority) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeCPU"); +} + +// Test reference edges with one node having kernels with regular priority and +// the other has no preference. We should respect priority here. +TEST_F(PlacerTest, TestITWithGPUPriority) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test reference edges with one node having prioritized kernels and other node +// can only be placed on GPU. We should respect the constraint then. +TEST_F(PlacerTest, TestITGPU) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorGPU", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test reference edges with one node having prioritized kernels and other node +// can only be placed on CPU. We should respect the constraint then. +TEST_F(PlacerTest, TestSimpleIteratorOnlyGPU) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetCPU", b.opts().WithName("ds")); + ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeCPU"); +} + +// Test constraints with agreeing priorities. 
+TEST_F(PlacerTest, TestAgreeingPriorities) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeCPU"); +} + +// Test constraints with agreeing regular priorities. +TEST_F(PlacerTest, TestAgreeingRegularPriorities) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test constraints with different priorities. In this case, we should bail +// and just revert to default. +TEST_F(PlacerTest, TestConflictingPriorities) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test constraints with different priorities. In this case, we should bail +// and just revert to default. +TEST_F(PlacerTest, TestConflictingPrioritiesReversed) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + // Test that a graph with device type and reference constraints on // some of the ops will successfully assign nodes to the constrained // device, and colocate nodes with reference connections. diff --git a/tensorflow/core/framework/kernel_def.proto b/tensorflow/core/framework/kernel_def.proto index e16c2ae73bd..358621dc0f5 100644 --- a/tensorflow/core/framework/kernel_def.proto +++ b/tensorflow/core/framework/kernel_def.proto @@ -33,6 +33,11 @@ message KernelDef { // won't be used unless the user specifies a "_kernel" attr with // value matching this. string label = 5; + + // Prioritization of kernel amongst different devices. By default we assume + // priority is 0. The higher the priority the better. By default (i.e. if + // this is not set), we prefer GPU kernels over CPU. 
+  int32 priority = 6;
 }
 
 // A collection of KernelDefs
diff --git a/tensorflow/core/framework/kernel_def_builder.cc b/tensorflow/core/framework/kernel_def_builder.cc
index eb86f18ff06..fcacc3bebba 100644
--- a/tensorflow/core/framework/kernel_def_builder.cc
+++ b/tensorflow/core/framework/kernel_def_builder.cc
@@ -66,6 +66,11 @@ KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
   return *this;
 }
 
+KernelDefBuilder& KernelDefBuilder::Priority(int32 priority) {
+  kernel_def_->set_priority(priority);
+  return *this;
+}
+
 const KernelDef* KernelDefBuilder::Build() {
   KernelDef* r = kernel_def_;
   kernel_def_ = nullptr;
diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h
index 32dd21f94e0..d74453cf606 100644
--- a/tensorflow/core/framework/kernel_def_builder.h
+++ b/tensorflow/core/framework/kernel_def_builder.h
@@ -64,6 +64,9 @@ class KernelDefBuilder {
   // "_kernel" attr. May only be specified once. Returns *this.
   KernelDefBuilder& Label(const char* label);
 
+  // Specify a priority number for this kernel.
+  KernelDefBuilder& Priority(int32 priority);
+
   // Returns a pointer to a KernelDef with fields set based on the
   // above calls to this instance.
   // Caller takes ownership of the result.
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index fe71196979a..e2a177569d6 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -1091,7 +1091,7 @@ Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
 
 Status SupportedDeviceTypesForNode(
     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
-    DeviceTypeVector* device_types) {
+    PrioritizedDeviceTypeVector* prioritized_device_types) {
   // TODO(zhifengc): Changes the callers (SimplePlacer and
   // DynamicPlacer) to consider the possibility that 'def' is call to
   // a user-defined function and only calls this
@@ -1104,12 +1104,21 @@ Status SupportedDeviceTypesForNode(
       bool was_attr_mismatch;
       TF_RETURN_IF_ERROR(
           FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
-      if (reg != nullptr) device_types->push_back(device_type);
+      if (reg != nullptr) {
+        int32 priority = reg->def.priority();
+        prioritized_device_types->emplace_back(device_type, priority);
+      }
     }
+    std::sort(prioritized_device_types->begin(),
+              prioritized_device_types->end(),
+              [](const std::pair<DeviceType, int32>& a,
+                 const std::pair<DeviceType, int32>& b) {
+                return a.second > b.second;
+              });
   } else {
     // Assumes that all device types support this node.
     for (const DeviceType& device_type : prioritized_types) {
-      device_types->push_back(device_type);
+      prioritized_device_types->push_back(std::make_pair(device_type, 0));
     }
   }
   return Status::OK();
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index 37a437136bd..9f4c57e880a 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1237,7 +1237,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
 // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
 Status SupportedDeviceTypesForNode(
     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
-    DeviceTypeVector* device_types);
+    PrioritizedDeviceTypeVector* device_types);
 
 // Returns a message with a description of the kernels registered for op
 // `op_name`.
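Taken together, the proto field, the KernelDefBuilder::Priority() setter, and the updated SupportedDeviceTypesForNode() let a kernel author state that one device's implementation of an op should be preferred over another's. A hypothetical registration (the op and kernel names below are made up for illustration; the pattern mirrors the TestDatasetOp registrations in placer_test.cc) would look like this:

// Sketch of how an op author could prefer the CPU kernel over the GPU one,
// instead of relying on the default GPU-over-CPU ordering.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

REGISTER_OP("MyPrefersCpuOp").Input("a: float").Output("b: float");

class MyPrefersCpuOpKernel : public OpKernel {
 public:
  explicit MyPrefersCpuOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    // Real implementation elided.
  }
};

// Higher priority wins, so SupportedDeviceTypesForNode() and the placer will
// list the CPU kernel ahead of the GPU one for this op.
REGISTER_KERNEL_BUILDER(Name("MyPrefersCpuOp").Device(DEVICE_CPU).Priority(2),
                        MyPrefersCpuOpKernel);
REGISTER_KERNEL_BUILDER(Name("MyPrefersCpuOp").Device(DEVICE_GPU).Priority(1),
                        MyPrefersCpuOpKernel);

}  // namespace tensorflow

With these registrations, SupportedDeviceTypesForNode() reports (CPU, 2) before (GPU, 1), which is exactly what the op_kernel_test.cc changes below assert for the analogous Test5 op.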
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 83dda6579b7..d8001cd0710 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -102,6 +102,27 @@ REGISTER_OP("Test4").Input("i: float").Output("o: float"); REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel); REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel); +// Kernels with different priorities. +REGISTER_OP("Test5").Input("a: T").Input("b: T").Attr("T: type"); + +class TestOp5Cpu : public tensorflow::OpKernel { + public: + explicit TestOp5Cpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_CPU).Priority(2), + TestOp5Cpu); + +class TestOp5Gpu : public tensorflow::OpKernel { + public: + explicit TestOp5Gpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_GPU).Priority(1), + TestOp5Gpu); + static std::vector DeviceTypes() { return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}; } @@ -185,10 +206,10 @@ TEST_F(OpKernelTest, SuccessBothCpuAndGpu) { TEST_F(OpKernelTest, CpuTypeRegistered) { NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(1, devs.size()); - EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first); } TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { @@ -196,24 +217,24 @@ TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { // Try a node def of an op that is registered for a specific type // only on CPU. NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(1, devs.size()); - EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first); } { // Try a node def of an op that is registered for a specific type // only on GPU. NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(1, devs.size()); - EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first); } { // Try a node def of an op that is only registered for other types. NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(0, devs.size()); } @@ -221,11 +242,23 @@ TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { { // Try a node def of an op that is registered for both. NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(2, devs.size()); - EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); - EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first); + } + + { + // Try a node def of an op where kernels have priorities. 
+ NodeDef ndef = CreateNodeDef("Test5", {DT_STRING, DT_STRING}); + PrioritizedDeviceTypeVector devs; + TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(2, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first); + EXPECT_EQ(2, devs[0].second); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[1].first); + EXPECT_EQ(1, devs[1].second); } } @@ -412,11 +445,11 @@ class OpKernelBuilderTest : public ::testing::Test { } // Test SupportedDeviceTypesForNode() - DeviceTypeVector devices; + PrioritizedDeviceTypeVector devices; TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); bool found = false; - for (const DeviceType& dt : devices) { - if (dt == device_type) { + for (const auto& dt : devices) { + if (dt.first == device_type) { found = true; } } @@ -445,11 +478,11 @@ class OpKernelBuilderTest : public ::testing::Test { EXPECT_EQ(code, status.code()); // Test SupportedDeviceTypesForNode(). - DeviceTypeVector devices; + PrioritizedDeviceTypeVector devices; if (errors::IsNotFound(status)) { TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); - for (const DeviceType& dt : devices) { - EXPECT_NE(dt, device_type); + for (const auto& dt : devices) { + EXPECT_NE(dt.first, device_type); } } else { Status status2 = @@ -562,7 +595,7 @@ REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), TEST_F(OpKernelBuilderTest, DuplicateKernel) { const NodeDef ndef = CreateNodeDef("DuplicateKernel", {}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains( @@ -582,7 +615,7 @@ REGISTER_KERNEL_BUILDER( TEST_F(OpKernelBuilderTest, DuplicateKernelForT) { const NodeDef ndef = CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains( @@ -603,7 +636,7 @@ REGISTER_KERNEL_BUILDER(Name("BadConstraint") TEST_F(OpKernelBuilderTest, BadConstraint) { const NodeDef ndef = CreateNodeDef("BadConstraint", {}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); EXPECT_TRUE( diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index a05dea19ec4..c0df1933421 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -104,6 +104,8 @@ typedef gtl::InlinedVector DataTypeVector; typedef gtl::ArraySlice DataTypeSlice; typedef gtl::InlinedVector DeviceTypeVector; +typedef gtl::InlinedVector, 4> + PrioritizedDeviceTypeVector; // Convert the enums to strings for errors: string DataTypeString(DataType dtype);