From 8cbd3301ec04145e7b90118ae5e0c44c75fd5c06 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Tue, 13 Nov 2018 07:34:20 -0800 Subject: [PATCH] Adding a priority option to kernels in order to be able to override Device selection prioritization. For the regular case, we always prefer to place on the GPU but for tf.data pipelines when no device is specified, we'd like to place it on the CPU actually since that is what the user usually intends. As we extend parts of the tf.data pipeline to run on GPU's / TPU's etc., this will become more and more necessary. PiperOrigin-RevId: 221264064 --- .../core/common_runtime/eager/execute.cc | 4 +- tensorflow/core/common_runtime/placer.cc | 146 ++++++++-- tensorflow/core/common_runtime/placer_test.cc | 252 ++++++++++++++++++ tensorflow/core/framework/kernel_def.proto | 5 + .../core/framework/kernel_def_builder.cc | 5 + .../core/framework/kernel_def_builder.h | 3 + tensorflow/core/framework/op_kernel.cc | 15 +- tensorflow/core/framework/op_kernel.h | 2 +- tensorflow/core/framework/op_kernel_test.cc | 71 +++-- tensorflow/core/framework/types.h | 2 + 10 files changed, 451 insertions(+), 54 deletions(-) diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 3cdda3ed753..a708033c650 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -192,7 +192,7 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device, } Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) { - DeviceTypeVector final_devices; + PrioritizedDeviceTypeVector final_devices; TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode( ctx->prioritized_device_type_list(), ndef, &final_devices)); if (final_devices.empty()) { @@ -202,7 +202,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) { " :\n", KernelsRegisteredForOp(ndef.op())); } for (Device* d : *ctx->devices()) { - if (d->device_type() == final_devices[0].type_string()) { + if (d->device_type() == final_devices[0].first.type_string()) { *device = d; return Status::OK(); } diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index f8d933b45e0..01e4072f603 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -47,42 +47,51 @@ const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix); // returned list is sorted by preferred type (higher numeric type is preferred). std::vector FilterSupportedDevices( const std::vector& devices, - const DeviceTypeVector& supported_device_types, + const PrioritizedDeviceTypeVector& supported_device_types, const Device* default_device) { Device* filtered_default_device = nullptr; - std::vector filtered_devices; - for (const DeviceType& d : supported_device_types) { + std::vector> prioritized_filtered_devices; + for (const auto& supported_device_type : supported_device_types) { for (Device* device : devices) { - if (DeviceType(device->attributes().device_type()) == d) { + if (DeviceType(device->attributes().device_type()) == + supported_device_type.first) { if (device == default_device) { filtered_default_device = device; } else { - filtered_devices.emplace_back(device); + prioritized_filtered_devices.emplace_back( + device, supported_device_type.second); } } } } - auto device_sort = [](const Device* a, const Device* b) { - auto a_priority = DeviceSet::DeviceTypeOrder(DeviceType(a->device_type())); - auto b_priority = DeviceSet::DeviceTypeOrder(DeviceType(b->device_type())); + auto device_sort = [](const std::pair& a, + const std::pair& b) { + if (a.second != b.second) { + return a.second > b.second; + } + + auto a_priority = + DeviceSet::DeviceTypeOrder(DeviceType(a.first->device_type())); + auto b_priority = + DeviceSet::DeviceTypeOrder(DeviceType(b.first->device_type())); // First sort by prioritized device type (higher is preferred) and // then by device name (lexicographically). if (a_priority != b_priority) { return a_priority > b_priority; } - return StringPiece(a->name()) < StringPiece(b->name()); + return StringPiece(a.first->name()) < StringPiece(b.first->name()); }; - std::vector::iterator sort_start; + std::sort(prioritized_filtered_devices.begin(), + prioritized_filtered_devices.end(), device_sort); + + std::vector filtered_devices; if (filtered_default_device != nullptr) { - // Put the default device first outside of the normal ordering. filtered_devices.emplace_back(filtered_default_device); - std::iter_swap(filtered_devices.begin(), std::prev(filtered_devices.end())); - sort_start = std::next(filtered_devices.begin()); - } else { - sort_start = filtered_devices.begin(); } - std::sort(sort_start, filtered_devices.end(), device_sort); + for (const auto& prioritized_filtered_device : prioritized_filtered_devices) { + filtered_devices.push_back(prioritized_filtered_device.first); + } return filtered_devices; } @@ -472,7 +481,7 @@ class ColocationGraph { // The intersection of all device types supported by this node, // and those of all of its children, in priority order // of the preferred device. - DeviceTypeVector supported_device_types; + PrioritizedDeviceTypeVector supported_device_types; // The merged form of the device requested for this node, with // those of all of its children. @@ -511,8 +520,8 @@ class ColocationGraph { const string& op_type = node->type_string(); string devices_registered; for (const auto& device_type : members_[id].supported_device_types) { - strings::StrAppend(&devices_registered, DeviceTypeString(device_type), - " "); + strings::StrAppend(&devices_registered, + DeviceTypeString(device_type.first), " "); } type_to_devices[op_type] = std::move(devices_registered); @@ -565,8 +574,9 @@ class ColocationGraph { "' does not match any device"); } - for (const DeviceType& d : member->supported_device_types) { - if (DeviceType(assigned_device->attributes().device_type()) == d) { + for (const auto& d : member->supported_device_types) { + if (DeviceType(assigned_device->attributes().device_type()) == + d.first) { return Status::OK(); } } @@ -623,24 +633,102 @@ class ColocationGraph { return Status::OK(); } + static bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) { + for (const auto& prioritized_device_type : device_types) { + if (prioritized_device_type.second != 0) return true; + } + return false; + } + + static bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types, + const PrioritizedDeviceTypeVector& b_types) { + if (a_types.size() != b_types.size()) { + return false; + } + for (int i = 0; i < a_types.size(); ++i) { + if (a_types[i].first != b_types[i].first) { + return false; + } + } + return true; + } + // Updates target to contain the intersection of the device types in // "target" and "other". - static void MergeSupportedDevices(DeviceTypeVector* target, - const DeviceTypeVector& other) { - DeviceTypeVector temp = *target; + static void MergeSupportedDevices(PrioritizedDeviceTypeVector* target, + const PrioritizedDeviceTypeVector& other) { + PrioritizedDeviceTypeVector temp = *target; target->clear(); - // Iterate in priority order. - for (const DeviceType& device_type : temp) { + // Generate intersection with priorities. + PrioritizedDeviceTypeVector target_intersection; + PrioritizedDeviceTypeVector other_intersection; + for (const auto& prioritized_device_type : temp) { bool found = false; - for (const DeviceType& other_device_type : other) { - if (device_type == other_device_type) { + for (const auto& other_prioritized_device_type : other) { + if (prioritized_device_type.first == + other_prioritized_device_type.first) { found = true; + other_intersection.push_back(other_prioritized_device_type); break; } } if (found) { - target->push_back(device_type); + target_intersection.push_back(prioritized_device_type); + } + } + + // Sort the devices by priority order. + auto device_sort = [](const std::pair& a, + const std::pair& b) { + // First look at set priorities. + if (a.second != b.second) { + return a.second > b.second; + } + // Then fallback to default priorities. + auto a_priority = DeviceSet::DeviceTypeOrder(a.first); + auto b_priority = DeviceSet::DeviceTypeOrder(b.first); + if (a_priority != b_priority) { + return a_priority > b_priority; + } + // Finally just look at the Device type strings. + return a.first.type_string() < b.first.type_string(); + }; + + std::sort(target_intersection.begin(), target_intersection.end(), + device_sort); + std::sort(other_intersection.begin(), other_intersection.end(), + device_sort); + + bool is_target_prioritized = HasPriorities(target_intersection); + bool is_other_prioritized = HasPriorities(other_intersection); + // If neither are prioritized then we just return the original i.e. target + // prioritization. + if (!is_target_prioritized && !is_other_prioritized) { + *target = target_intersection; + } + // If only one is prioritized, then we respect priorities of that in the + // intersection. + if (is_target_prioritized && !is_other_prioritized) { + *target = target_intersection; + } + if (!is_target_prioritized && is_other_prioritized) { + *target = other_intersection; + } + // If both have priorities and agree then we go with that. If the + // prioritization order is different, then we just fallback to the default + // i.e. what the DeviceTypeOrder suggests. In that case, we also set the + // merged priorities to 0, so that downstream merges work correctly as well. + if (is_target_prioritized && is_other_prioritized) { + bool priorities_agree = + ArePrioritiesSame(target_intersection, other_intersection); + if (priorities_agree) { + *target = target_intersection; + } else { + for (const auto& prioritized_device : target_intersection) { + target->push_back(std::make_pair(prioritized_device.first, 0)); + } + std::sort(target->begin(), target->end(), device_sort); } } } diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 69f1611c1dd..009f905f108 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -164,6 +164,13 @@ REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeGPU"), DummyOp); REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeCPU"), DummyOp); REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp); +// Op that has kernels with device priorities specified. +REGISTER_OP("TestDatasetOp").Input("a: float").Output("b: float"); +REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeCPU").Priority(2), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeGPU").Priority(1), + DummyOp); + //////////////////////////////////////////////////////////////////////////////// // // A PlacerTest method has three phases: @@ -285,6 +292,251 @@ TEST_F(PlacerTest, TestNoConstraints) { EXPECT_DEVICE_TYPE(g, "n2", "FakeGPU"); } +// Test that a graph with no constraints but using kernels that have a specified +// device priority will successfully assign nodes to the device with higher +// priority +TEST_F(PlacerTest, TestNoConstraintsWithPrioritizedKernels) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0), + b.opts().WithName("n1")); + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 1), + b.opts().WithName("n2")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU"); +} + +TEST_F(PlacerTest, TestGPUInputIntoPrioritizedKernel) { + Graph g(OpRegistry::Global()); + { + // Scope for temp variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in")); + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0), + b.opts().WithName("n1")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU"); +} + +// Tests that a GPU kernel colocated with prioritized kernel respects it. +TEST_F(PlacerTest, TestGPUInputColocatedWithPrioritizedKernel) { + Graph g(OpRegistry::Global()); + { + // Scope for temp variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in")); + // We colocate n1 with in. + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0), + b.opts().WithName("n1").WithAttr("_class", {"loc:@in"})); + // We don't colocate n2 with in. + ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0), + b.opts().WithName("n2")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "in", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU"); +} + +REGISTER_OP("CreateDatasetCPU").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetCPU").Device("FakeCPU"), DummyOp); + +REGISTER_OP("CreateDatasetSP").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeCPU").Priority(2), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeGPU").Priority(1), + DummyOp); + +REGISTER_OP("CreateDatasetRP").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeCPU").Priority(1), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeGPU").Priority(2), + DummyOp); + +REGISTER_OP("CreateDatasetNP").Output("o: resource"); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeCPU"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeGPU"), DummyOp); + +REGISTER_OP("IteratorNP").Input("i: resource").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeCPU"), DummyOp); +REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeGPU"), DummyOp); + +REGISTER_OP("IteratorSP").Input("i: resource").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeCPU").Priority(2), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeGPU").Priority(1), + DummyOp); + +REGISTER_OP("IteratorRP").Input("i: resource").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeCPU").Priority(1), + DummyOp); +REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeGPU").Priority(2), + DummyOp); + +REGISTER_OP("IteratorGPU").Input("i: resource").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("IteratorGPU").Device("FakeGPU"), DummyOp); + +// Test reference edges with one node having prioritized kernels and the other +// has no preference. We should respect priority here. +TEST_F(PlacerTest, TestDSWithPriority) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeCPU"); +} + +// Test reference edges with one node having kernels with regular priority and +// the other has no preference. We should respect priority here. +TEST_F(PlacerTest, TestDSWithGPUPriority) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test reference edges with one node having prioritized kernels and the other +// has no preference. We should respect priority here. +TEST_F(PlacerTest, TestITWithPriority) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeCPU"); +} + +// Test reference edges with one node having kernels with regular priority and +// the other has no preference. We should respect priority here. +TEST_F(PlacerTest, TestITWithGPUPriority) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test reference edges with one node having prioritized kernels and other node +// can only be placed on GPU. We should respect the constraint then. +TEST_F(PlacerTest, TestITGPU) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorGPU", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test reference edges with one node having prioritized kernels and other node +// can only be placed on CPU. We should respect the constraint then. +TEST_F(PlacerTest, TestSimpleIteratorOnlyGPU) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetCPU", b.opts().WithName("ds")); + ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeCPU"); +} + +// Test constraints with agreeing priorities. +TEST_F(PlacerTest, TestAgreeingPriorities) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeCPU"); +} + +// Test constraints with agreeing regular priorities. +TEST_F(PlacerTest, TestAgreeingRegularPriorities) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test constraints with different priorities. In this case, we should bail +// and just revert to default. +TEST_F(PlacerTest, TestConflictingPriorities) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + +// Test constraints with different priorities. In this case, we should bail +// and just revert to default. +TEST_F(PlacerTest, TestConflictingPrioritiesReversed) { + Graph g(OpRegistry::Global()); + { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds")); + ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + TF_EXPECT_OK(Place(&g)); + EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU"); + EXPECT_DEVICE_TYPE(g, "it", "FakeGPU"); +} + // Test that a graph with device type and reference constraints on // some of the ops will successfully assign nodes to the constrained // device, and colocate nodes with reference connections. diff --git a/tensorflow/core/framework/kernel_def.proto b/tensorflow/core/framework/kernel_def.proto index e16c2ae73bd..358621dc0f5 100644 --- a/tensorflow/core/framework/kernel_def.proto +++ b/tensorflow/core/framework/kernel_def.proto @@ -33,6 +33,11 @@ message KernelDef { // won't be used unless the user specifies a "_kernel" attr with // value matching this. string label = 5; + + // Prioritization of kernel amongst different devices. By default we assume + // priority is 0. The higher the priority the better. By default (i.e. if + // this is not set), we prefer GPU kernels over CPU. + int32 priority = 6; } // A collection of KernelDefs diff --git a/tensorflow/core/framework/kernel_def_builder.cc b/tensorflow/core/framework/kernel_def_builder.cc index eb86f18ff06..fcacc3bebba 100644 --- a/tensorflow/core/framework/kernel_def_builder.cc +++ b/tensorflow/core/framework/kernel_def_builder.cc @@ -66,6 +66,11 @@ KernelDefBuilder& KernelDefBuilder::Label(const char* label) { return *this; } +KernelDefBuilder& KernelDefBuilder::Priority(int32 priority) { + kernel_def_->set_priority(priority); + return *this; +} + const KernelDef* KernelDefBuilder::Build() { KernelDef* r = kernel_def_; kernel_def_ = nullptr; diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h index 32dd21f94e0..d74453cf606 100644 --- a/tensorflow/core/framework/kernel_def_builder.h +++ b/tensorflow/core/framework/kernel_def_builder.h @@ -64,6 +64,9 @@ class KernelDefBuilder { // "_kernel" attr. May only be specified once. Returns *this. KernelDefBuilder& Label(const char* label); + // Specify a priority number for this kernel. + KernelDefBuilder& Priority(int32 priority); + // Returns a pointer to a KernelDef with fields set based on the // above calls to this instance. // Caller takes ownership of the result. diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index fe71196979a..e2a177569d6 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -1091,7 +1091,7 @@ Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, Status SupportedDeviceTypesForNode( const std::vector& prioritized_types, const NodeDef& def, - DeviceTypeVector* device_types) { + PrioritizedDeviceTypeVector* prioritized_device_types) { // TODO(zhifengc): Changes the callers (SimplePlacer and // DynamicPlacer) to consider the possibility that 'def' is call to // a user-defined function and only calls this @@ -1104,12 +1104,21 @@ Status SupportedDeviceTypesForNode( bool was_attr_mismatch; TF_RETURN_IF_ERROR( FindKernelRegistration(device_type, def, ®, &was_attr_mismatch)); - if (reg != nullptr) device_types->push_back(device_type); + if (reg != nullptr) { + int32 priority = reg->def.priority(); + prioritized_device_types->emplace_back(device_type, priority); + } } + std::sort(prioritized_device_types->begin(), + prioritized_device_types->end(), + [](const std::pair& a, + const std::pair& b) { + return a.second > b.second; + }); } else { // Assumes that all device types support this node. for (const DeviceType& device_type : prioritized_types) { - device_types->push_back(device_type); + prioritized_device_types->push_back(std::make_pair(device_type, 0)); } } return Status::OK(); diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index 37a437136bd..9f4c57e880a 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -1237,7 +1237,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device, // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). Status SupportedDeviceTypesForNode( const std::vector& prioritized_types, const NodeDef& def, - DeviceTypeVector* device_types); + PrioritizedDeviceTypeVector* device_types); // Returns a message with a description of the kernels registered for op // `op_name`. diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index 83dda6579b7..d8001cd0710 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -102,6 +102,27 @@ REGISTER_OP("Test4").Input("i: float").Output("o: float"); REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel); REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel); +// Kernels with different priorities. +REGISTER_OP("Test5").Input("a: T").Input("b: T").Attr("T: type"); + +class TestOp5Cpu : public tensorflow::OpKernel { + public: + explicit TestOp5Cpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_CPU).Priority(2), + TestOp5Cpu); + +class TestOp5Gpu : public tensorflow::OpKernel { + public: + explicit TestOp5Gpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_GPU).Priority(1), + TestOp5Gpu); + static std::vector DeviceTypes() { return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}; } @@ -185,10 +206,10 @@ TEST_F(OpKernelTest, SuccessBothCpuAndGpu) { TEST_F(OpKernelTest, CpuTypeRegistered) { NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(1, devs.size()); - EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first); } TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { @@ -196,24 +217,24 @@ TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { // Try a node def of an op that is registered for a specific type // only on CPU. NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(1, devs.size()); - EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first); } { // Try a node def of an op that is registered for a specific type // only on GPU. NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(1, devs.size()); - EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first); } { // Try a node def of an op that is only registered for other types. NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(0, devs.size()); } @@ -221,11 +242,23 @@ TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { { // Try a node def of an op that is registered for both. NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); EXPECT_EQ(2, devs.size()); - EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); - EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first); + } + + { + // Try a node def of an op where kernels have priorities. + NodeDef ndef = CreateNodeDef("Test5", {DT_STRING, DT_STRING}); + PrioritizedDeviceTypeVector devs; + TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(2, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first); + EXPECT_EQ(2, devs[0].second); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[1].first); + EXPECT_EQ(1, devs[1].second); } } @@ -412,11 +445,11 @@ class OpKernelBuilderTest : public ::testing::Test { } // Test SupportedDeviceTypesForNode() - DeviceTypeVector devices; + PrioritizedDeviceTypeVector devices; TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); bool found = false; - for (const DeviceType& dt : devices) { - if (dt == device_type) { + for (const auto& dt : devices) { + if (dt.first == device_type) { found = true; } } @@ -445,11 +478,11 @@ class OpKernelBuilderTest : public ::testing::Test { EXPECT_EQ(code, status.code()); // Test SupportedDeviceTypesForNode(). - DeviceTypeVector devices; + PrioritizedDeviceTypeVector devices; if (errors::IsNotFound(status)) { TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); - for (const DeviceType& dt : devices) { - EXPECT_NE(dt, device_type); + for (const auto& dt : devices) { + EXPECT_NE(dt.first, device_type); } } else { Status status2 = @@ -562,7 +595,7 @@ REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), TEST_F(OpKernelBuilderTest, DuplicateKernel) { const NodeDef ndef = CreateNodeDef("DuplicateKernel", {}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains( @@ -582,7 +615,7 @@ REGISTER_KERNEL_BUILDER( TEST_F(OpKernelBuilderTest, DuplicateKernelForT) { const NodeDef ndef = CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains( @@ -603,7 +636,7 @@ REGISTER_KERNEL_BUILDER(Name("BadConstraint") TEST_F(OpKernelBuilderTest, BadConstraint) { const NodeDef ndef = CreateNodeDef("BadConstraint", {}); - DeviceTypeVector devs; + PrioritizedDeviceTypeVector devs; Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); ASSERT_FALSE(status.ok()); EXPECT_TRUE( diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index a05dea19ec4..c0df1933421 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -104,6 +104,8 @@ typedef gtl::InlinedVector DataTypeVector; typedef gtl::ArraySlice DataTypeSlice; typedef gtl::InlinedVector DeviceTypeVector; +typedef gtl::InlinedVector, 4> + PrioritizedDeviceTypeVector; // Convert the enums to strings for errors: string DataTypeString(DataType dtype);