parent
f80b1ca748
commit
d7b3e49283
@ -20,7 +20,6 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
@ -34,7 +33,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
#include "tensorflow/core/util/port.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -98,27 +96,6 @@ std::vector<Device*> FilterSupportedDevices(
|
||||
return filtered_devices;
|
||||
}
|
||||
|
||||
// Using absl::StrJoin with lambda does not work in tf-lite builds.
|
||||
std::vector<string> DevicesToString(const std::vector<Device*> devices) {
|
||||
std::vector<string> v;
|
||||
v.reserve(devices.size());
|
||||
for (Device* d : devices) {
|
||||
v.push_back(d->name());
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
// Using absl::StrJoin with lambda does not work in tf-lite builds.
|
||||
std::vector<string> DeviceTypeAndPriorityToString(
|
||||
const PrioritizedDeviceTypeVector& devices) {
|
||||
std::vector<string> v;
|
||||
v.reserve(devices.size());
|
||||
for (const std::pair<DeviceType, int32>& device_and_type : devices) {
|
||||
v.push_back(DeviceTypeString(device_and_type.first));
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
// Returns true if the node has no inputs and produces outputs
|
||||
// that are consumed by a single node.
|
||||
//
|
||||
@ -130,16 +107,6 @@ bool IsGeneratorNode(const Node* node) {
|
||||
!IsRefType(node->output_type(0));
|
||||
}
|
||||
|
||||
// While Placer can override requested device on ops processing
|
||||
// resources, i.e. node that take (and potentially return) a resource,
|
||||
// it must not override requested device on ops generating a resource,
|
||||
// e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
|
||||
// output nodes.
|
||||
bool IsResourceGeneratorNode(const Node& node) {
|
||||
return node.num_inputs() == 0 && node.num_outputs() == 1 &&
|
||||
(IsRefType(node.output_type(0)) || node.output_type(0) == DT_RESOURCE);
|
||||
}
|
||||
|
||||
bool IsExemptFromResourceInputColocation(const Node* node) {
|
||||
// Note: Partitioned function calls, which place and partition their
|
||||
// function bodies, are exempt from this check: they forward resource and
|
||||
@ -169,7 +136,7 @@ bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
|
||||
return true;
|
||||
}
|
||||
|
||||
// Represents a node in the disjoint node forest and the
|
||||
// Represents a node in the disjoint node set forest, and the
|
||||
// accumulated constraints on the device used by that node.
|
||||
class Member {
|
||||
public:
|
||||
@ -188,76 +155,18 @@ class Member {
|
||||
&supported_device_types_);
|
||||
}
|
||||
|
||||
const DeviceNameUtils::ParsedName& requested_device_name() const {
|
||||
return requested_device_name_;
|
||||
const DeviceNameUtils::ParsedName& device_name() const {
|
||||
return device_name_;
|
||||
}
|
||||
|
||||
Status SetAssignedDeviceName(const string& device_name) {
|
||||
if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
|
||||
return errors::Internal(
|
||||
"Setting assigned device name when there is a requested device set "
|
||||
"is unsupported");
|
||||
}
|
||||
if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
|
||||
Status SetDeviceName(const string& device_name) {
|
||||
if (!DeviceNameUtils::ParseFullName(device_name, &device_name_)) {
|
||||
return errors::Internal("Malformed assigned device '", device_name, "'");
|
||||
}
|
||||
// Set requested device to assigned_device to maintain the invariant that
|
||||
// requested is a specialization of assigned.
|
||||
requested_device_name_ = assigned_device_name_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SetRequestedDeviceName(const Node& node) {
|
||||
if (!DeviceNameUtils::ParseFullName(node.requested_device(),
|
||||
&requested_device_name_)) {
|
||||
return errors::InvalidArgument("Malformed device specification '",
|
||||
node.requested_device(),
|
||||
"' in node: ", node.DebugString());
|
||||
}
|
||||
if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
|
||||
return errors::Internal(
|
||||
"Setting requested device name when there is an assigned device set "
|
||||
"is unsupported");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EnsureCompatibilityAcrossResourceEdge(
|
||||
const Node& src, const Member& src_root,
|
||||
const Node& dst, /*dst_root is this*/
|
||||
bool log_device_placement) {
|
||||
if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
|
||||
assigned_device_name_)) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot place the graph because a reference or resource edge "
|
||||
"connects colocation groups with incompatible assigned devices: ",
|
||||
DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
|
||||
" vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_));
|
||||
}
|
||||
|
||||
if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
|
||||
requested_device_name_)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If we are here, assigned devices are compatible but requested ones are
|
||||
// not. We will be overriding the requested device for destination node, but
|
||||
// need to preserve the invariant that it will be a specialization of
|
||||
// the assigned device.
|
||||
if (log_device_placement) {
|
||||
LOG(INFO) << "Ignoring device specification "
|
||||
<< DeviceNameUtils::ParsedNameToString(requested_device_name_)
|
||||
<< " for node '" << dst.name()
|
||||
<< "' because the input edge from '" << src.name()
|
||||
<< "' is a reference connection and already has a device "
|
||||
"field set to "
|
||||
<< DeviceNameUtils::ParsedNameToString(
|
||||
src_root.requested_device_name_);
|
||||
}
|
||||
requested_device_name_ = src_root.requested_device_name_;
|
||||
DeviceNameUtils::EnsureSpecification(&requested_device_name_,
|
||||
assigned_device_name_);
|
||||
return Status::OK();
|
||||
void SetDeviceName(const DeviceNameUtils::ParsedName& device_name) {
|
||||
device_name_ = device_name;
|
||||
}
|
||||
|
||||
const PrioritizedDeviceTypeVector& supported_device_types() const {
|
||||
@ -319,13 +228,13 @@ class Member {
|
||||
}
|
||||
|
||||
Status MergeDeviceNames(const Member& other, bool allow_soft_placement) {
|
||||
// Assuming the "requested is a specialization of assigned" invariant holds
|
||||
// for this and `other`, it will hold after these two merges.
|
||||
TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
|
||||
&requested_device_name_, other.requested_device_name_,
|
||||
allow_soft_placement));
|
||||
return DeviceNameUtils::MergeDevNames(&assigned_device_name_,
|
||||
other.assigned_device_name_,
|
||||
return DeviceNameUtils::MergeDevNames(&device_name_, other.device_name_,
|
||||
allow_soft_placement);
|
||||
}
|
||||
Status MergeDeviceNames(const string& dev_name, bool allow_soft_placement) {
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
DeviceNameUtils::ParseFullName(dev_name, &parsed);
|
||||
return DeviceNameUtils::MergeDevNames(&device_name_, parsed,
|
||||
allow_soft_placement);
|
||||
}
|
||||
|
||||
@ -415,29 +324,16 @@ class Member {
|
||||
if (node.assigned_device_name_index() == assigned_device_name_index_) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DeviceNameUtils::ParsedName parsed;
|
||||
DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
|
||||
Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed,
|
||||
allow_soft_placement);
|
||||
Status s =
|
||||
MergeDeviceNames(node.assigned_device_name(), allow_soft_placement);
|
||||
if (!s.ok()) {
|
||||
return errors::Internal(
|
||||
"Constraining by assigned device should not cause an error. Original "
|
||||
"root's assigned device name: ",
|
||||
DeviceNameUtils::ParsedNameToString(assigned_device_name_),
|
||||
" node's assigned device name \"", node.assigned_device_name(),
|
||||
"root device name: ",
|
||||
DeviceNameUtils::ParsedNameToString(device_name_),
|
||||
" assigned device name \"", node.assigned_device_name(),
|
||||
". Error: ", s.error_message());
|
||||
}
|
||||
s = DeviceNameUtils::MergeDevNames(&requested_device_name_, parsed,
|
||||
allow_soft_placement);
|
||||
if (!s.ok()) {
|
||||
return errors::Internal(
|
||||
"Constraining by assigned device should not cause an error. Original "
|
||||
"root's requested device name: \"",
|
||||
DeviceNameUtils::ParsedNameToString(requested_device_name_),
|
||||
"\", node's assigned device name \"", node.assigned_device_name(),
|
||||
"\". Error: ", s.error_message());
|
||||
}
|
||||
|
||||
assigned_device_name_index_ = node.assigned_device_name_index();
|
||||
// Clear cached possible_devices, if any.
|
||||
@ -450,20 +346,6 @@ class Member {
|
||||
}
|
||||
const std::vector<Device*>& possible_devices() { return possible_devices_; }
|
||||
|
||||
string DebugString() {
|
||||
return absl::StrCat(
|
||||
"Member(assigned_device_name_index_=", assigned_device_name_index_,
|
||||
" requested_device_name_=",
|
||||
DeviceNameUtils::ParsedNameToString(requested_device_name_),
|
||||
" assigned_device_name_=",
|
||||
DeviceNameUtils::ParsedNameToString(assigned_device_name_),
|
||||
" supported_device_types_=[",
|
||||
absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
|
||||
", "),
|
||||
"] possible_devices_=[",
|
||||
absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
|
||||
}
|
||||
|
||||
private:
|
||||
// The id of the node that is the parent of this one, or its own
|
||||
// id if it is a root. parent <= 0 indicates that this member is invalid.
|
||||
@ -474,37 +356,19 @@ class Member {
|
||||
// sets.
|
||||
int rank_ = 0;
|
||||
|
||||
// Once colocation groups have been formed, the Placer starts actually
|
||||
// choosing devices. All nodes in a group must be assigned to the same
|
||||
// device. Once we assigned the first device to some node in this group,
|
||||
// we set assigned_device_name_index to this device name's index in the
|
||||
// graph.
|
||||
// The `*_device_name_` fields will contain the parsed name of this device
|
||||
// and `possible_devices`, if computed, will contain just this device.
|
||||
// Once colocation groups have been formed and we assigned at least
|
||||
// one node in this group to a device, assigned_device_name_index will
|
||||
// contain this device name's index in the graph. The `device_name` will
|
||||
// contain the parsed name of this device and `possible_devices`, if
|
||||
// computed, will contain just this device.
|
||||
// `assigned_device_name_index` is an optimization to avoid parsing and
|
||||
// comparing device names. The value of -1 signals that a single device
|
||||
// has not been chosen yet.
|
||||
int assigned_device_name_index_ = -1;
|
||||
|
||||
// The merged form of the device requested for this node, with those of all of
|
||||
// its children. requested_device_name_ is always kept a specialization (i.e.
|
||||
// DeviceNameUtils::IsSpecialization) of assigned_device_name_. When no device
|
||||
// is requested, this field is set to assigned_device_name_. As a
|
||||
// specialization of assigned_device_name_, requested_device_name_ represents
|
||||
// the most specific form of all assigned and requested devices of this node
|
||||
// and its children, if this node is a root. requested_device_name_ is used
|
||||
// to finally select devices for nodes. We can override requested devices due
|
||||
// to resource colocation constraints but not assigned devices (unless soft
|
||||
// placement is on).
|
||||
DeviceNameUtils::ParsedName requested_device_name_;
|
||||
|
||||
// The merged form of the device assigned for this node, with
|
||||
// The merged form of the device requested for this node, with
|
||||
// those of all of its children.
|
||||
// This field is used to raise errors due to unsatisfiable constraints.
|
||||
// Can be a partial specification.
|
||||
// INVARIANT: requested_device_name_ is always a
|
||||
// DeviceNameUtils::IsSpecialization of assigned_device_name_.
|
||||
DeviceNameUtils::ParsedName assigned_device_name_;
|
||||
DeviceNameUtils::ParsedName device_name_;
|
||||
|
||||
// The intersection of all device types supported by this node,
|
||||
// and those of all of its children, in priority order
|
||||
@ -515,7 +379,7 @@ class Member {
|
||||
// and all of its children have been assigned, or nullptr if this
|
||||
// has not yet been computed.
|
||||
std::vector<Device*> possible_devices_;
|
||||
}; // namespace
|
||||
};
|
||||
|
||||
// This class maintains the connected components of a colocation
|
||||
// constraint graph, and uses this information to assign a satisfying
|
||||
@ -629,9 +493,34 @@ class ColocationGraph {
|
||||
int dst_root_id = FindRoot(dst->id());
|
||||
auto& src_root = members_[src_root_id];
|
||||
auto& dst_root = members_[dst_root_id];
|
||||
// If both the source node and this node have partially
|
||||
// specified a device, then 'dst's device should be
|
||||
// cleared: the reference edge forces 'node' to be on the
|
||||
// same device as the source node.
|
||||
const auto& source_parsed_name = src_root.device_name();
|
||||
const auto& dest_parsed_name = dst_root.device_name();
|
||||
if (DeviceNameUtils::HasSomeDetails(source_parsed_name) &&
|
||||
DeviceNameUtils::HasSomeDetails(dest_parsed_name)) {
|
||||
// Ignore a specified device for 'dst' if the two names were
|
||||
// incompatible.
|
||||
if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name,
|
||||
dest_parsed_name)) {
|
||||
TF_RETURN_IF_ERROR(VerifyResourceAndRefInputsCanBeColocated(
|
||||
dst, src, source_parsed_name));
|
||||
if (log_device_placement_) {
|
||||
LOG(INFO) << "Ignoring device specification "
|
||||
<< DeviceNameUtils::ParsedNameToString(dest_parsed_name)
|
||||
<< " for node '" << dst->name()
|
||||
<< "' because the input edge from '" << src->name()
|
||||
<< "' is a reference connection and already has a device "
|
||||
"field set to "
|
||||
<< DeviceNameUtils::ParsedNameToString(source_parsed_name);
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
|
||||
*src, src_root, *dst, log_device_placement_));
|
||||
// Make 'dst' colocated with the source
|
||||
dst_root.SetDeviceName(source_parsed_name);
|
||||
}
|
||||
}
|
||||
Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
|
||||
if (!status.ok()) {
|
||||
return AttachDef(
|
||||
@ -792,13 +681,12 @@ class ColocationGraph {
|
||||
// "devices" will contain the set of feasible placements for the
|
||||
// colocated node set containing 'node'.
|
||||
std::vector<Device*> devices;
|
||||
if (DeviceNameUtils::HasSomeDetails(
|
||||
members_[node_root].requested_device_name())) {
|
||||
if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name())) {
|
||||
// The root node has a (possibly partial) device
|
||||
// specification, so enumerate the physical devices that
|
||||
// conform to it.
|
||||
device_set_->FindMatchingDevices(
|
||||
members_[node_root].requested_device_name(), &devices);
|
||||
device_set_->FindMatchingDevices(members_[node_root].device_name(),
|
||||
&devices);
|
||||
|
||||
if (!devices.empty()) {
|
||||
// Filter devices into those that are compatible with the root
|
||||
@ -813,7 +701,7 @@ class ColocationGraph {
|
||||
// The soft_device_name is the same as the node's device name
|
||||
// without specifying the device type or ID.
|
||||
DeviceNameUtils::ParsedName soft_device_name =
|
||||
members_[node_root].requested_device_name();
|
||||
members_[node_root].device_name();
|
||||
soft_device_name.type.clear();
|
||||
soft_device_name.has_type = false;
|
||||
soft_device_name.has_id = false;
|
||||
@ -834,8 +722,7 @@ class ColocationGraph {
|
||||
DeviceNameUtils::ParsedName specified_device_name;
|
||||
if (DeviceNameUtils::ParseFullName(node->requested_device(),
|
||||
&specified_device_name) &&
|
||||
specified_device_name ==
|
||||
members_[node_root].requested_device_name()) {
|
||||
specified_device_name == members_[node_root].device_name()) {
|
||||
// The specified device and merged set device match, and
|
||||
// will appear in the GraphDef (for debugging), so just
|
||||
// print the specified device.
|
||||
@ -887,7 +774,7 @@ class ColocationGraph {
|
||||
" was colocated with a group of nodes that ",
|
||||
"required incompatible device '",
|
||||
DeviceNameUtils::ParsedNameToString(
|
||||
members_[node_root].requested_device_name()),
|
||||
members_[node_root].device_name()),
|
||||
"'", debug_info);
|
||||
}
|
||||
}
|
||||
@ -916,7 +803,10 @@ class ColocationGraph {
|
||||
}
|
||||
|
||||
Status InitializeMembers() {
|
||||
for (Node* node : graph_->op_nodes()) {
|
||||
for (Node* node : graph_->nodes()) {
|
||||
if (!node->IsOp()) {
|
||||
continue;
|
||||
}
|
||||
Status status = InitializeMember(*node, &members_[node->id()]);
|
||||
if (!status.ok()) {
|
||||
return AttachDef(status, *node);
|
||||
@ -925,22 +815,6 @@ class ColocationGraph {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string DebugString() {
|
||||
std::unordered_set<int> roots;
|
||||
std::vector<string> root_strings;
|
||||
for (const Node* node : graph_->nodes()) {
|
||||
if (!node->IsOp()) {
|
||||
continue;
|
||||
}
|
||||
int node_root = FindRoot(node->id());
|
||||
if (roots.count(node_root) == 0) {
|
||||
root_strings.push_back(DebugInfo(node_root));
|
||||
roots.insert(node_root);
|
||||
}
|
||||
}
|
||||
return absl::StrJoin(root_strings, "\n");
|
||||
}
|
||||
|
||||
// Returns debugging info for the node referred to by 'node_root'.
|
||||
string DebugInfo(const int node_root) {
|
||||
string text(
|
||||
@ -986,52 +860,46 @@ class ColocationGraph {
|
||||
}
|
||||
strings::StrAppend(&text, "\n");
|
||||
|
||||
if (num_nodes_found <= 0) {
|
||||
if (num_nodes_found <= 1) {
|
||||
text.clear();
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
Status InitializeMemberWithAssignedDevice(const string& assigned_device_name,
|
||||
const string& node_type,
|
||||
bool must_be_full_name,
|
||||
Member* member) {
|
||||
// This node has already been assigned to a device, so we
|
||||
// respect this placement, after sanity-checking it.
|
||||
// NOTE: Since any assignment must have been performed by
|
||||
// the TensorFlow runtime, we consider errors in this branch to
|
||||
// be INTERNAL.
|
||||
TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
|
||||
if (!must_be_full_name) {
|
||||
return Status::OK();
|
||||
}
|
||||
// Since assigned device must be a full specification, do extra checks.
|
||||
const Device* assigned_device =
|
||||
device_set_->FindDeviceByName(assigned_device_name);
|
||||
if (assigned_device == nullptr) {
|
||||
return errors::Internal("Assigned device '", assigned_device_name,
|
||||
"' does not match any device");
|
||||
}
|
||||
|
||||
for (const auto& d : member->supported_device_types()) {
|
||||
if (DeviceType(assigned_device->attributes().device_type()) == d.first) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
return errors::Internal("Assigned device '", assigned_device_name,
|
||||
"' does not have registered OpKernel support "
|
||||
"for ",
|
||||
node_type);
|
||||
}
|
||||
|
||||
Status InitializeMember(const Node& node, Member* member) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
member->SetParentAndSupportedDevices(node, device_types_));
|
||||
|
||||
if (node.has_assigned_device_name()) {
|
||||
TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
|
||||
node.assigned_device_name(), node.type_string(), true, member));
|
||||
// This node has already been assigned to a device, so we
|
||||
// respect this placement, after sanity-checking it. The
|
||||
// device_name and supported_device_types for this node reflect
|
||||
// the assigned device, so any nodes colocated with this node
|
||||
// will be assigned to the same device (assuming this is
|
||||
// possible).
|
||||
// NOTE: Since any assignment must have been performed by
|
||||
// the TensorFlow runtime, we consider errors in this branch to
|
||||
// be INTERNAL.
|
||||
const string& assigned_device_name = node.assigned_device_name();
|
||||
TF_RETURN_IF_ERROR(member->SetDeviceName(assigned_device_name));
|
||||
const Device* assigned_device =
|
||||
device_set_->FindDeviceByName(assigned_device_name);
|
||||
if (assigned_device == nullptr) {
|
||||
return errors::Internal("Assigned device '", assigned_device_name,
|
||||
"' does not match any device");
|
||||
}
|
||||
|
||||
for (const auto& d : member->supported_device_types()) {
|
||||
if (DeviceType(assigned_device->attributes().device_type()) ==
|
||||
d.first) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
return errors::Internal("Assigned device '", assigned_device_name,
|
||||
"' does not have registered OpKernel support "
|
||||
"for ",
|
||||
node.type_string());
|
||||
} else {
|
||||
// This node has not yet been assigned to a device, so we
|
||||
// calculate any constraints due to the set of registered
|
||||
@ -1065,18 +933,15 @@ class ColocationGraph {
|
||||
// If the NodeDef contains a device, then we interpret it as a
|
||||
// (partial) device specification.
|
||||
if (!node.requested_device().empty()) {
|
||||
if (IsResourceGeneratorNode(node)) {
|
||||
// Treat requested device on resource generating nodes as assigned
|
||||
// device so that we don't override it.
|
||||
TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
|
||||
node.requested_device(), node.type_string(), false, member));
|
||||
} else {
|
||||
// The user has specified a device in the NodeDef, try to find a
|
||||
// valid device matching their specification in the set of
|
||||
// devices.
|
||||
// NOTE: The full name may specify a device that is not in
|
||||
// n.supported_device_types(), but we check that in AssignDevice().
|
||||
TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
|
||||
// The user has specified a device in the NodeDef, try to find a
|
||||
// valid device matching their specification in the set of
|
||||
// devices.
|
||||
// NOTE: The full name may specify a device that is not in
|
||||
// n.supported_device_types(), but we check that in AssignDevice().
|
||||
if (!member->SetDeviceName(node.requested_device()).ok()) {
|
||||
return errors::InvalidArgument("Malformed device specification '",
|
||||
node.requested_device(),
|
||||
"' in node: ", node.DebugString());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1167,6 +1032,40 @@ class ColocationGraph {
|
||||
// given id is connected.
|
||||
int FindRoot(int node_id) { return Member::FindRoot(&members_, node_id); }
|
||||
|
||||
// Ensures that the devices of 'dst's resource and reference match the device
|
||||
// specified for 'src', which is an input of 'dst' with a partially or fully
|
||||
// specified device.
|
||||
Status VerifyResourceAndRefInputsCanBeColocated(
|
||||
const Node* dst, const Node* src,
|
||||
const DeviceNameUtils::ParsedName& src_parsed_name) {
|
||||
std::vector<const Edge*> edges;
|
||||
TF_RETURN_IF_ERROR(dst->input_edges(&edges));
|
||||
for (const Edge* edge : edges) {
|
||||
DataType input_type = dst->input_type(edge->dst_input());
|
||||
if (input_type == DT_RESOURCE || IsRefType(input_type)) {
|
||||
const Node* input_node = edge->src();
|
||||
if (input_node == src) {
|
||||
continue;
|
||||
}
|
||||
const auto& input_root = members_[FindRoot(input_node->id())];
|
||||
const auto& input_parsed_name = input_root.device_name();
|
||||
if (DeviceNameUtils::HasSomeDetails(input_parsed_name) &&
|
||||
!DeviceNameUtils::AreCompatibleDevNames(input_parsed_name,
|
||||
src_parsed_name)) {
|
||||
return AttachDef(
|
||||
errors::InvalidArgument(
|
||||
"Could not colocate node with its "
|
||||
"resource and reference inputs; devices ",
|
||||
DeviceNameUtils::ParsedNameToString(input_parsed_name),
|
||||
" and ", DeviceNameUtils::ParsedNameToString(src_parsed_name),
|
||||
" are not compatible."),
|
||||
*dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const Graph* const graph_; // Not owned.
|
||||
std::vector<Member> members_;
|
||||
const DeviceSet* device_set_; // Not owned.
|
||||
@ -1220,15 +1119,6 @@ Status Placer::Run() {
|
||||
return errors::FailedPrecondition("No devices are registered");
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
DumpGraphToFile("placer_input", *graph_, nullptr, "/tmp");
|
||||
for (const Node* node : graph_->op_nodes()) {
|
||||
VLOG(3) << " " << node->name() << ": requested: '"
|
||||
<< node->requested_device() << "' assigned: '"
|
||||
<< node->assigned_device_name() << "'";
|
||||
}
|
||||
}
|
||||
|
||||
ColocationGraph colocation_graph(
|
||||
graph_, devices_, default_device_,
|
||||
options_ == nullptr || options_->config.allow_soft_placement(),
|
||||
@ -1236,15 +1126,14 @@ Status Placer::Run() {
|
||||
|
||||
TF_RETURN_IF_ERROR(colocation_graph.Initialize());
|
||||
|
||||
// For each node, assign a device based on the constraints in the disjoint
|
||||
// node set.
|
||||
// For each node, assign a device based on the constraints in the
|
||||
// disjoint node set.
|
||||
std::vector<Node*> second_pass;
|
||||
for (Node* node : graph_->op_nodes()) {
|
||||
// The graph may have come pre-populated by the framework with assigned
|
||||
// devices (e.g., for stateful placements), so the placer should not try to
|
||||
// place nodes that are already placed.
|
||||
if (node->has_assigned_device_name()) {
|
||||
TF_RETURN_IF_ERROR(colocation_graph.LimitToAssignedDevice(*node));
|
||||
LogDeviceAssignment(node, log_device_placement_);
|
||||
continue;
|
||||
}
|
||||
@ -1345,9 +1234,6 @@ Status Placer::Run() {
|
||||
log_device_placement_));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
DumpGraphToFile("placer_output", *graph_, nullptr, "/tmp");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@ -25,15 +24,11 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder_util.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
@ -45,16 +40,6 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
using ::tensorflow::test::function::GDef;
|
||||
using ::tensorflow::test::function::NDef;
|
||||
using FDH = ::tensorflow::FunctionDefHelper;
|
||||
|
||||
constexpr char kCPU[] = "/device:fakecpu:0";
|
||||
constexpr char kGPU[] = "/device:fakegpu:0";
|
||||
|
||||
constexpr char kFullCPU[] = "/job:a/replica:0/task:0/device:fakecpu:0";
|
||||
constexpr char kFullGPU[] = "/job:a/replica:0/task:0/device:fakegpu:0";
|
||||
|
||||
namespace {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -225,16 +210,6 @@ class PlacerTest : public ::testing::Test {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BuildGraph(const GraphDef& graph_def, Graph* out_graph) {
|
||||
GraphConstructorOptions opts;
|
||||
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, out_graph));
|
||||
nodes_by_name_.clear();
|
||||
for (Node* node : out_graph->nodes()) {
|
||||
nodes_by_name_[node->name()] = node->id();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Invokes the Placer on "graph". If no DeviceSet is specified, the
|
||||
// placement will use the default DeviceSet (of 10 CPU and 10 GPU devices).
|
||||
//
|
||||
@ -891,7 +866,7 @@ TEST_F(PlacerTest, TestResourceHandle) {
|
||||
}
|
||||
|
||||
TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) {
|
||||
auto handle_test = [this](bool allow_soft_placement, bool set_assigned) {
|
||||
auto handle_test = [this](bool allow_soft_placement) {
|
||||
Graph g(OpRegistry::Global());
|
||||
{ // Scope for temporary variables used to construct g.
|
||||
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||
@ -903,41 +878,27 @@ TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) {
|
||||
b.opts().WithName("two_handles_in"));
|
||||
TF_EXPECT_OK(BuildGraph(b, &g));
|
||||
|
||||
if (set_assigned) {
|
||||
GetNodeByName(g, "var_cpu")
|
||||
->set_assigned_device_name(
|
||||
"/job:a/replica:0/task:0/device:fakecpu:0");
|
||||
GetNodeByName(g, "var_gpu")
|
||||
->set_assigned_device_name(
|
||||
"/job:a/replica:0/task:0/device:fakegpu:0");
|
||||
} else {
|
||||
GetNodeByName(g, "var_cpu")
|
||||
->set_requested_device("/job:a/replica:0/task:0/device:fakecpu:0");
|
||||
GetNodeByName(g, "var_gpu")
|
||||
->set_requested_device("/job:a/replica:0/task:0/device:fakegpu:0");
|
||||
}
|
||||
GetNodeByName(g, "var_cpu")
|
||||
->set_assigned_device_name(
|
||||
"/job:a/replica:0/task:0/device:fakecpu:0");
|
||||
GetNodeByName(g, "var_gpu")
|
||||
->set_assigned_device_name(
|
||||
"/job:a/replica:0/task:0/device:fakegpu:0");
|
||||
}
|
||||
|
||||
SessionOptions options;
|
||||
options.config.set_allow_soft_placement(allow_soft_placement);
|
||||
options.config.set_log_device_placement(true);
|
||||
Status s = Place(&g, &options);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||
EXPECT_TRUE(str_util::StrContains(
|
||||
s.error_message(),
|
||||
"Cannot place the graph because a reference or resource edge "
|
||||
"connects "
|
||||
"colocation groups with incompatible assigned devices: "
|
||||
"/job:a/replica:0/task:0/device:fakegpu:0 vs "
|
||||
"/job:a/replica:0/task:0/device:fakecpu:0"));
|
||||
|
||||
"Could not colocate node with its resource and reference inputs"));
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
TF_EXPECT_OK(handle_test(false, false));
|
||||
TF_EXPECT_OK(handle_test(false, true));
|
||||
TF_EXPECT_OK(handle_test(true, false));
|
||||
TF_EXPECT_OK(handle_test(true, true));
|
||||
TF_EXPECT_OK(handle_test(false));
|
||||
TF_EXPECT_OK(handle_test(true));
|
||||
}
|
||||
|
||||
// Test that an assignment of an operator to the wrong device
|
||||
@ -1656,127 +1617,5 @@ TEST_F(PlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) {
|
||||
EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeCPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeGPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeCPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeGPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeCPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeGPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeCPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeGPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeCPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeGPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeCPU"), DummyOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeGPU"), DummyOp);
|
||||
|
||||
TEST_F(PlacerTest, RequestedDeviceOnResourceGeneratorIsTreatedAsAssigned) {
|
||||
/*
|
||||
* a:RES:GPU b:RES:CPU
|
||||
* | |
|
||||
* | |
|
||||
* v v
|
||||
* id1 id2
|
||||
* @loc:id2
|
||||
*/
|
||||
FunctionDef func = test::function::ResourceOutput();
|
||||
GraphDef graph = GDef(
|
||||
{
|
||||
NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
|
||||
NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
|
||||
NDef("id1", "Identity", {"a"},
|
||||
{{"T", DT_RESOURCE},
|
||||
{"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
|
||||
NDef("id2", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
|
||||
},
|
||||
// FunctionLib
|
||||
{func});
|
||||
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(BuildGraph(graph, &g));
|
||||
Status s = Place(&g);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||
EXPECT_TRUE(str_util::StrContains(
|
||||
s.error_message(),
|
||||
"Cannot place the graph because a reference or resource edge connects "
|
||||
"colocation groups with incompatible assigned devices:"));
|
||||
}
|
||||
|
||||
TEST_F(PlacerTest, RequestedDeviceCanBeOverridden) {
|
||||
/*
|
||||
* a:RES b:RES
|
||||
* | |
|
||||
* id_a:GPU id_b:CPU
|
||||
* | |
|
||||
* v v
|
||||
* id1 id2
|
||||
* @loc:id2
|
||||
*/
|
||||
FunctionDef func = test::function::ResourceOutput();
|
||||
GraphDef graph = GDef(
|
||||
{
|
||||
NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
|
||||
NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
|
||||
NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}, kGPU),
|
||||
NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}, kCPU),
|
||||
NDef("id1", "Identity", {"id_a"},
|
||||
{{"T", DT_RESOURCE},
|
||||
{"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
|
||||
NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
|
||||
},
|
||||
// FunctionLib
|
||||
{func});
|
||||
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(BuildGraph(graph, &g));
|
||||
TF_ASSERT_OK(Place(&g));
|
||||
|
||||
// All should be colocated
|
||||
EXPECT_COLOCATED(g, "a", "b");
|
||||
EXPECT_COLOCATED(g, "id_a", "id_b");
|
||||
EXPECT_COLOCATED(g, "id1", "id2");
|
||||
EXPECT_COLOCATED(g, "a", "id_a");
|
||||
EXPECT_COLOCATED(g, "a", "id1");
|
||||
}
|
||||
|
||||
TEST_F(PlacerTest, AssignedDevicesAreNotOverriddenDueToResourcesAndColocation) {
|
||||
/*
|
||||
* a:RES b:RES
|
||||
* | |
|
||||
* id_a:GPU id_b:CPU
|
||||
* | |
|
||||
* v v
|
||||
* id1 id2
|
||||
* @loc:id2
|
||||
*/
|
||||
FunctionDef func = test::function::ResourceOutput();
|
||||
GraphDef graph = GDef(
|
||||
{
|
||||
NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
|
||||
NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
|
||||
NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}),
|
||||
NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
|
||||
NDef("id1", "Identity", {"id_a"},
|
||||
{{"T", DT_RESOURCE},
|
||||
{"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
|
||||
NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
|
||||
},
|
||||
// FunctionLib
|
||||
{func});
|
||||
|
||||
Graph g(OpRegistry::Global());
|
||||
TF_ASSERT_OK(BuildGraph(graph, &g));
|
||||
std::unordered_map<string, Node*> nodes = g.BuildNodeNameIndex();
|
||||
GetNodeByName(g, "id_a")->set_assigned_device_name(kFullGPU);
|
||||
GetNodeByName(g, "id_b")->set_assigned_device_name(kFullCPU);
|
||||
Status s = Place(&g);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
|
||||
EXPECT_TRUE(str_util::StrContains(
|
||||
s.error_message(),
|
||||
"Cannot place the graph because a reference or resource edge connects "
|
||||
"colocation groups with incompatible assigned devices: "
|
||||
"/job:a/replica:0/task:0/device:fakecpu:0 vs "
|
||||
"/job:a/replica:0/task:0/device:fakegpu:0"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -289,30 +289,6 @@ bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
|
||||
return true;
|
||||
}
|
||||
|
||||
void DeviceNameUtils::EnsureSpecification(ParsedName* more_specific,
|
||||
const ParsedName& less_specific) {
|
||||
if (less_specific.has_job) {
|
||||
more_specific->has_job = true;
|
||||
more_specific->job = less_specific.job;
|
||||
}
|
||||
if (less_specific.has_replica) {
|
||||
more_specific->has_replica = true;
|
||||
more_specific->replica = less_specific.replica;
|
||||
}
|
||||
if (less_specific.has_task) {
|
||||
more_specific->has_task = true;
|
||||
more_specific->task = less_specific.task;
|
||||
}
|
||||
if (less_specific.has_type) {
|
||||
more_specific->has_type = true;
|
||||
more_specific->type = less_specific.type;
|
||||
}
|
||||
if (less_specific.has_id) {
|
||||
more_specific->has_id = true;
|
||||
more_specific->id = less_specific.id;
|
||||
}
|
||||
}
|
||||
|
||||
/* static */
|
||||
bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
|
||||
const ParsedName& name) {
|
||||
|
@ -110,11 +110,6 @@ class DeviceNameUtils {
|
||||
static bool IsSpecification(const ParsedName& less_specific,
|
||||
const ParsedName& more_specific);
|
||||
|
||||
// Makes minimal changes to more_specific so that it becomes a
|
||||
// specification of less_specific.
|
||||
static void EnsureSpecification(ParsedName* more_specific,
|
||||
const ParsedName& less_specific);
|
||||
|
||||
// Like IsSpecification, but the second argument "name" must have a
|
||||
// non-wildcard value for all of its components.
|
||||
static bool IsCompleteSpecification(const ParsedName& pattern,
|
||||
|
@ -454,8 +454,7 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(func)
|
||||
expected_error = (
|
||||
errors.InvalidArgumentError,
|
||||
"Cannot place the graph because a reference or resource edge "
|
||||
"connects colocation groups with incompatible assigned devices")
|
||||
"Could not colocate node with its resource and reference inputs")
|
||||
self.assertDatasetProduces(
|
||||
dataset, expected_error=expected_error, requires_initialization=True)
|
||||
|
||||
|
@ -876,9 +876,8 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
return None
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
'Cannot place the graph because a reference or resource edge connects '
|
||||
'colocation groups with incompatible assigned devices'):
|
||||
errors.InvalidArgumentError, 'Could not colocate node with its '
|
||||
'resource and reference inputs.*'):
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(resource_apply_adam())
|
||||
|
Loading…
Reference in New Issue
Block a user