Add a priority option to kernels so that the default device-selection prioritization can be overridden. In the regular case we always prefer to place an op on the GPU, but for tf.data pipelines with no device specified we would rather place it on the CPU, since that is usually what the user intends. As we extend parts of the tf.data pipeline to run on GPUs, TPUs, etc., this will become increasingly necessary.

PiperOrigin-RevId: 221264064
commit 8cbd3301ec (parent b9247ca08f)
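To make the intent concrete, the sketch below shows how a kernel author could use the new Priority() builder to prefer the CPU kernel of a tf.data-style op while still providing a GPU kernel. The op name ExampleDatasetOp and its trivial kernel are placeholders invented for illustration; only the registration macros and the Priority() builder come from this change.

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Placeholder kernel body; a real dataset op would produce elements here.
class ExampleDatasetOp : public OpKernel {
 public:
  explicit ExampleDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {}
};

REGISTER_OP("ExampleDatasetOp").Input("a: float").Output("b: float");

// The higher priority wins, so with no explicit device the placer picks the
// CPU kernel; the GPU kernel is still used if the user pins the op to GPU.
REGISTER_KERNEL_BUILDER(
    Name("ExampleDatasetOp").Device(DEVICE_CPU).Priority(2), ExampleDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExampleDatasetOp").Device(DEVICE_GPU).Priority(1), ExampleDatasetOp);

}  // namespace tensorflow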
@@ -192,7 +192,7 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
 }
 
 Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
-  DeviceTypeVector final_devices;
+  PrioritizedDeviceTypeVector final_devices;
   TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
       ctx->prioritized_device_type_list(), ndef, &final_devices));
   if (final_devices.empty()) {
@@ -202,7 +202,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
         " :\n", KernelsRegisteredForOp(ndef.op()));
   }
   for (Device* d : *ctx->devices()) {
-    if (d->device_type() == final_devices[0].type_string()) {
+    if (d->device_type() == final_devices[0].first.type_string()) {
       *device = d;
       return Status::OK();
     }
@@ -47,42 +47,51 @@ const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
 // returned list is sorted by preferred type (higher numeric type is preferred).
 std::vector<Device*> FilterSupportedDevices(
     const std::vector<Device*>& devices,
-    const DeviceTypeVector& supported_device_types,
+    const PrioritizedDeviceTypeVector& supported_device_types,
     const Device* default_device) {
   Device* filtered_default_device = nullptr;
-  std::vector<Device*> filtered_devices;
-  for (const DeviceType& d : supported_device_types) {
+  std::vector<std::pair<Device*, int32>> prioritized_filtered_devices;
+  for (const auto& supported_device_type : supported_device_types) {
     for (Device* device : devices) {
-      if (DeviceType(device->attributes().device_type()) == d) {
+      if (DeviceType(device->attributes().device_type()) ==
+          supported_device_type.first) {
         if (device == default_device) {
           filtered_default_device = device;
         } else {
-          filtered_devices.emplace_back(device);
+          prioritized_filtered_devices.emplace_back(
+              device, supported_device_type.second);
         }
       }
     }
   }
 
-  auto device_sort = [](const Device* a, const Device* b) {
-    auto a_priority = DeviceSet::DeviceTypeOrder(DeviceType(a->device_type()));
-    auto b_priority = DeviceSet::DeviceTypeOrder(DeviceType(b->device_type()));
+  auto device_sort = [](const std::pair<Device*, int32>& a,
+                        const std::pair<Device*, int32>& b) {
+    if (a.second != b.second) {
+      return a.second > b.second;
+    }
+
+    auto a_priority =
+        DeviceSet::DeviceTypeOrder(DeviceType(a.first->device_type()));
+    auto b_priority =
+        DeviceSet::DeviceTypeOrder(DeviceType(b.first->device_type()));
     // First sort by prioritized device type (higher is preferred) and
     // then by device name (lexicographically).
     if (a_priority != b_priority) {
       return a_priority > b_priority;
     }
-    return StringPiece(a->name()) < StringPiece(b->name());
+    return StringPiece(a.first->name()) < StringPiece(b.first->name());
   };
-  std::vector<Device*>::iterator sort_start;
+  std::sort(prioritized_filtered_devices.begin(),
+            prioritized_filtered_devices.end(), device_sort);
+
+  std::vector<Device*> filtered_devices;
   if (filtered_default_device != nullptr) {
     // Put the default device first outside of the normal ordering.
     filtered_devices.emplace_back(filtered_default_device);
-    std::iter_swap(filtered_devices.begin(), std::prev(filtered_devices.end()));
-    sort_start = std::next(filtered_devices.begin());
-  } else {
-    sort_start = filtered_devices.begin();
   }
-  std::sort(sort_start, filtered_devices.end(), device_sort);
+  for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
+    filtered_devices.push_back(prioritized_filtered_device.first);
+  }
   return filtered_devices;
 }
@@ -472,7 +481,7 @@ class ColocationGraph {
     // The intersection of all device types supported by this node,
     // and those of all of its children, in priority order
     // of the preferred device.
-    DeviceTypeVector supported_device_types;
+    PrioritizedDeviceTypeVector supported_device_types;
 
     // The merged form of the device requested for this node, with
     // those of all of its children.
@@ -511,8 +520,8 @@ class ColocationGraph {
       const string& op_type = node->type_string();
       string devices_registered;
       for (const auto& device_type : members_[id].supported_device_types) {
-        strings::StrAppend(&devices_registered, DeviceTypeString(device_type),
-                           " ");
+        strings::StrAppend(&devices_registered,
+                           DeviceTypeString(device_type.first), " ");
       }
 
       type_to_devices[op_type] = std::move(devices_registered);
@@ -565,8 +574,9 @@ class ColocationGraph {
           "' does not match any device");
     }
 
-    for (const DeviceType& d : member->supported_device_types) {
-      if (DeviceType(assigned_device->attributes().device_type()) == d) {
+    for (const auto& d : member->supported_device_types) {
+      if (DeviceType(assigned_device->attributes().device_type()) ==
+          d.first) {
         return Status::OK();
       }
     }
@@ -623,24 +633,102 @@ class ColocationGraph {
     return Status::OK();
   }
 
+  static bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
+    for (const auto& prioritized_device_type : device_types) {
+      if (prioritized_device_type.second != 0) return true;
+    }
+    return false;
+  }
+
+  static bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
+                                const PrioritizedDeviceTypeVector& b_types) {
+    if (a_types.size() != b_types.size()) {
+      return false;
+    }
+    for (int i = 0; i < a_types.size(); ++i) {
+      if (a_types[i].first != b_types[i].first) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   // Updates target to contain the intersection of the device types in
   // "target" and "other".
-  static void MergeSupportedDevices(DeviceTypeVector* target,
-                                    const DeviceTypeVector& other) {
-    DeviceTypeVector temp = *target;
+  static void MergeSupportedDevices(PrioritizedDeviceTypeVector* target,
+                                    const PrioritizedDeviceTypeVector& other) {
+    PrioritizedDeviceTypeVector temp = *target;
     target->clear();
 
-    // Iterate in priority order.
-    for (const DeviceType& device_type : temp) {
+    // Generate intersection with priorities.
+    PrioritizedDeviceTypeVector target_intersection;
+    PrioritizedDeviceTypeVector other_intersection;
+    for (const auto& prioritized_device_type : temp) {
       bool found = false;
-      for (const DeviceType& other_device_type : other) {
-        if (device_type == other_device_type) {
+      for (const auto& other_prioritized_device_type : other) {
+        if (prioritized_device_type.first ==
+            other_prioritized_device_type.first) {
           found = true;
+          other_intersection.push_back(other_prioritized_device_type);
           break;
         }
      }
      if (found) {
-        target->push_back(device_type);
+        target_intersection.push_back(prioritized_device_type);
      }
    }
+
+    // Sort the devices by priority order.
+    auto device_sort = [](const std::pair<DeviceType, int32>& a,
+                          const std::pair<DeviceType, int32>& b) {
+      // First look at set priorities.
+      if (a.second != b.second) {
+        return a.second > b.second;
+      }
+      // Then fallback to default priorities.
+      auto a_priority = DeviceSet::DeviceTypeOrder(a.first);
+      auto b_priority = DeviceSet::DeviceTypeOrder(b.first);
+      if (a_priority != b_priority) {
+        return a_priority > b_priority;
+      }
+      // Finally just look at the Device type strings.
+      return a.first.type_string() < b.first.type_string();
+    };
+
+    std::sort(target_intersection.begin(), target_intersection.end(),
+              device_sort);
+    std::sort(other_intersection.begin(), other_intersection.end(),
+              device_sort);
+
+    bool is_target_prioritized = HasPriorities(target_intersection);
+    bool is_other_prioritized = HasPriorities(other_intersection);
+    // If neither are prioritized then we just return the original i.e. target
+    // prioritization.
+    if (!is_target_prioritized && !is_other_prioritized) {
+      *target = target_intersection;
+    }
+    // If only one is prioritized, then we respect priorities of that in the
+    // intersection.
+    if (is_target_prioritized && !is_other_prioritized) {
+      *target = target_intersection;
+    }
+    if (!is_target_prioritized && is_other_prioritized) {
+      *target = other_intersection;
+    }
+    // If both have priorities and agree then we go with that. If the
+    // prioritization order is different, then we just fallback to the default
+    // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
+    // merged priorities to 0, so that downstream merges work correctly as well.
+    if (is_target_prioritized && is_other_prioritized) {
+      bool priorities_agree =
+          ArePrioritiesSame(target_intersection, other_intersection);
+      if (priorities_agree) {
+        *target = target_intersection;
+      } else {
+        for (const auto& prioritized_device : target_intersection) {
+          target->push_back(std::make_pair(prioritized_device.first, 0));
+        }
+        std::sort(target->begin(), target->end(), device_sort);
+      }
+    }
   }
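The case analysis in MergeSupportedDevices above is easier to follow with concrete inputs. Below is a standalone sketch of just the merge rules, using std::string in place of DeviceType and a hard-coded stand-in for DeviceSet::DeviceTypeOrder so that it compiles outside of TensorFlow; it illustrates the logic rather than reproducing the actual implementation.

#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// (device type, priority) pairs; priority 0 means "no explicit priority".
using Prioritized = std::vector<std::pair<std::string, int>>;

// Stand-in for DeviceSet::DeviceTypeOrder: GPU preferred over CPU by default.
int DefaultOrder(const std::string& type) { return type == "GPU" ? 2 : 1; }

bool HasPriorities(const Prioritized& v) {
  for (const auto& p : v) {
    if (p.second != 0) return true;
  }
  return false;
}

bool SameOrder(const Prioritized& a, const Prioritized& b) {
  if (a.size() != b.size()) return false;
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i].first != b[i].first) return false;
  }
  return true;
}

// Intersect the two lists, sort each side by (explicit priority, default
// order, name), then resolve according to the case analysis in the diff.
Prioritized Merge(const Prioritized& target, const Prioritized& other) {
  auto by_priority = [](const std::pair<std::string, int>& a,
                        const std::pair<std::string, int>& b) {
    if (a.second != b.second) return a.second > b.second;
    if (DefaultOrder(a.first) != DefaultOrder(b.first)) {
      return DefaultOrder(a.first) > DefaultOrder(b.first);
    }
    return a.first < b.first;
  };
  Prioritized t_int, o_int;
  for (const auto& t : target) {
    for (const auto& o : other) {
      if (t.first == o.first) {
        t_int.push_back(t);
        o_int.push_back(o);
        break;
      }
    }
  }
  std::sort(t_int.begin(), t_int.end(), by_priority);
  std::sort(o_int.begin(), o_int.end(), by_priority);
  const bool target_prioritized = HasPriorities(t_int);
  const bool other_prioritized = HasPriorities(o_int);
  if (target_prioritized && other_prioritized && !SameOrder(t_int, o_int)) {
    // Conflicting priorities: drop them and fall back to the default order.
    Prioritized result;
    for (const auto& t : t_int) result.emplace_back(t.first, 0);
    std::sort(result.begin(), result.end(), by_priority);
    return result;
  }
  // Otherwise respect whichever side carries priorities (target by default).
  return (!target_prioritized && other_prioritized) ? o_int : t_int;
}

int main() {
  // One side prioritized (CPU first), the other unprioritized: the explicit
  // CPU preference wins, as in the TestDSWithPriority placer test below.
  for (const auto& p : Merge({{"CPU", 2}, {"GPU", 1}}, {{"CPU", 0}, {"GPU", 0}}))
    std::cout << p.first << ":" << p.second << " ";
  std::cout << "\n";
  // Conflicting priorities: both are reset to 0 and the default GPU-first
  // order is restored, as in the TestConflictingPriorities placer test below.
  for (const auto& p : Merge({{"CPU", 2}, {"GPU", 1}}, {{"CPU", 1}, {"GPU", 2}}))
    std::cout << p.first << ":" << p.second << " ";
  std::cout << "\n";
}

Compiled and run, this prints "CPU:2 GPU:1" for the first call and "GPU:0 CPU:0" for the second, matching the expectations in the placer tests added further down in this commit.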
@@ -164,6 +164,13 @@ REGISTER_KERNEL_BUILDER(Name("TestDeviceEnforce").Device("FakeGPU"), DummyOp);
 REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeCPU"), DummyOp);
 REGISTER_KERNEL_BUILDER(Name("Shape").Device("FakeGPU"), DummyOp);
 
+// Op that has kernels with device priorities specified.
+REGISTER_OP("TestDatasetOp").Input("a: float").Output("b: float");
+REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeCPU").Priority(2),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("TestDatasetOp").Device("FakeGPU").Priority(1),
+                        DummyOp);
+
 ////////////////////////////////////////////////////////////////////////////////
 //
 // A PlacerTest method has three phases:
@@ -285,6 +292,251 @@ TEST_F(PlacerTest, TestNoConstraints) {
   EXPECT_DEVICE_TYPE(g, "n2", "FakeGPU");
 }
 
+// Test that a graph with no constraints but using kernels that have a specified
+// device priority will successfully assign nodes to the device with higher
+// priority.
+TEST_F(PlacerTest, TestNoConstraintsWithPrioritizedKernels) {
+  Graph g(OpRegistry::Global());
+  {  // Scope for temporary variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
+                 b.opts().WithName("n1"));
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 1),
+                 b.opts().WithName("n2"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU");
+}
+
+TEST_F(PlacerTest, TestGPUInputIntoPrioritizedKernel) {
+  Graph g(OpRegistry::Global());
+  {
+    // Scope for temp variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
+                 b.opts().WithName("n1"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "n1", "FakeCPU");
+}
+
+// Tests that a GPU kernel colocated with prioritized kernel respects it.
+TEST_F(PlacerTest, TestGPUInputColocatedWithPrioritizedKernel) {
+  Graph g(OpRegistry::Global());
+  {
+    // Scope for temp variables used to construct g.
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* input = ops::SourceOp("TestGPUOutput", b.opts().WithName("in"));
+    // We colocate n1 with in.
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
+                 b.opts().WithName("n1").WithAttr("_class", {"loc:@in"}));
+    // We don't colocate n2 with in.
+    ops::UnaryOp("TestDatasetOp", ops::NodeOut(input, 0),
+                 b.opts().WithName("n2"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "n1", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "n2", "FakeCPU");
+}
+
+REGISTER_OP("CreateDatasetCPU").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetCPU").Device("FakeCPU"), DummyOp);
+
+REGISTER_OP("CreateDatasetSP").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeCPU").Priority(2),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetSP").Device("FakeGPU").Priority(1),
+                        DummyOp);
+
+REGISTER_OP("CreateDatasetRP").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeCPU").Priority(1),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetRP").Device("FakeGPU").Priority(2),
+                        DummyOp);
+
+REGISTER_OP("CreateDatasetNP").Output("o: resource");
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("CreateDatasetNP").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("IteratorNP").Input("i: resource").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorNP").Device("FakeGPU"), DummyOp);
+
+REGISTER_OP("IteratorSP").Input("i: resource").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeCPU").Priority(2),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorSP").Device("FakeGPU").Priority(1),
+                        DummyOp);
+
+REGISTER_OP("IteratorRP").Input("i: resource").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeCPU").Priority(1),
+                        DummyOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorRP").Device("FakeGPU").Priority(2),
+                        DummyOp);
+
+REGISTER_OP("IteratorGPU").Input("i: resource").Output("o: float");
+REGISTER_KERNEL_BUILDER(Name("IteratorGPU").Device("FakeGPU"), DummyOp);
+
+// Test reference edges with one node having prioritized kernels and the other
+// has no preference. We should respect priority here.
+TEST_F(PlacerTest, TestDSWithPriority) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
+}
+
+// Test reference edges with one node having kernels with regular priority and
+// the other has no preference. We should respect priority here.
+TEST_F(PlacerTest, TestDSWithGPUPriority) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorNP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test reference edges with one node having prioritized kernels and the other
+// has no preference. We should respect priority here.
+TEST_F(PlacerTest, TestITWithPriority) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
+}
+
+// Test reference edges with one node having kernels with regular priority and
+// the other has no preference. We should respect priority here.
+TEST_F(PlacerTest, TestITWithGPUPriority) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetNP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test reference edges with one node having prioritized kernels and other node
+// can only be placed on GPU. We should respect the constraint then.
+TEST_F(PlacerTest, TestITGPU) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorGPU", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test reference edges with one node having prioritized kernels and other node
+// can only be placed on CPU. We should respect the constraint then.
+TEST_F(PlacerTest, TestSimpleIteratorOnlyGPU) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetCPU", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
+}
+
+// Test constraints with agreeing priorities.
+TEST_F(PlacerTest, TestAgreeingPriorities) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeCPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeCPU");
+}
+
+// Test constraints with agreeing regular priorities.
+TEST_F(PlacerTest, TestAgreeingRegularPriorities) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test constraints with different priorities. In this case, we should bail
+// and just revert to default.
+TEST_F(PlacerTest, TestConflictingPriorities) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetSP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorRP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
+// Test constraints with different priorities. In this case, we should bail
+// and just revert to default.
+TEST_F(PlacerTest, TestConflictingPrioritiesReversed) {
+  Graph g(OpRegistry::Global());
+  {
+    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+    Node* ds = ops::SourceOp("CreateDatasetRP", b.opts().WithName("ds"));
+    ops::UnaryOp("IteratorSP", ops::NodeOut(ds, 0), b.opts().WithName("it"));
+    TF_EXPECT_OK(BuildGraph(b, &g));
+  }
+  TF_EXPECT_OK(Place(&g));
+  EXPECT_DEVICE_TYPE(g, "ds", "FakeGPU");
+  EXPECT_DEVICE_TYPE(g, "it", "FakeGPU");
+}
+
 // Test that a graph with device type and reference constraints on
 // some of the ops will successfully assign nodes to the constrained
 // device, and colocate nodes with reference connections.
@@ -33,6 +33,11 @@ message KernelDef {
   // won't be used unless the user specifies a "_kernel" attr with
   // value matching this.
   string label = 5;
+
+  // Prioritization of kernel amongst different devices. By default we assume
+  // priority is 0. The higher the priority the better. By default (i.e. if
+  // this is not set), we prefer GPU kernels over CPU.
+  int32 priority = 6;
 }
 
 // A collection of KernelDefs
@@ -66,6 +66,11 @@ KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
   return *this;
 }
 
+KernelDefBuilder& KernelDefBuilder::Priority(int32 priority) {
+  kernel_def_->set_priority(priority);
+  return *this;
+}
+
 const KernelDef* KernelDefBuilder::Build() {
   KernelDef* r = kernel_def_;
   kernel_def_ = nullptr;
@@ -64,6 +64,9 @@ class KernelDefBuilder {
   // "_kernel" attr. May only be specified once. Returns *this.
   KernelDefBuilder& Label(const char* label);
 
+  // Specify a priority number for this kernel.
+  KernelDefBuilder& Priority(int32 priority);
+
   // Returns a pointer to a KernelDef with fields set based on the
   // above calls to this instance.
   // Caller takes ownership of the result.
@@ -1091,7 +1091,7 @@ Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def,
 
 Status SupportedDeviceTypesForNode(
     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
-    DeviceTypeVector* device_types) {
+    PrioritizedDeviceTypeVector* prioritized_device_types) {
   // TODO(zhifengc): Changes the callers (SimplePlacer and
   // DynamicPlacer) to consider the possibility that 'def' is call to
   // a user-defined function and only calls this
@@ -1104,12 +1104,21 @@ Status SupportedDeviceTypesForNode(
       bool was_attr_mismatch;
       TF_RETURN_IF_ERROR(
           FindKernelRegistration(device_type, def, &reg, &was_attr_mismatch));
-      if (reg != nullptr) device_types->push_back(device_type);
+      if (reg != nullptr) {
+        int32 priority = reg->def.priority();
+        prioritized_device_types->emplace_back(device_type, priority);
+      }
     }
+    std::sort(prioritized_device_types->begin(),
+              prioritized_device_types->end(),
+              [](const std::pair<DeviceType, int32>& a,
+                 const std::pair<DeviceType, int32>& b) {
+                return a.second > b.second;
+              });
   } else {
     // Assumes that all device types support this node.
     for (const DeviceType& device_type : prioritized_types) {
-      device_types->push_back(device_type);
+      prioritized_device_types->push_back(std::make_pair(device_type, 0));
     }
   }
   return Status::OK();
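For callers migrating from the old DeviceTypeVector signature, the sketch below shows the shape of the new result: (DeviceType, priority) pairs with the highest-priority type sorted to the front. The helper PreferredDeviceTypeForNode is invented here for illustration and is not part of this commit; it assumes the usual TensorFlow framework headers.

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Illustrative helper: report the device type the placer would favor for a
// node when nothing else constrains it, or an error if no kernel exists.
Status PreferredDeviceTypeForNode(
    const std::vector<DeviceType>& prioritized_types, const NodeDef& ndef,
    DeviceType* out) {
  PrioritizedDeviceTypeVector supported;
  TF_RETURN_IF_ERROR(
      SupportedDeviceTypesForNode(prioritized_types, ndef, &supported));
  if (supported.empty()) {
    return errors::NotFound("No kernel registered for op ", ndef.op());
  }
  // Entries are ordered by descending kernel priority, so take the front.
  *out = supported[0].first;
  return Status::OK();
}

}  // namespace tensorflow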
@@ -1237,7 +1237,7 @@ Status CreateOpKernel(DeviceType device_type, DeviceBase* device,
 // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()).
 Status SupportedDeviceTypesForNode(
     const std::vector<DeviceType>& prioritized_types, const NodeDef& def,
-    DeviceTypeVector* device_types);
+    PrioritizedDeviceTypeVector* device_types);
 
 // Returns a message with a description of the kernels registered for op
 // `op_name`.
@@ -102,6 +102,27 @@ REGISTER_OP("Test4").Input("i: float").Output("o: float");
 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel);
 REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel);
 
+// Kernels with different priorities.
+REGISTER_OP("Test5").Input("a: T").Input("b: T").Attr("T: type");
+
+class TestOp5Cpu : public tensorflow::OpKernel {
+ public:
+  explicit TestOp5Cpu(OpKernelConstruction* context) : OpKernel(context) {}
+  void Compute(OpKernelContext* context) override {}
+};
+
+REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_CPU).Priority(2),
+                        TestOp5Cpu);
+
+class TestOp5Gpu : public tensorflow::OpKernel {
+ public:
+  explicit TestOp5Gpu(OpKernelConstruction* context) : OpKernel(context) {}
+  void Compute(OpKernelContext* context) override {}
+};
+
+REGISTER_KERNEL_BUILDER(Name("Test5").Device(DEVICE_GPU).Priority(1),
+                        TestOp5Gpu);
+
 static std::vector<DeviceType> DeviceTypes() {
   return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)};
 }
@@ -185,10 +206,10 @@ TEST_F(OpKernelTest, SuccessBothCpuAndGpu) {
 
 TEST_F(OpKernelTest, CpuTypeRegistered) {
   NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32});
-  DeviceTypeVector devs;
+  PrioritizedDeviceTypeVector devs;
   TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
   EXPECT_EQ(1, devs.size());
-  EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
+  EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
 }
 
 TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
@@ -196,24 +217,24 @@ TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
     // Try a node def of an op that is registered for a specific type
     // only on CPU.
     NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8});
-    DeviceTypeVector devs;
+    PrioritizedDeviceTypeVector devs;
     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
     EXPECT_EQ(1, devs.size());
-    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]);
+    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
   }
   {
     // Try a node def of an op that is registered for a specific type
     // only on GPU.
     NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT});
-    DeviceTypeVector devs;
+    PrioritizedDeviceTypeVector devs;
     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
     EXPECT_EQ(1, devs.size());
-    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
+    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
   }
   {
     // Try a node def of an op that is only registered for other types.
     NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING});
-    DeviceTypeVector devs;
+    PrioritizedDeviceTypeVector devs;
     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
     EXPECT_EQ(0, devs.size());
   }
@@ -221,11 +242,23 @@ TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) {
   {
     // Try a node def of an op that is registered for both.
     NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT});
-    DeviceTypeVector devs;
+    PrioritizedDeviceTypeVector devs;
     TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
     EXPECT_EQ(2, devs.size());
-    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]);
-    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]);
+    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0].first);
+    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1].first);
   }
+
+  {
+    // Try a node def of an op where kernels have priorities.
+    NodeDef ndef = CreateNodeDef("Test5", {DT_STRING, DT_STRING});
+    PrioritizedDeviceTypeVector devs;
+    TF_ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs));
+    EXPECT_EQ(2, devs.size());
+    EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0].first);
+    EXPECT_EQ(2, devs[0].second);
+    EXPECT_EQ(DeviceType(DEVICE_GPU), devs[1].first);
+    EXPECT_EQ(1, devs[1].second);
+  }
 }
@@ -412,11 +445,11 @@ class OpKernelBuilderTest : public ::testing::Test {
     }
 
     // Test SupportedDeviceTypesForNode()
-    DeviceTypeVector devices;
+    PrioritizedDeviceTypeVector devices;
    TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
    bool found = false;
-    for (const DeviceType& dt : devices) {
-      if (dt == device_type) {
+    for (const auto& dt : devices) {
+      if (dt.first == device_type) {
        found = true;
      }
    }
@@ -445,11 +478,11 @@ class OpKernelBuilderTest : public ::testing::Test {
     EXPECT_EQ(code, status.code());
 
     // Test SupportedDeviceTypesForNode().
-    DeviceTypeVector devices;
+    PrioritizedDeviceTypeVector devices;
     if (errors::IsNotFound(status)) {
       TF_EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices));
-      for (const DeviceType& dt : devices) {
-        EXPECT_NE(dt, device_type);
+      for (const auto& dt : devices) {
+        EXPECT_NE(dt.first, device_type);
       }
     } else {
       Status status2 =
@@ -562,7 +595,7 @@ REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU),
 
 TEST_F(OpKernelBuilderTest, DuplicateKernel) {
   const NodeDef ndef = CreateNodeDef("DuplicateKernel", {});
-  DeviceTypeVector devs;
+  PrioritizedDeviceTypeVector devs;
   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(str_util::StrContains(
@@ -582,7 +615,7 @@ REGISTER_KERNEL_BUILDER(
 TEST_F(OpKernelBuilderTest, DuplicateKernelForT) {
   const NodeDef ndef =
       CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"});
-  DeviceTypeVector devs;
+  PrioritizedDeviceTypeVector devs;
   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(str_util::StrContains(
@@ -603,7 +636,7 @@ REGISTER_KERNEL_BUILDER(Name("BadConstraint")
 
 TEST_F(OpKernelBuilderTest, BadConstraint) {
   const NodeDef ndef = CreateNodeDef("BadConstraint", {});
-  DeviceTypeVector devs;
+  PrioritizedDeviceTypeVector devs;
   Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs);
   ASSERT_FALSE(status.ok());
   EXPECT_TRUE(
@@ -104,6 +104,8 @@ typedef gtl::InlinedVector<DataType, 4> DataTypeVector;
 typedef gtl::ArraySlice<DataType> DataTypeSlice;
 
 typedef gtl::InlinedVector<DeviceType, 4> DeviceTypeVector;
+typedef gtl::InlinedVector<std::pair<DeviceType, int32>, 4>
+    PrioritizedDeviceTypeVector;
 
 // Convert the enums to strings for errors:
 string DataTypeString(DataType dtype);