Added a new eager C API TFE_NewContextFromSession(), where TFE_NewContext will
get an owned device mgr from the input session. One use case is in S4TF, we run a graph session to enqueue a tensor into a fifo queue, and then call TFE_Execute() on a dequeue op over the same queue, as a way to transfer a tensor from TF to host (tensor tranfer in the other direction also works). To make this work, we need TFE_Context and the the TF_Session to use the same ResourceMgr object (attached to a Device, which is in turn owned by DeviceMgr), so that both can access the fifo queue resource op. PiperOrigin-RevId: 211471075
This commit is contained in:
parent
8f0487189a
commit
102e0de242
@ -117,6 +117,7 @@ tf_cuda_library(
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_internal",
|
||||
"//tensorflow/c/eager:c_api",
|
||||
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
|
||||
"//tensorflow/contrib/tpu:all_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <stdint.h>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Experimental C API for TensorFlow.
|
||||
@ -131,6 +132,9 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
|
||||
TF_Tensor* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
|
||||
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
@ -273,7 +273,20 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr.get());
|
||||
|
||||
return new TFE_Context(opts->session_options.options, opts->policy,
|
||||
opts->async, std::move(device_mgr), r);
|
||||
opts->async, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, r);
|
||||
}
|
||||
|
||||
TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
|
||||
TF_Session* sess, TF_Status* status) {
|
||||
const tensorflow::DeviceMgr* device_mgr = nullptr;
|
||||
status->status = sess->session->LocalDeviceManager(&device_mgr);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
tensorflow::Rendezvous* r =
|
||||
new tensorflow::IntraProcessRendezvous(device_mgr);
|
||||
return new TFE_Context(opts->session_options.options, opts->policy,
|
||||
opts->async, device_mgr, /*device_mgr_owned*/ false,
|
||||
r);
|
||||
}
|
||||
|
||||
void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
|
||||
|
@ -62,15 +62,14 @@ struct TFE_ContextOptions {
|
||||
};
|
||||
|
||||
struct TFE_Context {
|
||||
explicit TFE_Context(const tensorflow::SessionOptions& opts,
|
||||
TFE_ContextDevicePlacementPolicy default_policy,
|
||||
bool async,
|
||||
std::unique_ptr<tensorflow::DeviceMgr> device_mgr,
|
||||
tensorflow::Rendezvous* rendezvous)
|
||||
TFE_Context(const tensorflow::SessionOptions& opts,
|
||||
TFE_ContextDevicePlacementPolicy default_policy, bool async,
|
||||
const tensorflow::DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||
tensorflow::Rendezvous* rendezvous)
|
||||
: context(opts,
|
||||
static_cast<tensorflow::ContextDevicePlacementPolicy>(
|
||||
default_policy),
|
||||
async, std::move(device_mgr), rendezvous) {}
|
||||
async, device_mgr, device_mgr_owned, rendezvous) {}
|
||||
|
||||
tensorflow::EagerContext context;
|
||||
};
|
||||
|
@ -36,22 +36,34 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
|
||||
|
||||
EagerContext::EagerContext(const SessionOptions& opts,
|
||||
ContextDevicePlacementPolicy default_policy,
|
||||
bool async, std::unique_ptr<DeviceMgr> device_mgr,
|
||||
bool async,
|
||||
std::unique_ptr<const DeviceMgr> device_mgr,
|
||||
Rendezvous* rendezvous)
|
||||
: EagerContext(opts, default_policy, async, device_mgr.release(),
|
||||
/*device_mgr_owned*/ true, rendezvous) {}
|
||||
|
||||
EagerContext::EagerContext(const SessionOptions& opts,
|
||||
ContextDevicePlacementPolicy default_policy,
|
||||
bool async, const DeviceMgr* device_mgr,
|
||||
bool device_mgr_owned, Rendezvous* rendezvous)
|
||||
: policy_(default_policy),
|
||||
local_device_manager_(std::move(device_mgr)),
|
||||
local_unowned_device_manager_(nullptr),
|
||||
devices_(local_device_manager_->ListDevices()),
|
||||
devices_(device_mgr->ListDevices()),
|
||||
rendezvous_(rendezvous),
|
||||
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
|
||||
pflr_(new ProcessFunctionLibraryRuntime(
|
||||
local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
|
||||
&func_lib_def_, {}, thread_pool_.get())),
|
||||
device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
|
||||
thread_pool_.get())),
|
||||
log_device_placement_(opts.config.log_device_placement()),
|
||||
num_active_steps_(0),
|
||||
async_default_(async),
|
||||
env_(opts.env),
|
||||
use_send_tensor_rpc_(false) {
|
||||
if (device_mgr_owned) {
|
||||
local_device_manager_.reset(device_mgr);
|
||||
local_unowned_device_manager_ = nullptr;
|
||||
} else {
|
||||
local_unowned_device_manager_ = device_mgr;
|
||||
}
|
||||
InitDeviceMapAndAsync();
|
||||
if (opts.config.inter_op_parallelism_threads() > 0) {
|
||||
runner_ = [this](std::function<void()> closure) {
|
||||
|
@ -65,10 +65,17 @@ enum ContextDevicePlacementPolicy {
|
||||
|
||||
class EagerContext {
|
||||
public:
|
||||
explicit EagerContext(const SessionOptions& opts,
|
||||
ContextDevicePlacementPolicy default_policy, bool async,
|
||||
std::unique_ptr<DeviceMgr> device_mgr,
|
||||
Rendezvous* rendezvous);
|
||||
// TODO: remove this constructor once we migrate all callers to the next one.
|
||||
EagerContext(const SessionOptions& opts,
|
||||
ContextDevicePlacementPolicy default_policy, bool async,
|
||||
std::unique_ptr<const DeviceMgr> device_mgr,
|
||||
Rendezvous* rendezvous);
|
||||
|
||||
EagerContext(const SessionOptions& opts,
|
||||
ContextDevicePlacementPolicy default_policy, bool async,
|
||||
const DeviceMgr* device_mgr, bool device_mgr_owned,
|
||||
Rendezvous* rendezvous);
|
||||
|
||||
~EagerContext();
|
||||
|
||||
// Returns the function library runtime for the given device.
|
||||
@ -207,8 +214,8 @@ class EagerContext {
|
||||
thread_local_policies_ GUARDED_BY(policy_map_mu_);
|
||||
|
||||
// Only one of the below is set.
|
||||
std::unique_ptr<DeviceMgr> local_device_manager_;
|
||||
DeviceMgr* local_unowned_device_manager_;
|
||||
std::unique_ptr<const DeviceMgr> local_device_manager_;
|
||||
const DeviceMgr* local_unowned_device_manager_;
|
||||
std::unique_ptr<DeviceMgr> remote_device_manager_;
|
||||
|
||||
// Devices owned by device_manager
|
||||
|
Loading…
x
Reference in New Issue
Block a user