Introduce CompositeDevice which represents a set of physical devices.
Support variables assigned to a CompositeDevice in ColocationGraph. PiperOrigin-RevId: 307683468 Change-Id: I3d7f2b6e747cc8659e343744190f173cedb2d010
This commit is contained in:
parent
2a6eb169c8
commit
852f93a20a
@ -53,6 +53,7 @@ package(
|
|||||||
tf_cuda_library(
|
tf_cuda_library(
|
||||||
name = "core_cpu",
|
name = "core_cpu",
|
||||||
hdrs = [
|
hdrs = [
|
||||||
|
"composite_device.h",
|
||||||
"device.h",
|
"device.h",
|
||||||
"device_factory.h",
|
"device_factory.h",
|
||||||
"function.h",
|
"function.h",
|
||||||
@ -145,6 +146,7 @@ filegroup(
|
|||||||
filegroup(
|
filegroup(
|
||||||
name = "core_cpu_base_headers",
|
name = "core_cpu_base_headers",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"composite_device.h",
|
||||||
"device.h",
|
"device.h",
|
||||||
"device_factory.h",
|
"device_factory.h",
|
||||||
"device_mgr.h",
|
"device_mgr.h",
|
||||||
@ -274,6 +276,7 @@ tf_cuda_library(
|
|||||||
"collective_rma_local.cc",
|
"collective_rma_local.cc",
|
||||||
"collective_util.cc",
|
"collective_util.cc",
|
||||||
"colocation_graph.cc",
|
"colocation_graph.cc",
|
||||||
|
"composite_device.cc",
|
||||||
"constant_folding.cc",
|
"constant_folding.cc",
|
||||||
"copy_tensor.cc",
|
"copy_tensor.cc",
|
||||||
"costmodel_manager.cc",
|
"costmodel_manager.cc",
|
||||||
@ -529,6 +532,22 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "composite_device_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"composite_device_test.cc",
|
||||||
|
],
|
||||||
|
linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
|
deps = [
|
||||||
|
":core_cpu",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_tests(
|
tf_cc_tests(
|
||||||
name = "core_higher_level_tests",
|
name = "core_higher_level_tests",
|
||||||
size = "small",
|
size = "small",
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/device_set.h"
|
#include "tensorflow/core/common_runtime/device_set.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
@ -136,6 +137,10 @@ bool IsXlaDevice(absl::string_view device_type) {
|
|||||||
device_type == "TPU");
|
device_type == "TPU");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsCompositeDevice(absl::string_view device_type) {
|
||||||
|
return device_type == kCompositeDeviceType;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status Member::SetParentAndSupportedDevices(
|
Status Member::SetParentAndSupportedDevices(
|
||||||
@ -220,6 +225,26 @@ Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice(
|
||||||
|
const Member& src_root) const {
|
||||||
|
auto compatible_edge_from_composite_device_to_physical_device =
|
||||||
|
[](const DeviceNameUtils::ParsedName& src_device,
|
||||||
|
const DeviceNameUtils::ParsedName& dst_device) -> bool {
|
||||||
|
return src_device.has_type && dst_device.has_type &&
|
||||||
|
IsCompositeDevice(src_device.type) &&
|
||||||
|
!IsCompositeDevice(dst_device.type);
|
||||||
|
};
|
||||||
|
if (compatible_edge_from_composite_device_to_physical_device(
|
||||||
|
src_root.assigned_device_name_, assigned_device_name_) ||
|
||||||
|
compatible_edge_from_composite_device_to_physical_device(
|
||||||
|
src_root.resource_device_name_, resource_device_name_) ||
|
||||||
|
compatible_edge_from_composite_device_to_physical_device(
|
||||||
|
src_root.requested_device_name_, requested_device_name_)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
Status Member::EnsureCompatibilityAcrossResourceEdge(
|
Status Member::EnsureCompatibilityAcrossResourceEdge(
|
||||||
const Node& src, const Member& src_root,
|
const Node& src, const Member& src_root,
|
||||||
const Node& dst, /*dst_root is this*/
|
const Node& dst, /*dst_root is this*/
|
||||||
@ -484,7 +509,10 @@ Status Member::AssignDevice(const Node& node) {
|
|||||||
void Member::MaybeExcludeXlaDevices() {
|
void Member::MaybeExcludeXlaDevices() {
|
||||||
for (const auto& parsed_name :
|
for (const auto& parsed_name :
|
||||||
{requested_device_name_, assigned_device_name_, resource_device_name_}) {
|
{requested_device_name_, assigned_device_name_, resource_device_name_}) {
|
||||||
if (parsed_name.has_type && IsXlaDevice(parsed_name.type)) {
|
// Don't exculde XLA devices from supported devices if member is explicitly
|
||||||
|
// assigned to a CompositeDevice.
|
||||||
|
if (parsed_name.has_type && (IsXlaDevice(parsed_name.type) ||
|
||||||
|
IsCompositeDevice(parsed_name.type))) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -664,6 +692,12 @@ Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
|
|||||||
auto& src_root = members_[src_root_id];
|
auto& src_root = members_[src_root_id];
|
||||||
auto& dst_root = members_[dst_root_id];
|
auto& dst_root = members_[dst_root_id];
|
||||||
|
|
||||||
|
if (dst_root.IsEdgeFromCompositeDeviceToPhysicalDevice(src_root)) {
|
||||||
|
// If the src root is assigned to a composite device and the dst root is
|
||||||
|
// assigned to a physical device, don't colocate the dst root with the src
|
||||||
|
// root.
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
|
TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
|
||||||
*src, src_root, *dst, log_device_placement_));
|
*src, src_root, *dst, log_device_placement_));
|
||||||
Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
|
Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
|
||||||
@ -890,6 +924,15 @@ Status GetGroupNodes(const IOColocationGroups& groups, const Node& node,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns whether the device_type in `device_attributes` is supported.
|
||||||
|
bool IsSupportedDeviceType(const DeviceAttributes& device_attributes,
|
||||||
|
const DeviceType& supported_type) {
|
||||||
|
if (DeviceType(device_attributes.device_type()) == supported_type) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return IsCompositeDevice(device_attributes.device_type());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status ColocationGraph::ApplyIOColocationGroups(
|
Status ColocationGraph::ApplyIOColocationGroups(
|
||||||
@ -1364,7 +1407,7 @@ Status ColocationGraph::InitializeMemberWithAssignedDevice(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& d : member->supported_device_types()) {
|
for (const auto& d : member->supported_device_types()) {
|
||||||
if (DeviceType(assigned_device->attributes().device_type()) == d.first) {
|
if (IsSupportedDeviceType(assigned_device->attributes(), d.first)) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1434,8 +1477,8 @@ Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
|
|||||||
PrioritizedDeviceVector prioritized_filtered_devices;
|
PrioritizedDeviceVector prioritized_filtered_devices;
|
||||||
for (const auto& supported_device_type : supported_device_types) {
|
for (const auto& supported_device_type : supported_device_types) {
|
||||||
for (Device* device : devices) {
|
for (Device* device : devices) {
|
||||||
if (DeviceType(device->attributes().device_type()) ==
|
if (IsSupportedDeviceType(device->attributes(),
|
||||||
supported_device_type.first) {
|
supported_device_type.first)) {
|
||||||
if (default_local_device &&
|
if (default_local_device &&
|
||||||
(device == default_local_device ||
|
(device == default_local_device ||
|
||||||
// TODO(nareshmodi, fishx): At times the device pointer in the
|
// TODO(nareshmodi, fishx): At times the device pointer in the
|
||||||
|
@ -51,6 +51,10 @@ class Member {
|
|||||||
|
|
||||||
Status FillPossibleDevices(PossibleDevices* possible_device) const;
|
Status FillPossibleDevices(PossibleDevices* possible_device) const;
|
||||||
|
|
||||||
|
// Returns whether `src_root` is assigned to a CompositeDevice and `this` is
|
||||||
|
// assigned to a physical device.
|
||||||
|
bool IsEdgeFromCompositeDeviceToPhysicalDevice(const Member& src_root) const;
|
||||||
|
|
||||||
Status EnsureCompatibilityAcrossResourceEdge(
|
Status EnsureCompatibilityAcrossResourceEdge(
|
||||||
const Node& src, const Member& src_root,
|
const Node& src, const Member& src_root,
|
||||||
const Node& dst, /*dst_root is this*/
|
const Node& dst, /*dst_root is this*/
|
||||||
|
69
tensorflow/core/common_runtime/composite_device.cc
Normal file
69
tensorflow/core/common_runtime/composite_device.cc
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||||
|
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
const char* const kCompositeDeviceType = "COMPOSITE";
|
||||||
|
|
||||||
|
std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
|
||||||
|
const std::vector<string>& underlying_devices, const int unique_device_id,
|
||||||
|
Status* status) {
|
||||||
|
if (underlying_devices.empty()) {
|
||||||
|
status->Update(
|
||||||
|
errors::InvalidArgument("underlying_devices should not be empty."));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
DeviceNameUtils::ParsedName parsed_name;
|
||||||
|
if (!DeviceNameUtils::ParseFullName(underlying_devices.at(0), &parsed_name)) {
|
||||||
|
status->Update(tensorflow::errors::InvalidArgument(
|
||||||
|
"Cannot parse device name ", underlying_devices.at(0),
|
||||||
|
" when creating CompositeDevice."));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
const string& underlying_type = parsed_name.type;
|
||||||
|
for (int i = 1; i < underlying_devices.size(); ++i) {
|
||||||
|
DeviceNameUtils::ParsedName name;
|
||||||
|
if (!DeviceNameUtils::ParseFullName(underlying_devices.at(i), &name)) {
|
||||||
|
status->Update(tensorflow::errors::InvalidArgument(
|
||||||
|
"Cannot parse device name ", underlying_devices.at(i),
|
||||||
|
" when creating CompositeDevice."));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (name.type != underlying_type) {
|
||||||
|
status->Update(tensorflow::errors::InvalidArgument(
|
||||||
|
"Expect device type ", parsed_name.type, "; but got type ", name.type,
|
||||||
|
" from device: ", underlying_devices.at(i),
|
||||||
|
" when creating CompositeDevice."));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DeviceAttributes device_attributes;
|
||||||
|
parsed_name.type = kCompositeDeviceType;
|
||||||
|
device_attributes.set_device_type(parsed_name.type);
|
||||||
|
parsed_name.id = unique_device_id;
|
||||||
|
const string composite_name =
|
||||||
|
DeviceNameUtils::ParsedNameToString(parsed_name);
|
||||||
|
device_attributes.set_name(composite_name);
|
||||||
|
|
||||||
|
return absl::WrapUnique(
|
||||||
|
new CompositeDevice(device_attributes, underlying_devices));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
63
tensorflow/core/common_runtime/composite_device.h
Normal file
63
tensorflow/core/common_runtime/composite_device.h
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COMPOSITE_DEVICE_H_
|
||||||
|
#define TENSORFLOW_CORE_COMMON_RUNTIME_COMPOSITE_DEVICE_H_
|
||||||
|
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/framework/device_attributes.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
extern const char* const kCompositeDeviceType;
|
||||||
|
|
||||||
|
// A virtual device which represents a set of devices. We don't execute any
|
||||||
|
// op on this virtial device.
|
||||||
|
class CompositeDevice : public Device {
|
||||||
|
public:
|
||||||
|
Status Sync() override {
|
||||||
|
return errors::Internal(
|
||||||
|
"Sync() should never been invoked on CompositeDevice.");
|
||||||
|
}
|
||||||
|
|
||||||
|
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
|
||||||
|
|
||||||
|
const std::vector<string>* underlying_devices() const {
|
||||||
|
return &underlying_devices_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper for creating a CompositeDevice.
|
||||||
|
static std::unique_ptr<CompositeDevice> MakeDevice(
|
||||||
|
const std::vector<string>& underlying_devices, const int unique_device_id,
|
||||||
|
Status* status);
|
||||||
|
|
||||||
|
private:
|
||||||
|
CompositeDevice(const DeviceAttributes& device_attributes,
|
||||||
|
const std::vector<string>& underlying_devices)
|
||||||
|
: Device(/*env=*/nullptr, device_attributes),
|
||||||
|
underlying_devices_(underlying_devices) {}
|
||||||
|
|
||||||
|
const std::vector<string> underlying_devices_;
|
||||||
|
|
||||||
|
TF_DISALLOW_COPY_AND_ASSIGN(CompositeDevice);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COMPOSITE_DEVICE_H_
|
65
tensorflow/core/common_runtime/composite_device_test.cc
Normal file
65
tensorflow/core/common_runtime/composite_device_test.cc
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/composite_device.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
TEST(CompositeDeviceTest, Basic) {
|
||||||
|
std::vector<string> underlying_devices;
|
||||||
|
{
|
||||||
|
Status status;
|
||||||
|
std::unique_ptr<CompositeDevice> composite_device =
|
||||||
|
CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/0,
|
||||||
|
&status);
|
||||||
|
EXPECT_EQ(composite_device, nullptr);
|
||||||
|
EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
|
||||||
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
|
"underlying_devices should not be empty"))
|
||||||
|
<< status.ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Status status;
|
||||||
|
underlying_devices.push_back(
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:0");
|
||||||
|
underlying_devices.push_back(
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:1");
|
||||||
|
std::unique_ptr<CompositeDevice> composite_device =
|
||||||
|
CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/0,
|
||||||
|
&status);
|
||||||
|
TF_ASSERT_OK(status);
|
||||||
|
EXPECT_EQ(composite_device->device_type(), kCompositeDeviceType);
|
||||||
|
EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Status status;
|
||||||
|
underlying_devices.push_back(
|
||||||
|
"/job:localhost/replica:0/task:0/device:GPU:0");
|
||||||
|
std::unique_ptr<CompositeDevice> composite_device =
|
||||||
|
CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1,
|
||||||
|
&status);
|
||||||
|
EXPECT_EQ(composite_device, nullptr);
|
||||||
|
EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
|
||||||
|
EXPECT_TRUE(absl::StrContains(status.error_message(),
|
||||||
|
"Expect device type CPU; but got type GPU"))
|
||||||
|
<< status.ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
@ -96,7 +96,7 @@ class FakeDevice : public Device {
|
|||||||
const string& device_type) {
|
const string& device_type) {
|
||||||
DeviceAttributes device_attributes;
|
DeviceAttributes device_attributes;
|
||||||
device_attributes.set_name(name);
|
device_attributes.set_name(name);
|
||||||
device_attributes.set_device_type(DeviceType(device_type).type());
|
device_attributes.set_device_type(device_type);
|
||||||
return std::unique_ptr<Device>(new FakeDevice(device_attributes));
|
return std::unique_ptr<Device>(new FakeDevice(device_attributes));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,6 +233,9 @@ class PlacerTest : public ::testing::Test {
|
|||||||
local_devices_.emplace_back(FakeDevice::MakeDevice(
|
local_devices_.emplace_back(FakeDevice::MakeDevice(
|
||||||
"/job:a/replica:0/task:0/device:XLA_CPU:0", "XLA_CPU"));
|
"/job:a/replica:0/task:0/device:XLA_CPU:0", "XLA_CPU"));
|
||||||
devices_.AddDevice(local_devices_.back().get());
|
devices_.AddDevice(local_devices_.back().get());
|
||||||
|
local_devices_.emplace_back(FakeDevice::MakeDevice(
|
||||||
|
"/job:a/replica:0/task:0/device:COMPOSITE:0", "COMPOSITE"));
|
||||||
|
devices_.AddDevice(local_devices_.back().get());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Builds the given graph, and (if successful) indexes the node
|
// Builds the given graph, and (if successful) indexes the node
|
||||||
@ -1175,6 +1178,40 @@ TEST_F(PlacerTest, TestReferenceConnectionNoSourceDevice) {
|
|||||||
EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU");
|
EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(PlacerTest, TestResourceHandleOnCompositeDevice) {
|
||||||
|
auto build_graph = [this](Graph* g) -> Status {
|
||||||
|
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
|
||||||
|
Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
|
||||||
|
// Build ten variable-and-assignment pairs.
|
||||||
|
Node* var = ops::SourceOp("HandleVariableCPU", b.opts().WithName("var"));
|
||||||
|
ops::BinaryOp("TestHandleAssign", var, input, b.opts().WithName("assign"));
|
||||||
|
TF_RETURN_IF_ERROR(BuildGraph(b, g));
|
||||||
|
// `var` is assigned to COMPOSITE.
|
||||||
|
GetNodeByName(*g, "var")->set_assigned_device_name(
|
||||||
|
"/job:a/replica:0/task:0/device:COMPOSITE:0");
|
||||||
|
return Status::OK();
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
// `assign` is not assigned to any device.
|
||||||
|
Graph g(OpRegistry::Global());
|
||||||
|
TF_ASSERT_OK(build_graph(&g));
|
||||||
|
TF_ASSERT_OK(Place(&g));
|
||||||
|
EXPECT_DEVICE_TYPE(g, "var", "COMPOSITE");
|
||||||
|
EXPECT_DEVICE_TYPE(g, "assign", "COMPOSITE");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// `assign` is assigned to FakeCPU.
|
||||||
|
Graph g(OpRegistry::Global());
|
||||||
|
TF_ASSERT_OK(build_graph(&g));
|
||||||
|
GetNodeByName(g, "assign")
|
||||||
|
->set_assigned_device_name("/job:a/replica:0/task:0/device:FakeCPU:0");
|
||||||
|
TF_ASSERT_OK(Place(&g));
|
||||||
|
EXPECT_DEVICE_TYPE(g, "var", "COMPOSITE");
|
||||||
|
EXPECT_DEVICE_TYPE(g, "assign", "FakeCPU");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(PlacerTest, TestColocationGroup) {
|
TEST_F(PlacerTest, TestColocationGroup) {
|
||||||
Graph g(OpRegistry::Global());
|
Graph g(OpRegistry::Global());
|
||||||
{ // Scope for temporary variables used to construct g.
|
{ // Scope for temporary variables used to construct g.
|
||||||
@ -1282,6 +1319,9 @@ TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
|
|||||||
Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
|
Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
|
||||||
Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
|
Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
|
||||||
Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
|
Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
|
||||||
|
Node* var3 = ops::SourceOp(
|
||||||
|
"VariableCPU",
|
||||||
|
b.opts().WithName("var3").WithDevice("/device:COMPOSITE:0"));
|
||||||
|
|
||||||
// Two assigns (reference connections) with two different
|
// Two assigns (reference connections) with two different
|
||||||
// colocation groups. Because their colocation groups all map to the
|
// colocation groups. Because their colocation groups all map to the
|
||||||
@ -1292,14 +1332,20 @@ TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
|
|||||||
ops::BinaryOp(
|
ops::BinaryOp(
|
||||||
"TestAssign", var2, input,
|
"TestAssign", var2, input,
|
||||||
b.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"}));
|
b.opts().WithName("assign2").WithAttr("_class", {"loc:@var2"}));
|
||||||
|
ops::BinaryOp(
|
||||||
|
"TestAssign", var3, input,
|
||||||
|
b.opts().WithName("assign3").WithAttr("_class", {"loc:@var3"}));
|
||||||
TF_EXPECT_OK(BuildGraph(b, &g));
|
TF_EXPECT_OK(BuildGraph(b, &g));
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_EXPECT_OK(Place(&g));
|
TF_EXPECT_OK(Place(&g));
|
||||||
|
EXPECT_DEVICE_TYPE(g, "in", "FakeCPU");
|
||||||
EXPECT_COLOCATED(g, "in", "var1");
|
EXPECT_COLOCATED(g, "in", "var1");
|
||||||
EXPECT_COLOCATED(g, "in", "var2");
|
EXPECT_COLOCATED(g, "in", "var2");
|
||||||
EXPECT_COLOCATED(g, "var1", "assign2");
|
EXPECT_COLOCATED(g, "var1", "assign2");
|
||||||
EXPECT_COLOCATED(g, "var2", "assign1");
|
EXPECT_COLOCATED(g, "var2", "assign1");
|
||||||
|
EXPECT_DEVICE_TYPE(g, "var3", "COMPOSITE");
|
||||||
|
EXPECT_COLOCATED(g, "var3", "assign3");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(SoftPlacementPlacerTest,
|
TEST_P(SoftPlacementPlacerTest,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user