From d968853cc6825c705a4443844319279c464b152e Mon Sep 17 00:00:00 2001 From: Xiao Yu Date: Fri, 15 May 2020 12:12:51 -0700 Subject: [PATCH] Skip TFE_ContextAsyncWait for tfrt. In current TF-TFRT integration, all ops are executed synchronously. We will revisit this later. PiperOrigin-RevId: 311777624 Change-Id: I3a27805dcce53ccf572f3c500d6fd0a532b286b2 --- tensorflow/c/eager/c_api.cc | 4 +--- tensorflow/c/eager/context_interface.h | 3 +++ tensorflow/core/common_runtime/eager/context.h | 2 ++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 5c01ccb82bb..f5535c80d30 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -899,9 +899,7 @@ TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx, #if defined(IS_MOBILE_PLATFORM) status->status = tensorflow::Status::OK(); #else // !defined(IS_MOBILE_PLATFORM) - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = context->SyncExecutors(); + status->status = tensorflow::unwrap(ctx)->AsyncWait(); #endif // !IS_MOBILE_PLATFORM } diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h index d21ab45e579..76f182f4945 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/context_interface.h @@ -101,6 +101,9 @@ class AbstractContextInterface { // Destroy the step resource container for a training step. virtual void EndStep() = 0; + // Block until all pending nodes are finished, + virtual Status AsyncWait() = 0; + protected: virtual ~AbstractContextInterface() {} }; diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index d034aaf2f9c..d03a91c817a 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -295,6 +295,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { // errors, and the error message will be combined from all executors. Status SyncExecutors(); + Status AsyncWait() override { return SyncExecutors(); } + core::RefCountPtr GetCachedKernel(Fprint128 cache_key); void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);