Added a new eager C API TFE_NewContextFromSession(), where TFE_NewContext will

get an owned device mgr from the input session.

One use case is in S4TF, we run a graph session to enqueue a tensor into a fifo
queue, and then call TFE_Execute() on a dequeue op over the same queue, as a way
to transfer a tensor from TF to host (tensor tranfer in the other direction also
works).

To make this work, we need TFE_Context and the the TF_Session to use the same
ResourceMgr object (attached to a Device, which is in turn owned by DeviceMgr),
so that both can access the fifo queue resource op.

PiperOrigin-RevId: 211471075
This commit is contained in:
Mingsheng Hong 2018-09-04 09:41:42 -07:00 committed by TensorFlower Gardener
parent 8f0487189a
commit 102e0de242
6 changed files with 55 additions and 19 deletions

View File

@ -117,6 +117,7 @@ tf_cuda_library(
deps = [
":c_api",
":c_api_internal",
"//tensorflow/c/eager:c_api",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/contrib/tpu:all_ops",
"//tensorflow/core:core_cpu",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
// --------------------------------------------------------------------------
// Experimental C API for TensorFlow.
@ -131,6 +132,9 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
TF_Tensor* tensor,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
new tensorflow::IntraProcessRendezvous(device_mgr.get());
return new TFE_Context(opts->session_options.options, opts->policy,
opts->async, std::move(device_mgr), r);
opts->async, device_mgr.release(),
/*device_mgr_owned*/ true, r);
}
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
TF_Session* sess, TF_Status* status) {
const tensorflow::DeviceMgr* device_mgr = nullptr;
status->status = sess->session->LocalDeviceManager(&device_mgr);
if (!status->status.ok()) return nullptr;
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr);
return new TFE_Context(opts->session_options.options, opts->policy,
opts->async, device_mgr, /*device_mgr_owned*/ false,
r);
}
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }

View File

@ -62,15 +62,14 @@ struct TFE_ContextOptions {
};
struct TFE_Context {
explicit TFE_Context(const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_policy,
bool async,
std::unique_ptr<tensorflow::DeviceMgr> device_mgr,
tensorflow::Rendezvous* rendezvous)
TFE_Context(const tensorflow::SessionOptions& opts,
TFE_ContextDevicePlacementPolicy default_policy, bool async,
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
tensorflow::Rendezvous* rendezvous)
: context(opts,
static_cast<tensorflow::ContextDevicePlacementPolicy>(
default_policy),
async, std::move(device_mgr), rendezvous) {}
async, device_mgr, device_mgr_owned, rendezvous) {}
tensorflow::EagerContext context;
};

View File

@ -36,22 +36,34 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
EagerContext::EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy,
bool async, std::unique_ptr<DeviceMgr> device_mgr,
bool async,
std::unique_ptr<const DeviceMgr> device_mgr,
Rendezvous* rendezvous)
: EagerContext(opts, default_policy, async, device_mgr.release(),
/*device_mgr_owned*/ true, rendezvous) {}
EagerContext::EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy,
bool async, const DeviceMgr* device_mgr,
bool device_mgr_owned, Rendezvous* rendezvous)
: policy_(default_policy),
local_device_manager_(std::move(device_mgr)),
local_unowned_device_manager_(nullptr),
devices_(local_device_manager_->ListDevices()),
devices_(device_mgr->ListDevices()),
rendezvous_(rendezvous),
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
pflr_(new ProcessFunctionLibraryRuntime(
local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
&func_lib_def_, {}, thread_pool_.get())),
device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
num_active_steps_(0),
async_default_(async),
env_(opts.env),
use_send_tensor_rpc_(false) {
if (device_mgr_owned) {
local_device_manager_.reset(device_mgr);
local_unowned_device_manager_ = nullptr;
} else {
local_unowned_device_manager_ = device_mgr;
}
InitDeviceMapAndAsync();
if (opts.config.inter_op_parallelism_threads() > 0) {
runner_ = [this](std::function<void()> closure) {

View File

@ -65,10 +65,17 @@ enum ContextDevicePlacementPolicy {
class EagerContext {
public:
explicit EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy, bool async,
std::unique_ptr<DeviceMgr> device_mgr,
Rendezvous* rendezvous);
// TODO: remove this constructor once we migrate all callers to the next one.
EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy, bool async,
std::unique_ptr<const DeviceMgr> device_mgr,
Rendezvous* rendezvous);
EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy, bool async,
const DeviceMgr* device_mgr, bool device_mgr_owned,
Rendezvous* rendezvous);
~EagerContext();
// Returns the function library runtime for the given device.
@ -207,8 +214,8 @@ class EagerContext {
thread_local_policies_ GUARDED_BY(policy_map_mu_);
// Only one of the below is set.
std::unique_ptr<DeviceMgr> local_device_manager_;
DeviceMgr* local_unowned_device_manager_;
std::unique_ptr<const DeviceMgr> local_device_manager_;
const DeviceMgr* local_unowned_device_manager_;
std::unique_ptr<DeviceMgr> remote_device_manager_;
// Devices owned by device_manager