Support remote multi-device functions with outputs on any remote devices.
PiperOrigin-RevId: 326549047
Change-Id: Ic6a03936e7923360b05e723a5bd4a788ec57d06b
commit 3ebcb8dadc (parent 43288ecdda)
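For orientation before the hunks: this change lets a remote multi-device function place its outputs on any remote device. Output handles are now created with an "unknown device" placeholder, the serving worker reports each output's device back in QueueResponse.device, and the client resets the handle's devices via the new SetRemoteShapeAndDevice. The sketch below condenses that lifecycle using only APIs that appear in this diff; it mirrors the RemoteTensorHandleTest.UnknownRemoteDevice test added further down, with context/device setup elided and the variable names (d1, d2) purely illustrative.

// Minimal sketch, assuming an EagerContext* `context` and two Device* `d1`, `d2`
// already registered with its device manager (see the RemoteTensorHandleTest
// fixture added in this change).
#include <string>

#include "tensorflow/core/common_runtime/eager/tensor_handle.h"

void UnknownDeviceLifecycleSketch(tensorflow::EagerContext* context,
                                  tensorflow::Device* d1,
                                  tensorflow::Device* d2) {
  using tensorflow::TensorHandle;
  // 1) A function output is created before its real device is known; d1 is only
  //    a placeholder (e.g. the default function device).
  TensorHandle* h = TensorHandle::CreateUnshapedRemoteHandle(
      /*op_id=*/0, /*output_num=*/0,
      /*remote_task=*/"/job:worker/replica:0/task:1", tensorflow::DT_FLOAT, d1,
      context, /*unknown_device=*/true);

  // 2) When the remote worker reports the shape and the actual op device, both
  //    are set together; the handle's devices move from d1 to d2.
  TF_CHECK_OK(h->SetRemoteShapeAndDevice(tensorflow::TensorShape(), d1,
                                         context->GetContextViewId(),
                                         d2->name()));

  // 3) Device accessors call WaitUnknownDevice() internally, so they return the
  //    corrected device rather than the placeholder.
  tensorflow::Status s;
  CHECK_EQ(std::string(h->BackingDeviceName(&s)), d2->name());
  TF_CHECK_OK(s);
  h->Unref();
}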
@@ -30,18 +30,26 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
   TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
                                     /*heavy_load_on_streaming_rpc=*/false);
 }
+TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
+  TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
+                                    /*heavy_load_on_streaming_rpc=*/false,
+                                    /*remote_func_outputs=*/true);
+}
+TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
+  TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
+                                    /*heavy_load_on_streaming_rpc=*/false,
+                                    /*remote_func_outputs=*/true);
+}
 TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
   TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
                                     /*heavy_load_on_streaming_rpc=*/false);
 }
-// TODO(b/162618595): Enable this test once we remove the check of remote
-// outputs in ProcessFunctionLibraryRuntime.
-TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
+TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
   TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
                                     /*heavy_load_on_streaming_rpc=*/false,
                                     /*remote_func_outputs=*/true);
 }
-TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
+TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
   TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
                                     /*heavy_load_on_streaming_rpc=*/false,
                                     /*remote_func_outputs=*/true);
@@ -169,6 +169,13 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
     ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
   }
 
+  if (remote_func_outputs) {
+    const string backing_device =
+        TFE_TensorHandleBackingDeviceName(retvals[0], status);
+    ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+    EXPECT_EQ(backing_device, task2_name);
+  }
+
   auto* retval_task0 = TFE_TensorHandleCopyToDevice(
       retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
   ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
@@ -191,6 +191,9 @@ Status EagerContext::RegisterFunction(AbstractFunction* f) {
 // eager_operation.cc we can avoid a circular dependency between them.
 Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                                int* num_retvals) {
+  for (int i = 0; i < Inputs().size(); ++i) {
+    TF_RETURN_IF_ERROR(Inputs()[i]->WaitUnknownDevice());
+  }
   // Run eager placement logic.
   VariantDevice device;
   TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(&device, *this));
@@ -607,8 +607,14 @@ Status CreateUnshapedOutput(
         "Unable to find remote task corresponding to device ",
         output_device->name());
   }
-  *output = TensorHandle::CreateUnshapedRemoteHandle(
-      op_id, output_num, remote_task, output_dtype, output_device, ctx);
+  if (ctx->RemoteMgr()->IsMaster()) {
+    *output = TensorHandle::CreateUnshapedRemoteHandle(
+        op_id, output_num, remote_task, output_dtype, output_device, ctx);
+  } else {
+    *output = TensorHandle::CreateLazyRemoteHandle(op_id, output_num,
+                                                   output_dtype, output_device,
+                                                   /*is_ready=*/false, ctx);
+  }
   return Status::OK();
 #endif  // !IS_MOBILE_PLATFORM
 }
@@ -916,14 +922,15 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
     // execute.
 
     // The device_ and resource_device_ of this TensorHandle might be
-    // incorrect. It is pretty hard to make it correct because for
-    // multi-device functions, we don't know the output device until the
-    // function is instantiated. Luckily, we don't need to know the correct
-    // remote device here. We just need to know that it is remote. If we need
-    // to copy this tensor to this process, the remote end will know the
-    // correct device of this handle.
+    // incorrect. For multi-device functions, we don't know the output device
+    // until the function is instantiated on a remote worker. Luckily, we don't
+    // need to know the correct remote device here. We just need to know that it
+    // is remote. If we need copy this tensor to this process or run any ops
+    // which take this tensor as an input, block until the correct device is
+    // set.
+    const bool unknown_device = op->is_function();
     retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
-        id, i, remote_task, output_dtypes[i], op_device, &ctx);
+        id, i, remote_task, output_dtypes[i], op_device, &ctx, unknown_device);
   }
 
   if (ctx.LazyCopyFunctionRemoteInputs()) {
@@ -1206,6 +1213,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                          EagerExecutor* executor, Device* device, bool mirror,
                          TensorHandle** result) {
+  TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
   auto send_device = h->DeviceOrHostCPU(*ctx);
   if (VariantDeviceIsCustom(send_device)) {
     return errors::Unimplemented(
@@ -94,9 +94,9 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) {
       TensorHandle::CreateLocalHandle(std::move(t1), device0, device0, ctx);
   // Create two remote TensorHandles
   TensorHandle* h2 = TensorHandle::CreateLazyRemoteHandle(
-      /*op_id=*/1, /*output_num=*/0, dtype, device1, ctx);
+      /*op_id=*/1, /*output_num=*/0, dtype, device1, /*is_ready=*/true, ctx);
   TensorHandle* h3 = TensorHandle::CreateLazyRemoteHandle(
-      /*op_id=*/2, /*output_num=*/1, dtype, device1, ctx);
+      /*op_id=*/2, /*output_num=*/1, dtype, device1, /*is_ready=*/true, ctx);
   // Create a packed TensorHandle
   TensorHandle* packed_h = nullptr;
   TF_ASSERT_OK(TensorHandle::CreatePackedHandle({h1, h2}, ctx, &packed_h));
@@ -115,6 +115,20 @@ bool TensorHandle::PackedTensorHandleData::IsReady() const {
   return true;
 }
 
+Status TensorHandle::PackedTensorHandleData::WaitReady(
+    const char* caller) const {
+  {
+    tf_shared_lock l(mu_);
+    if (!is_poisoned_.ok()) {
+      return is_poisoned_;
+    }
+  }
+  for (auto* handle : handles_) {
+    TF_RETURN_IF_ERROR(handle->WaitReady(caller));
+  }
+  return Status::OK();
+}
+
 void TensorHandle::PackedTensorHandleData::Poison(Status status) {
   mutex_lock l(mu_);
   is_poisoned_ = status;
@@ -370,14 +384,16 @@ TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
 #if !defined(IS_MOBILE_PLATFORM)
 TensorHandle* TensorHandle::CreateUnshapedRemoteHandle(
     int64 op_id, int32 output_num, const string& remote_task,
-    tensorflow::DataType dtype, Device* d, EagerContext* ctx) {
-  return new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);
+    tensorflow::DataType dtype, Device* d, EagerContext* ctx,
+    const bool unknown_device) {
+  return new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx,
+                          unknown_device);
 }
 
 TensorHandle::TensorHandle(int64 op_id, int32 output_num,
                            const string& remote_task,
                            tensorflow::DataType dtype, Device* d,
-                           EagerContext* ctx)
+                           EagerContext* ctx, const bool unknown_device)
     : ImmediateExecutionTensorHandle(kEager),
       dtype(dtype),
       device_(d),
@@ -385,6 +401,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
       resource_device_(dtype == DT_RESOURCE ? d : nullptr),
       resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
+      unknown_device_(unknown_device),
       ctx_(ctx),
       data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
             remote_task, ctx) {
@@ -392,17 +409,15 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
            << " device: " << VariantDeviceDebugString(device_);
 }
 
-TensorHandle* TensorHandle::CreateLazyRemoteHandle(int64 op_id,
-                                                   int32 output_num,
-                                                   tensorflow::DataType dtype,
-                                                   Device* d,
-                                                   EagerContext* ctx) {
-  return new TensorHandle(op_id, output_num, dtype, d, ctx);
+TensorHandle* TensorHandle::CreateLazyRemoteHandle(
+    int64 op_id, int32 output_num, tensorflow::DataType dtype, Device* d,
+    const bool is_ready, EagerContext* ctx) {
+  return new TensorHandle(op_id, output_num, dtype, d, is_ready, ctx);
 }
 
 TensorHandle::TensorHandle(int64 op_id, int32 output_num,
                            tensorflow::DataType dtype, Device* d,
-                           EagerContext* ctx)
+                           const bool is_ready, EagerContext* ctx)
     : ImmediateExecutionTensorHandle(kEager),
       dtype(dtype),
       device_(d),
@@ -412,7 +427,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
           GetRemoteDeviceIncarnation(resource_device_)),
       ctx_(ctx),
       data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
-            ctx->GetContextViewId()) {
+            ctx->GetContextViewId(), is_ready) {
   DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }
@@ -431,6 +446,11 @@ bool TensorHandle::IsReady() const {
   return absl::visit([](auto& data) { return data.IsReady(); }, data_);
 }
 
+Status TensorHandle::WaitReady(const char* caller) const {
+  return absl::visit([caller](auto& data) { return data.WaitReady(caller); },
+                     data_);
+}
+
 TensorHandle::HandleType TensorHandle::Type() const {
   if (data_.index() == 0) {
     return LOCAL;
@@ -518,6 +538,17 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
   return mirror.TensorValue(t);
 }
 
+Status TensorHandle::WaitUnknownDevice() const {
+  if (unknown_device_) {
+    TF_RETURN_IF_ERROR(absl::visit(
+        [](auto& data) {
+          return data.WaitReady("TensorHandle::UnknownDevice");
+        },
+        data_));
+  }
+  return Status::OK();
+}
+
 VariantDevice TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
   if (VariantDeviceIsCustom(device_)) {
     return device_;
@@ -786,13 +817,21 @@ Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
 
   resource_shape_mirrors_.emplace(
       std::piecewise_construct, std::forward_as_tuple(d->name()),
-      std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId()));
+      std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId(),
+                            /*is_ready=*/true));
 
   return Status::OK();
 }
 
 Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
                                     uint64 context_view_id) {
+  return SetRemoteShapeAndDevice(shape, d, context_view_id, /*op_device=*/"");
+}
+
+Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape,
+                                             const Device* d,
+                                             uint64 context_view_id,
+                                             string op_device) {
   DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d
            << " " << d->name();
 
@@ -830,7 +869,27 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
   // For mirrors, this is not the case because they colocate with the data
   // consuming op/function device, and we (for now) have to aggressively
   // invalidate those copies to avoid any false positives during cluster update.
-  return data.SetShape(shape);
+  if (op_device.empty()) {
+    return data.SetShape(shape);
+  } else {
+    if (!unknown_device_) {
+      return errors::Internal("Cannot reset known devices.");
+    }
+    Device* device;
+    TF_RETURN_IF_ERROR(ctx_->FindDeviceFromName(op_device.c_str(), &device));
+    device_ = device;
+    op_device_ = device;
+    resource_device_ = dtype == DT_RESOURCE ? device : nullptr;
+    resource_remote_device_incarnation_ =
+        GetRemoteDeviceIncarnation(resource_device_);
+    string remote_task;
+    if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
+      return errors::InvalidArgument(
+          "Unable to find remote task corresponding to device ",
+          device->name());
+    }
+    return data.SetShapeAndRemoteTask(shape, remote_task);
+  }
 }
 
 void TensorHandle::PoisonRemote(Status status, const Device* d,
@@ -1040,6 +1099,7 @@ const char* TensorHandle::DeviceName(Status* status) const {
   if (VariantDeviceIsCustom(device())) {
     return absl::get<CustomDevice*>(device())->name().c_str();
   }
+  status->Update(WaitUnknownDevice());
   tensorflow::Device* d = op_device();
   return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                         : d->name().c_str();
@@ -1049,6 +1109,7 @@ const char* TensorHandle::BackingDeviceName(Status* status) const {
   if (VariantDeviceIsCustom(device())) {
     return absl::get<tensorflow::CustomDevice*>(device())->name().c_str();
   } else {
+    status->Update(WaitUnknownDevice());
     tensorflow::Device* d = absl::get<tensorflow::Device*>(device());
     return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                           : d->name().c_str();
@@ -66,9 +66,10 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
 
 #if !defined(IS_MOBILE_PLATFORM)
   TensorHandle(int64 op_id, int32 output_num, const string& remote_task,
-               tensorflow::DataType dtype, Device* device, EagerContext* ctx);
+               tensorflow::DataType dtype, Device* device, EagerContext* ctx,
+               const bool unknown_device);
   TensorHandle(int64 op_id, int32 output_num, tensorflow::DataType dtype,
-               Device* device, EagerContext* ctx);
+               Device* device, const bool is_ready, EagerContext* ctx);
 #endif  // IS_MOBILE_PLATFORM
 
  public:
@@ -100,13 +101,21 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
                                   TensorHandle** packed_handle);
 
 #if !defined(IS_MOBILE_PLATFORM)
-  static TensorHandle* CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
-                                                  const string& remote_task,
-                                                  tensorflow::DataType dtype,
-                                                  Device* d, EagerContext* ctx);
+  // An unshaped remote handle refers to a tensor on a remote worker. It's not
+  // ready until the shape is set. It controls the lifetime of the remote
+  // tensor.
+  static TensorHandle* CreateUnshapedRemoteHandle(
+      int64 op_id, int32 output_num, const string& remote_task,
+      tensorflow::DataType dtype, Device* d, EagerContext* ctx,
+      const bool unknown_device = false);
+  // A lazy remote handle refers to a tensor on a remote worker. The lifetime of
+  // the remote tensor is controlled by the remote worker, but not by the lazy
+  // remote handle. Lazy handles are normally created on a default function
+  // device.
   static TensorHandle* CreateLazyRemoteHandle(int64 op_id, int32 output_num,
                                               tensorflow::DataType dtype,
-                                              Device* d, EagerContext* ctx);
+                                              Device* d, const bool is_ready,
+                                              EagerContext* ctx);
 #endif  // IS_MOBILE_PLATFORM
 
   void Release() override;
@@ -141,6 +150,10 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
     return resource_remote_device_incarnation_;
   }
 
+  // If the devices are unknown at creation time, block until the actual devices
+  // are set (data is ready).
+  Status WaitUnknownDevice() const;
+
   VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
 
   Status Shape(tensorflow::TensorShape* shape);
@@ -177,10 +190,15 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
   // transitions the tensor handle from a non-ready to a ready state by
   // replacing the backing data abstraction to allow for the shape to be
   // queried.
+  // creating a TensorHandle (e.g. a remote output of a remote function).
   // This method or Poison must be called exactly once for remote tensors that
   // were created without a known shape.
   Status SetRemoteShape(const TensorShape& shape, const Device* d,
                         uint64 context_view_id);
+  // If op_device is not empty, reset the devices of a remote tensor which is
+  // created without known devices (e.g. function outputs).
+  Status SetRemoteShapeAndDevice(const TensorShape& shape, const Device* d,
+                                 uint64 context_view_id, string op_device);
 
   // Poisons either this handle or a remote mirror with error `status`.
   // Poisoning means that the handle will become ready and methods trying
@@ -258,21 +276,27 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
   // to either SetTensor or SetRemoteShape which replaces the underlying data
   // with a ready version of the tensor handle data.
   bool IsReady() const;
+  Status WaitReady(const char* caller) const;
 
-  VariantDevice const device_;
+  VariantDevice device_;
 
   // Device in which the op producing this tensor was executed. Equals to
   // device_ for constant tensors.
   // Can be nullptr if the op producing this tensor was a function executed
   // with function library runtime.
-  tensorflow::Device* const op_device_;
+  tensorflow::Device* op_device_;
 
   // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
   // backing the resource. Else resource_device_ is nullptr.
-  tensorflow::Device* const resource_device_;
+  tensorflow::Device* resource_device_;
   // Incarnation ID of the resource device if it locates on a remote device, or
   // 0 if it locates on a local device.
-  const int64 resource_remote_device_incarnation_;
+  int64 resource_remote_device_incarnation_;
+
+  // If true, the handle refers to a remote tensor which is created without
+  // known devices. The actual devices are set by SetRemoteShape. The devices
+  // should be accessed once the handle is ready.
+  const bool unknown_device_ = false;
 
   mutable mutex mu_;
 
@@ -323,6 +347,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
     Status NumElements(int64* num_elements) const;
     Status Unprotect();
     bool IsReady() const;
+    Status WaitReady(const char* caller) const;
     void Poison(Status status);
     string DebugString() const;
 
@@ -334,4 +334,84 @@ TEST(TensorHandle_ResourceDeviceTest, OnRemoteDevice) {
   ctx->Unref();
 }
 
+class RemoteTensorHandleTest : public ::testing::Test {
+ public:
+  RemoteTensorHandleTest() {
+    std::vector<std::unique_ptr<Device>> devices;
+    for (const char* name : device_names_) {
+      devices.emplace_back(CreateDevice("CPU", name));
+    }
+    device_mgr_ = new StaticDeviceMgr(std::move(devices));
+
+    context_ = new EagerContext(
+        SessionOptions(),
+        tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+        tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false,
+        /* lazy_copy_function_remote_inputs= */ false, device_mgr_,
+        /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
+        /* custom_kernel_creator= */ nullptr,
+        /* cluster_flr= */ nullptr);
+  }
+
+  ~RemoteTensorHandleTest() override {
+    delete device_mgr_;
+    context_->Unref();
+  }
+
+  EagerContext* context() { return context_; }
+
+  std::vector<Device*> ListDevices() const {
+    return device_mgr_->ListDevices();
+  }
+
+ private:
+  const std::vector<const char*> device_names_ = {
+      "/job:worker/replica:0/task:0/device:CPU:0",
+      "/job:worker/replica:0/task:1/device:CPU:0",
+      "/job:worker/replica:0/task:2/device:CPU:0"};
+
+  StaticDeviceMgr* device_mgr_;
+  EagerContext* context_;
+};
+
+TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
+  std::vector<std::unique_ptr<Device>> devices;
+  devices.emplace_back(
+      CreateDevice("CPU", "/job:worker/replica:0/task:0/device:CPU:0"));
+  devices.emplace_back(
+      CreateDevice("CPU", "/job:worker/replica:0/task:1/device:CPU:0"));
+  devices.emplace_back(
+      CreateDevice("CPU", "/job:worker/replica:0/task:2/device:CPU:0"));
+  StaticDeviceMgr device_mgr(std::move(devices));
+
+  EagerContext* context = new EagerContext(
+      SessionOptions(),
+      tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+      tensorflow::ContextMirroringPolicy::MIRRORING_NONE, /* async= */ false,
+      /* lazy_copy_function_remote_inputs= */ false, &device_mgr,
+      /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
+      /* custom_kernel_creator= */ nullptr,
+      /* cluster_flr= */ nullptr);
+
+  tensorflow::DataType dtype = DT_FLOAT;
+  TensorShape shape = {};
+
+  const string remote_task = "/job:worker/replica:0/task:1";
+  Device* d1 = device_mgr.ListDevices().at(1);
+  TensorHandle* h = TensorHandle::CreateUnshapedRemoteHandle(
+      /*op_id=*/0, /*output_num=*/0, remote_task, dtype, d1, context,
+      /*unknown_device=*/true);
+  EXPECT_EQ(absl::get<Device*>(h->device()), d1);
+
+  Device* d2 = device_mgr.ListDevices().at(2);
+  TF_ASSERT_OK(h->SetRemoteShapeAndDevice(
+      shape, d1, context->GetContextViewId(), d2->name()));
+  Status s;
+  EXPECT_EQ(h->BackingDeviceName(&s), d2->name());
+  TF_EXPECT_OK(s);
+  EXPECT_EQ(absl::get<Device*>(h->device()), d2);
+  h->Unref();
+  context->Unref();
+}
+
 }  // namespace tensorflow
@@ -1006,16 +1006,13 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
     const string& target = pair.first;
     FunctionLibraryRuntime* target_flr = GetFLR(target);
     Device* target_device = nullptr;
+    Device* host = nullptr;
     if (target_flr == nullptr) {
-      // TODO(b/162618595): Remove this error once we support a remote
-      // multi-device function with remote outputs.
-      return errors::Unimplemented(
-          "Currently, outputting tensors on remote devices is not supported."
-          "The ",
-          comp_data.ret_indices[0],
-          "-th return value of the function outputs to target_device: ", target,
-          " Please copy the tensor to local device explicitly using "
-          "tf.identity and return the new Tensor instead.");
+      target_device = device_set()->FindDeviceByName(target);
+      string remote_host;
+      TF_RETURN_IF_ERROR(
+          DeviceNameUtils::DeviceNameToCpuDeviceName(target, &remote_host));
+      host = device_set()->FindDeviceByName(remote_host);
     } else {
       target_device = target_flr->device();
     }
@@ -1026,7 +1023,7 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices(
         (*output_devices)[ret_index] = target_device;
       } else {
         (*output_devices)[ret_index] =
-            comp_data.ret_alloc_attrs[j].on_host() ? nullptr : target_device;
+            comp_data.ret_alloc_attrs[j].on_host() ? host : target_device;
       }
     }
   }
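A note on the GetOutputDevices hunks above: when the target component device has no local FunctionLibraryRuntime (i.e. it is remote), the runtime now resolves a Device* from the device set instead of returning Unimplemented, and host-memory outputs are attributed to that task's CPU device rather than nullptr. The helper below merely restates the name mapping involved, using DeviceNameUtils as in the diff; the example device name in the comment is illustrative.

// Sketch of the remote-host lookup performed in the new branch.
#include <string>

#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/device_name_utils.h"

tensorflow::Status RemoteHostCpuName(const std::string& target,
                                     std::string* remote_host) {
  // e.g. "/job:worker/replica:0/task:2/device:GPU:0"
  //   -> "/job:worker/replica:0/task:2/device:CPU:0"
  return tensorflow::DeviceNameUtils::DeviceNameToCpuDeviceName(target,
                                                                remote_host);
}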
@@ -203,7 +203,7 @@ class ProcessFunctionLibraryRuntime {
 
   const DeviceMgr* device_mgr() { return device_mgr_; }
 
-  const std::shared_ptr<DeviceSet> device_set() {
+  const std::shared_ptr<DeviceSet> device_set() const {
     tf_shared_lock l(mu_);
     return device_set_;
   }
@@ -156,9 +156,15 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
   const tensorflow::Tensor* t = nullptr;
 
   // TODO(nareshmodi): This call makes async calls sync calls. Fix this.
-  TF_RETURN_IF_ERROR(handle->Tensor(&t));
+  if (handle->Type() == TensorHandle::LOCAL) {
+    TF_RETURN_IF_ERROR(handle->Tensor(&t));
 
-  t->shape().AsProto(proto);
+    t->shape().AsProto(proto);
+  } else {
+    TensorShape shape;
+    TF_RETURN_IF_ERROR(handle->Shape(&shape));
+    shape.AsProto(proto);
+  }
 
   return Status::OK();
 }
@@ -166,7 +172,8 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) {
 Status AddOpRetvalsToResponse(
     EagerContext* eager_context, int op_id, int num_retvals,
     TensorHandle** retvals, std::function<TensorProto*()> add_tensor_proto_fn,
-    std::function<TensorShapeProto*()> add_shape_proto_fn) {
+    std::function<TensorShapeProto*()> add_shape_proto_fn,
+    std::function<string*()> add_device_fn = nullptr) {
   if (op_id == kInvalidRemoteOpId) {
     // Copy the output tensors back along with the response, since the op id
     // is invalid which cannot be added to RemoteMgr.
@@ -175,10 +182,21 @@ Status AddOpRetvalsToResponse(
       retvals[i]->Unref();
     }
   } else {
-    eager_context->RemoteMgr()->AddOperationOutputs(
-        absl::MakeSpan(retvals, num_retvals), op_id);
     for (int i = 0; i < num_retvals; i++) {
       TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
+      const bool is_remote = retvals[i]->Type() == TensorHandle::REMOTE;
+      if (add_device_fn) {
+        *add_device_fn() =
+            is_remote ? absl::get<Device*>(
+                            retvals[i]->DeviceOrHostCPU(*eager_context))
+                            ->name()
+                      : "";
+      }
+      if (is_remote) {
+        retvals[i]->Unref();
+      } else {
+        eager_context->RemoteMgr()->AddOperationOutput(retvals[i], op_id, i);
+      }
     }
   }
   return Status::OK();
@@ -479,6 +497,8 @@ void EagerServiceImpl::RunComponentFunction(
           wrapped_done(status);
          return;
         }
+        // The output device of a component function is the component device
+        // which is known on the default device of it's parent function.
        wrapped_done(AddOpRetvalsToResponse(
            eager_context, op_id, *num_retvals, retvals->data(),
            [response] { return response->add_tensor(); },
@@ -510,10 +530,19 @@ Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts,
                                               num_retvals),
                          &num_retvals));
 
+  std::function<string*()> add_device_fn = nullptr;
+  // Send the output devices of a function back to let a client know where the
+  // outputs are. For a primitive op, an output devics is the op device which is
+  // known on a client.
+  if (op.is_function()) {
+    add_device_fn = [queue_response] { return queue_response->add_device(); };
+  }
+
   return AddOpRetvalsToResponse(
       eager_context, operation.id(), num_retvals, retvals.data(),
       [queue_response] { return queue_response->add_tensor(); },
-      [queue_response] { return queue_response->add_shape(); });
+      [queue_response] { return queue_response->add_shape(); },
+      std::move(add_device_fn));
 }
 
 Status EagerServiceImpl::Enqueue(CallOptions* call_opts,
@@ -88,8 +88,14 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) {
         }
         for (size_t i = 0; i < retvals.size(); ++i) {
           if (status.ok()) {
-            Status s = retvals[i]->SetRemoteShape(
-                response->queue_response(0).shape(i), device, context_view_id);
+            const string output_device =
+                response->queue_response(0).device().empty()
+                    ? ""
+                    : response->queue_response(0).device(i);
+            Status s = retvals[i]->SetRemoteShapeAndDevice(
+                response->queue_response(0).shape(i), device, context_view_id,
+                output_device);
+
             if (!s.ok()) {
               LOG(ERROR) << "Ignoring an error encountered when setting "
                             "remote shape of tensor handle: "
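Taken together, the eager_service hunks and the RemoteExecuteNode hunk above form a round trip: the server fills QueueResponse.device for each function output, and the client feeds that value into SetRemoteShapeAndDevice. The helper below simply restates the client-side read shown above; it is illustrative only and assumes the QueueResponse message from the generated eager_service protobuf.

// Sketch: pick the reported output device for return value `i`, treating an
// empty `device` list as "device already known on the client" (primitive ops,
// or servers that do not fill the field).
#include <string>

#include "tensorflow/core/protobuf/eager_service.pb.h"

std::string OutputDeviceOrEmpty(
    const tensorflow::eager::QueueResponse& response, int i) {
  return response.device().empty() ? "" : response.device(i);
}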
@@ -35,6 +35,13 @@ void RemoteMgr::AddOperationOutputs(
   }
 }
 
+void RemoteMgr::AddOperationOutput(tensorflow::TensorHandle* handle,
+                                   int64 operation_id, int32 output_num) {
+  mutex_lock l(remote_tensor_handle_mu_);
+  remote_tensor_handle_map_.emplace(
+      RemoteTensorHandleInternal(operation_id, output_num), handle);
+}
+
 Status RemoteMgr::GetTensorHandleImpl(
     const RemoteTensorHandleInternal& remote_handle,
     tensorflow::TensorHandle** handle) {
@@ -160,13 +167,14 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
     (*out)->Ref();
   } else {
     // Create a remote TensorHandle for remote tensors which have not been
-    // copied to the local worker yet.
+    // copied to the local worker yet (e.g. remote function inputs).
     const string& device_name =
         in.op_device().empty() ? in.device() : in.op_device();
     TF_RETURN_IF_ERROR(
         parent_->FindDeviceFromName(device_name.c_str(), &device));
     *out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
-                                                in.dtype(), device, parent_);
+                                                in.dtype(), device,
+                                                /*is_ready=*/true, parent_);
     std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
     if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                   &dtypes_and_shapes)
@@ -47,6 +47,9 @@ class RemoteMgr {
       const gtl::ArraySlice<tensorflow::TensorHandle*> handles,
       int64 operation_id);
 
+  void AddOperationOutput(tensorflow::TensorHandle* handles, int64 operation_id,
+                          int32 output_num);
+
   Status GetTensorHandle(const RemoteTensorHandleInternal& remote_handle,
                          tensorflow::TensorHandle** handle);
 
@@ -95,7 +95,7 @@ TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
   const uint64 op_id = 3;
   const int output_num = 1;
   TensorHandle* handle = TensorHandle::CreateLazyRemoteHandle(
-      op_id, output_num, DT_FLOAT, remote_device_, ctx_);
+      op_id, output_num, DT_FLOAT, remote_device_, /*is_ready=*/true, ctx_);
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
       handle, /*wait_until_ready=*/true, &remote_handle, remote_device_,
@@ -85,8 +85,9 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
 }  // namespace
 
 RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
-                                               uint64 context_view_id)
-    : is_ready_(true),
+                                               uint64 context_view_id,
+                                               bool is_ready)
+    : is_ready_(is_ready),
       op_id_(op_id),
       output_num_(output_num),
       context_view_id_(context_view_id),
@@ -173,6 +174,11 @@ Status RemoteTensorHandleData::IsPoisoned() const {
 }
 
 Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
+  return SetShapeAndRemoteTask(shape, /*remote_task=*/"");
+}
+
+Status RemoteTensorHandleData::SetShapeAndRemoteTask(
+    const TensorShape& shape, const string& remote_task) {
   // If `is_ready_` is set previously due to poisoning, return the original
   // error that poisoned this tensor.
   TF_RETURN_IF_ERROR(IsPoisoned());
@@ -183,6 +189,9 @@ Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
   }
 
   shape_ = shape;
+  if (!remote_task.empty()) {
+    remote_task_ = remote_task;
+  }
   is_poisoned_ = Status::OK();
   is_ready_ = true;
 
@@ -26,11 +26,16 @@ namespace tensorflow {
 class RemoteTensorHandleData {
  public:
   // Constructor for lazy remote handles. A lazy remote handle is created on
-  // a remote worker with an op_id and an output_num sent by a client. The
-  // client won't serialize them until the corresponding remote tensor is ready.
-  // So the remote tensor should be ready when we create a lazy remote handle.
-  RemoteTensorHandleData(int64 op_id, int output_num, uint64 context_view_id);
-  // Constructor for unshaped remote handles
+  // a remote worker with an op_id and an output_num. It doesn't control the
+  // lifetime of a remote handle that it refers to. If it refers to a remote
+  // function input, it's sent by a client which won't serialize it until
+  // the corresponding remote tensor is ready. So the remote tensor should be
+  // ready when we create a lazy remote handle. If it refers to a remote output,
+  // it's not ready until the shape is set.
+  RemoteTensorHandleData(int64 op_id, int output_num, uint64 context_view_id,
+                         bool is_ready);
+  // Constructor for unshaped remote handles. It controls the lifetime of a
+  // remote handel that it refers to.
   RemoteTensorHandleData(int64 op_id, int output_num, const string& remote_task,
                          EagerContext* ctx);
   ~RemoteTensorHandleData();
@@ -44,7 +49,10 @@ class RemoteTensorHandleData {
   Status Unprotect() { return Status::OK(); }
 
   bool IsReady() const;
+  Status WaitReady(const char* caller) const;
   Status SetShape(const TensorShape& shape);
+  Status SetShapeAndRemoteTask(const TensorShape& shape,
+                               const string& remote_task);
   void Poison(Status status);
   Status IsPoisoned() const;
 
@@ -58,8 +66,6 @@ class RemoteTensorHandleData {
   uint64 context_view_id() const { return context_view_id_; }
 
  private:
-  Status WaitReady(const char* caller) const;
-
   mutable mutex mu_;
   bool is_ready_ TF_GUARDED_BY(mu_);
   Status is_poisoned_ TF_GUARDED_BY(mu_);
@@ -68,7 +74,7 @@ class RemoteTensorHandleData {
   // IDs required when this class is representing a remote tensor handle.
   const int64 op_id_;
   const int32 output_num_;
-  string remote_task_;
+  string remote_task_ TF_GUARDED_BY(mu_);
   uint64 context_id_;
   uint64 context_view_id_;
   EagerContext* ctx_;
@@ -77,6 +77,8 @@ message QueueResponse {
   // `shape` and `tensor` cannot be set in the same response.
   // Shapes of output tensors for creating remote TensorHandles.
   repeated TensorShapeProto shape = 1;
+  // Optional. If set, represents the output devices of a function.
+  repeated string device = 3;
 
   // Output tensors of a remote function. Set when Operation.id is invalid.
   repeated TensorProto tensor = 2;