Automated rollback of commit 8b633af6ff

PiperOrigin-RevId: 233657173
Igor Ganichev 2019-02-12 12:30:40 -08:00 committed by TensorFlower Gardener
parent f80b1ca748
commit d7b3e49283
6 changed files with 153 additions and 459 deletions

View File

@ -20,7 +20,6 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
@ -34,7 +33,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/port.h"
namespace tensorflow {
@ -98,27 +96,6 @@ std::vector<Device*> FilterSupportedDevices(
return filtered_devices;
}
// Using absl::StrJoin with lambda does not work in tf-lite builds.
std::vector<string> DevicesToString(const std::vector<Device*> devices) {
std::vector<string> v;
v.reserve(devices.size());
for (Device* d : devices) {
v.push_back(d->name());
}
return v;
}
// Using absl::StrJoin with lambda does not work in tf-lite builds.
std::vector<string> DeviceTypeAndPriorityToString(
const PrioritizedDeviceTypeVector& devices) {
std::vector<string> v;
v.reserve(devices.size());
for (const std::pair<DeviceType, int32>& device_and_type : devices) {
v.push_back(DeviceTypeString(device_and_type.first));
}
return v;
}
// Returns true if the node has no inputs and produces outputs
// that are consumed by a single node.
//
@ -130,16 +107,6 @@ bool IsGeneratorNode(const Node* node) {
!IsRefType(node->output_type(0));
}
// While Placer can override requested device on ops processing
// resources, i.e. nodes that take (and potentially return) a resource,
// it must not override requested device on ops generating a resource,
// e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
// output nodes.
bool IsResourceGeneratorNode(const Node& node) {
return node.num_inputs() == 0 && node.num_outputs() == 1 &&
(IsRefType(node.output_type(0)) || node.output_type(0) == DT_RESOURCE);
}
bool IsExemptFromResourceInputColocation(const Node* node) {
// Note: Partitioned function calls, which place and partition their
// function bodies, are exempt from this check: they forward resource and
@ -169,7 +136,7 @@ bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
return true;
}
// Represents a node in the disjoint node forest and the
// Represents a node in the disjoint node set forest, and the
// accumulated constraints on the device used by that node.
class Member {
public:
@ -188,76 +155,18 @@ class Member {
&supported_device_types_);
}
const DeviceNameUtils::ParsedName& requested_device_name() const {
return requested_device_name_;
const DeviceNameUtils::ParsedName& device_name() const {
return device_name_;
}
Status SetAssignedDeviceName(const string& device_name) {
if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
return errors::Internal(
"Setting assigned device name when there is a requested device set "
"is unsupported");
}
if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
Status SetDeviceName(const string& device_name) {
if (!DeviceNameUtils::ParseFullName(device_name, &device_name_)) {
return errors::Internal("Malformed assigned device '", device_name, "'");
}
// Set requested device to assigned_device to maintain the invariant that
// requested is a specialization of assigned.
requested_device_name_ = assigned_device_name_;
return Status::OK();
}
Status SetRequestedDeviceName(const Node& node) {
if (!DeviceNameUtils::ParseFullName(node.requested_device(),
&requested_device_name_)) {
return errors::InvalidArgument("Malformed device specification '",
node.requested_device(),
"' in node: ", node.DebugString());
}
if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
return errors::Internal(
"Setting requested device name when there is an assigned device set "
"is unsupported");
}
return Status::OK();
}
Status EnsureCompatibilityAcrossResourceEdge(
const Node& src, const Member& src_root,
const Node& dst, /*dst_root is this*/
bool log_device_placement) {
if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
assigned_device_name_)) {
return errors::InvalidArgument(
"Cannot place the graph because a reference or resource edge "
"connects colocation groups with incompatible assigned devices: ",
DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
" vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_));
}
if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
requested_device_name_)) {
return Status::OK();
}
// If we are here, assigned devices are compatible but requested ones are
// not. We will be overriding the requested device for the destination node,
// need to preserve the invariant that it will be a specialization of
// the assigned device.
if (log_device_placement) {
LOG(INFO) << "Ignoring device specification "
<< DeviceNameUtils::ParsedNameToString(requested_device_name_)
<< " for node '" << dst.name()
<< "' because the input edge from '" << src.name()
<< "' is a reference connection and already has a device "
"field set to "
<< DeviceNameUtils::ParsedNameToString(
src_root.requested_device_name_);
}
requested_device_name_ = src_root.requested_device_name_;
DeviceNameUtils::EnsureSpecification(&requested_device_name_,
assigned_device_name_);
return Status::OK();
void SetDeviceName(const DeviceNameUtils::ParsedName& device_name) {
device_name_ = device_name;
}
const PrioritizedDeviceTypeVector& supported_device_types() const {
@ -319,13 +228,13 @@ class Member {
}
Status MergeDeviceNames(const Member& other, bool allow_soft_placement) {
// Assuming the "requested is a specialization of assigned" invariant holds
// for this and `other`, it will hold after these two merges.
TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
&requested_device_name_, other.requested_device_name_,
allow_soft_placement));
return DeviceNameUtils::MergeDevNames(&assigned_device_name_,
other.assigned_device_name_,
return DeviceNameUtils::MergeDevNames(&device_name_, other.device_name_,
allow_soft_placement);
}
Status MergeDeviceNames(const string& dev_name, bool allow_soft_placement) {
DeviceNameUtils::ParsedName parsed;
DeviceNameUtils::ParseFullName(dev_name, &parsed);
return DeviceNameUtils::MergeDevNames(&device_name_, parsed,
allow_soft_placement);
}
@ -415,29 +324,16 @@ class Member {
if (node.assigned_device_name_index() == assigned_device_name_index_) {
return Status::OK();
}
DeviceNameUtils::ParsedName parsed;
DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed,
allow_soft_placement);
Status s =
MergeDeviceNames(node.assigned_device_name(), allow_soft_placement);
if (!s.ok()) {
return errors::Internal(
"Constraining by assigned device should not cause an error. Original "
"root's assigned device name: ",
DeviceNameUtils::ParsedNameToString(assigned_device_name_),
" node's assigned device name \"", node.assigned_device_name(),
"root device name: ",
DeviceNameUtils::ParsedNameToString(device_name_),
" assigned device name \"", node.assigned_device_name(),
". Error: ", s.error_message());
}
s = DeviceNameUtils::MergeDevNames(&requested_device_name_, parsed,
allow_soft_placement);
if (!s.ok()) {
return errors::Internal(
"Constraining by assigned device should not cause an error. Original "
"root's requested device name: \"",
DeviceNameUtils::ParsedNameToString(requested_device_name_),
"\", node's assigned device name \"", node.assigned_device_name(),
"\". Error: ", s.error_message());
}
assigned_device_name_index_ = node.assigned_device_name_index();
// Clear cached possible_devices, if any.
@ -450,20 +346,6 @@ class Member {
}
const std::vector<Device*>& possible_devices() { return possible_devices_; }
string DebugString() {
return absl::StrCat(
"Member(assigned_device_name_index_=", assigned_device_name_index_,
" requested_device_name_=",
DeviceNameUtils::ParsedNameToString(requested_device_name_),
" assigned_device_name_=",
DeviceNameUtils::ParsedNameToString(assigned_device_name_),
" supported_device_types_=[",
absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
", "),
"] possible_devices_=[",
absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
}
private:
// The id of the node that is the parent of this one, or its own
// id if it is a root. parent <= 0 indicates that this member is invalid.
@ -474,37 +356,19 @@ class Member {
// sets.
int rank_ = 0;
// Once colocation groups have been formed, the Placer starts actually
// choosing devices. All nodes in a group must be assigned to the same
// device. Once we assigned the first device to some node in this group,
// we set assigned_device_name_index to this device name's index in the
// graph.
// The `*_device_name_` fields will contain the parsed name of this device
// and `possible_devices`, if computed, will contain just this device.
// Once colocation groups have been formed and we have assigned at least
// one node in this group to a device, assigned_device_name_index will
// contain this device name's index in the graph. The `device_name` will
// contain the parsed name of this device and `possible_devices`, if
// computed, will contain just this device.
// `assigned_device_name_index` is an optimization to avoid parsing and
// comparing device names. The value of -1 signals that a single device
// has not been chosen yet.
int assigned_device_name_index_ = -1;
// The merged form of the device requested for this node, with those of all of
// its children. requested_device_name_ is always kept a specialization (i.e.
// DeviceNameUtils::IsSpecialization) of assigned_device_name_. When no device
// is requested, this field is set to assigned_device_name_. As a
// specialization of assigned_device_name_, requested_device_name_ represents
// the most specific form of all assigned and requested devices of this node
// and its children, if this node is a root. requested_device_name_ is used
// to finally select devices for nodes. We can override requested devices due
// to resource colocation constraints but not assigned devices (unless soft
// placement is on).
DeviceNameUtils::ParsedName requested_device_name_;
// The merged form of the device assigned for this node, with
// The merged form of the device requested for this node, with
// those of all of its children.
// This field is used to raise errors due to unsatisfiable constraints.
// Can be a partial specification.
// INVARIANT: requested_device_name_ is always a
// DeviceNameUtils::IsSpecialization of assigned_device_name_.
DeviceNameUtils::ParsedName assigned_device_name_;
DeviceNameUtils::ParsedName device_name_;
// The intersection of all device types supported by this node,
// and those of all of its children, in priority order
@ -515,7 +379,7 @@ class Member {
// and all of its children have been assigned, or nullptr if this
// has not yet been computed.
std::vector<Device*> possible_devices_;
}; // namespace
};
// This class maintains the connected components of a colocation
// constraint graph, and uses this information to assign a satisfying
@ -629,9 +493,34 @@ class ColocationGraph {
int dst_root_id = FindRoot(dst->id());
auto& src_root = members_[src_root_id];
auto& dst_root = members_[dst_root_id];
// If both the source node and this node have partially
// specified a device, then 'dst's device should be
// cleared: the reference edge forces 'dst' to be on the
// same device as the source node.
const auto& source_parsed_name = src_root.device_name();
const auto& dest_parsed_name = dst_root.device_name();
if (DeviceNameUtils::HasSomeDetails(source_parsed_name) &&
DeviceNameUtils::HasSomeDetails(dest_parsed_name)) {
// Ignore a specified device for 'dst' if the two names were
// incompatible.
if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name,
dest_parsed_name)) {
TF_RETURN_IF_ERROR(VerifyResourceAndRefInputsCanBeColocated(
dst, src, source_parsed_name));
if (log_device_placement_) {
LOG(INFO) << "Ignoring device specification "
<< DeviceNameUtils::ParsedNameToString(dest_parsed_name)
<< " for node '" << dst->name()
<< "' because the input edge from '" << src->name()
<< "' is a reference connection and already has a device "
"field set to "
<< DeviceNameUtils::ParsedNameToString(source_parsed_name);
}
TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
*src, src_root, *dst, log_device_placement_));
// Make 'dst' colocated with the source
dst_root.SetDeviceName(source_parsed_name);
}
}
Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
if (!status.ok()) {
return AttachDef(
@ -792,13 +681,12 @@ class ColocationGraph {
// "devices" will contain the set of feasible placements for the
// colocated node set containing 'node'.
std::vector<Device*> devices;
if (DeviceNameUtils::HasSomeDetails(
members_[node_root].requested_device_name())) {
if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name())) {
// The root node has a (possibly partial) device
// specification, so enumerate the physical devices that
// conform to it.
device_set_->FindMatchingDevices(
members_[node_root].requested_device_name(), &devices);
device_set_->FindMatchingDevices(members_[node_root].device_name(),
&devices);
if (!devices.empty()) {
// Filter devices into those that are compatible with the root
@ -813,7 +701,7 @@ class ColocationGraph {
// The soft_device_name is the same as the node's device name
// without specifying the device type or ID.
DeviceNameUtils::ParsedName soft_device_name =
members_[node_root].requested_device_name();
members_[node_root].device_name();
soft_device_name.type.clear();
soft_device_name.has_type = false;
soft_device_name.has_id = false;
@ -834,8 +722,7 @@ class ColocationGraph {
DeviceNameUtils::ParsedName specified_device_name;
if (DeviceNameUtils::ParseFullName(node->requested_device(),
&specified_device_name) &&
specified_device_name ==
members_[node_root].requested_device_name()) {
specified_device_name == members_[node_root].device_name()) {
// The specified device and merged set device match, and
// will appear in the GraphDef (for debugging), so just
// print the specified device.
@ -887,7 +774,7 @@ class ColocationGraph {
" was colocated with a group of nodes that ",
"required incompatible device '",
DeviceNameUtils::ParsedNameToString(
members_[node_root].requested_device_name()),
members_[node_root].device_name()),
"'", debug_info);
}
}
@ -916,7 +803,10 @@ class ColocationGraph {
}
Status InitializeMembers() {
for (Node* node : graph_->op_nodes()) {
for (Node* node : graph_->nodes()) {
if (!node->IsOp()) {
continue;
}
Status status = InitializeMember(*node, &members_[node->id()]);
if (!status.ok()) {
return AttachDef(status, *node);
@ -925,22 +815,6 @@ class ColocationGraph {
return Status::OK();
}
string DebugString() {
std::unordered_set<int> roots;
std::vector<string> root_strings;
for (const Node* node : graph_->nodes()) {
if (!node->IsOp()) {
continue;
}
int node_root = FindRoot(node->id());
if (roots.count(node_root) == 0) {
root_strings.push_back(DebugInfo(node_root));
roots.insert(node_root);
}
}
return absl::StrJoin(root_strings, "\n");
}
// Returns debugging info for the node referred to by 'node_root'.
string DebugInfo(const int node_root) {
string text(
@ -986,52 +860,46 @@ class ColocationGraph {
}
strings::StrAppend(&text, "\n");
if (num_nodes_found <= 0) {
if (num_nodes_found <= 1) {
text.clear();
}
return text;
}
Status InitializeMemberWithAssignedDevice(const string& assigned_device_name,
const string& node_type,
bool must_be_full_name,
Member* member) {
// This node has already been assigned to a device, so we
// respect this placement, after sanity-checking it.
// NOTE: Since any assignment must have been performed by
// the TensorFlow runtime, we consider errors in this branch to
// be INTERNAL.
TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
if (!must_be_full_name) {
return Status::OK();
}
// Since assigned device must be a full specification, do extra checks.
const Device* assigned_device =
device_set_->FindDeviceByName(assigned_device_name);
if (assigned_device == nullptr) {
return errors::Internal("Assigned device '", assigned_device_name,
"' does not match any device");
}
for (const auto& d : member->supported_device_types()) {
if (DeviceType(assigned_device->attributes().device_type()) == d.first) {
return Status::OK();
}
}
return errors::Internal("Assigned device '", assigned_device_name,
"' does not have registered OpKernel support "
"for ",
node_type);
}
Status InitializeMember(const Node& node, Member* member) {
TF_RETURN_IF_ERROR(
member->SetParentAndSupportedDevices(node, device_types_));
if (node.has_assigned_device_name()) {
TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
node.assigned_device_name(), node.type_string(), true, member));
// This node has already been assigned to a device, so we
// respect this placement, after sanity-checking it. The
// device_name and supported_device_types for this node reflect
// the assigned device, so any nodes colocated with this node
// will be assigned to the same device (assuming this is
// possible).
// NOTE: Since any assignment must have been performed by
// the TensorFlow runtime, we consider errors in this branch to
// be INTERNAL.
const string& assigned_device_name = node.assigned_device_name();
TF_RETURN_IF_ERROR(member->SetDeviceName(assigned_device_name));
const Device* assigned_device =
device_set_->FindDeviceByName(assigned_device_name);
if (assigned_device == nullptr) {
return errors::Internal("Assigned device '", assigned_device_name,
"' does not match any device");
}
for (const auto& d : member->supported_device_types()) {
if (DeviceType(assigned_device->attributes().device_type()) ==
d.first) {
return Status::OK();
}
}
return errors::Internal("Assigned device '", assigned_device_name,
"' does not have registered OpKernel support "
"for ",
node.type_string());
} else {
// This node has not yet been assigned to a device, so we
// calculate any constraints due to the set of registered
@ -1065,18 +933,15 @@ class ColocationGraph {
// If the NodeDef contains a device, then we interpret it as a
// (partial) device specification.
if (!node.requested_device().empty()) {
if (IsResourceGeneratorNode(node)) {
// Treat requested device on resource generating nodes as assigned
// device so that we don't override it.
TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
node.requested_device(), node.type_string(), false, member));
} else {
// The user has specified a device in the NodeDef, try to find a
// valid device matching their specification in the set of
// devices.
// NOTE: The full name may specify a device that is not in
// n.supported_device_types(), but we check that in AssignDevice().
TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
// The user has specified a device in the NodeDef, try to find a
// valid device matching their specification in the set of
// devices.
// NOTE: The full name may specify a device that is not in
// n.supported_device_types(), but we check that in AssignDevice().
if (!member->SetDeviceName(node.requested_device()).ok()) {
return errors::InvalidArgument("Malformed device specification '",
node.requested_device(),
"' in node: ", node.DebugString());
}
}
}
@ -1167,6 +1032,40 @@ class ColocationGraph {
// given id is connected.
int FindRoot(int node_id) { return Member::FindRoot(&members_, node_id); }
// Ensures that the devices of 'dst's resource and reference inputs match the device
// specified for 'src', which is an input of 'dst' with a partially or fully
// specified device.
Status VerifyResourceAndRefInputsCanBeColocated(
const Node* dst, const Node* src,
const DeviceNameUtils::ParsedName& src_parsed_name) {
std::vector<const Edge*> edges;
TF_RETURN_IF_ERROR(dst->input_edges(&edges));
for (const Edge* edge : edges) {
DataType input_type = dst->input_type(edge->dst_input());
if (input_type == DT_RESOURCE || IsRefType(input_type)) {
const Node* input_node = edge->src();
if (input_node == src) {
continue;
}
const auto& input_root = members_[FindRoot(input_node->id())];
const auto& input_parsed_name = input_root.device_name();
if (DeviceNameUtils::HasSomeDetails(input_parsed_name) &&
!DeviceNameUtils::AreCompatibleDevNames(input_parsed_name,
src_parsed_name)) {
return AttachDef(
errors::InvalidArgument(
"Could not colocate node with its "
"resource and reference inputs; devices ",
DeviceNameUtils::ParsedNameToString(input_parsed_name),
" and ", DeviceNameUtils::ParsedNameToString(src_parsed_name),
" are not compatible."),
*dst);
}
}
}
return Status::OK();
}
const Graph* const graph_; // Not owned.
std::vector<Member> members_;
const DeviceSet* device_set_; // Not owned.
@ -1220,15 +1119,6 @@ Status Placer::Run() {
return errors::FailedPrecondition("No devices are registered");
}
if (VLOG_IS_ON(3)) {
DumpGraphToFile("placer_input", *graph_, nullptr, "/tmp");
for (const Node* node : graph_->op_nodes()) {
VLOG(3) << " " << node->name() << ": requested: '"
<< node->requested_device() << "' assigned: '"
<< node->assigned_device_name() << "'";
}
}
ColocationGraph colocation_graph(
graph_, devices_, default_device_,
options_ == nullptr || options_->config.allow_soft_placement(),
@ -1236,15 +1126,14 @@ Status Placer::Run() {
TF_RETURN_IF_ERROR(colocation_graph.Initialize());
// For each node, assign a device based on the constraints in the disjoint
// node set.
// For each node, assign a device based on the constraints in the
// disjoint node set.
std::vector<Node*> second_pass;
for (Node* node : graph_->op_nodes()) {
// The graph may have come pre-populated by the framework with assigned
// devices (e.g., for stateful placements), so the placer should not try to
// place nodes that are already placed.
if (node->has_assigned_device_name()) {
TF_RETURN_IF_ERROR(colocation_graph.LimitToAssignedDevice(*node));
LogDeviceAssignment(node, log_device_placement_);
continue;
}
@ -1345,9 +1234,6 @@ Status Placer::Run() {
log_device_placement_));
}
if (VLOG_IS_ON(3)) {
DumpGraphToFile("placer_output", *graph_, nullptr, "/tmp");
}
return Status::OK();
}
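
Aside: the colocation machinery above maintains groups as a disjoint node set forest — each Member stores a parent id and a rank, and FindRoot collapses chains while looking up a group's representative. Below is a minimal standalone sketch of that union-find idea with hypothetical names; it is not the Placer's Member/FindRoot code, only the structure it follows.

#include <numeric>
#include <utility>
#include <vector>

// Hypothetical, illustrative union-find: parent/rank per node, root lookup
// with path compression, union by rank.
class DisjointSet {
 public:
  explicit DisjointSet(int n) : parent_(n), rank_(n, 0) {
    std::iota(parent_.begin(), parent_.end(), 0);  // each node starts as its own root
  }

  // Returns the representative of x's set, compressing the path on the way up.
  int FindRoot(int x) {
    if (parent_[x] != x) parent_[x] = FindRoot(parent_[x]);
    return parent_[x];
  }

  // Merges the sets containing x and y; the root with larger rank becomes the
  // parent, mirroring how colocation groups are folded into one representative.
  void Union(int x, int y) {
    int rx = FindRoot(x), ry = FindRoot(y);
    if (rx == ry) return;
    if (rank_[rx] < rank_[ry]) std::swap(rx, ry);
    parent_[ry] = rx;
    if (rank_[rx] == rank_[ry]) ++rank_[rx];
  }

 private:
  std::vector<int> parent_;
  std::vector<int> rank_;
};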

View File

@ -17,7 +17,6 @@ limitations under the License.
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
@ -25,15 +24,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
@ -45,16 +40,6 @@ limitations under the License.
namespace tensorflow {
using ::tensorflow::test::function::GDef;
using ::tensorflow::test::function::NDef;
using FDH = ::tensorflow::FunctionDefHelper;
constexpr char kCPU[] = "/device:fakecpu:0";
constexpr char kGPU[] = "/device:fakegpu:0";
constexpr char kFullCPU[] = "/job:a/replica:0/task:0/device:fakecpu:0";
constexpr char kFullGPU[] = "/job:a/replica:0/task:0/device:fakegpu:0";
namespace {
////////////////////////////////////////////////////////////////////////////////
@ -225,16 +210,6 @@ class PlacerTest : public ::testing::Test {
return Status::OK();
}
Status BuildGraph(const GraphDef& graph_def, Graph* out_graph) {
GraphConstructorOptions opts;
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, out_graph));
nodes_by_name_.clear();
for (Node* node : out_graph->nodes()) {
nodes_by_name_[node->name()] = node->id();
}
return Status::OK();
}
// Invokes the Placer on "graph". If no DeviceSet is specified, the
// placement will use the default DeviceSet (of 10 CPU and 10 GPU devices).
//
@ -891,7 +866,7 @@ TEST_F(PlacerTest, TestResourceHandle) {
}
TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) {
auto handle_test = [this](bool allow_soft_placement, bool set_assigned) {
auto handle_test = [this](bool allow_soft_placement) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@ -903,41 +878,27 @@ TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) {
b.opts().WithName("two_handles_in"));
TF_EXPECT_OK(BuildGraph(b, &g));
if (set_assigned) {
GetNodeByName(g, "var_cpu")
->set_assigned_device_name(
"/job:a/replica:0/task:0/device:fakecpu:0");
GetNodeByName(g, "var_gpu")
->set_assigned_device_name(
"/job:a/replica:0/task:0/device:fakegpu:0");
} else {
GetNodeByName(g, "var_cpu")
->set_requested_device("/job:a/replica:0/task:0/device:fakecpu:0");
GetNodeByName(g, "var_gpu")
->set_requested_device("/job:a/replica:0/task:0/device:fakegpu:0");
}
GetNodeByName(g, "var_cpu")
->set_assigned_device_name(
"/job:a/replica:0/task:0/device:fakecpu:0");
GetNodeByName(g, "var_gpu")
->set_assigned_device_name(
"/job:a/replica:0/task:0/device:fakegpu:0");
}
SessionOptions options;
options.config.set_allow_soft_placement(allow_soft_placement);
options.config.set_log_device_placement(true);
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
EXPECT_TRUE(str_util::StrContains(
s.error_message(),
"Cannot place the graph because a reference or resource edge "
"connects "
"colocation groups with incompatible assigned devices: "
"/job:a/replica:0/task:0/device:fakegpu:0 vs "
"/job:a/replica:0/task:0/device:fakecpu:0"));
"Could not colocate node with its resource and reference inputs"));
return Status::OK();
};
TF_EXPECT_OK(handle_test(false, false));
TF_EXPECT_OK(handle_test(false, true));
TF_EXPECT_OK(handle_test(true, false));
TF_EXPECT_OK(handle_test(true, true));
TF_EXPECT_OK(handle_test(false));
TF_EXPECT_OK(handle_test(true));
}
// Test that an assignment of an operator to the wrong device
@ -1656,127 +1617,5 @@ TEST_F(PlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) {
EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
}
REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeGPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeCPU"), DummyOp);
REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeGPU"), DummyOp);
TEST_F(PlacerTest, RequestedDeviceOnResourceGeneratorIsTreatedAsAssigned) {
/*
* a:RES:GPU b:RES:CPU
* | |
* | |
* v v
* id1 id2
* @loc:id2
*/
FunctionDef func = test::function::ResourceOutput();
GraphDef graph = GDef(
{
NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
NDef("id1", "Identity", {"a"},
{{"T", DT_RESOURCE},
{"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
NDef("id2", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
},
// FunctionLib
{func});
Graph g(OpRegistry::Global());
TF_ASSERT_OK(BuildGraph(graph, &g));
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
EXPECT_TRUE(str_util::StrContains(
s.error_message(),
"Cannot place the graph because a reference or resource edge connects "
"colocation groups with incompatible assigned devices:"));
}
TEST_F(PlacerTest, RequestedDeviceCanBeOverridden) {
/*
* a:RES b:RES
* | |
* id_a:GPU id_b:CPU
* | |
* v v
* id1 id2
* @loc:id2
*/
FunctionDef func = test::function::ResourceOutput();
GraphDef graph = GDef(
{
NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}, kGPU),
NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}, kCPU),
NDef("id1", "Identity", {"id_a"},
{{"T", DT_RESOURCE},
{"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
},
// FunctionLib
{func});
Graph g(OpRegistry::Global());
TF_ASSERT_OK(BuildGraph(graph, &g));
TF_ASSERT_OK(Place(&g));
// All should be colocated
EXPECT_COLOCATED(g, "a", "b");
EXPECT_COLOCATED(g, "id_a", "id_b");
EXPECT_COLOCATED(g, "id1", "id2");
EXPECT_COLOCATED(g, "a", "id_a");
EXPECT_COLOCATED(g, "a", "id1");
}
TEST_F(PlacerTest, AssignedDevicesAreNotOverriddenDueToResourcesAndColocation) {
/*
* a:RES b:RES
* | |
* id_a:GPU id_b:CPU
* | |
* v v
* id1 id2
* @loc:id2
*/
FunctionDef func = test::function::ResourceOutput();
GraphDef graph = GDef(
{
NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}),
NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
NDef("id1", "Identity", {"id_a"},
{{"T", DT_RESOURCE},
{"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
},
// FunctionLib
{func});
Graph g(OpRegistry::Global());
TF_ASSERT_OK(BuildGraph(graph, &g));
std::unordered_map<string, Node*> nodes = g.BuildNodeNameIndex();
GetNodeByName(g, "id_a")->set_assigned_device_name(kFullGPU);
GetNodeByName(g, "id_b")->set_assigned_device_name(kFullCPU);
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
EXPECT_TRUE(str_util::StrContains(
s.error_message(),
"Cannot place the graph because a reference or resource edge connects "
"colocation groups with incompatible assigned devices: "
"/job:a/replica:0/task:0/device:fakecpu:0 vs "
"/job:a/replica:0/task:0/device:fakegpu:0"));
}
} // namespace
} // namespace tensorflow
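
Aside: both the placer and the tests above hinge on how partial device names combine. Two names are compatible when no component that both of them specify disagrees, and merging fills in the components only one of them sets. A simplified, hypothetical sketch of that idea follows (MiniName, Compatible, and Merge are illustrative stand-ins, not DeviceNameUtils::ParsedName or its API).

#include <string>

// Hypothetical miniature device name with three optional components.
struct MiniName {
  bool has_job = false;  std::string job;
  bool has_task = false; int task = 0;
  bool has_type = false; std::string type;  // e.g. "fakecpu", "fakegpu"
};

// Compatible iff every component set in both names agrees.
bool Compatible(const MiniName& a, const MiniName& b) {
  if (a.has_job && b.has_job && a.job != b.job) return false;
  if (a.has_task && b.has_task && a.task != b.task) return false;
  if (a.has_type && b.has_type && a.type != b.type) return false;
  return true;
}

// Merge b into a: adopt the components only b specifies.
// Fails (returns false) when the two names conflict.
bool Merge(MiniName* a, const MiniName& b) {
  if (!Compatible(*a, b)) return false;
  if (!a->has_job && b.has_job) { a->has_job = true; a->job = b.job; }
  if (!a->has_task && b.has_task) { a->has_task = true; a->task = b.task; }
  if (!a->has_type && b.has_type) { a->has_type = true; a->type = b.type; }
  return true;
}

In these terms, a name that only fixes the job merged with one that only fixes the device type yields a single, more specific name, while two names that pin the type to fakecpu and fakegpu respectively conflict — the situation TestResourceHandlesOnDifferentDevicesFails provokes.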

View File

@ -289,30 +289,6 @@ bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
return true;
}
void DeviceNameUtils::EnsureSpecification(ParsedName* more_specific,
const ParsedName& less_specific) {
if (less_specific.has_job) {
more_specific->has_job = true;
more_specific->job = less_specific.job;
}
if (less_specific.has_replica) {
more_specific->has_replica = true;
more_specific->replica = less_specific.replica;
}
if (less_specific.has_task) {
more_specific->has_task = true;
more_specific->task = less_specific.task;
}
if (less_specific.has_type) {
more_specific->has_type = true;
more_specific->type = less_specific.type;
}
if (less_specific.has_id) {
more_specific->has_id = true;
more_specific->id = less_specific.id;
}
}
/* static */
bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
const ParsedName& name) {
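
Aside: the EnsureSpecification helper removed above copies every component that less_specific sets into more_specific, so the result is guaranteed to be a specialization of less_specific. A small illustrative usage against the pre-rollback API; the device strings and the parse of a partial name are assumptions of the sketch.

#include "tensorflow/core/util/device_name_utils.h"

using tensorflow::DeviceNameUtils;

void EnsureSpecificationExample() {
  DeviceNameUtils::ParsedName requested, assigned;
  DeviceNameUtils::ParseFullName("/job:worker", &requested);    // partial spec
  DeviceNameUtils::ParseFullName(
      "/job:worker/replica:0/task:0/device:CPU:0", &assigned);  // full spec
  // After this call `requested` carries every component `assigned` specifies,
  // so IsSpecification(assigned, requested) holds.
  DeviceNameUtils::EnsureSpecification(&requested, assigned);
}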

View File

@ -110,11 +110,6 @@ class DeviceNameUtils {
static bool IsSpecification(const ParsedName& less_specific,
const ParsedName& more_specific);
// Makes minimal changes to more_specific so that it becomes a
// specification of less_specific.
static void EnsureSpecification(ParsedName* more_specific,
const ParsedName& less_specific);
// Like IsSpecification, but the second argument "name" must have a
// non-wildcard value for all of its components.
static bool IsCompleteSpecification(const ParsedName& pattern,
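
Aside: a brief usage sketch of the predicates declared here, clarifying the argument order (less_specific first); the device strings and the parse of a partial name are assumptions of the example, not taken from this diff.

#include "tensorflow/core/util/device_name_utils.h"

using tensorflow::DeviceNameUtils;

void SpecificationPredicatesExample() {
  DeviceNameUtils::ParsedName partial, full;
  DeviceNameUtils::ParseFullName("/job:a/device:CPU:0", &partial);
  DeviceNameUtils::ParseFullName(
      "/job:a/replica:0/task:0/device:CPU:0", &full);
  // `full` agrees with every component `partial` sets and adds the rest,
  // so it is a specification of `partial`.
  bool is_spec = DeviceNameUtils::IsSpecification(partial, full);              // true
  // IsCompleteSpecification additionally requires `full` to set every component.
  bool is_complete = DeviceNameUtils::IsCompleteSpecification(partial, full);  // true
  (void)is_spec;
  (void)is_complete;
}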

View File

@ -454,8 +454,7 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(func)
expected_error = (
errors.InvalidArgumentError,
"Cannot place the graph because a reference or resource edge "
"connects colocation groups with incompatible assigned devices")
"Could not colocate node with its resource and reference inputs")
self.assertDatasetProduces(
dataset, expected_error=expected_error, requires_initialization=True)

View File

@ -876,9 +876,8 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
return None
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
'Cannot place the graph because a reference or resource edge connects '
'colocation groups with incompatible assigned devices'):
errors.InvalidArgumentError, 'Could not colocate node with its '
'resource and reference inputs.*'):
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.evaluate(resource_apply_adam())