Adds a priority option to kernels so that the default device selection prioritization can be overridden. In the regular case we always prefer to place an op on the GPU, but for tf.data pipelines with no device specified we would rather place it on the CPU, since that is what the user usually intends.

As we extend parts of the tf.data pipeline to run on GPUs, TPUs, etc., this will become increasingly necessary.
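As an illustration, a kernel author opts into this via the Priority() method added to KernelDefBuilder in this change. The sketch below uses a hypothetical MakeExampleDataset op (the op and kernel names are placeholders, not part of this change); with these registrations the placer prefers the CPU kernel (priority 2) over the GPU kernel (priority 1) whenever the user has not explicitly requested a device:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Hypothetical op, for illustration only.
REGISTER_OP("MakeExampleDataset").Output("handle: resource");

// Placeholder kernel; a real tf.data kernel would produce a dataset resource.
class MakeExampleDatasetOp : public OpKernel {
 public:
  explicit MakeExampleDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {}
};

// Higher priority wins when the user has not pinned the op to a device, so
// these registrations steer placement to the CPU by default.
REGISTER_KERNEL_BUILDER(
    Name("MakeExampleDataset").Device(DEVICE_CPU).Priority(2),
    MakeExampleDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("MakeExampleDataset").Device(DEVICE_GPU).Priority(1),
    MakeExampleDatasetOp);

}  // namespace tensorflow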

PiperOrigin-RevId: 221264064
Rohan Jain 2018-11-13 07:34:20 -08:00 committed by TensorFlower Gardener
parent b9247ca08f
commit 8cbd3301ec
10 changed files with 451 additions and 54 deletions


@@ -192,7 +192,7 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
 }
 
 Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
-  DeviceTypeVector final_devices;
+  PrioritizedDeviceTypeVector final_devices;
   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
       ctx->prioritized_device_type_list(), ndef, &final_devices));
   if (final_devices.empty()) {
@@ -202,7 +202,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
         " :\n", KernelsRegisteredForOp(ndef.op()));
   }
   for (Device* d : *ctx->devices()) {
-    if (d->device_type() == final_devices[0].type_string()) {
+    if (d->device_type() == final_devices[0].first.type_string()) {
       *device = d;
       return Status::OK();
     }


@@ -47,42 +47,51 @@ const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
 // returned list is sorted by preferred type (higher numeric type is preferred).
 std::vector<Device*> FilterSupportedDevices(
     const std::vector<Device*>& devices,
-    const DeviceTypeVector& supported_device_types,
+    const PrioritizedDeviceTypeVector& supported_device_types,
     const Device* default_device) {
   Device* filtered_default_device = nullptr;
-  std::vector<Device*> filtered_devices;
-  for (const DeviceType& d : supported_device_types) {
+  std::vector<std::pair<Device*, int32>> prioritized_filtered_devices;
+  for (const auto& supported_device_type : supported_device_types) {
     for (Device* device : devices) {
-      if (DeviceType(device->attributes().device_type()) == d) {
+      if (DeviceType(device->attributes().device_type()) ==
+          supported_device_type.first) {
         if (device == default_device) {
           filtered_default_device = device;
         } else {
-          filtered_devices.emplace_back(device);
+          prioritized_filtered_devices.emplace_back(
+              device, supported_device_type.second);
        }
      }
    }
  }
 
-  auto device_sort = [](const Device* a, const Device* b) {
-    auto a_priority = DeviceSet::DeviceTypeOrder(DeviceType(a->device_type()));
-    auto b_priority = DeviceSet::DeviceTypeOrder(DeviceType(b->device_type()));
+  auto device_sort = [](const std::pair<Device*, int32>& a,
+                        const std::pair<Device*, int32>& b) {
+    if (a.second != b.second) {
+      return a.second > b.second;
+    }
+
+    auto a_priority =
+        DeviceSet::DeviceTypeOrder(DeviceType(a.first->device_type()));
+    auto b_priority =
+        DeviceSet::DeviceTypeOrder(DeviceType(b.first->device_type()));
     // First sort by prioritized device type (higher is preferred) and
     // then by device name (lexicographically).
     if (a_priority != b_priority) {
       return a_priority > b_priority;
     }
 
-    return StringPiece(a->name()) < StringPiece(b->name());
+    return StringPiece(a.first->name()) < StringPiece(b.first->name());
   };
+  std::sort(prioritized_filtered_devices.begin(),
+            prioritized_filtered_devices.end(), device_sort);
 
-  std::vector<Device*>::iterator sort_start;
+  std::vector<Device*> filtered_devices;
   if (filtered_default_device != nullptr) {
-    // Put the default device first outside of the normal ordering.
     filtered_devices.emplace_back(filtered_default_device);
-    std::iter_swap(filtered_devices.begin(), std::prev(filtered_devices.end()));
-    sort_start = std::next(filtered_devices.begin());
-  } else {
-    sort_start = filtered_devices.begin();
   }
-  std::sort(sort_start, filtered_devices.end(), device_sort);
+  for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
+    filtered_devices.push_back(prioritized_filtered_device.first);
+  }
 
   return filtered_devices;
 }
@@ -472,7 +481,7 @@ class ColocationGraph {
     // The intersection of all device types supported by this node,
     // and those of all of its children, in priority order
     // of the preferred device.
-    DeviceTypeVector supported_device_types;
+    PrioritizedDeviceTypeVector supported_device_types;
 
     // The merged form of the device requested for this node, with
     // those of all of its children.
@@ -511,8 +520,8 @@ class ColocationGraph {
       const string& op_type = node->type_string();
       string devices_registered;
       for (const auto& device_type : members_[id].supported_device_types) {
-        strings::StrAppend(&devices_registered, DeviceTypeString(device_type),
-                           " ");
+        strings::StrAppend(&devices_registered,
+                           DeviceTypeString(device_type.first), " ");
       }
 
       type_to_devices[op_type] = std::move(devices_registered);
@@ -565,8 +574,9 @@ class ColocationGraph {
                                      "' does not match any device");
     }
 
-    for (const DeviceType& d : member->supported_device_types) {
-      if (DeviceType(assigned_device->attributes().device_type()) == d) {
+    for (const auto& d : member->supported_device_types) {
+      if (DeviceType(assigned_device->attributes().device_type()) ==
+          d.first) {
         return Status::OK();
       }
     }
@@ -623,24 +633,102 @@ class ColocationGraph {
     return Status::OK();
   }
 
+  static bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
+    for (const auto& prioritized_device_type : device_types) {
+      if (prioritized_device_type.second != 0) return true;
+    }
+    return false;
+  }
+
+  static bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
+                                const PrioritizedDeviceTypeVector& b_types) {
+    if (a_types.size() != b_types.size()) {
+      return false;
+    }
+    for (int i = 0; i < a_types.size(); ++i) {
+      if (a_types[i].first != b_types[i].first) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   // Updates target to contain the intersection of the device types in
   // "target" and "other".
-  static void MergeSupportedDevices(DeviceTypeVector* target,
-                                    const DeviceTypeVector& other) {
-    DeviceTypeVector temp = *target;
+  static void MergeSupportedDevices(PrioritizedDeviceTypeVector* target,
+                                    const PrioritizedDeviceTypeVector& other) {
+    PrioritizedDeviceTypeVector temp = *target;
     target->clear();
 
-    // Iterate in priority order.
-    for (const DeviceType& device_type : temp) {
+    // Generate intersection with priorities.
+    PrioritizedDeviceTypeVector target_intersection;
+    PrioritizedDeviceTypeVector other_intersection;
+    for (const auto& prioritized_device_type : temp) {
       bool found = false;
-      for (const DeviceType& other_device_type : other) {
-        if (device_type == other_device_type) {
+      for (const auto& other_prioritized_device_type : other) {
+        if (prioritized_device_type.first ==
+            other_prioritized_device_type.first) {
           found = true;
+          other_intersection.push_back(other_prioritized_device_type);
           break;
         }
      }
      if (found) {
-        target->push_back(device_type);
+        target_intersection.push_back(prioritized_device_type);
+      }
+    }
+
+    // Sort the devices by priority order.
+    auto device_sort = [](const std::pair<DeviceType, int32>& a,
+                          const std::pair<DeviceType, int32>& b) {
+      // First look at set priorities.
+      if (a.second != b.second) {
+        return a.second > b.second;
+      }
+      // Then fallback to default priorities.
+      auto a_priority = DeviceSet::DeviceTypeOrder(a.first);
+      auto b_priority = DeviceSet::DeviceTypeOrder(b.first);
+      if (a_priority != b_priority) {
+        return a_priority > b_priority;
+      }
+      // Finally just look at the Device type strings.
+      return a.first.type_string() < b.first.type_string();
+    };
+    std::sort(target_intersection.begin(), target_intersection.end(),
+              device_sort);
+    std::sort(other_intersection.begin(), other_intersection.end(),
+              device_sort);
+
+    bool is_target_prioritized = HasPriorities(target_intersection);
+    bool is_other_prioritized = HasPriorities(other_intersection);
+    // If neither are prioritized then we just return the original i.e. target
+    // prioritization.
+    if (!is_target_prioritized && !is_other_prioritized) {
+      *target = target_intersection;
+    }
+    // If only one is prioritized, then we respect priorities of that in the
+    // intersection.
+    if (is_target_prioritized && !is_other_prioritized) {
+      *target = target_intersection;
+    }
+    if (!is_target_prioritized && is_other_prioritized) {
+      *target = other_intersection;
+    }
+    // If both have priorities and agree then we go with that. If the
+    // prioritization order is different, then we just fallback to the default
+    // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
+    // merged priorities to 0, so that downstream merges work correctly as well.
+    if (is_target_prioritized && is_other_prioritized) {
+      bool priorities_agree =
+          ArePrioritiesSame(target_intersection, other_intersection);
+      if (priorities_agree) {
+        *target = target_intersection;
+      } else {
+        for (const auto& prioritized_device : target_intersection) {
+          target->push_back(std::make_pair(prioritized_device.first, 0));
+        }
+        std::sort(target->begin(), target->end(), device_sort);
       }
     }
   }
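To make the merge policy in the MergeSupportedDevices comments above concrete, here is a standalone sketch (plain C++ with illustrative type names and priorities, not TensorFlow API) of how two prioritized device-type lists combine once they have been reduced to their common device types and sorted by priority:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Mirrors PrioritizedDeviceTypeVector: (device type, priority) pairs.
using PrioritizedTypes = std::vector<std::pair<std::string, int32_t>>;

bool HasPriorities(const PrioritizedTypes& types) {
  for (const auto& t : types) {
    if (t.second != 0) return true;
  }
  return false;
}

bool SamePriorityOrder(const PrioritizedTypes& a, const PrioritizedTypes& b) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (a[i].first != b[i].first) return false;
  }
  return true;
}

// Both inputs are assumed to already be the intersection of the two supported
// lists, sorted by descending priority, as the real code arranges before this
// step.
PrioritizedTypes Merge(const PrioritizedTypes& target,
                       const PrioritizedTypes& other) {
  const bool target_prioritized = HasPriorities(target);
  const bool other_prioritized = HasPriorities(other);
  if (target_prioritized && other_prioritized &&
      !SamePriorityOrder(target, other)) {
    // Conflicting priorities: reset them to 0 so the default device-type
    // order can decide downstream.
    PrioritizedTypes result;
    for (const auto& d : target) result.emplace_back(d.first, 0);
    return result;
  }
  if (!target_prioritized && other_prioritized) return other;
  // Neither prioritized, only target prioritized, or both agreeing: keep
  // target's ordering.
  return target;
}

int main() {
  // A CreateDatasetSP-style producer colocated with an IteratorRP-style
  // consumer: the two priority orders disagree, so both are dropped.
  PrioritizedTypes ds = {{"FakeCPU", 2}, {"FakeGPU", 1}};
  PrioritizedTypes it = {{"FakeGPU", 2}, {"FakeCPU", 1}};
  for (const auto& d : Merge(ds, it)) {
    std::cout << d.first << ":" << d.second << " ";  // FakeCPU:0 FakeGPU:0
  }
  std::cout << "\n";
  return 0;
}

(The real code then re-sorts the zero-priority result with its device_sort comparator, which falls back to DeviceSet::DeviceTypeOrder; that is why the conflicting-priority placer tests below end up on FakeGPU.)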


@@ -164,6 +164,13 @@ REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeGPU"), DummyOp);
 REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeCPU"), DummyOp);
 REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp);
 
+// Op that has kernels with device priorities specified.
+REGISTER_OP("TestDatasetOp").Input("a: float").Output("b: float");
+REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeCPU").Priority(2),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeGPU").Priority(1),
+                        DummyOp);
+
 ////////////////////////////////////////////////////////////////////////////////
 //
 // A PlacerTest method has three phases:
@@ -285,6 +292,251 @@ TEST_F(PlacerTest, TestNoConstraints) {
   EXPECT_DEVICE_TYPE(g, "n2", "FakeGPU");
 }
 
+// Test that a graph with no constraints but using kernels that have a specified
+// device priority will successfully assign nodes to the device with higher
+// priority
+TEST_F(PlacerTest, TestNoConstraintsWithPrioritizedKernels) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
+                 b.opts().WithName("n1"));
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 1),
+                 b.opts().WithName("n2"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU");
+}
+
+TEST_F(PlacerTest, TestGPUInputIntoPrioritizedKernel) {
+  Graph g(OpRegistry::Global());
+  {
+    // Scope for temp variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
+                 b.opts().WithName("n1"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU");
+}
+
+// Tests that a GPU kernel colocated with prioritized kernel respects it.
+TEST_F(PlacerTest, TestGPUInputColocatedWithPrioritizedKernel) {
+  Graph g(OpRegistry::Global());
+  {
+    // Scope for temp variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));
+    // We colocate n1 with in.
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
+                 b.opts().WithName("n1").WithAttr("_class", {"loc:@in"}));
+    // We don't colocate n2 with in.
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
+                 b.opts().WithName("n2"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU");
+}
+
+REGISTER_OP("CreateDatasetCPU").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetCPU").Device("FakeCPU"), DummyOp);
+
+REGISTER_OP("CreateDatasetSP").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeCPU").Priority(2),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeGPU").Priority(1),
+                        DummyOp);
+
+REGISTER_OP("CreateDatasetRP").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeCPU").Priority(1),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeGPU").Priority(2),
+                        DummyOp);
+
+REGISTER_OP("CreateDatasetNP").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("IteratorNP").Input("i: resource").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("IteratorSP").Input("i: resource").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeCPU").Priority(2),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeGPU").Priority(1),
+                        DummyOp);
+
+REGISTER_OP("IteratorRP").Input("i: resource").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeCPU").Priority(1),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeGPU").Priority(2),
+                        DummyOp);
+
+REGISTER_OP("IteratorGPU").Input("i: resource").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("IteratorGPU").Device("FakeGPU"), DummyOp);
+
+// Test reference edges with one node having prioritized kernels and the other
+// has no preference. We should respect priority here.
+TEST_F(PlacerTest, TestDSWithPriority) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
+}
+
+// Test reference edges with one node having kernels with regular priority and
+// the other has no preference. We should respect priority here.
+TEST_F(PlacerTest, TestDSWithGPUPriority) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test reference edges with one node having prioritized kernels and the other
+// has no preference. We should respect priority here.
+TEST_F(PlacerTest, TestITWithPriority) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
+}
+
+// Test reference edges with one node having kernels with regular priority and
+// the other has no preference. We should respect priority here.
+TEST_F(PlacerTest, TestITWithGPUPriority) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test reference edges with one node having prioritized kernels and other node
+// can only be placed on GPU. We should respect the constraint then.
+TEST_F(PlacerTest, TestITGPU) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorGPU", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test reference edges with one node having prioritized kernels and other node
+// can only be placed on CPU. We should respect the constraint then.
+TEST_F(PlacerTest, TestSimpleIteratorOnlyGPU) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetCPU", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
+}
+
+// Test constraints with agreeing priorities.
+TEST_F(PlacerTest, TestAgreeingPriorities) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
+}
+
+// Test constraints with agreeing regular priorities.
+TEST_F(PlacerTest, TestAgreeingRegularPriorities) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test constraints with different priorities. In this case, we should bail
+// and just revert to default.
+TEST_F(PlacerTest, TestConflictingPriorities) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test constraints with different priorities. In this case, we should bail
+// and just revert to default.
+TEST_F(PlacerTest, TestConflictingPrioritiesReversed) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
 // Test that a graph with device type and reference constraints on
 // some of the ops will successfully assign nodes to the constrained
 // device, and colocate nodes with reference connections.


@@ -33,6 +33,11 @@ message KernelDef {
   // won't be used unless the user specifies a "_kernel" attr with
   // value matching this.
   string label = 5;
+
+  // Prioritization of kernel amongst different devices. By default we assume
+  // priority is 0. The higher the priority the better. By default (i.e. if
+  // this is not set), we prefer GPU kernels over CPU.
+  int32 priority = 6;
 }
 
 // A collection of KernelDefs


@@ -66,6 +66,11 @@ KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
   return *this;
 }
 
+KernelDefBuilder& KernelDefBuilder::Priority(int32 priority) {
+  kernel_def_->set_priority(priority);
+  return *this;
+}
+
 const KernelDef* KernelDefBuilder::Build() {
   KernelDef* r = kernel_def_;
   kernel_def_ = nullptr;


@@ -64,6 +64,9 @@ class KernelDefBuilder {
   // "_kernel" attr. May only be specified once. Returns *this.
   KernelDefBuilder& Label(const char* label);
 
+  // Specify a priority number for this kernel.
+  KernelDefBuilder& Priority(int32 priority);
+
   // Returns a pointer to a KernelDef with fields set based on the
   // above calls to this instance.
   // Caller takes ownership of the result.


@@ -1091,7 +1091,7 @@ Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
 
 Status SupportedDeviceTypesForNode(
     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
-    DeviceTypeVector* device_types) {
+    PrioritizedDeviceTypeVector* prioritized_device_types) {
   // TODO(zhifengc): Changes the callers (SimplePlacer and
   // DynamicPlacer) to consider the possibility that 'def' is call to
   // a user-defined function and only calls this
@@ -1104,12 +1104,21 @@ Status SupportedDeviceTypesForNode(
       bool was_attr_mismatch;
       TF_RETURN_IF_ERROR(
           FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
-      if (reg != nullptr) device_types->push_back(device_type);
+      if (reg != nullptr) {
+        int32 priority = reg->def.priority();
+        prioritized_device_types->emplace_back(device_type, priority);
+      }
     }
+    std::sort(prioritized_device_types->begin(),
+              prioritized_device_types->end(),
+              [](const std::pair<DeviceType, int32>& a,
+                 const std::pair<DeviceType, int32>& b) {
+                return a.second > b.second;
+              });
   } else {
     // Assumes that all device types support this node.
     for (const DeviceType& device_type : prioritized_types) {
-      device_types->push_back(device_type);
+      prioritized_device_types->push_back(std::make_pair(device_type, 0));
    }
  }
  return Status::OK();


@@ -1237,7 +1237,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
 // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
 Status SupportedDeviceTypesForNode(
     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
-    DeviceTypeVector* device_types);
+    PrioritizedDeviceTypeVector* device_types);
 
 // Returns a message with a description of the kernels registered for op
 // `op_name`.


@@ -102,6 +102,27 @@ REGISTER_OP("Test4").Input("i: float").Output("o: float");
 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);
 
+// Kernels with different priorities.
+REGISTER_OP("Test5").Input("a: T").Input("b: T").Attr("T: type");
+class TestOp5Cpu : public tensorflow::OpKernel {
+ public:
+  explicit TestOp5Cpu(OpKernelConstruction* context) : OpKernel(context) {}
+  void Compute(OpKernelContext* context) override {}
+};
+
+REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_CPU).Priority(2),
+                        TestOp5Cpu);
+
+class TestOp5Gpu : public tensorflow::OpKernel {
+ public:
+  explicit TestOp5Gpu(OpKernelConstruction* context) : OpKernel(context) {}
+  void Compute(OpKernelContext* context) override {}
+};
+
+REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_GPU).Priority(1),
+                        TestOp5Gpu);
+
 static std::vector<DeviceType> DeviceTypes() {
   return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
 }
@@ -185,10 +206,10 @@ TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
 
 TEST_F(OpKernelTest, CpuTypeRegistered) {
   NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
-  DeviceTypeVector devs;
+  PrioritizedDeviceTypeVector devs;
   TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
   EXPECT_EQ(1, devs.size());
-  EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
+  EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
 }
 
 TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
@@ -196,24 +217,24 @@ TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
     // Try a node def of an op that is registered for a specific type
     // only on CPU.
     NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
-    DeviceTypeVector devs;
+    PrioritizedDeviceTypeVector devs;
     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
     EXPECT_EQ(1, devs.size());
-    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
+    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
   }
   {
     // Try a node def of an op that is registered for a specific type
     // only on GPU.
     NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
-    DeviceTypeVector devs;
+    PrioritizedDeviceTypeVector devs;
     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
     EXPECT_EQ(1, devs.size());
-    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
+    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
   }
   {
     // Try a node def of an op that is only registered for other types.
     NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
-    DeviceTypeVector devs;
+    PrioritizedDeviceTypeVector devs;
     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
     EXPECT_EQ(0, devs.size());
   }
@@ -221,11 +242,23 @@ TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
   {
     // Try a node def of an op that is registered for both.
     NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT});
-    DeviceTypeVector devs;
+    PrioritizedDeviceTypeVector devs;
     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
     EXPECT_EQ(2, devs.size());
-    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
-    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]);
+    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
+    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first);
+  }
+  {
+    // Try a node def of an op where kernels have priorities.
+    NodeDef ndef = CreateNodeDef("Test5", {DT_STRING, DT_STRING});
+    PrioritizedDeviceTypeVector devs;
+    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+    EXPECT_EQ(2, devs.size());
+    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
+    EXPECT_EQ(2, devs[0].second);
+    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[1].first);
+    EXPECT_EQ(1, devs[1].second);
   }
 }
@@ -412,11 +445,11 @@ class OpKernelBuilderTest : public ::testing::Test {
     }
 
     // Test SupportedDeviceTypesForNode()
-    DeviceTypeVector devices;
+    PrioritizedDeviceTypeVector devices;
     TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
     bool found = false;
-    for (const DeviceType& dt : devices) {
-      if (dt == device_type) {
+    for (const auto& dt : devices) {
+      if (dt.first == device_type) {
        found = true;
      }
    }
@@ -445,11 +478,11 @@ class OpKernelBuilderTest : public ::testing::Test {
     EXPECT_EQ(code, status.code());
 
     // Test SupportedDeviceTypesForNode().
-    DeviceTypeVector devices;
+    PrioritizedDeviceTypeVector devices;
     if (errors::IsNotFound(status)) {
       TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
-      for (const DeviceType& dt : devices) {
-        EXPECT_NE(dt, device_type);
+      for (const auto& dt : devices) {
+        EXPECT_NE(dt.first, device_type);
       }
     } else {
       Status status2 =
@@ -562,7 +595,7 @@ REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
 
 TEST_F(OpKernelBuilderTest, DuplicateKernel) {
   const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
-  DeviceTypeVector devs;
+  PrioritizedDeviceTypeVector devs;
   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(str_util::StrContains(
@@ -582,7 +615,7 @@ REGISTER_KERNEL_BUILDER(
 TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
   const NodeDef ndef =
       CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
-  DeviceTypeVector devs;
+  PrioritizedDeviceTypeVector devs;
   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(str_util::StrContains(
@@ -603,7 +636,7 @@ REGISTER_KERNEL_BUILDER(Name("BadConstraint")
 
 TEST_F(OpKernelBuilderTest, BadConstraint) {
   const NodeDef ndef = CreateNodeDef("BadConstraint", {});
-  DeviceTypeVector devs;
+  PrioritizedDeviceTypeVector devs;
   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(


@@ -104,6 +104,8 @@ typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
 typedef gtl::ArraySlice<DataType> DataTypeSlice;
 
 typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
+typedef gtl::InlinedVector<std::pair<DeviceType, int32>, 4>
+    PrioritizedDeviceTypeVector;
 
 // Convert the enums to strings for errors:
 string DataTypeString(DataType dtype);