Support remote multi-device functions with outputs on any remote devices.

PiperOrigin-RevId: 326549047
Change-Id: Ic6a03936e7923360b05e723a5bd4a788ec57d06b
Yujing Zhang 2020-08-13 16:31:21 -07:00 committed by TensorFlower Gardener
parent 43288ecdda
commit 3ebcb8dadc
18 changed files with 323 additions and 71 deletions


@@ -30,18 +30,26 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
   TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
                                     /*heavy_load_on_streaming_rpc=*/false);
 }
+TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
+  TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
+                                    /*heavy_load_on_streaming_rpc=*/false,
+                                    /*remote_func_outputs=*/true);
+}
+TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
+  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
+                                    /*heavy_load_on_streaming_rpc=*/false,
+                                    /*remote_func_outputs=*/true);
+}
 TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
   TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
                                     /*heavy_load_on_streaming_rpc=*/false);
 }
-// TODO(b/162618595): Enable this test once we remove the check of remote
-// outputs in ProcessFunctionLibraryRuntime.
-TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
+TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
   TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
                                     /*heavy_load_on_streaming_rpc=*/false,
                                     /*remote_func_outputs=*/true);
 }
-TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
+TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
   TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
                                     /*heavy_load_on_streaming_rpc=*/false,
                                     /*remote_func_outputs=*/true);
 }
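
Editorial note: the RemoteOutputs tests above differ only in the flags they pass to a shared helper. A sketch of the helper's declaration as assumed from these call sites (the actual definition lives in the companion test util file; the default argument is an assumption):

    // Assumed declaration, reconstructed from the call sites above.
    void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote,
                                           bool heavy_load_on_streaming_rpc,
                                           bool remote_func_outputs = false);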


@@ -169,6 +169,13 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
     ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
   }
+  if (remote_func_outputs) {
+    const string backing_device =
+        TFE_TensorHandleBackingDeviceName(retvals[0], status);
+    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+    EXPECT_EQ(backing_device, task2_name);
+  }
   auto* retval_task0 = TFE_TensorHandleCopyToDevice(
       retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);


@@ -191,6 +191,9 @@ Status EagerContext::RegisterFunction(AbstractFunction* f) {
 // eager_operation.cc we can avoid a circular dependency between them.
 Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                                int* num_retvals) {
+  for (int i = 0; i < Inputs().size(); ++i) {
+    TF_RETURN_IF_ERROR(Inputs()[i]->WaitUnknownDevice());
+  }
   // Run eager placement logic.
   VariantDevice device;
   TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(&device, *this));


@@ -607,8 +607,14 @@ Status CreateUnshapedOutput(
         "Unable to find remote task corresponding to device ",
         output_device->name());
   }
-  *output = TensorHandle::CreateUnshapedRemoteHandle(
-      op_id, output_num, remote_task, output_dtype, output_device, ctx);
+  if (ctx->RemoteMgr()->IsMaster()) {
+    *output = TensorHandle::CreateUnshapedRemoteHandle(
+        op_id, output_num, remote_task, output_dtype, output_device, ctx);
+  } else {
+    *output = TensorHandle::CreateLazyRemoteHandle(op_id, output_num,
+                                                   output_dtype, output_device,
+                                                   /*is_ready=*/false, ctx);
+  }
   return Status::OK();
 #endif  // !IS_MOBILE_PLATFORM
 }
@@ -916,14 +922,15 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
     // execute.
     // The device_ and resource_device_ of this TensorHandle might be
-    // incorrect. It is pretty hard to make it correct because for
-    // multi-device functions, we don't know the output device until the
-    // function is instantiated. Luckily, we don't need to know the correct
-    // remote device here. We just need to know that it is remote. If we need
-    // to copy this tensor to this process, the remote end will know the
-    // correct device of this handle.
+    // incorrect. For multi-device functions, we don't know the output device
+    // until the function is instantiated on a remote worker. Luckily, we don't
+    // need to know the correct remote device here. We just need to know that
+    // it is remote. If we need to copy this tensor to this process or run any
+    // ops which take this tensor as an input, block until the correct device
+    // is set.
+    const bool unknown_device = op->is_function();
     retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
-        id, i, remote_task, output_dtypes[i], op_device, &ctx);
+        id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device);
   }
   if (ctx.LazyCopyFunctionRemoteInputs()) {
@@ -1206,6 +1213,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                          EagerExecutor* executor, Device* device, bool mirror,
                          TensorHandle** result) {
+  TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
   auto send_device = h->DeviceOrHostCPU(*ctx);
   if (VariantDeviceIsCustom(send_device)) {
     return errors::Unimplemented(
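
Editorial note: the comment above is the core trick of this change. A function's output handle is created before its device is known, and every consumer blocks until the device is set. A minimal, self-contained sketch of that latch (plain C++, not TensorFlow code; all names are illustrative):

    #include <condition_variable>
    #include <mutex>
    #include <string>
    #include <utility>

    // Readers of the device block until the producer publishes the real
    // device, mirroring the SetRemoteShapeAndDevice / WaitUnknownDevice pair.
    class UnknownDeviceLatch {
     public:
      // Producer side: called once the remote worker reports the output device.
      void SetDevice(std::string device) {
        {
          std::lock_guard<std::mutex> l(mu_);
          device_ = std::move(device);
          ready_ = true;
        }
        cv_.notify_all();
      }

      // Consumer side: analogous to TensorHandle::WaitUnknownDevice(), which
      // blocks before DeviceName()/BackingDeviceName() return a device string.
      std::string WaitForDevice() {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return ready_; });
        return device_;
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      bool ready_ = false;
      std::string device_;
    };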


@@ -94,9 +94,9 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) {
       TensorHandle::CreateLocalHandle(std::move(t1), device0, device0, ctx);
   // Create two remote TensorHandles
   TensorHandle* h2 = TensorHandle::CreateLazyRemoteHandle(
-      /*op_id=*/1, /*output_num=*/0, dtype, device1, ctx);
+      /*op_id=*/1, /*output_num=*/0, dtype, device1, /*is_ready=*/true, ctx);
   TensorHandle* h3 = TensorHandle::CreateLazyRemoteHandle(
-      /*op_id=*/2, /*output_num=*/1, dtype, device1, ctx);
+      /*op_id=*/2, /*output_num=*/1, dtype, device1, /*is_ready=*/true, ctx);
   // Create a packed TensorHandle
   TensorHandle* packed_h = nullptr;
   TF_ASSERT_OK(TensorHandle::CreatePackedHandle({h1, h2}, ctx, &packed_h));


@@ -115,6 +115,20 @@ bool TensorHandle::PackedTensorHandleData::IsReady() const {
   return true;
 }
+Status TensorHandle::PackedTensorHandleData::WaitReady(
+    const char* caller) const {
+  {
+    tf_shared_lock l(mu_);
+    if (!is_poisoned_.ok()) {
+      return is_poisoned_;
+    }
+  }
+  for (auto* handle : handles_) {
+    TF_RETURN_IF_ERROR(handle->WaitReady(caller));
+  }
+  return Status::OK();
+}
 void TensorHandle::PackedTensorHandleData::Poison(Status status) {
   mutex_lock l(mu_);
   is_poisoned_ = status;
@@ -370,14 +384,16 @@ TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
 #if !defined(IS_MOBILE_PLATFORM)
 TensorHandle* TensorHandle::CreateUnshapedRemoteHandle(
     int64 op_id, int32 output_num, const string& remote_task,
-    tensorflow::DataType dtype, Device* d, EagerContext* ctx) {
-  return new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);
+    tensorflow::DataType dtype, Device* d, EagerContext* ctx,
+    const bool unknown_device) {
+  return new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx,
+                          unknown_device);
 }
 TensorHandle::TensorHandle(int64 op_id, int32 output_num,
                            const string& remote_task,
                            tensorflow::DataType dtype, Device* d,
-                           EagerContext* ctx)
+                           EagerContext* ctx, const bool unknown_device)
     : ImmediateExecutionTensorHandle(kEager),
       dtype(dtype),
       device_(d),
@@ -385,6 +401,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
       resource_device_(dtype == DT_RESOURCE ? d : nullptr),
       resource_remote_device_incarnation_(
           GetRemoteDeviceIncarnation(resource_device_)),
+      unknown_device_(unknown_device),
       ctx_(ctx),
       data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
             remote_task, ctx) {
@@ -392,17 +409,15 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
           << " device: " << VariantDeviceDebugString(device_);
 }
-TensorHandle* TensorHandle::CreateLazyRemoteHandle(int64 op_id,
-                                                   int32 output_num,
-                                                   tensorflow::DataType dtype,
-                                                   Device* d,
-                                                   EagerContext* ctx) {
-  return new TensorHandle(op_id, output_num, dtype, d, ctx);
+TensorHandle* TensorHandle::CreateLazyRemoteHandle(
+    int64 op_id, int32 output_num, tensorflow::DataType dtype, Device* d,
+    const bool is_ready, EagerContext* ctx) {
+  return new TensorHandle(op_id, output_num, dtype, d, is_ready, ctx);
 }
 TensorHandle::TensorHandle(int64 op_id, int32 output_num,
                            tensorflow::DataType dtype, Device* d,
-                           EagerContext* ctx)
+                           const bool is_ready, EagerContext* ctx)
     : ImmediateExecutionTensorHandle(kEager),
       dtype(dtype),
       device_(d),
@@ -412,7 +427,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
           GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
-            ctx->GetContextViewId()) {
+            ctx->GetContextViewId(), is_ready) {
   DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }
@@ -431,6 +446,11 @@ bool TensorHandle::IsReady() const {
   return absl::visit([](auto& data) { return data.IsReady(); }, data_);
 }
+Status TensorHandle::WaitReady(const char* caller) const {
+  return absl::visit([caller](auto& data) { return data.WaitReady(caller); },
+                     data_);
+}
 TensorHandle::HandleType TensorHandle::Type() const {
   if (data_.index() == 0) {
     return LOCAL;
@@ -518,6 +538,17 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
   return mirror.TensorValue(t);
 }
+Status TensorHandle::WaitUnknownDevice() const {
+  if (unknown_device_) {
+    TF_RETURN_IF_ERROR(absl::visit(
+        [](auto& data) {
+          return data.WaitReady("TensorHandle::UnknownDevice");
+        },
+        data_));
+  }
+  return Status::OK();
+}
 VariantDevice TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
   if (VariantDeviceIsCustom(device_)) {
     return device_;
@@ -786,13 +817,21 @@ Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
   resource_shape_mirrors_.emplace(
       std::piecewise_construct, std::forward_as_tuple(d->name()),
-      std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId()));
+      std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId(),
+                            /*is_ready=*/true));
   return Status::OK();
 }
 Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
                                     uint64 context_view_id) {
+  return SetRemoteShapeAndDevice(shape, d, context_view_id, /*op_device=*/"");
+}
+
+Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape,
+                                             const Device* d,
+                                             uint64 context_view_id,
+                                             string op_device) {
   DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d
            << " " << d->name();
@@ -830,7 +869,27 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
   // For mirrors, this is not the case because they colocate with the data
   // consuming op/function device, and we (for now) have to aggressively
   // invalidate those copies to avoid any false positives during cluster update.
-  return data.SetShape(shape);
+  if (op_device.empty()) {
+    return data.SetShape(shape);
+  } else {
+    if (!unknown_device_) {
+      return errors::Internal("Cannot reset known devices.");
+    }
+    Device* device;
+    TF_RETURN_IF_ERROR(ctx_->FindDeviceFromName(op_device.c_str(), &device));
+    device_ = device;
+    op_device_ = device;
+    resource_device_ = dtype == DT_RESOURCE ? device : nullptr;
+    resource_remote_device_incarnation_ =
+        GetRemoteDeviceIncarnation(resource_device_);
+    string remote_task;
+    if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
+      return errors::InvalidArgument(
+          "Unable to find remote task corresponding to device ",
+          device->name());
+    }
+    return data.SetShapeAndRemoteTask(shape, remote_task);
+  }
 }
 void TensorHandle::PoisonRemote(Status status, const Device* d,
@@ -1040,6 +1099,7 @@ const char* TensorHandle::DeviceName(Status* status) const {
   if (VariantDeviceIsCustom(device())) {
     return absl::get<CustomDevice*>(device())->name().c_str();
   }
+  status->Update(WaitUnknownDevice());
   tensorflow::Device* d = op_device();
   return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                         : d->name().c_str();
@@ -1049,6 +1109,7 @@ const char* TensorHandle::BackingDeviceName(Status* status) const {
   if (VariantDeviceIsCustom(device())) {
     return absl::get<tensorflow::CustomDevice*>(device())->name().c_str();
   } else {
+    status->Update(WaitUnknownDevice());
    tensorflow::Device* d = absl::get<tensorflow::Device*>(device());
    return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                          : d->name().c_str();
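
Editorial note: taken together, these changes give function outputs a two-step lifecycle when placement is not known up front. A minimal usage sketch (it mirrors the UnknownRemoteDevice unit test added further down; remote_task, shape, placeholder_device, actual_device and ctx are stand-ins for real objects):

    // Step 1: the producer creates the handle before placement is known; the
    // device passed here is provisional.
    TensorHandle* h = TensorHandle::CreateUnshapedRemoteHandle(
        /*op_id=*/0, /*output_num=*/0, remote_task, DT_FLOAT,
        placeholder_device, ctx, /*unknown_device=*/true);
    // Step 2: once the worker reports back, one call fixes shape, devices and
    // the remote task; readers blocked in WaitUnknownDevice() are released.
    Status s = h->SetRemoteShapeAndDevice(shape, placeholder_device,
                                          ctx->GetContextViewId(),
                                          actual_device->name());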


@@ -66,9 +66,10 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
 #if !defined(IS_MOBILE_PLATFORM)
   TensorHandle(int64 op_id, int32 output_num, const string& remote_task,
-               tensorflow::DataType dtype, Device* device, EagerContext* ctx);
+               tensorflow::DataType dtype, Device* device, EagerContext* ctx,
+               const bool unknown_device);
   TensorHandle(int64 op_id, int32 output_num, tensorflow::DataType dtype,
-               Device* device, EagerContext* ctx);
+               Device* device, const bool is_ready, EagerContext* ctx);
 #endif  // IS_MOBILE_PLATFORM
  public:
@@ -100,13 +101,21 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
                                    TensorHandle** packed_handle);
 #if !defined(IS_MOBILE_PLATFORM)
-  static TensorHandle* CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
-                                                  const string& remote_task,
-                                                  tensorflow::DataType dtype,
-                                                  Device* d, EagerContext* ctx);
+  // An unshaped remote handle refers to a tensor on a remote worker. It's not
+  // ready until the shape is set. It controls the lifetime of the remote
+  // tensor.
+  static TensorHandle* CreateUnshapedRemoteHandle(
+      int64 op_id, int32 output_num, const string& remote_task,
+      tensorflow::DataType dtype, Device* d, EagerContext* ctx,
+      const bool unknown_device = false);
+  // A lazy remote handle refers to a tensor on a remote worker. The lifetime
+  // of the remote tensor is controlled by the remote worker, but not by the
+  // lazy remote handle. Lazy handles are normally created on a default
+  // function device.
   static TensorHandle* CreateLazyRemoteHandle(int64 op_id, int32 output_num,
                                               tensorflow::DataType dtype,
-                                              Device* d, EagerContext* ctx);
+                                              Device* d, const bool is_ready,
+                                              EagerContext* ctx);
 #endif  // IS_MOBILE_PLATFORM
   void Release() override;
@@ -141,6 +150,10 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
     return resource_remote_device_incarnation_;
   }
+  // If the devices are unknown at creation time, block until the actual
+  // devices are set (data is ready).
+  Status WaitUnknownDevice() const;
+
   VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
   Status Shape(tensorflow::TensorShape* shape);
@@ -177,10 +190,15 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
   // transitions the tensor handle from a non-ready to a ready state by
   // replacing the backing data abstraction to allow for the shape to be
   // queried.
+  // creating a TensorHandle (e.g. a remote output of a remote function).
   // This method or Poison must be called exactly once for remote tensors that
   // were created without a known shape.
   Status SetRemoteShape(const TensorShape& shape, const Device* d,
                         uint64 context_view_id);
+  // If op_device is not empty, reset the devices of a remote tensor which is
+  // created without known devices (e.g. function outputs).
+  Status SetRemoteShapeAndDevice(const TensorShape& shape, const Device* d,
+                                 uint64 context_view_id, string op_device);
   // Poisons either this handle or a remote mirror with error `status`.
   // Poisoning means that the handle will become ready and methods trying
@@ -258,21 +276,27 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
   // to either SetTensor or SetRemoteShape which replaces the underlying data
   // with a ready version of the tensor handle data.
   bool IsReady() const;
+  Status WaitReady(const char* caller) const;
-  VariantDevice const device_;
+  VariantDevice device_;
   // Device in which the op producing this tensor was executed. Equals to
   // device_ for constant tensors.
   // Can be nullptr if the op producing this tensor was a function executed
   // with function library runtime.
-  tensorflow::Device* const op_device_;
+  tensorflow::Device* op_device_;
   // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
   // backing the resource. Else resource_device_ is nullptr.
-  tensorflow::Device* const resource_device_;
+  tensorflow::Device* resource_device_;
   // Incarnation ID of the resource device if it locates on a remote device, or
   // 0 if it locates on a local device.
-  const int64 resource_remote_device_incarnation_;
+  int64 resource_remote_device_incarnation_;
+  // If true, the handle refers to a remote tensor which is created without
+  // known devices. The actual devices are set by SetRemoteShape. The devices
+  // should be accessed once the handle is ready.
+  const bool unknown_device_ = false;
   mutable mutex mu_;
@@ -323,6 +347,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
     Status NumElements(int64* num_elements) const;
     Status Unprotect();
     bool IsReady() const;
+    Status WaitReady(const char* caller) const;
     void Poison(Status status);
     string DebugString() const;


@@ -334,4 +334,84 @@ TEST(TensorHandle_ResourceDeviceTest, OnRemoteDevice) {
   ctx->Unref();
 }
+class RemoteTensorHandleTest : public ::testing::Test {
+ public:
+  RemoteTensorHandleTest() {
+    std::vector<std::unique_ptr<Device>> devices;
+    for (const char* name : device_names_) {
+      devices.emplace_back(CreateDevice("CPU", name));
+    }
+    device_mgr_ = new StaticDeviceMgr(std::move(devices));
+
+    context_ = new EagerContext(
+        SessionOptions(),
+        tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false,
+        /* lazy_copy_function_remote_inputs= */ false, device_mgr_,
+        /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
+        /* custom_kernel_creator= */ nullptr,
+        /* cluster_flr= */ nullptr);
+  }
+
+  ~RemoteTensorHandleTest() override {
+    delete device_mgr_;
+    context_->Unref();
+  }
+
+  EagerContext* context() { return context_; }
+
+  std::vector<Device*> ListDevices() const {
+    return device_mgr_->ListDevices();
+  }
+
+ private:
+  const std::vector<const char*> device_names_ = {
+      "/job:worker/replica:0/task:0/device:CPU:0",
+      "/job:worker/replica:0/task:1/device:CPU:0",
+      "/job:worker/replica:0/task:2/device:CPU:0"};
+
+  StaticDeviceMgr* device_mgr_;
+  EagerContext* context_;
+};
+
+TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
+  std::vector<std::unique_ptr<Device>> devices;
+  devices.emplace_back(
+      CreateDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:0"));
+  devices.emplace_back(
+      CreateDevice("CPU", "/job:worker/replica:0/task:1/device:CPU:0"));
+  devices.emplace_back(
+      CreateDevice("CPU", "/job:worker/replica:0/task:2/device:CPU:0"));
+  StaticDeviceMgr device_mgr(std::move(devices));
+
+  EagerContext* context = new EagerContext(
+      SessionOptions(),
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false,
+      /* lazy_copy_function_remote_inputs= */ false, &device_mgr,
+      /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
+      /* custom_kernel_creator= */ nullptr,
+      /* cluster_flr= */ nullptr);
+
+  tensorflow::DataType dtype = DT_FLOAT;
+  TensorShape shape = {};
+
+  const string remote_task = "/job:worker/replica:0/task:1";
+  Device* d1 = device_mgr.ListDevices().at(1);
+  TensorHandle* h = TensorHandle::CreateUnshapedRemoteHandle(
+      /*op_id=*/0, /*output_num=*/0, remote_task, dtype, d1, context,
+      /*unknown_device=*/true);
+  EXPECT_EQ(absl::get<Device*>(h->device()), d1);
+
+  Device* d2 = device_mgr.ListDevices().at(2);
+  TF_ASSERT_OK(h->SetRemoteShapeAndDevice(
+      shape, d1, context->GetContextViewId(), d2->name()));
+  Status s;
+  EXPECT_EQ(h->BackingDeviceName(&s), d2->name());
+  TF_EXPECT_OK(s);
+  EXPECT_EQ(absl::get<Device*>(h->device()), d2);
+
+  h->Unref();
+  context->Unref();
+}
+
 }  // namespace tensorflow


@@ -1006,16 +1006,13 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
     const string& target = pair.first;
     FunctionLibraryRuntime* target_flr = GetFLR(target);
     Device* target_device = nullptr;
+    Device* host = nullptr;
     if (target_flr == nullptr) {
-      // TODO(b/162618595): Remove this error once we support a remote
-      // multi-device function with remote outputs.
-      return errors::Unimplemented(
-          "Currently, outputting tensors on remote devices is not supported."
-          "The ",
-          comp_data.ret_indices[0],
-          "-th return value of the function outputs to target_device: ", target,
-          " Please copy the tensor to local device explicitly using "
-          "tf.identity and return the new Tensor instead.");
+      target_device = device_set()->FindDeviceByName(target);
+      string remote_host;
+      TF_RETURN_IF_ERROR(
+          DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host));
+      host = device_set()->FindDeviceByName(remote_host);
     } else {
       target_device = target_flr->device();
     }
@@ -1026,7 +1023,7 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
         (*output_devices)[ret_index] = target_device;
       } else {
         (*output_devices)[ret_index] =
-            comp_data.ret_alloc_attrs[j].on_host() ? nullptr : target_device;
+            comp_data.ret_alloc_attrs[j].on_host() ? host : target_device;
       }
     }
   }
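
Editorial note: GetOutputDevices now resolves a remote target to two devices, the target itself and the host CPU of the same task, which it uses for on-host outputs. A rough stand-in for what DeviceNameUtils::DeviceNameToCpuDeviceName computes here (simplified sketch, not the TensorFlow implementation):

    #include <string>

    // Keep the /job:.../replica:.../task:... prefix and replace the trailing
    // device with CPU:0.
    std::string ToCpuDeviceName(const std::string& device_name) {
      const std::string key = "/device:";
      const size_t pos = device_name.rfind(key);
      if (pos == std::string::npos) return device_name;  // not a full name
      return device_name.substr(0, pos) + key + "CPU:0";
    }

    // e.g. "/job:worker/replica:0/task:1/device:GPU:0"
    //   -> "/job:worker/replica:0/task:1/device:CPU:0"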


@@ -203,7 +203,7 @@ class ProcessFunctionLibraryRuntime {
   const DeviceMgr* device_mgr() { return device_mgr_; }
-  const std::shared_ptr<DeviceSet> device_set() {
+  const std::shared_ptr<DeviceSet> device_set() const {
     tf_shared_lock l(mu_);
     return device_set_;
   }


@@ -156,9 +156,15 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
   const tensorflow::Tensor* t = nullptr;
   // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
-  TF_RETURN_IF_ERROR(handle->Tensor(&t));
-  t->shape().AsProto(proto);
+  if (handle->Type() == TensorHandle::LOCAL) {
+    TF_RETURN_IF_ERROR(handle->Tensor(&t));
+    t->shape().AsProto(proto);
+  } else {
+    TensorShape shape;
+    TF_RETURN_IF_ERROR(handle->Shape(&shape));
+    shape.AsProto(proto);
+  }
   return Status::OK();
 }
@@ -166,7 +172,8 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
 Status AddOpRetvalsToResponse(
     EagerContext* eager_context, int op_id, int num_retvals,
     TensorHandle** retvals, std::function<TensorProto*()> add_tensor_proto_fn,
-    std::function<TensorShapeProto*()> add_shape_proto_fn) {
+    std::function<TensorShapeProto*()> add_shape_proto_fn,
+    std::function<string*()> add_device_fn = nullptr) {
   if (op_id == kInvalidRemoteOpId) {
     // Copy the output tensors back along with the response, since the op id
     // is invalid which cannot be added to RemoteMgr.
@@ -175,10 +182,21 @@ Status AddOpRetvalsToResponse(
       retvals[i]->Unref();
     }
   } else {
-    eager_context->RemoteMgr()->AddOperationOutputs(
-        absl::MakeSpan(retvals, num_retvals), op_id);
     for (int i = 0; i < num_retvals; i++) {
       TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
+      const bool is_remote = retvals[i]->Type() == TensorHandle::REMOTE;
+      if (add_device_fn) {
+        *add_device_fn() =
+            is_remote ? absl::get<Device*>(
+                            retvals[i]->DeviceOrHostCPU(*eager_context))
+                            ->name()
+                      : "";
+      }
+      if (is_remote) {
+        retvals[i]->Unref();
+      } else {
+        eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id, i);
+      }
     }
   }
   return Status::OK();
@@ -479,6 +497,8 @@ void EagerServiceImpl::RunComponentFunction(
           wrapped_done(status);
           return;
         }
+        // The output device of a component function is the component device
+        // which is known on the default device of its parent function.
        wrapped_done(AddOpRetvalsToResponse(
            eager_context, op_id, *num_retvals, retvals->data(),
            [response] { return response->add_tensor(); },
@@ -510,10 +530,19 @@ Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
                                       num_retvals),
                        &num_retvals));
+  std::function<string*()> add_device_fn = nullptr;
+  // Send the output devices of a function back to let a client know where the
+  // outputs are. For a primitive op, an output device is the op device which
+  // is known on a client.
+  if (op.is_function()) {
+    add_device_fn = [queue_response] { return queue_response->add_device(); };
+  }
+
   return AddOpRetvalsToResponse(
       eager_context, operation.id(), num_retvals, retvals.data(),
       [queue_response] { return queue_response->add_tensor(); },
-      [queue_response] { return queue_response->add_shape(); });
+      [queue_response] { return queue_response->add_shape(); },
+      std::move(add_device_fn));
 }
 Status EagerServiceImpl::Enqueue(CallOptions* call_opts,
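
Editorial note: the add_*_fn callbacks let AddOpRetvalsToResponse fill either response type without knowing its proto layout; each call returns a fresh mutable slot. A self-contained sketch of the pattern (simplified types, not the TensorFlow signatures):

    #include <functional>
    #include <string>
    #include <vector>

    // Each call to add_device_fn() yields a new mutable slot in the response,
    // so the filler never touches the container type itself.
    void FillOutputDevices(const std::vector<std::string>& output_devices,
                           std::function<std::string*()> add_device_fn) {
      if (!add_device_fn) return;  // caller did not opt in (primitive op)
      for (const std::string& d : output_devices) {
        *add_device_fn() = d;  // empty string means "device already known"
      }
    }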


@@ -88,8 +88,14 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) {
   }
   for (size_t i = 0; i < retvals.size(); ++i) {
     if (status.ok()) {
-      Status s = retvals[i]->SetRemoteShape(
-          response->queue_response(0).shape(i), device, context_view_id);
+      const string output_device =
+          response->queue_response(0).device().empty()
+              ? ""
+              : response->queue_response(0).device(i);
+      Status s = retvals[i]->SetRemoteShapeAndDevice(
+          response->queue_response(0).shape(i), device, context_view_id,
+          output_device);
       if (!s.ok()) {
         LOG(ERROR) << "Ignoring an error encountered when setting "
                       "remote shape of tensor handle: "

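Editorial note: on the client, the device list in QueueResponse parallels the shape list and is only populated for functions. A sketch of the lookup contract assumed above (plain C++, protobuf accessors simplified to a vector):

    #include <cstddef>
    #include <string>
    #include <vector>

    // Empty list: not a function, keep the device the client already assumed.
    // Non-empty list: entry i names the actual device of output i.
    std::string OutputDeviceForRetval(const std::vector<std::string>& devices,
                                      size_t i) {
      return devices.empty() ? std::string() : devices.at(i);
    }
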

@@ -35,6 +35,13 @@ void RemoteMgr::AddOperationOutputs(
   }
 }
+void RemoteMgr::AddOperationOutput(tensorflow::TensorHandle* handle,
+                                   int64 operation_id, int32 output_num) {
+  mutex_lock l(remote_tensor_handle_mu_);
+  remote_tensor_handle_map_.emplace(
+      RemoteTensorHandleInternal(operation_id, output_num), handle);
+}
+
 Status RemoteMgr::GetTensorHandleImpl(
     const RemoteTensorHandleInternal& remote_handle,
     tensorflow::TensorHandle** handle) {
@@ -160,13 +167,14 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
     (*out)->Ref();
   } else {
     // Create a remote TensorHandle for remote tensors which have not been
-    // copied to the local worker yet.
+    // copied to the local worker yet (e.g. remote function inputs).
     const string& device_name =
         in.op_device().empty() ? in.device() : in.op_device();
     TF_RETURN_IF_ERROR(
         parent_->FindDeviceFromName(device_name.c_str(), &device));
     *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
-                                                in.dtype(), device, parent_);
+                                                in.dtype(), device,
+                                                /*is_ready=*/true, parent_);
     std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
     if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                   &dtypes_and_shapes)


@@ -47,6 +47,9 @@ class RemoteMgr {
       const gtl::ArraySlice<tensorflow::TensorHandle*> handles,
       int64 operation_id);
+  void AddOperationOutput(tensorflow::TensorHandle* handles, int64 operation_id,
+                          int32 output_num);
+
   Status GetTensorHandle(const RemoteTensorHandleInternal& remote_handle,
                          tensorflow::TensorHandle** handle);


@@ -95,7 +95,7 @@ TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
   const uint64 op_id = 3;
   const int output_num = 1;
   TensorHandle* handle = TensorHandle::CreateLazyRemoteHandle(
-      op_id, output_num, DT_FLOAT, remote_device_, ctx_);
+      op_id, output_num, DT_FLOAT, remote_device_, /*is_ready=*/true, ctx_);
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
       handle, /*wait_until_ready=*/true, &remote_handle, remote_device_,


@@ -85,8 +85,9 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
 }  // namespace
 RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
-                                               uint64 context_view_id)
-    : is_ready_(true),
+                                               uint64 context_view_id,
+                                               bool is_ready)
+    : is_ready_(is_ready),
       op_id_(op_id),
       output_num_(output_num),
       context_view_id_(context_view_id),
@@ -173,6 +174,11 @@ Status RemoteTensorHandleData::IsPoisoned() const {
 }
 Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
+  return SetShapeAndRemoteTask(shape, /*remote_task=*/"");
+}
+
+Status RemoteTensorHandleData::SetShapeAndRemoteTask(
+    const TensorShape& shape, const string& remote_task) {
   // If `is_ready_` is set previously due to poisoning, return the original
   // error that poisoned this tensor.
   TF_RETURN_IF_ERROR(IsPoisoned());
@@ -183,6 +189,9 @@ Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
   }
   shape_ = shape;
+  if (!remote_task.empty()) {
+    remote_task_ = remote_task;
+  }
   is_poisoned_ = Status::OK();
   is_ready_ = true;


@@ -26,11 +26,16 @@ namespace tensorflow {
 class RemoteTensorHandleData {
  public:
   // Constructor for lazy remote handles. A lazy remote handle is created on
-  // a remote worker with an op_id and an output_num sent by a client. The
-  // client won't serialize them until the corresponding remote tensor is ready.
-  // So the remote tensor should be ready when we create a lazy remote handle.
-  RemoteTensorHandleData(int64 op_id, int output_num, uint64 context_view_id);
-  // Constructor for unshaped remote handles
+  // a remote worker with an op_id and an output_num. It doesn't control the
+  // lifetime of a remote handle that it refers to. If it refers to a remote
+  // function input, it's sent by a client which won't serialize it until
+  // the corresponding remote tensor is ready. So the remote tensor should be
+  // ready when we create a lazy remote handle. If it refers to a remote
+  // output, it's not ready until the shape is set.
+  RemoteTensorHandleData(int64 op_id, int output_num, uint64 context_view_id,
+                         bool is_ready);
+  // Constructor for unshaped remote handles. It controls the lifetime of a
+  // remote handle that it refers to.
   RemoteTensorHandleData(int64 op_id, int output_num, const string& remote_task,
                          EagerContext* ctx);
   ~RemoteTensorHandleData();
@@ -44,7 +49,10 @@ class RemoteTensorHandleData {
   Status Unprotect() { return Status::OK(); }
   bool IsReady() const;
+  Status WaitReady(const char* caller) const;
   Status SetShape(const TensorShape& shape);
+  Status SetShapeAndRemoteTask(const TensorShape& shape,
+                               const string& remote_task);
   void Poison(Status status);
   Status IsPoisoned() const;
@@ -58,8 +66,6 @@ class RemoteTensorHandleData {
   uint64 context_view_id() const { return context_view_id_; }
  private:
-  Status WaitReady(const char* caller) const;
-
   mutable mutex mu_;
   bool is_ready_ TF_GUARDED_BY(mu_);
   Status is_poisoned_ TF_GUARDED_BY(mu_);
@@ -68,7 +74,7 @@ class RemoteTensorHandleData {
   // IDs required when this class is representing a remote tensor handle.
   const int64 op_id_;
   const int32 output_num_;
-  string remote_task_;
+  string remote_task_ TF_GUARDED_BY(mu_);
   uint64 context_id_;
   uint64 context_view_id_;
   EagerContext* ctx_;


@@ -77,6 +77,8 @@ message QueueResponse {
   // `shape` and `tensor` cannot be set in the same response.
   // Shapes of output tensors for creating remote TensorHandles.
   repeated TensorShapeProto shape = 1;
+  // Optional. If set, represents the output devices of a function.
+  repeated string device = 3;
   // Output tensors of a remote function. Set when Operation.id is invalid.
   repeated TensorProto tensor = 2;