Internal C API headers.

Change: 153717143
This commit is contained in:
Alexandre Passos 2017-04-20 08:18:28 -08:00 committed by TensorFlower Gardener
parent 7933420df6
commit 79d6be3b86
3 changed files with 135 additions and 84 deletions

View File

@ -26,6 +26,22 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
tf_cuda_library(
name = "c_api_internal",
srcs = ["c_api.h"],
hdrs = ["c_api_internal.h"],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
}),
)
tf_cuda_library(
name = "c_api",
srcs = ["c_api.cc"],
@ -34,9 +50,11 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
":c_api_internal",
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
":c_api_internal",
"//tensorflow/cc/saved_model:loader",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/saved_model/loader.h"
#endif
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def_util.h"
@ -96,9 +97,6 @@ size_t TF_DataTypeSize(TF_DataType dt) {
}
// --------------------------------------------------------------------------
struct TF_Status {
Status status;
};
TF_Status* TF_NewStatus() { return new TF_Status; }
@ -182,12 +180,6 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
} // namespace
struct TF_Tensor {
TF_DataType dtype;
TensorShape shape;
TensorBuffer* buffer;
};
TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims,
int num_dims, size_t len) {
void* data = allocate_tensor("TF_AllocateTensor", len);
@ -292,9 +284,6 @@ size_t TF_StringEncodedSize(size_t len) {
}
// --------------------------------------------------------------------------
struct TF_SessionOptions {
SessionOptions options;
};
TF_SessionOptions* TF_NewSessionOptions() { return new TF_SessionOptions; }
void TF_DeleteSessionOptions(TF_SessionOptions* opt) { delete opt; }
@ -335,9 +324,6 @@ void TF_DeleteBuffer(TF_Buffer* buffer) {
TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; }
// --------------------------------------------------------------------------
struct TF_DeprecatedSession {
Session* session;
};
TF_DeprecatedSession* TF_NewDeprecatedSession(const TF_SessionOptions* opt,
TF_Status* status) {
@ -701,11 +687,6 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle,
c_outputs, target_oper_names, nullptr, status);
}
struct TF_Library {
void* lib_handle;
TF_Buffer op_list;
};
TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
TF_Library* lib_handle = new TF_Library;
status->status = tensorflow::LoadLibrary(
@ -742,66 +723,6 @@ TF_Buffer* TF_GetAllOpList() {
// --------------------------------------------------------------------------
// New Graph and Session API
// Structures -----------------------------------------------------------------
extern "C" {
struct TF_Graph {
TF_Graph()
: graph(OpRegistry::Global()),
refiner(graph.versions().producer(), graph.op_registry()),
num_sessions(0),
delete_requested(false),
parent(nullptr),
parent_inputs(nullptr) {}
mutex mu;
Graph graph GUARDED_BY(mu);
// Runs shape inference.
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
// Maps from name of an operation to the Node* in 'graph'.
std::unordered_map<tensorflow::string, Node*> name_map GUARDED_BY(mu);
// TF_Graph may only / must be deleted when
// num_sessions == 0 && delete_requested == true
// num_sessions incremented by TF_NewSession, and decremented by
// TF_DeleteSession.
int num_sessions GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
// Used to link graphs contained in TF_WhileParams to the parent graph that
// will eventually contain the full while loop.
TF_Graph* parent;
TF_Output* parent_inputs;
};
struct TF_OperationDescription {
TF_OperationDescription(TF_Graph* g, const char* op_type,
const char* node_name)
: node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
NodeBuilder node_builder;
TF_Graph* graph;
std::vector<tensorflow::string> colocation_constraints;
};
struct TF_Operation {
Node node;
};
struct TF_Session {
TF_Session(Session* s, TF_Graph* g)
: session(s), graph(g), last_num_graph_nodes(0) {}
Session* session;
TF_Graph* graph;
mutex mu;
int last_num_graph_nodes;
};
} // end extern "C"
// Helper functions -----------------------------------------------------------
namespace {
@ -1691,10 +1612,6 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
status->status = MessageToBuffer(def, output_graph_def);
}
struct TF_ImportGraphDefOptions {
tensorflow::ImportGraphDefOptions opts;
};
TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() {
return new TF_ImportGraphDefOptions;
}

View File

@ -0,0 +1,116 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/c_api.h"
#include <vector>
#include <unordered_map>
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
struct TF_Status {
tensorflow::Status status;
};
struct TF_Tensor {
TF_DataType dtype;
tensorflow::TensorShape shape;
tensorflow::TensorBuffer* buffer;
};
struct TF_SessionOptions {
tensorflow::SessionOptions options;
};
struct TF_DeprecatedSession {
tensorflow::Session* session;
};
struct TF_Library {
void* lib_handle;
TF_Buffer op_list;
};
struct TF_Graph {
TF_Graph()
: graph(tensorflow::OpRegistry::Global()),
refiner(graph.versions().producer(), graph.op_registry()),
num_sessions(0),
delete_requested(false),
parent(nullptr),
parent_inputs(nullptr) {}
tensorflow::mutex mu;
tensorflow::Graph graph GUARDED_BY(mu);
// Runs shape inference.
tensorflow::ShapeRefiner refiner GUARDED_BY(mu);
// Maps from name of an operation to the Node* in 'graph'.
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
GUARDED_BY(mu);
// TF_Graph may only / must be deleted when
// num_sessions == 0 && delete_requested == true
// num_sessions incremented by TF_NewSession, and decremented by
// TF_DeleteSession.
int num_sessions GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
// Used to link graphs contained in TF_WhileParams to the parent graph that
// will eventually contain the full while loop.
TF_Graph* parent;
TF_Output* parent_inputs;
};
struct TF_OperationDescription {
TF_OperationDescription(TF_Graph* g, const char* op_type,
const char* node_name)
: node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
tensorflow::NodeBuilder node_builder;
TF_Graph* graph;
std::vector<tensorflow::string> colocation_constraints;
};
struct TF_Operation {
tensorflow::Node node;
};
struct TF_Session {
TF_Session(tensorflow::Session* s, TF_Graph* g)
: session(s), graph(g), last_num_graph_nodes(0) {}
tensorflow::Session* session;
TF_Graph* graph;
tensorflow::mutex mu;
int last_num_graph_nodes;
};
struct TF_ImportGraphDefOptions {
tensorflow::ImportGraphDefOptions opts;
};