Temporarily revert the WaitReady change for handles with unknown devices.

PiperOrigin-RevId: 326949430
Change-Id: I7170402785dcd86e1e16dc6f8391e6586b84a2ae
This commit is contained in:
Yujing Zhang 2020-08-16 20:28:39 -07:00 committed by TensorFlower Gardener
parent ca21b3f9f4
commit 431b88f123
3 changed files with 21 additions and 19 deletions

View File

@ -30,13 +30,12 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
// TODO(b/164506563): Re-enable after the fix.
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesFuncRemoteOutputs) {
TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);

View File

@ -539,14 +539,13 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
}
Status TensorHandle::WaitUnknownDevice() const {
// TODO(b/164506563): uncomment this when b/164506563 is fixed.
// if (unknown_device_) {
// TF_RETURN_IF_ERROR(absl::visit(
// [](auto& data) {
// return data.WaitReady("TensorHandle::UnknownDevice");
// },
// data_));
// }
if (unknown_device_) {
TF_RETURN_IF_ERROR(absl::visit(
[](auto& data) {
return data.WaitReady("TensorHandle::UnknownDevice");
},
data_));
}
return Status::OK();
}

View File

@ -49,13 +49,6 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) {
}
VLOG(3) << "Issuing: " << rpc_description;
for (auto handle : inputs_) {
handle->Ref();
}
for (auto handle : retvals) {
handle->Ref();
}
CancellationManager* cm = cancellation_manager_;
CancellationToken token = 0;
auto call_opts = std::make_shared<CallOptions>();
@ -64,11 +57,22 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) {
const bool already_cancelled = !cm->RegisterCallback(
token, [call_opts, response, done]() { call_opts->StartCancel(); });
if (already_cancelled) {
done(errors::Cancelled("RemoteExecuteNode::RunAsync"));
Status s = errors::Cancelled("RemoteExecuteNode::RunAsync");
for (size_t i = 0; i < retvals.size(); ++i) {
retvals[i]->PoisonRemote(s, device, context_view_id_);
}
done(s);
return;
}
}
for (auto handle : inputs_) {
handle->Ref();
}
for (auto handle : retvals) {
handle->Ref();
}
eager_client_->StreamingEnqueueAsync(
call_opts.get(), request_.get(), response.get(),
[inputs, retvals, call_opts, response, device,