/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/c_api_experimental.h"
|
|
|
|
#include "absl/strings/substitute.h"
|
|
#include "tensorflow/c/c_api.h"
|
|
#include "tensorflow/c/c_api_internal.h"
|
|
#include "tensorflow/c/checkpoint_reader.h"
|
|
#include "tensorflow/c/eager/c_api.h"
|
|
#include "tensorflow/c/eager/c_api_internal.h"
|
|
#include "tensorflow/compiler/jit/flags.h"
|
|
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
|
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
|
#include "tensorflow/core/framework/node_def.pb.h"
|
|
#include "tensorflow/core/framework/shape_inference.h"
|
|
#include "tensorflow/core/framework/tensor.pb.h"
|
|
#include "tensorflow/core/graph/graph.h"
|
|
#include "tensorflow/core/graph/node_builder.h"
|
|
#include "tensorflow/core/lib/strings/strcat.h"
|
|
#include "tensorflow/core/platform/init_main.h"
|
|
#include "tensorflow/core/platform/net.h"
|
|
#include "tensorflow/core/platform/platform.h"
|
|
#include "tensorflow/core/protobuf/config.pb.h"
|
|
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
|
|
|
|
using tensorflow::FunctionDef;
using tensorflow::Node;
using tensorflow::NodeBuilder;
using tensorflow::Status;
using tensorflow::errors::InvalidArgument;

namespace {
typedef std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>
    UniqueFuncPtr;
}

// struct TF_Operation { tensorflow::Node node; };
static TF_Operation* ToTF_Operation(Node* node) {
  return static_cast<TF_Operation*>(static_cast<void*>(node));
}

void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) {
  tensorflow::ConfigProto& config = options->options.config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (more generally
    // non-Python) clients. If this API is called again with `enable` set to
    // false, it is safe to keep these flag values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }
}

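// Example usage of TF_EnableXLACompilation() (a minimal sketch; `graph` is an
// assumed, previously built TF_Graph*, and error checking is elided):
//
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_EnableXLACompilation(opts, /*enable=*/1);
//   TF_Status* status = TF_NewStatus();
//   TF_Session* session = TF_NewSession(graph, opts, status);
//   TF_DeleteSessionOptions(opts);
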
unsigned char TF_SetXlaEnableLazyCompilation(unsigned char enable) {
  tensorflow::BuildXlaOpsPassFlags* flags =
      tensorflow::GetBuildXlaOpsPassFlags();
  bool original = flags->tf_xla_enable_lazy_compilation;
  flags->tf_xla_enable_lazy_compilation = enable;
  return original;
}

void TF_SetXlaAutoJitMode(const char* mode) {
  tensorflow::SetXlaAutoJitFlagFromFlagString(mode);
}

unsigned char TF_GetXlaConstantFoldingDisabled() {
  return static_cast<unsigned char>(
      tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding);
}

void TF_SetXlaConstantFoldingDisabled(unsigned char should_enable) {
  tensorflow::GetBuildXlaOpsPassFlags()->tf_xla_disable_constant_folding =
      static_cast<bool>(should_enable);
}

void TF_SetXlaMinClusterSize(int size) {
  tensorflow::MarkForCompilationPassFlags* flags =
      tensorflow::GetMarkForCompilationPassFlags();
  flags->tf_xla_min_cluster_size = size;
}

TF_Buffer* TF_CreateConfig(unsigned char enable_xla_compilation,
                           unsigned char gpu_memory_allow_growth,
                           unsigned int num_cpu_devices) {
  tensorflow::ConfigProto config;
  auto* optimizer_options =
      config.mutable_graph_options()->mutable_optimizer_options();
  if (enable_xla_compilation) {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1);

    // These XLA flags are needed to trigger XLA properly from C (more generally
    // non-Python) clients. If this API is called again with
    // `enable_xla_compilation` set to false, it is safe to keep these flag
    // values as is.
    tensorflow::MarkForCompilationPassFlags* flags =
        tensorflow::GetMarkForCompilationPassFlags();
    flags->tf_xla_cpu_global_jit = true;
    flags->tf_xla_min_cluster_size = 1;
  } else {
    optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF);
  }

  auto* gpu_options = config.mutable_gpu_options();
  gpu_options->set_allow_growth(gpu_memory_allow_growth);

  (*config.mutable_device_count())["CPU"] = num_cpu_devices;

  // TODO(b/113217601): This is needed for EagerContext::runner_ to use a
  // threadpool, so that we avoid the possibility of running the runner_ in the
  // threadpool of GPU event mgr, as that can trigger more callbacks to be
  // scheduled on that same threadpool, causing a deadlock in cases where the
  // caller of event_mgr->ThenExecute() blocks on the completion of the callback
  // (as in the case of ConstOp kernel creation on GPU, which involves copying a
  // CPU tensor to GPU).
  // Setting a larger thread pool does not help with the Swift caller, as we use
  // a different TFE context for each thread of execution (for running graph
  // functions, and their send/recv coroutines).
  config.set_inter_op_parallelism_threads(1);

  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(config, ret));
  return ret;
}

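// Example usage of TF_CreateConfig() (a minimal sketch; error checking
// elided). The returned buffer holds a serialized ConfigProto and can be
// applied to session options via TF_SetConfig():
//
//   TF_Buffer* config = TF_CreateConfig(/*enable_xla_compilation=*/0,
//                                       /*gpu_memory_allow_growth=*/1,
//                                       /*num_cpu_devices=*/1);
//   TF_SessionOptions* opts = TF_NewSessionOptions();
//   TF_Status* status = TF_NewStatus();
//   TF_SetConfig(opts, config->data, config->length, status);
//   TF_DeleteBuffer(config);
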
TF_Buffer* TF_CreateRunOptions(unsigned char enable_full_trace) {
  tensorflow::RunOptions options;
  if (enable_full_trace) {
    options.set_trace_level(tensorflow::RunOptions::FULL_TRACE);
  } else {
    options.set_trace_level(tensorflow::RunOptions::NO_TRACE);
  }
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(options, ret));
  return ret;
}

// Caller must call free() on the returned string.
const char* TF_GraphDebugString(TF_Graph* graph, size_t* len) {
  tensorflow::mutex_lock c(graph->mu);
  const auto& debug_str = graph->graph.ToGraphDefDebug().DebugString();
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

// Caller must call free() on the returned string.
char* TF_FunctionDebugString(TF_Function* func, size_t* len) {
  const auto& debug_str = DebugString(func->fdef);
  *len = debug_str.size();
  char* ret = static_cast<char*>(malloc(*len + 1));
  memcpy(ret, debug_str.c_str(), *len + 1);
  return ret;
}

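// Example usage of the debug-string helpers above (a minimal sketch; `graph`
// is an assumed TF_Graph*). Both return heap-allocated buffers that the
// caller releases with free():
//
//   size_t len = 0;
//   const char* str = TF_GraphDebugString(graph, &len);
//   fprintf(stderr, "%.*s\n", static_cast<int>(len), str);
//   free(const_cast<char*>(str));
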
// On success, returns a set of TF_Function instances from `text_proto` of
// GraphDef type. These functions must be deleted by calling TF_DeleteFunction.
//
// If `mutate_proto_func` is non-NULL, run it over each FunctionDef proto
// before creating a TF_Function out of the possibly mutated proto.
static std::vector<UniqueFuncPtr> CreateFunctionsFromTextProto(
    const char* text_proto,
    std::function<void(FunctionDef*)>* mutate_proto_func, TF_Status* status) {
  tensorflow::GraphDef gdef;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto, &gdef)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for GraphDef: ", text_proto);
    return {};
  }
  const auto& fdef_lib = gdef.library();
  if (fdef_lib.gradient_size() > 0) {
    status->status = tensorflow::errors::Internal(
        "GradientDef is not supported in reading Dataset related functions: ",
        text_proto);
    return {};
  }
  std::vector<UniqueFuncPtr> ret;
  for (const FunctionDef& fdef : fdef_lib.function()) {
    // Make a copy so that we can mutate it.
    FunctionDef fdef_to_load = fdef;
    if (mutate_proto_func) {
      (*mutate_proto_func)(&fdef_to_load);
    }
    VLOG(1) << "Adding func to graph: " << fdef_to_load.DebugString();
    std::vector<char> binary_proto_buf(fdef_to_load.ByteSizeLong());
    fdef_to_load.SerializeToArray(binary_proto_buf.data(),
                                  binary_proto_buf.size());
    TF_Function* func = TF_FunctionImportFunctionDef(
        binary_proto_buf.data(), binary_proto_buf.size(), status);
    if (!status->status.ok()) return {};
    ret.push_back(UniqueFuncPtr(func, TF_DeleteFunction));
  }
  return ret;
}

TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id,
                                 TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    VLOG(1) << "Dequeuing named tensor with id " << tensor_id
            << ", with input graph: "
            << session->graph->graph.ToGraphDefDebug().DebugString();
  }

  TF_Operation* dequeue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str());
  if (dequeue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the dequeue node in the TF graph.");
    return nullptr;
  }

  VLOG(1) << "Running the dequeue op";
  TF_Output output{dequeue_op, 0};
  TF_Tensor* ret;
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
                // output related parameters
                /*outputs*/ &output, /*output_values*/ &ret,
                /*noutputs*/ 1,
                /*targets*/ nullptr, /*ntargets*/ 0,
                /*run_metadata*/ nullptr, status);
  if (VLOG_IS_ON(1) && status->status.ok()) {
    tensorflow::Tensor tensor;
    if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) {
      VLOG(1) << "Dequeued tensor content: " << tensor.DebugString();
    }
  }
  return ret;
}

void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id,
                           TF_Tensor* tensor, TF_Status* status) {
  assert(session);
  {
    tensorflow::mutex_lock c(session->graph->mu);
    if (VLOG_IS_ON(1)) {
      VLOG(1) << "Enqueuing named tensor with id " << tensor_id
              << ", with input graph: "
              << session->graph->graph.ToGraphDefDebug().DebugString();
      tensorflow::Tensor internal_tensor;
      if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) {
        VLOG(1) << "Enqueuing tensor content: "
                << internal_tensor.DebugString();
      }
    }
  }

  TF_Operation* enqueue_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str());
  if (enqueue_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the enqueue node in the TF graph.");
    return;
  }

  TF_Operation* placeholder_op = TF_GraphOperationByName(
      session->graph,
      tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str());
  if (placeholder_op == nullptr) {
    status->status = tensorflow::errors::Internal(
        "Unable to find the placeholder node as input to enqueue in the TF "
        "graph.");
    return;
  }

  VLOG(1) << "Running the enqueue op";
  TF_Output input{placeholder_op, 0};
  TF_SessionRun(session, /*run_options*/ nullptr,
                // input related parameters
                /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1,
                // output related parameters
                /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0,
                /*targets*/ &enqueue_op, /*ntargets*/ 1,
                /*run_metadata*/ nullptr, status);
  VLOG(1) << "Enqueuing is done.";
}

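// Example round trip for the two functions above (a minimal sketch; assumes
// the graph follows the naming convention used here, i.e. it contains ops
// "fifo_queue_enqueue_0"/"fifo_queue_dequeue_0" and a placeholder
// "arg_tensor_enqueue_0" for tensor id 0):
//
//   TF_Status* status = TF_NewStatus();
//   TF_EnqueueNamedTensor(session, /*tensor_id=*/0, tensor, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Tensor* out = TF_DequeueNamedTensor(session, /*tensor_id=*/0, status);
//     // ... use `out`, then TF_DeleteTensor(out) ...
//   }
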
TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!tensorflow::protobuf::TextFormat::ParseFromString(text_proto,
                                                         &server_def)) {
    status->status = tensorflow::errors::Internal(
        "Invalid text proto for ServerDef: ", text_proto);
    return nullptr;
  }
  status->status = tensorflow::Status();
  TF_Buffer* ret = TF_NewBuffer();
  TF_CHECK_OK(MessageToBuffer(server_def, ret));
  return ret;
}

TFE_Context* TFE_CreateContextFromSession(TF_Session* session,
                                          TF_Status* status) {
  auto* opts = TFE_NewContextOptions();

  // Reduce GPU memory allocation, and set appropriate config options for TFE
  // context.
  auto* config = TF_CreateConfig(
      /*xla*/ false, /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
      10);
  TFE_ContextOptionsSetConfig(opts, config->data, config->length, status);
  if (!status->status.ok()) {
    // Release `config` on the error path; it is always non-null here. (The
    // previous `CHECK(!config)` would have failed unconditionally.)
    TF_DeleteBuffer(config);
    TFE_DeleteContextOptions(opts);
    return nullptr;
  }

  auto* ctx = TFE_NewContextFromSession(opts, session, status);
  TF_DeleteBuffer(config);
  TFE_DeleteContextOptions(opts);
  return ctx;
}

// TODO: retrieve the device string via TFE_ContextListDevices()
static const char DEFAULT_CPU_DEVICE[] =
    "/job:localhost/replica:0/task:0/device:CPU:0";

static TFE_TensorHandle* createTFEQueue(TFE_Context* ctx, TF_DataType inputType,
                                        int tensor_id, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> queueOp(
      TFE_NewOp(ctx, "FIFOQueueV2", status), TFE_DeleteOp);
  TFE_OpSetDevice(queueOp.get(), DEFAULT_CPU_DEVICE, status);
  if (!status->status.ok()) return nullptr;
  // TODO: use NAMED_TENSOR_QUEUE_CAPACITY in S4TF compiler.
  TFE_OpSetAttrInt(queueOp.get(), "capacity", 1);
  TFE_OpSetAttrTypeList(queueOp.get(), "component_types", &inputType, 1);
  auto shared_name = tensorflow::strings::StrCat("fifo_queue_", tensor_id);
  TFE_OpSetAttrString(queueOp.get(), "shared_name", shared_name.data(),
                      shared_name.size());
  TFE_OpSetAttrString(queueOp.get(), "container", "", 0);

  // TODO: consider making this an unknown shape.
  const int64_t* dims_ptr = nullptr;
  int num_dims = 0;
  TFE_OpSetAttrShapeList(queueOp.get(), "shapes", &dims_ptr, &num_dims,
                         /*num_values*/ 0, status);
  if (!status->status.ok()) return nullptr;

  int num_retvals = 1;
  TFE_TensorHandle* queue = nullptr;
  TFE_Execute(queueOp.get(), &queue, &num_retvals, status);
  if (!status->status.ok()) return nullptr;
  CHECK_EQ(num_retvals, 1);

  return queue;
}

static void createTFEEnqueue(TFE_Context* ctx, TF_DataType inputType,
                             TFE_TensorHandle* queue, TFE_TensorHandle* tensor,
                             TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "QueueEnqueueV2", status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
  TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
  if (!status->status.ok()) return;
  TFE_OpAddInput(op, queue, status);
  if (!status->status.ok()) return;
  TFE_OpAddInput(op, tensor, status);
  if (!status->status.ok()) return;
  TFE_OpSetAttrTypeList(op, "Tcomponents", &inputType, 1);
  TFE_OpSetAttrInt(op, "timeout_ms", -1);

  int num_retvals = 0;
  TFE_Execute(op, nullptr /*retvals*/, &num_retvals, status);
  if (!status->status.ok()) return;
  CHECK_EQ(num_retvals, 0);
}

static TFE_TensorHandle* createTFEDequeue(TFE_Context* ctx,
                                          TF_DataType inputType,
                                          TFE_TensorHandle* queue,
                                          TF_Status* status) {
  TFE_Op* op = TFE_NewOp(ctx, "QueueDequeueV2", status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op_deleter(op, TFE_DeleteOp);
  TFE_OpSetDevice(op, DEFAULT_CPU_DEVICE, status);
  if (!status->status.ok()) return nullptr;

  TFE_OpAddInput(op, queue, status);
  if (!status->status.ok()) return nullptr;
  TFE_OpSetAttrTypeList(op, "component_types", &inputType, 1);
  TFE_OpSetAttrInt(op, "timeout_ms", -1);
  TFE_TensorHandle* ret;
  int num_retvals = 1;
  TFE_Execute(op, &ret, &num_retvals, status);
  if (!status->status.ok()) return nullptr;
  CHECK_EQ(num_retvals, 1);
  return ret;
}

TFE_TensorHandle* TFE_DequeueNamedTensor(TF_Session* session, int tensor_id,
                                         TF_DataType inputType,
                                         TF_Status* status) {
  assert(session);
  VLOG(1) << "Dequeuing data tensor with id " << tensor_id;

  auto ctx = TFE_CreateContextFromSession(session, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
      ctx, TFE_DeleteContext);

  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  auto* ret = createTFEDequeue(ctx, inputType, queue, status);
  return ret;
}

TFE_TensorHandle* TFE_DequeueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
                                                TF_DataType inputType,
                                                TF_Status* status) {
  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  auto* ret = createTFEDequeue(ctx, inputType, queue, status);

  return ret;
}

void TFE_EnqueueNamedTensor(TF_Session* session, int tensor_id,
                            TFE_TensorHandle* tensor, TF_Status* status) {
  assert(session);
  VLOG(1) << "Enqueuing data tensor with id " << tensor_id;

  auto ctx = TFE_CreateContextFromSession(session, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
      ctx, TFE_DeleteContext);

  TF_DataType inputType = TFE_TensorHandleDataType(tensor);
  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  createTFEEnqueue(ctx, inputType, queue, tensor, status);
}

void TFE_EnqueueNamedTensorFromCtx(TFE_Context* ctx, int tensor_id,
                                   TFE_TensorHandle* tensor,
                                   TF_Status* status) {
  VLOG(1) << "Enqueuing data tensor with id " << tensor_id;

  TF_DataType inputType = TFE_TensorHandleDataType(tensor);
  TFE_TensorHandle* queue = createTFEQueue(ctx, inputType, tensor_id, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  createTFEEnqueue(ctx, inputType, queue, tensor, status);
}

void TFE_EnqueueVariantTensor(TF_Session* session, int tensor_id,
                              TFE_TensorHandle* tensor, TF_Status* status) {
  VLOG(1) << "Enqueuing variant tensor with id " << tensor_id;

  auto ctx = TFE_CreateContextFromSession(session, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
      ctx, TFE_DeleteContext);

  TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
  if (!status->status.ok()) return;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  createTFEEnqueue(ctx, TF_VARIANT, queue, tensor, status);
}

TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
                                           TF_Status* status) {
  VLOG(1) << "Dequeuing variant tensor with id " << tensor_id;

  auto ctx = TFE_CreateContextFromSession(session, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> ctx_deleter(
      ctx, TFE_DeleteContext);

  TFE_TensorHandle* queue = createTFEQueue(ctx, TF_VARIANT, tensor_id, status);
  if (!status->status.ok()) return nullptr;
  std::unique_ptr<TFE_TensorHandle, decltype(&TFE_DeleteTensorHandle)>
      queue_deleter(queue, TFE_DeleteTensorHandle);

  return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}

static void CheckOk(TF_Status* status) {
  CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}

void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
  auto* status = TF_NewStatus();
  if (!TFE_TensorHandleIsConcrete(handle)) {
    VLOG(1) << "Symbolic tensor: " << handle;
    TF_DeleteStatus(status);
    return;
  }

  TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  tensorflow::Tensor dst;
  TF_CHECK_OK(TF_TensorToTensor(t, &dst));
  LOG(INFO) << dst.DebugString();

  TF_DeleteTensor(t);
  TF_DeleteStatus(status);
}

void TFE_OpPrintDebugString(TFE_Op* op) {
  VLOG(1) << "TFE_OpPrintDebugString() over " << op;
  LOG(INFO) << op->operation.DebugString();
}

struct TFE_ExecuteOpNotification {
  TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
  tensorflow::Notification n;
  std::unique_ptr<tensorflow::Thread> thread;
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
};

TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
                                                    TFE_TensorHandle** retvals,
                                                    int* num_retvals,
                                                    TF_Status* status) {
  TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;

  n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
      tensorflow::ThreadOptions(), "ExecuteOpThread",
      [op, retvals, num_retvals, n]() {
        TFE_Execute(op, retvals, num_retvals, n->status.get());
        n->n.Notify();
      }));

  return n;
}

void TFE_ExecuteOpNotificationWaitAndDelete(
    TFE_ExecuteOpNotification* notification, TF_Status* status) {
  if (notification == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Passed in notification is a nullptr.");

    return;
  }
  if (notification->thread == nullptr) {
    status->status = tensorflow::errors::InvalidArgument(
        "Passed in notification didn't start a thread correctly. Cleaning up "
        "this notification. Please re-execute the operation to get a new "
        "notification.");

    delete notification;
    return;
  }

  notification->n.WaitForNotification();

  status->status = notification->status->status;

  delete notification;
}

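// Example usage of the async-execution pair above (a minimal sketch; `op` is
// an assumed, fully configured TFE_Op whose inputs stay alive until the
// background thread finishes):
//
//   TFE_TensorHandle* retvals[1];
//   int num_retvals = 1;
//   TFE_ExecuteOpNotification* n =
//       TFE_ExecuteOpInNewThread(op, retvals, &num_retvals, status);
//   // ... do other work while the op runs ...
//   TFE_ExecuteOpNotificationWaitAndDelete(n, status);
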
void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
  status->status = tensorflow::errors::Internal(errMsg);
}

struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader {
  using tensorflow::checkpoint::CheckpointReader::CheckpointReader;
  std::vector<std::string> variable_list;
};

TF_CheckpointReader* TF_NewCheckpointReader(const char* filename,
                                            TF_Status* status) {
  TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status);
  if (!status->status.ok()) {
    TF_DeleteCheckpointReader(reader);
    return nullptr;
  }
  const auto& m = reader->GetVariableToDataTypeMap();
  for (auto it = m.begin(); it != m.end(); ++it)
    reader->variable_list.push_back(it->first);
  std::sort(reader->variable_list.begin(), reader->variable_list.end());
  return reader;
}

void TF_DeleteCheckpointReader(TF_CheckpointReader* reader) { delete reader; }

int TF_CheckpointReaderHasTensor(TF_CheckpointReader* reader,
                                 const char* name) {
  return reader->HasTensor(name);
}

const char* TF_CheckpointReaderGetVariable(TF_CheckpointReader* reader,
                                           int index) {
  return reader->variable_list[index].c_str();
}

int TF_CheckpointReaderSize(TF_CheckpointReader* reader) {
  return reader->variable_list.size();
}

TF_DataType TF_CheckpointReaderGetVariableDataType(TF_CheckpointReader* reader,
                                                   const char* name) {
  const auto& m = reader->GetVariableToDataTypeMap();
  return static_cast<TF_DataType>(m.at(name));
}

TF_Tensor* TF_CheckpointReaderGetTensor(TF_CheckpointReader* reader,
                                        const char* name, TF_Status* status) {
  std::unique_ptr<tensorflow::Tensor> tensor;
  reader->GetTensor(name, &tensor, status);
  if (!status->status.ok()) return nullptr;
  return tensorflow::TF_TensorFromTensor(*tensor, status);
}

void TF_CheckpointReaderGetVariableShape(TF_CheckpointReader* reader,
                                         const char* name, int64_t* dims,
                                         int num_dims, TF_Status* status) {
  const auto& shape = reader->GetVariableToShapeMap().at(name);
  int rank = shape.dims();
  if (num_dims != rank) {
    status->status = InvalidArgument("Expected rank is ", num_dims,
                                     " but actual rank is ", rank);
    return;
  }
  for (int i = 0; i < num_dims; i++) {
    dims[i] = shape.dim_size(i);
  }
}

int TF_CheckpointReaderGetVariableNumDims(TF_CheckpointReader* reader,
                                          const char* name) {
  const auto& m = reader->GetVariableToShapeMap();
  return m.at(name).dims();
}

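// Example usage of the checkpoint-reader API above (a minimal sketch;
// "/path/to/ckpt" is a placeholder checkpoint prefix, and error checking is
// elided):
//
//   TF_Status* status = TF_NewStatus();
//   TF_CheckpointReader* reader = TF_NewCheckpointReader("/path/to/ckpt", status);
//   for (int i = 0; i < TF_CheckpointReaderSize(reader); ++i) {
//     const char* name = TF_CheckpointReaderGetVariable(reader, i);
//     TF_Tensor* t = TF_CheckpointReaderGetTensor(reader, name, status);
//     // ... use `t`, then TF_DeleteTensor(t) ...
//   }
//   TF_DeleteCheckpointReader(reader);
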
// This builder is used in the eager API to build a NodeDef.
struct TF_AttrBuilder : public tensorflow::AttrBuilder {
  using tensorflow::AttrBuilder::AttrBuilder;
  // String buffers that ensure any `attr_name` passed into `builder->Set()`
  // outlives the subsequent `TF_AttrBuilderCheckCanRunOnDevice()` call(s) on
  // the same `builder`.
  std::set<std::string> attr_names;
};

TF_AttrBuilder* TF_NewAttrBuilder(const char* op_name) {
  return new TF_AttrBuilder(op_name);
}

void TF_DeleteAttrBuilder(TF_AttrBuilder* builder) { delete builder; }

void TF_AttrBuilderSetType(TF_AttrBuilder* builder, const char* attr_name,
                           TF_DataType value) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(*iter, static_cast<tensorflow::DataType>(value));
}

void TF_AttrBuilderSetTypeList(TF_AttrBuilder* builder, const char* attr_name,
                               const TF_DataType* values, int num_values) {
  auto iter = builder->attr_names.insert(attr_name).first;
  builder->Set(
      (*iter).c_str(),
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}

void TF_AttrBuilderCheckCanRunOnDevice(TF_AttrBuilder* builder,
                                       const char* device_type,
                                       TF_Status* status) {
  status->status = tensorflow::FindKernelDef(
      tensorflow::DeviceType(device_type), builder->BuildNodeDef(),
      /* def = */ nullptr, /* kernel_class_name = */ nullptr);
}

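// Example usage of the attr-builder API above (a minimal sketch): checking
// whether a float MatMul kernel is registered for a given device type:
//
//   TF_Status* status = TF_NewStatus();
//   TF_AttrBuilder* builder = TF_NewAttrBuilder("MatMul");
//   TF_AttrBuilderSetType(builder, "T", TF_FLOAT);
//   TF_AttrBuilderCheckCanRunOnDevice(builder, "GPU", status);
//   // TF_GetCode(status) == TF_OK iff a matching kernel was found.
//   TF_DeleteAttrBuilder(builder);
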
const char* TF_GetNumberAttrForOpListInput(const char* op_name, int input_index,
                                           TF_Status* status) {
  const tensorflow::OpDef* op_def = nullptr;
  status->status =
      tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
  if (!status->status.ok()) return nullptr;

  if (input_index >= op_def->input_arg_size() || input_index < 0) {
    status->status = tensorflow::errors::InvalidArgument(
        input_index, " out of range for ", op_name);
    return nullptr;
  }

  const tensorflow::OpDef_ArgDef& input_arg = op_def->input_arg()[input_index];

  if (input_arg.number_attr().empty()) {
    status->status = tensorflow::errors::NotFound(
        op_name, " does not have number_attr() defined.");
    return nullptr;
  }

  // The returned string is owned by OpRegistry, so liveness is not a concern.
  return input_arg.number_attr().c_str();
}

int TF_OpIsStateful(const char* op_type, TF_Status* status) {
  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(op_type, &op_reg_data);
  if (!status->status.ok()) {
    return 0;
  }
  return op_reg_data->op_def.is_stateful();
}

void TF_InitMain(const char* usage, int* argc, char*** argv) {
  tensorflow::port::InitMain(usage, argc, argv);
}

int TF_PickUnusedPortOrDie() {
  return tensorflow::internal::PickUnusedPortOrDie();
}

TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
                                                void* data, size_t len,
                                                TF_Status* status) {
  auto dtype = static_cast<tensorflow::DataType>(data_type);
  DCHECK(tensorflow::DataTypeCanUseMemcpy(dtype));

  tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
  std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
  return TFE_TensorHandle::CreateLocalHandle(tensor, status);
}

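// Example usage of TFE_NewTensorHandleFromScalar() (a minimal sketch; `status`
// is an assumed TF_Status*):
//
//   float value = 1.0f;
//   TFE_TensorHandle* h =
//       TFE_NewTensorHandleFromScalar(TF_FLOAT, &value, sizeof(value), status);
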
namespace {
tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
                                       TFE_Context* ctx) {
  // We don't use the TF_RETURN_IF_ERROR macro directly here, since that
  // destroys the server object (which currently CHECK-fails) and we would miss
  // the error. Instead, we log the error and then return, so that the user can
  // see the error message.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0);

  std::unique_ptr<tensorflow::ServerInterface> server;
  LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server));

  tensorflow::GrpcServer* grpc_server =
      dynamic_cast<tensorflow::GrpcServer*>(server.get());
  if (grpc_server == nullptr) {
    LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
        "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
  }

  LOG_AND_RETURN_IF_ERROR(grpc_server->Start());

  LOG_AND_RETURN_IF_ERROR(ctx->context->StoreCollectiveOpsServer(
      std::move(server), grpc_server->worker_env()->device_mgr,
      grpc_server->worker_env()->collective_executor_mgr));

  return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
}
}  // namespace

// Set server_def on the context, possibly updating it.
TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
                                                   const void* proto,
                                                   size_t proto_len,
                                                   TF_Status* status) {
  tensorflow::ServerDef server_def;
  if (!server_def.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Invalid tensorflow.ServerDef protocol buffer");
    return;
  }
  status->status = EnableCollectiveOps(server_def, ctx);
}

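// Example usage (a minimal sketch; `ctx` is an assumed TFE_Context* and
// `server_def_text_proto` an assumed ServerDef text proto). A serialized
// ServerDef can be produced with TFE_GetServerDef() above:
//
//   TF_Buffer* def = TFE_GetServerDef(server_def_text_proto, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TFE_EnableCollectiveOps(ctx, def->data, def->length, status);
//     TF_DeleteBuffer(def);
//   }
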
std::string tensorflow::getTF_OutputDebugString(TF_Output node) {
  return absl::Substitute("TF_Output($0, $1)", node.oper, node.index);
}

using tensorflow::getTF_OutputDebugString;

TFE_TensorHandle* TFE_NewTensorHandleFromTFOutput(TF_Output t,
                                                  TF_DataType data_type) {
  auto ret = new TFE_TensorHandle(t, data_type);
  VLOG(1) << "Storing TFOutput " << getTF_OutputDebugString(t)
          << " into tensor handle " << ret << " with internal handle "
          << ret->handle;
  return ret;
}

unsigned char TFE_TensorHandleIsConcrete(TFE_TensorHandle* handle) {
  assert(handle->handle != nullptr);
  return handle->handle->getSymbolicTensor() == nullptr;
}

TF_Output TFE_GetTFOutputFromTensorHandle(TFE_TensorHandle* handle,
                                          TF_Status* status) {
  if (TFE_TensorHandleIsConcrete(handle)) {
    status->status =
        tensorflow::errors::Internal("Not a symbolic tensor: ", handle);
    return TF_Output{nullptr, -1};
  }

  auto* sym_tensor = handle->handle->getSymbolicTensor();
  CHECK(sym_tensor != nullptr);
  auto ret = TF_Output{sym_tensor->oper, sym_tensor->index};
  VLOG(1) << "Retrieving " << getTF_OutputDebugString(ret)
          << " from tensor handle " << handle;
  CHECK_GE(sym_tensor->index, 0);
  return ret;
}

TFE_TraceContext* TFE_NewTraceContext(TF_Graph* graph) {
  return new TFE_TraceContext(graph);
}

void TFE_DeleteTraceContext(TFE_TraceContext* trace_ctx) { delete trace_ctx; }

// If `handle` is already symbolic, return it. Otherwise map it to a new
// symbolic tensor (a Placeholder op) and return that.
static TF_Output getOrCreateSymbolicTensor(TFE_TraceContext* trace_ctx,
                                           tensorflow::TensorHandle* handle,
                                           TF_Status* status) {
  VLOG(1) << "Getting symbolic tensor for input tensor handle " << handle
          << ": " << handle->DebugString();

  auto* sym_tensor = handle->getSymbolicTensor();
  if (sym_tensor != nullptr) {
    auto ret = TF_Output{sym_tensor->oper, sym_tensor->index};
    VLOG(1) << "This handle is a symbolic tensor " << sym_tensor << ": "
            << getTF_OutputDebugString(ret);
    return ret;
  }

  auto find_it = trace_ctx->input_tensor_map.find(handle);
  if (find_it != trace_ctx->input_tensor_map.end()) {
    VLOG(1) << "There exists a map entry from this concrete tensor to: "
            << getTF_OutputDebugString(find_it->second);
    return find_it->second;
  }

  auto node_name = tensorflow::strings::StrCat("additional_input_",
                                               trace_ctx->node_counter++);
  VLOG(1) << "Adding a placeholder node named " << node_name;
  auto* desc =
      TF_NewOperation(trace_ctx->graph, "Placeholder", node_name.c_str());
  TF_SetAttrType(desc, "dtype",
                 static_cast<TF_DataType>(handle->dtype) /*TF_FLOAT*/);
  auto* result = TF_FinishOperation(desc, status);
  if (!status->status.ok()) {
    return TF_Output{nullptr, -1};
  }

  auto ret = TF_Output{result, 0};
  VLOG(1) << "Creating a new map entry to map to: "
          << getTF_OutputDebugString(ret);
  trace_ctx->input_tensor_map[handle] = ret;
  // `handle` could be destroyed before it's read from `input_tensor_map` (say
  // during a subsequent TFE_FinalizeInputTensorsFromTraceContext() call), so we
  // increment its ref count to extend its life span to that of `trace_ctx`.
  handle->Ref();
  VLOG(1) << "Ref count for handle " << handle
          << " is 1?: " << handle->RefCountIsOne();
  return ret;
}

TF_Operation* TFE_AddEagerOpToGraph(TFE_Op* op, TFE_TraceContext* trace_ctx,
                                    TFE_TensorHandle** retvals,
                                    int* num_retvals, TF_Status* status) {
  VLOG(1) << "Calling TFE_AddEagerOpToGraph() with op " << op << ": "
          << op->operation.DebugString();

  const auto& op_type = op->operation.Name();
  auto op_name =
      tensorflow::strings::StrCat(op_type, "_", trace_ctx->node_counter++);
  std::unique_ptr<TF_OperationDescription> desc(
      TF_NewOperation(trace_ctx->graph, op_type.c_str(), op_name.c_str()));

  VLOG(1) << "Adding attrs.";
  tensorflow::AttrValueMap attrs;
  op->operation.Attrs().FillAttrValueMap(&attrs);
  for (const auto& attr : attrs) {
    desc->node_builder.Attr(attr.first, attr.second);
  }

  VLOG(1) << "Adding inputs.";
  const auto& inputs = op->operation.Inputs();
  size_t inputIndex = 0;
  const tensorflow::OpDef& op_def = desc->node_builder.op_def();
  for (const tensorflow::OpDef::ArgDef& input_arg : op_def.input_arg()) {
    if (input_arg.type_list_attr().empty() && input_arg.number_attr().empty()) {
      auto symbolic_input =
          getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
      if (!status->status.ok()) return nullptr;
      TF_AddInput(desc.get(), symbolic_input);
      continue;
    }
    size_t list_size = 0;
    if (!input_arg.type_list_attr().empty()) {
      const std::string& type_list_attr = input_arg.type_list_attr();
      const auto& attr_value = attrs[type_list_attr];
      CHECK(attr_value.value_case() == tensorflow::AttrValue::kList)
          << "Type list attribute should be a list!";
      list_size = attr_value.list().type_size();
    } else {
      CHECK(!input_arg.number_attr().empty());
      const auto& attr_value = attrs[input_arg.number_attr()];
      CHECK(attr_value.value_case() == tensorflow::AttrValue::kI)
          << "Number attribute should be int!";
      if (attr_value.i() < 0) {
        status->status = tensorflow::errors::Internal(
            "Number attribute for length should be >=0!");
        return nullptr;
      }
      list_size = attr_value.i();
    }
    std::vector<TF_Output> list_inputs(list_size);
    for (TF_Output& list_input : list_inputs) {
      list_input =
          getOrCreateSymbolicTensor(trace_ctx, inputs[inputIndex++], status);
      if (!status->status.ok()) return nullptr;
    }
    TF_AddInputList(desc.get(), list_inputs.data(), list_inputs.size());
  }

  auto* graph_op = TF_FinishOperation(desc.release(), status);
  if (!status->status.ok()) return nullptr;

  VLOG(1) << "Op finalized; setting return tensors.";
  *num_retvals = TF_OperationNumOutputs(graph_op);
  VLOG(1) << "This op has " << *num_retvals << " outputs.";
  for (int i = 0; i < *num_retvals; ++i) {
    auto output = TF_Output{graph_op, i};
    auto dtype = TF_OperationOutputType(output);
    retvals[i] = TFE_NewTensorHandleFromTFOutput(output, dtype);
  }
  return graph_op;
}

int TFE_FinalizeInputTensorsFromTraceContext(TFE_TraceContext* trace_ctx) {
  if (trace_ctx->input_tensors == nullptr) {
    trace_ctx->input_tensors =
        new std::vector<std::pair<tensorflow::TensorHandle*, TF_Output>>();
    trace_ctx->input_tensors->reserve(trace_ctx->input_tensor_map.size());

    for (auto input : trace_ctx->input_tensor_map) {
      trace_ctx->input_tensors->emplace_back(input.first, input.second);
    }
  }
  return trace_ctx->input_tensor_map.size();
}

TF_Output TFE_GetInputGraphNodeFromTraceContext(TFE_TraceContext* trace_ctx,
                                                unsigned int idx) {
  CHECK(trace_ctx->input_tensors != nullptr);
  CHECK(trace_ctx->input_tensors->size() > idx);
  return trace_ctx->input_tensors->at(idx).second;
}

TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext(
    TFE_TraceContext* trace_ctx, unsigned int idx) {
  CHECK(trace_ctx->input_tensors != nullptr);
  CHECK(trace_ctx->input_tensors->size() > idx);
  auto* handle = trace_ctx->input_tensors->at(idx).first;
  VLOG(1) << "Ref count for internal handle " << handle
          << " is 1?: " << handle->RefCountIsOne();
  handle->Ref();
  auto* ret = new TFE_TensorHandle(handle);
  VLOG(1) << "Returning a new tensor handle " << ret << ": "
          << handle->DebugString();
  return ret;
}

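// Example usage of the tracing API above (a minimal sketch; `graph`, `op` and
// `status` are assumed to exist, error checking is elided, and the retvals
// array is assumed large enough for the traced op's outputs):
//
//   TFE_TraceContext* trace_ctx = TFE_NewTraceContext(graph);
//   TFE_TensorHandle* retvals[8];
//   int num_retvals = 8;
//   TFE_AddEagerOpToGraph(op, trace_ctx, retvals, &num_retvals, status);
//   int num_inputs = TFE_FinalizeInputTensorsFromTraceContext(trace_ctx);
//   for (int i = 0; i < num_inputs; ++i) {
//     TF_Output node = TFE_GetInputGraphNodeFromTraceContext(trace_ctx, i);
//     TFE_TensorHandle* handle =
//         TFE_ConsumeInputConcreteTensorFromTraceContext(trace_ctx, i);
//     // ... feed `handle` for placeholder `node` when running the graph ...
//   }
//   TFE_DeleteTraceContext(trace_ctx);
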
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
  TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
  result->num_items = num_items;
  result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items]();
  return result;
}

void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index,
                                 const int64_t* dims, int num_dims) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!";
  shape.num_dims = num_dims;
  shape.dims = new int64_t[num_dims];
  memcpy(shape.dims, dims, sizeof(int64_t) * num_dims);
}

void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list,
                                        int index) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape = shape_list->items[index];
  DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!";
  shape.num_dims = -1;
  shape.dims = nullptr;
}

void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index,
                                 TF_DataType dtype) {
  DCHECK(index >= 0 && index < shape_list->num_items);
  TF_ShapeAndType& shape_and_type = shape_list->items[index];
  shape_and_type.dtype = dtype;
}

void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) {
  if (shape_list == nullptr) return;
  for (int i = 0; i < shape_list->num_items; ++i) {
    delete[] shape_list->items[i].dims;
  }
  delete[] shape_list->items;
  delete shape_list;
}

void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array,
                                    int num_items) {
  if (shape_list_array == nullptr) return;
  for (int i = 0; i < num_items; ++i) {
    TF_DeleteShapeAndTypeList(shape_list_array[i]);
  }
  delete[] shape_list_array;
}

namespace tensorflow {
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
}  // namespace tensorflow

void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes,
                     TF_Tensor** input_tensors,
                     TF_ShapeAndTypeList* input_tensors_as_shapes,
                     TF_ShapeAndTypeList** input_resource_shapes_and_types,
                     TF_ShapeAndTypeList** output_shapes,
                     TF_ShapeAndTypeList*** output_resource_shapes_and_types,
                     TF_Status* status) {
  using tensorflow::NodeDef;
  using tensorflow::OpRegistrationData;
  using tensorflow::Tensor;
  using tensorflow::shape_inference::DimensionHandle;
  using tensorflow::shape_inference::InferenceContext;
  using tensorflow::shape_inference::ShapeAndType;
  using tensorflow::shape_inference::ShapeHandle;

  const int num_inputs = input_shapes->num_items;
  NodeDef node_def;
  node_def.set_name(tfe_op->operation.Name());
  node_def.set_op(tfe_op->operation.Name());
  for (int i = 0; i < num_inputs; ++i) {
    node_def.add_input("dummy_input");
  }
  tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr());

  const tensorflow::OpRegistrationData* op_reg_data;
  status->status =
      tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
  if (!status->status.ok()) return;

  // Initialize an input_tensors vector with `nullptr` values.
  std::vector<const Tensor*> input_tensors_vector(num_inputs, nullptr);
  // A vector to keep track of newly created `tf::Tensor` objects. Reserve up
  // front so that emplace_back() never reallocates and invalidates the
  // pointers stored in `input_tensors_vector`.
  std::vector<Tensor> all_input_tensors;
  all_input_tensors.reserve(num_inputs);
  // Update the vector with information from `input_tensors` if provided.
  if (input_tensors != nullptr) {
    for (int i = 0; i < num_inputs; ++i) {
      if (input_tensors[i] == nullptr) continue;
      all_input_tensors.emplace_back();
      Tensor& input_tensor = all_input_tensors.back();
      status->status = TF_TensorToTensor(input_tensors[i], &input_tensor);
      if (!status->status.ok()) return;
      input_tensors_vector[i] = &input_tensor;
    }
  }

  // Create an inference context with dummy values, which will be updated later.
  InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def,
                     std::vector<ShapeHandle>(num_inputs), input_tensors_vector,
                     {},
                     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

  // Set input_shapes.
  for (int i = 0; i < num_inputs; ++i) {
    std::vector<DimensionHandle> dims;
    const TF_ShapeAndType& input_shape = input_shapes->items[i];
    if (input_shape.num_dims == InferenceContext::kUnknownRank) {
      c.SetInput(i, c.UnknownShape());
      continue;
    }
    for (int j = 0; j < input_shape.num_dims; ++j) {
      dims.push_back(c.MakeDim(input_shape.dims[j]));
    }
    c.SetInput(i, c.MakeShape(dims));
  }

  // TODO(bgogul): Handle input_tensors_as_shapes.
  // TODO(bgogul): Handle input_resource_shapes_and_types.

  status->status = c.construction_status();
  if (!status->status.ok()) return;

  if (op_reg_data->shape_inference_fn == nullptr) {
    status->status =
        InvalidArgument("No shape inference function exists for op '",
                        node_def.op(), "', did you forget to define it?");
    return;
  }

  status->status = c.Run(op_reg_data->shape_inference_fn);
  if (!status->status.ok()) return;

  // Set output_shapes.
  TF_ShapeAndTypeList* output_shapes_result =
      TF_NewShapeAndTypeList(c.num_outputs());
  for (int i = 0; i < c.num_outputs(); ++i) {
    ShapeHandle shape_handle = c.output(i);
    TF_ShapeAndType& shape = output_shapes_result->items[i];
    shape.num_dims = c.Rank(shape_handle);
    if (shape.num_dims == InferenceContext::kUnknownRank) {
      shape.dims = nullptr;
      continue;
    }
    shape.dims = new int64_t[shape.num_dims];
    for (int j = 0; j < shape.num_dims; ++j) {
      shape.dims[j] = c.Value(c.Dim(shape_handle, j));
    }
  }
  if (output_shapes != nullptr) *output_shapes = output_shapes_result;

  // TODO(bgogul): Set output_resource_shapes_and_types.
}

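// Example usage of TFE_InferShapes() (a minimal sketch; `ctx` and `status` are
// assumed to exist, and error checking is elided): inferring the output shape
// of an Identity op from its input shape alone:
//
//   TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
//   TFE_OpSetAttrType(op, "T", TF_FLOAT);
//   TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(/*num_items=*/1);
//   int64_t dims[] = {2, 3};
//   TF_ShapeAndTypeListSetShape(input_shapes, 0, dims, /*num_dims=*/2);
//   TF_ShapeAndTypeList* output_shapes = nullptr;
//   TFE_InferShapes(op, input_shapes, /*input_tensors=*/nullptr,
//                   /*input_tensors_as_shapes=*/nullptr,
//                   /*input_resource_shapes_and_types=*/nullptr, &output_shapes,
//                   /*output_resource_shapes_and_types=*/nullptr, status);
//   // output_shapes->items[0] now holds rank 2 with dims {2, 3}.
//   TF_DeleteShapeAndTypeList(input_shapes);
//   TF_DeleteShapeAndTypeList(output_shapes);
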
void TF_ImportGraphDefOptionsSetValidateColocationConstraints(
    TF_ImportGraphDefOptions* opts, unsigned char enable) {
  opts->opts.validate_colocation_constraints = enable;
}