Added experimental C APIs based on eager, as a first step towards using eager
based runtime in Swift for Tensorflow. PiperOrigin-RevId: 211892308
This commit is contained in:
parent
d79a9d5595
commit
e583f1090f
@ -79,6 +79,18 @@ TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
|
||||
auto* gpu_options = config.mutable_gpu_options();
|
||||
gpu_options->set_allow_growth(gpu_memory_allow_growth);
|
||||
|
||||
// TODO(b/113217601): This is needed for EagerContext::runner_ to use a
|
||||
// threadpool, so that we avoid the possibility of running the runner_ in the
|
||||
// threadpool of GPU event mgr, as that can trigger more callbacks to be
|
||||
// scheduled on that same threadpool, causing a deadlock in cases where the
|
||||
// caller of event_mgr->ThenExecute() blocks on the completion of the callback
|
||||
// (as in the case of ConstOp kernel creation on GPU, which involves copying a
|
||||
// CPU tensor to GPU).
|
||||
// Setting a larger thread pool does not help with the Swift caller, as we use
|
||||
// a different TFE context for each thread of execution (for running graph
|
||||
// functions, and their send/recvs corountines).
|
||||
config.set_inter_op_parallelism_threads(1);
|
||||
|
||||
TF_Buffer* ret = TF_NewBuffer();
|
||||
TF_CHECK_OK(MessageToBuffer(config, ret));
|
||||
return ret;
|
||||
@ -8494,3 +8506,201 @@ void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
|
||||
/*run_metadata*/ nullptr, status);
|
||||
VLOG(1) << "Enqueuing is done.";
|
||||
}
|
||||
|
||||
TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
|
||||
TF_Status* status) {
|
||||
auto* opts = TFE_NewContextOptions();
|
||||
|
||||
// Reduce GPU memory allocation, and set appropriate config options for TFE
|
||||
// context.
|
||||
auto* config =
|
||||
TF_CreateConfig(/*xla*/ false, /* gpu_memory_allow_growth */ true);
|
||||
TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
|
||||
if (!status->status.ok()) {
|
||||
CHECK(!config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* ctx = TFE_NewContextFromSession(opts, session, status);
|
||||
TF_DeleteBuffer(config);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// TODO: retrieve the device string via TFE_ContextListDevices()
|
||||
static const char DEFAULT_CPU_DEVICE[] =
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
|
||||
static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
int tensor_id, TF_Status* status) {
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
|
||||
TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
|
||||
TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
// TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
|
||||
TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
|
||||
TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
|
||||
auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
|
||||
TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
|
||||
shared_name.size());
|
||||
TFE_OpSetAttrString(queueOp.get(), "container", "", 0);
|
||||
|
||||
// TODO: consider making this an unknown shape.
|
||||
const int64_t* dims_ptr = nullptr;
|
||||
int num_dims = 0;
|
||||
TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
|
||||
/*num_values*/ 0, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
int num_retvals = 1;
|
||||
TFE_TensorHandle* queue = nullptr;
|
||||
TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
|
||||
return queue;
|
||||
}
|
||||
|
||||
static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
|
||||
TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpAddInput(op, tensor, status);
|
||||
if (!status->status.ok()) return;
|
||||
TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
|
||||
int num_retvals = 0;
|
||||
TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
|
||||
if (!status->status.ok()) return;
|
||||
CHECK_EQ(num_retvals, 0);
|
||||
}
|
||||
|
||||
static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
|
||||
TF_DataType inputType,
|
||||
TFE_TensorHandle* queue,
|
||||
TF_Status* status) {
|
||||
TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
|
||||
TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
|
||||
TFE_OpAddInput(op, queue, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
|
||||
TFE_OpSetAttrInt(op, "timeout_ms", -1);
|
||||
TFE_TensorHandle* ret;
|
||||
int num_retvals = 1;
|
||||
TFE_Execute(op, &ret, &num_retvals, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
CHECK_EQ(num_retvals, 1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Dequeuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
return ret;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TF_DataType inputType,
|
||||
TF_Status* status) {
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
auto* ret = createTFEDequeue(ctx, inputType, queue, status);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
assert(session);
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing data tensor with id " << tensor_id;
|
||||
|
||||
TF_DataType inputType = TFE_TensorHandleDataType(tensor);
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, inputType, queue, tensor, status);
|
||||
}
|
||||
|
||||
void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TFE_TensorHandle* tensor, TF_Status* status) {
|
||||
VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
|
||||
TF_Status* status) {
|
||||
VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;
|
||||
|
||||
auto ctx = TFE_CreateContextFromSession(session, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
|
||||
ctx, TFE_DeleteContext);
|
||||
|
||||
TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
|
||||
queue_deleter(queue, TFE_DeleteTensorHandle);
|
||||
|
||||
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
|
||||
}
|
||||
|
@ -132,9 +132,48 @@ TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session,
|
||||
TF_Tensor* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
// TODO: remove this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContextFromSession(
|
||||
const TFE_ContextOptions* opts, TF_Session* sess, TF_Status* status);
|
||||
|
||||
// Creates from `session` a new eager context to run a graph function or
|
||||
// sends/recvs, so that these concurrent TFE executions can share (via
|
||||
// `session` and its associated device mgr) the same set of fifo queue resource
|
||||
// ops, used for host<->TF tensor transfers. This way the sends/recvs calls and
|
||||
// graph function execution can access the same fifo queue resource handles
|
||||
// (associated with devices managed by the device manager, which can be obtained
|
||||
// from `session`).
|
||||
//
|
||||
// TODO: Remove this function once we migrate away from using session.
|
||||
TF_CAPI_EXPORT extern TFE_Context* TFE_CreateContextFromSession(
|
||||
TF_Session* session, TF_Status* status);
|
||||
|
||||
// TODO: Retire this API in favor of the next one.
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensor(
|
||||
TF_Session* session, int tensor_id, TF_DataType inputType,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TF_DataType inputType, TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueNamedTensorFromCtx(
|
||||
TFE_Context* ctx, int tensor_id, TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
// TODO: consider folding the 2 APIs below into the ones above.
|
||||
TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
|
||||
int tensor_id,
|
||||
TFE_TensorHandle* tensor,
|
||||
TF_Status* status);
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
|
||||
TF_Session* session, int tensor_id, TF_Status* status);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
Loading…
x
Reference in New Issue
Block a user