Fix two issues in remote tensor handle.

* When setting a TensorHandle's remote shape, context_view_id is currently used to index mirrors, but it should not also be checked when setting the shape of the RemoteTensorHandleData itself.
* If a tensor is ready because it was previously poisoned, return the original error instead of the less useful error "SetShape is only called on non-ready handles".

PiperOrigin-RevId: 307633262
Change-Id: I22436402c6beeb41731802060b59851e807627d9
This commit is contained in:
Haoyu Zhang 2020-04-21 10:35:04 -07:00 committed by TensorFlower Gardener
parent 203c417a20
commit ffac1f3e0d
3 changed files with 58 additions and 4 deletions

View File

@ -626,10 +626,14 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
DCHECK(IsRemote()) << "SetRemoteShape is only called on remote handles.";
auto& data = absl::get<RemoteTensorHandleData>(data_);
if (data.context_view_id() != context_view_id) {
return errors::Internal("Attempted to set remote shape for an old handle.");
}
// context_view_id is currently used to validate mirrors. The shape of
// RemoteTensorHandleData should be set without checking context_view_id.
// The reason behind it is that for the primary copy of data, if the remote
// worker / device is removed, the consumer should report a connection error
// indicating the remote tensor is no longer available.
// For mirrors, this is not the case because they colocate with the data
// consuming op/function device, and we (for now) have to aggressively
// invalidate those copies to avoid any false positives during cluster update.
return data.SetShape(shape);
}

View File

@ -102,6 +102,52 @@ TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
handle->Unref();
}
// Checks that a remote mirror attached to a local handle is no longer
// reported after the context view id is incremented, and that SetRemoteShape
// called with the new view id still returns OK.
TEST_F(RemoteMgrTest, InvalidateRemoteMirrorWithClusterUpdate) {
  RemoteMgr remote_mgr(false, ctx_);

  // Op id and output index passed along when registering the mirror below.
  const uint64 op_id = 2;
  const int output_num = 3;

  // Build a local handle around a float tensor with a single dimension of
  // size 0.
  Tensor empty_tensor(DT_FLOAT, TensorShape({0}));
  TensorHandle* handle = TensorHandle::CreateLocalHandle(
      std::move(empty_tensor), local_device_, local_device_, ctx_);

  // Register an unshaped mirror on the remote device; it must be visible
  // under the current context view id.
  TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id,
                                               output_num, "", ctx_));
  EXPECT_TRUE(
      handle->HasRemoteMirror(remote_device_, ctx_->GetContextViewId()));

  // A cluster update bumps the context view id, after which the mirror is
  // expected to be invalidated.
  ctx_->IncrementContextViewId();
  EXPECT_FALSE(
      handle->HasRemoteMirror(remote_device_, ctx_->GetContextViewId()));

  // Even with the mirror invalidated, setting the remote shape must succeed.
  TF_ASSERT_OK(handle->SetRemoteShape(TensorShape({0}), remote_device_,
                                      ctx_->GetContextViewId()));

  // Release the reference held by this test.
  handle->Unref();
}
// Checks that SetRemoteShape succeeds on an unshaped remote handle both
// without any view id change and after the context view id has been
// incremented following the handle's creation.
TEST_F(RemoteMgrTest, SetRemoteShapeWithClusterUpdate) {
  RemoteMgr remote_mgr(false, ctx_);

  // Op id and output index shared by both handles created in this test.
  const uint64 op_id = 3;
  const int output_num = 1;

  // Case 1: the view id is unchanged between creating the handle and setting
  // its shape.
  TensorHandle* handle = TensorHandle::CreateUnshapedRemoteHandle(
      op_id, output_num,
      /*remote_task=*/"", DT_FLOAT, remote_device_, ctx_);
  TF_ASSERT_OK(handle->SetRemoteShape(TensorShape({0}), remote_device_,
                                      ctx_->GetContextViewId()));
  handle->Unref();

  // Case 2: the view id is incremented after the primary (non-mirror) remote
  // handle is created; setting its shape with the new view id must still
  // work.
  handle = TensorHandle::CreateUnshapedRemoteHandle(
      op_id, output_num,
      /*remote_task=*/"", DT_FLOAT, remote_device_, ctx_);
  ctx_->IncrementContextViewId();
  TF_ASSERT_OK(handle->SetRemoteShape(TensorShape({0}), remote_device_,
                                      ctx_->GetContextViewId()));
  handle->Unref();
}
} // namespace
} // namespace eager
} // namespace tensorflow

View File

@ -173,6 +173,10 @@ Status RemoteTensorHandleData::IsPoisoned() const {
}
Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
// If `is_ready_` is set previously due to poisoning, return the original
// error that poisoned this tensor.
TF_RETURN_IF_ERROR(IsPoisoned());
mutex_lock l(mu_);
if (is_ready_) {
return errors::Internal("SetShape is only called on non-ready handles.");