Temporarily revert the WaitReady change for handles with unknown devices.
PiperOrigin-RevId: 326949430 Change-Id: I7170402785dcd86e1e16dc6f8391e6586b84a2ae
This commit is contained in:
parent
ca21b3f9f4
commit
431b88f123
@ -30,13 +30,12 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
|
||||
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false);
|
||||
}
|
||||
// TODO(b/164506563): Re-enable after the fix.
|
||||
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesFuncRemoteOutputs) {
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesFuncRemoteOutputs) {
|
||||
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false,
|
||||
/*remote_func_outputs=*/true);
|
||||
}
|
||||
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
|
||||
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) {
|
||||
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
|
||||
/*heavy_load_on_streaming_rpc=*/false,
|
||||
/*remote_func_outputs=*/true);
|
||||
|
@ -539,14 +539,13 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
|
||||
}
|
||||
|
||||
Status TensorHandle::WaitUnknownDevice() const {
|
||||
// TODO(b/164506563): uncomment this when b/164506563 is fixed.
|
||||
// if (unknown_device_) {
|
||||
// TF_RETURN_IF_ERROR(absl::visit(
|
||||
// [](auto& data) {
|
||||
// return data.WaitReady("TensorHandle::UnknownDevice");
|
||||
// },
|
||||
// data_));
|
||||
// }
|
||||
if (unknown_device_) {
|
||||
TF_RETURN_IF_ERROR(absl::visit(
|
||||
[](auto& data) {
|
||||
return data.WaitReady("TensorHandle::UnknownDevice");
|
||||
},
|
||||
data_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -49,13 +49,6 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) {
|
||||
}
|
||||
VLOG(3) << "Issuing: " << rpc_description;
|
||||
|
||||
for (auto handle : inputs_) {
|
||||
handle->Ref();
|
||||
}
|
||||
for (auto handle : retvals) {
|
||||
handle->Ref();
|
||||
}
|
||||
|
||||
CancellationManager* cm = cancellation_manager_;
|
||||
CancellationToken token = 0;
|
||||
auto call_opts = std::make_shared<CallOptions>();
|
||||
@ -64,11 +57,22 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) {
|
||||
const bool already_cancelled = !cm->RegisterCallback(
|
||||
token, [call_opts, response, done]() { call_opts->StartCancel(); });
|
||||
if (already_cancelled) {
|
||||
done(errors::Cancelled("RemoteExecuteNode::RunAsync"));
|
||||
Status s = errors::Cancelled("RemoteExecuteNode::RunAsync");
|
||||
for (size_t i = 0; i < retvals.size(); ++i) {
|
||||
retvals[i]->PoisonRemote(s, device, context_view_id_);
|
||||
}
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto handle : inputs_) {
|
||||
handle->Ref();
|
||||
}
|
||||
for (auto handle : retvals) {
|
||||
handle->Ref();
|
||||
}
|
||||
|
||||
eager_client_->StreamingEnqueueAsync(
|
||||
call_opts.get(), request_.get(), response.get(),
|
||||
[inputs, retvals, call_opts, response, device,
|
||||
|
Loading…
Reference in New Issue
Block a user