Set the executor when an eager operation is created. This prevents the executor used by an op from being changed during the op's execution.

PiperOrigin-RevId: 260791313
This commit is contained in:
Xiao Yu 2019-07-30 14:07:47 -07:00 committed by TensorFlower Gardener
parent 5c05370ebd
commit 3746cb0a26
5 changed files with 25 additions and 17 deletions

View File

@ -566,7 +566,8 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
const tensorflow::Tensor* t = nullptr;
tensorflow::TensorHandle* h_cpu = nullptr;
status->status = EagerCopyToDevice(
handle, handle->Context(), handle->Context()->HostCPU(), false, &h_cpu);
handle, handle->Context(), handle->Context()->Executor(),
handle->Context()->HostCPU(), false, &h_cpu);
if (!status->status.ok()) {
return nullptr;
}
@ -918,6 +919,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
return nullptr;
}
status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context,
ctx->context->Executor(),
device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle(handle);

View File

@ -31,7 +31,8 @@ class EagerOperation {
attrs_(op),
attr_types_(t),
device_(nullptr),
is_function_(is_function) {}
is_function_(is_function),
executor_(ctx ? ctx->Executor() : nullptr) {}
~EagerOperation() {
for (tensorflow::TensorHandle* h : inputs_) {
@ -81,6 +82,8 @@ class EagerOperation {
cancellation_manager_ = cancellation_manager;
}
// Returns the executor stored in executor_, which the constructor takes from
// ctx->Executor() (nullptr when ctx is null). The pointer is not owned.
EagerExecutor* Executor() { return executor_; }
string DebugString() const;
private:
@ -94,6 +97,7 @@ class EagerOperation {
bool use_xla_ = false;
const bool is_function_;
CancellationManager* cancellation_manager_ = nullptr; // Not owned.
EagerExecutor* const executor_; // Not owned.
};
} // namespace tensorflow

View File

@ -177,8 +177,9 @@ Status MaybeCopyInputToExpectedDevice(EagerOperation* op, Device* op_device,
// trigger a copy.
auto pre_time_nanos = Env::Default()->NowNanos();
TensorHandle* result_handle = nullptr;
Status status = EagerCopyToDevice(handle, ctx, expected_input_device,
ctx->MirrorTensors(), &result_handle);
Status status =
EagerCopyToDevice(handle, ctx, op->Executor(), expected_input_device,
ctx->MirrorTensors(), &result_handle);
if (run_metadata != nullptr) {
auto* step_stats = run_metadata->mutable_step_stats();
MaybeInitializeStepStats(step_stats, ctx);
@ -474,7 +475,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
[&] { return absl::StrCat("EagerLocalExecute: ", op->Name()); },
profiler::TraceMeLevel::kInfo);
EagerContext* ctx = op->EagerContext();
auto* executor = ctx->Executor();
auto* executor = op->Executor();
TF_RETURN_IF_ERROR(executor->status());
Device* device = op->Device();
@ -508,7 +509,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
if (input->IsRemote()) {
TensorHandle* handle = nullptr;
TF_RETURN_IF_ERROR(EagerCopyToDevice(
input, ctx, device == nullptr ? ctx->HostCPU() : device,
input, ctx, executor, device == nullptr ? ctx->HostCPU() : device,
ctx->MirrorTensors(), &handle));
op->UpdateInput(i, handle);
// Unref handle since it has a ref as an input now
@ -834,7 +835,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
tensorflow::Device* op_device = op->Device();
auto* executor = ctx->Executor();
auto* executor = op->Executor();
bool is_async = executor->Async();
VLOG(4) << "Execute remote eager op: " << op->Name()
<< " (is async?: " << is_async << ").";
@ -1149,9 +1150,9 @@ Status EagerKernelExecute(EagerContext* ctx,
namespace {
Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
EagerExecutor* executor, Device* dstd,
TensorHandle** result) {
auto* executor = ctx->Executor();
TF_RETURN_IF_ERROR(executor->status());
Device* resource_device = (h->dtype == DT_RESOURCE) ? dstd : nullptr;
TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
@ -1173,8 +1174,9 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* dstd,
} // namespace
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
bool mirror, TensorHandle** result) {
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
EagerExecutor* executor, Device* device, bool mirror,
TensorHandle** result) {
Device* send_device = h->DeviceOrHostCPU(ctx);
bool sender_is_local = send_device->IsLocal();
@ -1182,7 +1184,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
bool recver_is_local = device->IsLocal();
if (sender_is_local && recver_is_local) {
return LocalEagerCopyToDevice(h, ctx, device, result);
return LocalEagerCopyToDevice(h, ctx, executor, device, result);
} else {
#if defined(IS_MOBILE_PLATFORM)
return errors::Unimplemented(
@ -1199,7 +1201,6 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
if (ctx->UseSendTensorRPC() && sender_is_local && !recver_is_local) {
return EagerRemoteSendTensor(ctx, h, device, mirror, result);
} else {
auto* executor = ctx->Executor();
uint64 recv_op_id = 0;
if (recver_is_local) {
TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(

View File

@ -58,8 +58,9 @@ Status EagerKernelExecute(EagerContext* ctx,
// the mirror flag, EagerCopyToDevice will attempt to add a mirror to the
// original handle and update *result to point to h. Since this is not
// guaranteed, callers should always use the value in *result.
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, Device* device,
bool mirror, TensorHandle** result);
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
EagerExecutor* executor, Device* device, bool mirror,
TensorHandle** result);
} // namespace tensorflow

View File

@ -346,8 +346,8 @@ Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
Device* device;
TF_RETURN_IF_ERROR(
ctx->FindDeviceFromName(request->device_name().c_str(), &device));
TF_RETURN_IF_ERROR(
EagerCopyToDevice(tensor_handle, ctx, device, false, &copied_handle));
TF_RETURN_IF_ERROR(EagerCopyToDevice(tensor_handle, ctx, ctx->Executor(),
device, false, &copied_handle));
tensors.push_back(copied_handle);
tensor_handle->Unref();
}