Set executor when creating eager operation. This change can avoid executor being changed during op execution.
PiperOrigin-RevId: 260791313
This commit is contained in:
parent
5c05370ebd
commit
3746cb0a26
@ -566,7 +566,8 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
|
||||
const tensorflow::Tensor* t = nullptr;
|
||||
tensorflow::TensorHandle* h_cpu = nullptr;
|
||||
status->status = EagerCopyToDevice(
|
||||
handle, handle->Context(), handle->Context()->HostCPU(), false, &h_cpu);
|
||||
handle, handle->Context(), handle->Context()->Executor(),
|
||||
handle->Context()->HostCPU(), false, &h_cpu);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -918,6 +919,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
return nullptr;
|
||||
}
|
||||
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
|
||||
ctx->context->Executor(),
|
||||
device, false, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle(handle);
|
||||
|
@ -31,7 +31,8 @@ class EagerOperation {
|
||||
attrs_(op),
|
||||
attr_types_(t),
|
||||
device_(nullptr),
|
||||
is_function_(is_function) {}
|
||||
is_function_(is_function),
|
||||
executor_(ctx ? ctx->Executor() : nullptr) {}
|
||||
|
||||
~EagerOperation() {
|
||||
for (tensorflow::TensorHandle* h : inputs_) {
|
||||
@ -81,6 +82,8 @@ class EagerOperation {
|
||||
cancellation_manager_ = cancellation_manager;
|
||||
}
|
||||
|
||||
EagerExecutor* Executor() { return executor_; }
|
||||
|
||||
string DebugString() const;
|
||||
|
||||
private:
|
||||
@ -94,6 +97,7 @@ class EagerOperation {
|
||||
bool use_xla_ = false;
|
||||
const bool is_function_;
|
||||
CancellationManager* cancellation_manager_ = nullptr; // Not owned.
|
||||
EagerExecutor* const executor_; // Not owned.
|
||||
};
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -177,8 +177,9 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, Device* op_device,
|
||||
// trigger a copy.
|
||||
auto pre_time_nanos = Env::Default()->NowNanos();
|
||||
TensorHandle* result_handle = nullptr;
|
||||
Status status = EagerCopyToDevice(handle, ctx, expected_input_device,
|
||||
ctx->MirrorTensors(), &result_handle);
|
||||
Status status =
|
||||
EagerCopyToDevice(handle, ctx, op->Executor(), expected_input_device,
|
||||
ctx->MirrorTensors(), &result_handle);
|
||||
if (run_metadata != nullptr) {
|
||||
auto* step_stats = run_metadata->mutable_step_stats();
|
||||
MaybeInitializeStepStats(step_stats, ctx);
|
||||
@ -474,7 +475,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
[&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
EagerContext* ctx = op->EagerContext();
|
||||
auto* executor = ctx->Executor();
|
||||
auto* executor = op->Executor();
|
||||
TF_RETURN_IF_ERROR(executor->status());
|
||||
Device* device = op->Device();
|
||||
|
||||
@ -508,7 +509,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
if (input->IsRemote()) {
|
||||
TensorHandle* handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(EagerCopyToDevice(
|
||||
input, ctx, device == nullptr ? ctx->HostCPU() : device,
|
||||
input, ctx, executor, device == nullptr ? ctx->HostCPU() : device,
|
||||
ctx->MirrorTensors(), &handle));
|
||||
op->UpdateInput(i, handle);
|
||||
// Unref handle since it has a ref as an input now
|
||||
@ -834,7 +835,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
|
||||
tensorflow::Device* op_device = op->Device();
|
||||
|
||||
auto* executor = ctx->Executor();
|
||||
auto* executor = op->Executor();
|
||||
bool is_async = executor->Async();
|
||||
VLOG(4) << "Execute remote eager op: " << op->Name()
|
||||
<< " (is async?: " << is_async << ").";
|
||||
@ -1149,9 +1150,9 @@ Status EagerKernelExecute(EagerContext* ctx,
|
||||
|
||||
namespace {
|
||||
|
||||
Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
|
||||
Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
EagerExecutor* executor, Device* dstd,
|
||||
TensorHandle** result) {
|
||||
auto* executor = ctx->Executor();
|
||||
TF_RETURN_IF_ERROR(executor->status());
|
||||
Device* resource_device = (h->dtype == DT_RESOURCE) ? dstd : nullptr;
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
|
||||
@ -1173,8 +1174,9 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
|
||||
|
||||
} // namespace
|
||||
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
|
||||
bool mirror, TensorHandle** result) {
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
EagerExecutor* executor, Device* device, bool mirror,
|
||||
TensorHandle** result) {
|
||||
Device* send_device = h->DeviceOrHostCPU(ctx);
|
||||
|
||||
bool sender_is_local = send_device->IsLocal();
|
||||
@ -1182,7 +1184,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
|
||||
bool recver_is_local = device->IsLocal();
|
||||
|
||||
if (sender_is_local && recver_is_local) {
|
||||
return LocalEagerCopyToDevice(h, ctx, device, result);
|
||||
return LocalEagerCopyToDevice(h, ctx, executor, device, result);
|
||||
} else {
|
||||
#if defined(IS_MOBILE_PLATFORM)
|
||||
return errors::Unimplemented(
|
||||
@ -1199,7 +1201,6 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
|
||||
if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
|
||||
return EagerRemoteSendTensor(ctx, h, device, mirror, result);
|
||||
} else {
|
||||
auto* executor = ctx->Executor();
|
||||
uint64 recv_op_id = 0;
|
||||
if (recver_is_local) {
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
|
||||
|
@ -58,8 +58,9 @@ Status EagerKernelExecute(EagerContext* ctx,
|
||||
// the mirror flag, EagerCopyToDevice will attempt to add a mirror to the
|
||||
// original handle and update *result to point to h. Since this is not
|
||||
// guaranteed, callers should always use the value in *result.
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
|
||||
bool mirror, TensorHandle** result);
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
EagerExecutor* executor, Device* device, bool mirror,
|
||||
TensorHandle** result);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -346,8 +346,8 @@ Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
|
||||
Device* device;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->FindDeviceFromName(request->device_name().c_str(), &device));
|
||||
TF_RETURN_IF_ERROR(
|
||||
EagerCopyToDevice(tensor_handle, ctx, device, false, &copied_handle));
|
||||
TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, ctx, ctx->Executor(),
|
||||
device, false, &copied_handle));
|
||||
tensors.push_back(copied_handle);
|
||||
tensor_handle->Unref();
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user