Use the same CompositeDevice name on remote workers as the one on a client.
PiperOrigin-RevId: 317702206 Change-Id: I7068efb25eb930252f89a167108ed59c69c2078f
This commit is contained in:
parent
8cf9784629
commit
39d080e8b9
@ -25,6 +25,16 @@ const char* const kCompositeDeviceType = "COMPOSITE";
|
|||||||
std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
|
std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
|
||||||
const std::vector<string>& underlying_devices, const int unique_device_id,
|
const std::vector<string>& underlying_devices, const int unique_device_id,
|
||||||
const DeviceNameUtils::ParsedName& host_name, Status* status) {
|
const DeviceNameUtils::ParsedName& host_name, Status* status) {
|
||||||
|
DeviceNameUtils::ParsedName parsed_name = host_name;
|
||||||
|
parsed_name.type = kCompositeDeviceType;
|
||||||
|
parsed_name.id = unique_device_id;
|
||||||
|
const string device_name = DeviceNameUtils::ParsedNameToString(parsed_name);
|
||||||
|
return CompositeDevice::MakeDevice(underlying_devices, device_name, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
|
||||||
|
const std::vector<string>& underlying_devices, const string& device_name,
|
||||||
|
Status* status) {
|
||||||
if (underlying_devices.empty()) {
|
if (underlying_devices.empty()) {
|
||||||
status->Update(
|
status->Update(
|
||||||
errors::InvalidArgument("underlying_devices should not be empty."));
|
errors::InvalidArgument("underlying_devices should not be empty."));
|
||||||
@ -63,13 +73,8 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceNameUtils::ParsedName parsed_composite_name = host_name;
|
|
||||||
DeviceAttributes device_attributes;
|
DeviceAttributes device_attributes;
|
||||||
parsed_composite_name.type = kCompositeDeviceType;
|
device_attributes.set_name(device_name);
|
||||||
parsed_composite_name.id = unique_device_id;
|
|
||||||
const string composite_name =
|
|
||||||
DeviceNameUtils::ParsedNameToString(parsed_composite_name);
|
|
||||||
device_attributes.set_name(composite_name);
|
|
||||||
device_attributes.set_device_type(kCompositeDeviceType);
|
device_attributes.set_device_type(kCompositeDeviceType);
|
||||||
|
|
||||||
return absl::WrapUnique(
|
return absl::WrapUnique(
|
||||||
|
@ -48,6 +48,11 @@ class CompositeDevice : public Device {
|
|||||||
const std::vector<string>& underlying_devices, const int unique_device_id,
|
const std::vector<string>& underlying_devices, const int unique_device_id,
|
||||||
const DeviceNameUtils::ParsedName& host_name, Status* status);
|
const DeviceNameUtils::ParsedName& host_name, Status* status);
|
||||||
|
|
||||||
|
// Helper for creating a CompositeDevice with the given device name.
|
||||||
|
static std::unique_ptr<CompositeDevice> MakeDevice(
|
||||||
|
const std::vector<string>& underlying_devices, const string& device_name,
|
||||||
|
Status* status);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CompositeDevice(const DeviceAttributes& device_attributes,
|
CompositeDevice(const DeviceAttributes& device_attributes,
|
||||||
const std::vector<string>& underlying_devices)
|
const std::vector<string>& underlying_devices)
|
||||||
|
@ -80,4 +80,20 @@ TEST(CompositeDeviceTest, Basic) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CompositeDeviceTest, DeviceName) {
|
||||||
|
const string composite_device_name =
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:10";
|
||||||
|
std::vector<string> underlying_devices;
|
||||||
|
underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:0");
|
||||||
|
underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:1");
|
||||||
|
Status status;
|
||||||
|
std::unique_ptr<CompositeDevice> composite_device =
|
||||||
|
CompositeDevice::MakeDevice(underlying_devices, composite_device_name,
|
||||||
|
&status);
|
||||||
|
TF_ASSERT_OK(status);
|
||||||
|
EXPECT_EQ(composite_device->name(), composite_device_name);
|
||||||
|
EXPECT_EQ(composite_device->device_type(), kCompositeDeviceType);
|
||||||
|
EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -893,7 +893,7 @@ Status EagerContext::FindDeviceFromName(const char* device_name,
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status EagerContext::FindCompositeDeviceFromName(
|
Status EagerContext::FindCompositeDeviceFromName(
|
||||||
const char* device_name, CompositeDevice** device) const {
|
StringPiece device_name, CompositeDevice** device) const {
|
||||||
tf_shared_lock l(composite_devices_mu_);
|
tf_shared_lock l(composite_devices_mu_);
|
||||||
for (const auto& d : composite_devices_) {
|
for (const auto& d : composite_devices_) {
|
||||||
if (d.second->name() == device_name) {
|
if (d.second->name() == device_name) {
|
||||||
@ -939,8 +939,13 @@ Status EagerContext::RegisterCustomDevice(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status EagerContext::FindOrCreateCompositeDevice(
|
Status EagerContext::FindOrCreateCompositeDevice(
|
||||||
const std::vector<string>& underlying_devices,
|
const std::vector<string>& underlying_devices, const string& device_name,
|
||||||
CompositeDevice** composite_device) {
|
CompositeDevice** composite_device) {
|
||||||
|
if (!device_name.empty() &&
|
||||||
|
FindCompositeDeviceFromName(device_name, composite_device).ok()) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
|
const uint64 hash_key = Fingerprint64(absl::StrJoin(underlying_devices, ","));
|
||||||
|
|
||||||
mutex_lock l(composite_devices_mu_);
|
mutex_lock l(composite_devices_mu_);
|
||||||
@ -951,11 +956,16 @@ Status EagerContext::FindOrCreateCompositeDevice(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status s;
|
Status s;
|
||||||
|
std::unique_ptr<CompositeDevice> device;
|
||||||
|
if (device_name.empty()) {
|
||||||
// Create a CompositeDevice on the same task as the host CPU, in order to
|
// Create a CompositeDevice on the same task as the host CPU, in order to
|
||||||
// trigger packed TensorHandle copy from a client to a remote worker.
|
// trigger packed TensorHandle copy from a client to a remote worker.
|
||||||
auto device =
|
device = CompositeDevice::MakeDevice(underlying_devices,
|
||||||
CompositeDevice::MakeDevice(underlying_devices, composite_devices_.size(),
|
composite_devices_.size(),
|
||||||
HostCPU()->parsed_name(), &s);
|
HostCPU()->parsed_name(), &s);
|
||||||
|
} else {
|
||||||
|
device = CompositeDevice::MakeDevice(underlying_devices, device_name, &s);
|
||||||
|
}
|
||||||
TF_RETURN_IF_ERROR(s);
|
TF_RETURN_IF_ERROR(s);
|
||||||
*composite_device = device.get();
|
*composite_device = device.get();
|
||||||
pflr_->AddCompositeDevice(*composite_device);
|
pflr_->AddCompositeDevice(*composite_device);
|
||||||
|
@ -486,7 +486,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
|
|
||||||
Status FindDeviceFromName(const char* device_name, Device** device) const;
|
Status FindDeviceFromName(const char* device_name, Device** device) const;
|
||||||
|
|
||||||
Status FindCompositeDeviceFromName(const char* device_name,
|
Status FindCompositeDeviceFromName(StringPiece device_name,
|
||||||
CompositeDevice** device) const;
|
CompositeDevice** device) const;
|
||||||
|
|
||||||
Status FindCustomDeviceFromName(const string& device_name,
|
Status FindCustomDeviceFromName(const string& device_name,
|
||||||
@ -495,9 +495,10 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
|
|||||||
Status RegisterCustomDevice(const string& name,
|
Status RegisterCustomDevice(const string& name,
|
||||||
std::unique_ptr<CustomDevice> device);
|
std::unique_ptr<CustomDevice> device);
|
||||||
|
|
||||||
// Find or create a composite device with the given `underlying_devices`.
|
// Find or create a composite device with the given `underlying_devices` and
|
||||||
|
// `device_name` (if not empty).
|
||||||
Status FindOrCreateCompositeDevice(
|
Status FindOrCreateCompositeDevice(
|
||||||
const std::vector<string>& underlying_devices,
|
const std::vector<string>& underlying_devices, const string& device_name,
|
||||||
CompositeDevice** composite_device);
|
CompositeDevice** composite_device);
|
||||||
|
|
||||||
bool OnSameTask(const Device* first, const Device* second) const;
|
bool OnSameTask(const Device* first, const Device* second) const;
|
||||||
|
@ -177,6 +177,7 @@ TEST_F(EagerContextTest, CompositeDevice) {
|
|||||||
"/job:worker/replica:0/task:0/device:CPU:1"};
|
"/job:worker/replica:0/task:0/device:CPU:1"};
|
||||||
CompositeDevice* composite_device_0 = nullptr;
|
CompositeDevice* composite_device_0 = nullptr;
|
||||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||||
|
/*device_name=*/"",
|
||||||
&composite_device_0));
|
&composite_device_0));
|
||||||
EXPECT_EQ(composite_device_0->name(),
|
EXPECT_EQ(composite_device_0->name(),
|
||||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:0");
|
"/job:localhost/replica:0/task:0/device:COMPOSITE:0");
|
||||||
@ -186,11 +187,13 @@ TEST_F(EagerContextTest, CompositeDevice) {
|
|||||||
EXPECT_EQ(device, composite_device_0);
|
EXPECT_EQ(device, composite_device_0);
|
||||||
CompositeDevice* composite_device_1 = nullptr;
|
CompositeDevice* composite_device_1 = nullptr;
|
||||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||||
|
/*device_name=*/"",
|
||||||
&composite_device_1));
|
&composite_device_1));
|
||||||
EXPECT_EQ(composite_device_1, composite_device_0);
|
EXPECT_EQ(composite_device_1, composite_device_0);
|
||||||
underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:2");
|
underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:2");
|
||||||
CompositeDevice* composite_device_2 = nullptr;
|
CompositeDevice* composite_device_2 = nullptr;
|
||||||
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(underlying_devices,
|
||||||
|
/*device_name=*/"",
|
||||||
&composite_device_2));
|
&composite_device_2));
|
||||||
EXPECT_EQ(composite_device_2->name(),
|
EXPECT_EQ(composite_device_2->name(),
|
||||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:1");
|
"/job:localhost/replica:0/task:0/device:COMPOSITE:1");
|
||||||
@ -202,5 +205,33 @@ TEST_F(EagerContextTest, CompositeDevice) {
|
|||||||
"/job:localhost/replica:0/task:0/device:COMPOSITE:2", &device)));
|
"/job:localhost/replica:0/task:0/device:COMPOSITE:2", &device)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(EagerContextTest, CompositeDeviceWithGivenName) {
|
||||||
|
InitContext(SessionOptions(), DEVICE_PLACEMENT_EXPLICIT);
|
||||||
|
const std::vector<string> underlying_devices_0 = {
|
||||||
|
"/job:worker/replica:0/task:0/device:CPU:0",
|
||||||
|
"/job:worker/replica:0/task:0/device:CPU:1"};
|
||||||
|
const string composite_device_name =
|
||||||
|
"/job:worker1/replica:0/task:0/device:COMPOSITE:5";
|
||||||
|
// Create a CompositeDevice with the given name.
|
||||||
|
CompositeDevice* composite_device_0 = nullptr;
|
||||||
|
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(
|
||||||
|
underlying_devices_0, composite_device_name, &composite_device_0));
|
||||||
|
EXPECT_EQ(composite_device_0->name(), composite_device_name);
|
||||||
|
|
||||||
|
CompositeDevice* device = nullptr;
|
||||||
|
TF_EXPECT_OK(
|
||||||
|
context()->FindCompositeDeviceFromName(composite_device_name, &device));
|
||||||
|
EXPECT_EQ(device, composite_device_0);
|
||||||
|
|
||||||
|
std::vector<string> underlying_devices_1 = {
|
||||||
|
"/job:worker/replica:0/task:0/device:CPU:1",
|
||||||
|
"/job:worker/replica:0/task:0/device:CPU:2"};
|
||||||
|
// Find a CompositeDevice with the given name.
|
||||||
|
CompositeDevice* composite_device_1 = nullptr;
|
||||||
|
TF_ASSERT_OK(context()->FindOrCreateCompositeDevice(
|
||||||
|
underlying_devices_1, composite_device_0->name(), &composite_device_1));
|
||||||
|
EXPECT_EQ(composite_device_1, composite_device_0);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -415,7 +415,7 @@ Status GetOrCreateKernelAndDevice(
|
|||||||
TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
|
TF_RETURN_IF_ERROR(GetDeviceForInput(ctx, input, &input_device));
|
||||||
input_dev_ptrs.push_back(input_device);
|
input_dev_ptrs.push_back(input_device);
|
||||||
CompositeDevice* composite_device = nullptr;
|
CompositeDevice* composite_device = nullptr;
|
||||||
if (ctx.FindCompositeDeviceFromName(input_device->name().c_str(),
|
if (ctx.FindCompositeDeviceFromName(input_device->name(),
|
||||||
&composite_device)
|
&composite_device)
|
||||||
.ok()) {
|
.ok()) {
|
||||||
composite_devices[input_device->name()] =
|
composite_devices[input_device->name()] =
|
||||||
|
@ -320,6 +320,7 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
|
|||||||
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||||
const tensorflow::DataType dtype,
|
const tensorflow::DataType dtype,
|
||||||
const tensorflow::TensorShape& shape,
|
const tensorflow::TensorShape& shape,
|
||||||
|
const string& device_name,
|
||||||
EagerContext* ctx,
|
EagerContext* ctx,
|
||||||
TensorHandle** packed_handle) {
|
TensorHandle** packed_handle) {
|
||||||
if (handles.empty()) {
|
if (handles.empty()) {
|
||||||
@ -343,8 +344,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
|||||||
}
|
}
|
||||||
|
|
||||||
CompositeDevice* composite_device = nullptr;
|
CompositeDevice* composite_device = nullptr;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(ctx->FindOrCreateCompositeDevice(devices, device_name,
|
||||||
ctx->FindOrCreateCompositeDevice(devices, &composite_device));
|
&composite_device));
|
||||||
*packed_handle =
|
*packed_handle =
|
||||||
new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx);
|
new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx);
|
||||||
(*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info));
|
(*packed_handle)->SetResourceHandleInfo(std::move(resource_handle_info));
|
||||||
@ -363,8 +364,8 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
|||||||
tensorflow::DataType dtype = handles.at(0)->dtype;
|
tensorflow::DataType dtype = handles.at(0)->dtype;
|
||||||
tensorflow::TensorShape shape;
|
tensorflow::TensorShape shape;
|
||||||
TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
|
TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
|
||||||
return CreatePackedHandle(std::move(handles), dtype, shape, ctx,
|
return CreatePackedHandle(std::move(handles), dtype, shape,
|
||||||
packed_handle);
|
/*device_name*/ "", ctx, packed_handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
|
TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
|
||||||
|
@ -94,7 +94,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle,
|
|||||||
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||||
const tensorflow::DataType dtype,
|
const tensorflow::DataType dtype,
|
||||||
const tensorflow::TensorShape& shape,
|
const tensorflow::TensorShape& shape,
|
||||||
EagerContext* ctx,
|
const string& device_name, EagerContext* ctx,
|
||||||
TensorHandle** packed_handle);
|
TensorHandle** packed_handle);
|
||||||
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
|
||||||
EagerContext* ctx,
|
EagerContext* ctx,
|
||||||
|
@ -685,7 +685,7 @@ Status EagerServiceImpl::SendPackedHandle(
|
|||||||
// Create a unshaped packed TensorHandle.
|
// Create a unshaped packed TensorHandle.
|
||||||
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
|
TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
|
||||||
std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
|
std::move(handles_to_pack), handles.at(0)->dtype, TensorShape(),
|
||||||
eager_context, &packed_handle));
|
send_packed_handle.device_name(), eager_context, &packed_handle));
|
||||||
|
|
||||||
for (auto* h : handles) {
|
for (auto* h : handles) {
|
||||||
// Unref handle since it has a ref in the packed handle now.
|
// Unref handle since it has a ref in the packed handle now.
|
||||||
|
@ -1071,6 +1071,8 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
|
|||||||
const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
|
const string device0 = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||||
const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
|
const string device1 = "/job:localhost/replica:0/task:1/device:CPU:0";
|
||||||
const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
|
const string device2 = "/job:localhost/replica:0/task:2/device:CPU:0";
|
||||||
|
const string composite_device =
|
||||||
|
"/job:localhost/replica:0/task:0/device:COMPOSITE:0";
|
||||||
|
|
||||||
uint64 context_id = random::New64();
|
uint64 context_id = random::New64();
|
||||||
CreateContextRequest request;
|
CreateContextRequest request;
|
||||||
@ -1125,6 +1127,8 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
|
|||||||
|
|
||||||
EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
|
EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
|
||||||
EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
|
EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
|
||||||
|
EXPECT_EQ(absl::get<Device*>(packed_handle->device())->name(),
|
||||||
|
composite_device);
|
||||||
|
|
||||||
TensorHandle* handle0 = nullptr;
|
TensorHandle* handle0 = nullptr;
|
||||||
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
|
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
|
||||||
|
@ -298,6 +298,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
|
|||||||
const Device* target_device, EagerContext* ctx,
|
const Device* target_device, EagerContext* ctx,
|
||||||
SendPackedHandleOp* op) {
|
SendPackedHandleOp* op) {
|
||||||
op->set_op_id(op_id);
|
op->set_op_id(op_id);
|
||||||
|
op->set_device_name(VariantDeviceName(packed_handle->DeviceOrHostCPU(*ctx)));
|
||||||
for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
|
for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
|
||||||
TensorHandle* h = nullptr;
|
TensorHandle* h = nullptr;
|
||||||
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
|
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
|
||||||
|
@ -258,6 +258,8 @@ message SendPackedHandleOp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
repeated Handle handles = 2;
|
repeated Handle handles = 2;
|
||||||
|
|
||||||
|
string device_name = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
Loading…
x
Reference in New Issue
Block a user