Fix two issues in remote tensor handle.
* When setting TensorHandle remote shape, it currently uses context_view_id to index mirrors, but should not use it to check when setting the shape of RemoteTensorHandleData itself. * If a tensor is ready because it was previously poisoned, return the original error instead of the less useful error "SetShape is only called on non-ready handles". PiperOrigin-RevId: 307633262 Change-Id: I22436402c6beeb41731802060b59851e807627d9
This commit is contained in:
parent
203c417a20
commit
ffac1f3e0d
@ -626,10 +626,14 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
|
||||
DCHECK(IsRemote()) << "SetRemoteShape is only called on remote handles.";
|
||||
|
||||
auto& data = absl::get<RemoteTensorHandleData>(data_);
|
||||
if (data.context_view_id() != context_view_id) {
|
||||
return errors::Internal("Attempted to set remote shape for an old handle.");
|
||||
}
|
||||
|
||||
// context_view_id is currently used to validate mirrors. The shape of
|
||||
// RemoteTensorHandleData should be set without checking context_view_id.
|
||||
// The reason behind it is that for the primary copy of data, if the remote
|
||||
// worker / device is removed, the consumer should report a connection error
|
||||
// indicating the remote tensor is no longer available.
|
||||
// For mirrors, this is not the case because they colocate with the data
|
||||
// consuming op/function device, and we (for now) have to aggressively
|
||||
// invalidate those copies to avoid any false positives during cluster update.
|
||||
return data.SetShape(shape);
|
||||
}
|
||||
|
||||
|
@ -102,6 +102,52 @@ TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
|
||||
handle->Unref();
|
||||
}
|
||||
|
||||
TEST_F(RemoteMgrTest, InvalidateRemoteMirrorWithClusterUpdate) {
|
||||
RemoteMgr remote_mgr(false, ctx_);
|
||||
Tensor t(DT_FLOAT, TensorShape({0}));
|
||||
|
||||
TensorHandle* handle = TensorHandle::CreateLocalHandle(
|
||||
std::move(t), local_device_, local_device_, ctx_);
|
||||
const uint64 op_id = 2;
|
||||
const int output_num = 3;
|
||||
TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id,
|
||||
output_num, "", ctx_));
|
||||
EXPECT_TRUE(
|
||||
handle->HasRemoteMirror(remote_device_, ctx_->GetContextViewId()));
|
||||
|
||||
// When updating cluster, remote mirror should be invalidated.
|
||||
ctx_->IncrementContextViewId();
|
||||
EXPECT_FALSE(
|
||||
handle->HasRemoteMirror(remote_device_, ctx_->GetContextViewId()));
|
||||
// Setting remote shape should still be OK
|
||||
TF_ASSERT_OK(handle->SetRemoteShape(TensorShape({0}), remote_device_,
|
||||
ctx_->GetContextViewId()));
|
||||
handle->Unref();
|
||||
}
|
||||
|
||||
TEST_F(RemoteMgrTest, SetRemoteShapeWithClusterUpdate) {
|
||||
RemoteMgr remote_mgr(false, ctx_);
|
||||
|
||||
const uint64 op_id = 3;
|
||||
const int output_num = 1;
|
||||
TensorHandle* handle = TensorHandle::CreateUnshapedRemoteHandle(
|
||||
op_id, output_num,
|
||||
/*remote_task=*/"", DT_FLOAT, remote_device_, ctx_);
|
||||
TF_ASSERT_OK(handle->SetRemoteShape(TensorShape({0}), remote_device_,
|
||||
ctx_->GetContextViewId()));
|
||||
handle->Unref();
|
||||
|
||||
// Setting remote shape on primary (non-mirror) remote handle works after
|
||||
// cluster being updated
|
||||
handle = TensorHandle::CreateUnshapedRemoteHandle(
|
||||
op_id, output_num,
|
||||
/*remote_task=*/"", DT_FLOAT, remote_device_, ctx_);
|
||||
ctx_->IncrementContextViewId();
|
||||
TF_ASSERT_OK(handle->SetRemoteShape(TensorShape({0}), remote_device_,
|
||||
ctx_->GetContextViewId()));
|
||||
handle->Unref();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace eager
|
||||
} // namespace tensorflow
|
||||
|
@ -173,6 +173,10 @@ Status RemoteTensorHandleData::IsPoisoned() const {
|
||||
}
|
||||
|
||||
Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
|
||||
// If `is_ready_` is set previously due to poisoning, return the original
|
||||
// error that poisoned this tensor.
|
||||
TF_RETURN_IF_ERROR(IsPoisoned());
|
||||
|
||||
mutex_lock l(mu_);
|
||||
if (is_ready_) {
|
||||
return errors::Internal("SetShape is only called on non-ready handles.");
|
||||
|
Loading…
Reference in New Issue
Block a user